# auto-diffuser-config / model_memory_calculator.py
import json
from typing import Dict

from huggingface_hub import HfApi


class ModelMemoryCalculator:
    """Estimate model memory requirements from safetensors metadata on the Hugging Face Hub."""

    def __init__(self):
        self.hf_api = HfApi()
        self.cache = {}  # Cache results to avoid repeated API calls

def get_model_memory_requirements(self, model_id: str) -> Dict:
"""
Calculate memory requirements for a given HuggingFace model.
Args:
model_id: HuggingFace model identifier (e.g., "black-forest-labs/FLUX.1-schnell")
Returns:
Dict with memory information including:
- total_params: Total parameter count
- memory_fp32: Memory in GB at FP32 precision
- memory_fp16: Memory in GB at FP16 precision
- memory_bf16: Memory in GB at BF16 precision
- safetensors_files: List of safetensor files and their sizes
"""
if model_id in self.cache:
return self.cache[model_id]
try:
print(f"Fetching model info for {model_id}...")
# Get model info
model_info = self.hf_api.model_info(model_id)
print(f"Model info retrieved successfully")
# Get safetensors metadata
print(f"Fetching safetensors metadata...")
safetensors_metadata = self.hf_api.get_safetensors_metadata(model_id)
print(f"Found {len(safetensors_metadata)} safetensor files")
total_params = 0
safetensors_files = []
# Iterate through all safetensor files
            for filename, file_metadata in safetensors_metadata.files_metadata.items():
                file_params = 0
                file_size_bytes = 0

                # Each file's metadata exposes a `tensors` mapping of
                # tensor name -> TensorInfo (dtype, shape, data offsets)
                for tensor_name, tensor_info in file_metadata.tensors.items():
                    # Parameter count is the product of the tensor's dimensions
                    tensor_params = 1
                    for dim in tensor_info.shape:
                        tensor_params *= dim
                    file_params += tensor_params

                    # Byte size depends on the tensor's dtype
                    bytes_per_param = self._get_bytes_per_param(tensor_info.dtype)
                    file_size_bytes += tensor_params * bytes_per_param

                total_params += file_params
safetensors_files.append({
'filename': filename,
'parameters': file_params,
'size_bytes': file_size_bytes,
'size_mb': file_size_bytes / (1024 * 1024)
})
# Calculate memory requirements for different precisions
memory_requirements = {
'model_id': model_id,
'total_params': total_params,
'total_params_billions': total_params / 1e9,
'memory_fp32_gb': (total_params * 4) / (1024**3), # 4 bytes per param
'memory_fp16_gb': (total_params * 2) / (1024**3), # 2 bytes per param
'memory_bf16_gb': (total_params * 2) / (1024**3), # 2 bytes per param
'memory_int8_gb': (total_params * 1) / (1024**3), # 1 byte per param
'safetensors_files': safetensors_files,
'estimated_inference_memory_fp16_gb': self._estimate_inference_memory(total_params, 'fp16'),
'estimated_inference_memory_bf16_gb': self._estimate_inference_memory(total_params, 'bf16'),
}
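
            # Note: the memory_*_gb figures count model weights only; the
            # estimated_inference_memory_*_gb fields add a rough allowance for
            # activations (see _estimate_inference_memory below).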
# Cache the result
self.cache[model_id] = memory_requirements
return memory_requirements
except Exception as e:
return {
'error': str(e),
'model_id': model_id,
'total_params': 0,
'memory_fp32_gb': 0,
'memory_fp16_gb': 0,
'memory_bf16_gb': 0,
}
def _get_bytes_per_param(self, dtype: str) -> int:
"""Get bytes per parameter for different data types."""
dtype_map = {
'F32': 4, 'float32': 4,
'F16': 2, 'float16': 2,
'BF16': 2, 'bfloat16': 2,
'I8': 1, 'int8': 1,
'I32': 4, 'int32': 4,
'I64': 8, 'int64': 8,
}
return dtype_map.get(dtype, 4) # Default to 4 bytes (FP32)
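
    # Example (illustrative): _get_bytes_per_param("BF16") returns 2, while an
    # unmapped dtype string such as "F8_E4M3" falls back to 4 bytes (FP32).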
def _estimate_inference_memory(self, total_params: int, precision: str) -> float:
"""
Estimate memory requirements during inference.
This includes model weights + activations + intermediate tensors.
"""
bytes_per_param = 2 if precision in ['fp16', 'bf16'] else 4
# Model weights
model_memory = (total_params * bytes_per_param) / (1024**3)
# Estimate activation memory (rough approximation)
# For diffusion models, activations can be 1.5-3x model size during inference
activation_multiplier = 2.0
total_inference_memory = model_memory * (1 + activation_multiplier)
return total_inference_memory
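
    # Worked example (illustrative): a 12B-parameter model at FP16 needs
    # 12e9 * 2 / 1024**3 ≈ 22.4 GB for weights alone; with the 2.0x activation
    # allowance the inference estimate is 22.4 * (1 + 2.0) ≈ 67 GB.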
def get_memory_recommendation(self, model_id: str, available_vram_gb: float) -> Dict:
"""
Get memory recommendations based on available VRAM.
Args:
model_id: HuggingFace model identifier
available_vram_gb: Available VRAM in GB
Returns:
Dict with recommendations for precision, offloading, etc.
"""
memory_info = self.get_model_memory_requirements(model_id)
if 'error' in memory_info:
return {'error': memory_info['error']}
recommendations = {
'model_id': model_id,
'available_vram_gb': available_vram_gb,
'model_memory_fp16_gb': memory_info['memory_fp16_gb'],
'estimated_inference_memory_fp16_gb': memory_info['estimated_inference_memory_fp16_gb'],
'recommendations': []
}
inference_memory_fp16 = memory_info['estimated_inference_memory_fp16_gb']
inference_memory_bf16 = memory_info['estimated_inference_memory_bf16_gb']
# Determine recommendations
if available_vram_gb >= inference_memory_bf16:
recommendations['recommendations'].append("βœ… Full model can fit in VRAM with BF16 precision")
recommendations['recommended_precision'] = 'bfloat16'
recommendations['cpu_offload'] = False
recommendations['attention_slicing'] = False
elif available_vram_gb >= inference_memory_fp16:
recommendations['recommendations'].append("βœ… Full model can fit in VRAM with FP16 precision")
recommendations['recommended_precision'] = 'float16'
recommendations['cpu_offload'] = False
recommendations['attention_slicing'] = False
elif available_vram_gb >= memory_info['memory_fp16_gb']:
recommendations['recommendations'].append("⚠️ Model weights fit, but may need memory optimizations")
recommendations['recommended_precision'] = 'float16'
recommendations['cpu_offload'] = False
recommendations['attention_slicing'] = True
recommendations['vae_slicing'] = True
else:
recommendations['recommendations'].append("πŸ”„ Requires CPU offloading and memory optimizations")
recommendations['recommended_precision'] = 'float16'
recommendations['cpu_offload'] = True
recommendations['sequential_offload'] = True
recommendations['attention_slicing'] = True
recommendations['vae_slicing'] = True
return recommendations
def format_memory_info(self, model_id: str) -> str:
"""Format memory information for display."""
info = self.get_model_memory_requirements(model_id)
if 'error' in info:
return f"❌ Error calculating memory for {model_id}: {info['error']}"
output = f"""
πŸ“Š **Memory Requirements for {model_id}**
πŸ”’ **Parameters**: {info['total_params_billions']:.2f}B parameters
πŸ’Ύ **Model Memory**:
β€’ FP32: {info['memory_fp32_gb']:.2f} GB
β€’ FP16/BF16: {info['memory_fp16_gb']:.2f} GB
β€’ INT8: {info['memory_int8_gb']:.2f} GB
πŸš€ **Estimated Inference Memory**:
β€’ FP16: {info['estimated_inference_memory_fp16_gb']:.2f} GB
β€’ BF16: {info['estimated_inference_memory_bf16_gb']:.2f} GB
πŸ“ **SafeTensor Files**: {len(info['safetensors_files'])} files
"""
return output.strip()
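

# ---------------------------------------------------------------------------
# Optional sketch (not part of the original module): shows how the flags
# returned by get_memory_recommendation() could be applied to an already
# loaded diffusers pipeline, assuming the standard diffusers memory helpers
# (enable_attention_slicing, enable_vae_slicing, enable_model_cpu_offload,
# enable_sequential_cpu_offload) are available on the pipeline object.
# ---------------------------------------------------------------------------
def apply_recommendations_sketch(pipe, recommendations: Dict) -> None:
    """Apply recommended memory optimizations to a loaded diffusers pipeline."""
    if recommendations.get('cpu_offload'):
        if recommendations.get('sequential_offload'):
            pipe.enable_sequential_cpu_offload()  # lowest VRAM usage, slowest
        else:
            pipe.enable_model_cpu_offload()
    if recommendations.get('attention_slicing'):
        pipe.enable_attention_slicing()
    if recommendations.get('vae_slicing'):
        pipe.enable_vae_slicing()
    # Note: 'recommended_precision' (float16/bfloat16) has to be applied when
    # the pipeline is created, e.g. via the torch_dtype argument of
    # DiffusionPipeline.from_pretrained().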
# Example usage and testing
if __name__ == "__main__":
calculator = ModelMemoryCalculator()
# Test with FLUX.1-schnell
model_id = "black-forest-labs/FLUX.1-schnell"
print(f"Testing memory calculation for {model_id}...")
memory_info = calculator.get_model_memory_requirements(model_id)
print(json.dumps(memory_info, indent=2))
# Test recommendations
print("\n" + "="*50)
print("MEMORY RECOMMENDATIONS")
print("="*50)
vram_options = [8, 16, 24, 40]
for vram in vram_options:
rec = calculator.get_memory_recommendation(model_id, vram)
print(f"\n🎯 For {vram}GB VRAM:")
if 'recommendations' in rec:
for r in rec['recommendations']:
print(f" {r}")
# Format for display
print("\n" + "="*50)
print("FORMATTED OUTPUT")
print("="*50)
print(calculator.format_memory_info(model_id))
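
    # Illustrative extra check: dump the raw recommendation dict for a single
    # VRAM budget so the flag names consumed by apply_recommendations_sketch
    # are visible (actual values depend on the model fetched above).
    print("\n" + "="*50)
    print("RAW RECOMMENDATION DICT (24GB VRAM)")
    print("="*50)
    print(json.dumps(calculator.get_memory_recommendation(model_id, 24), indent=2))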