import json
from typing import Dict

from huggingface_hub import HfApi


class ModelMemoryCalculator:
    def __init__(self):
        self.hf_api = HfApi()
        self.cache = {}  # Cache results to avoid repeated API calls

    def get_model_memory_requirements(self, model_id: str) -> Dict:
        """
        Calculate memory requirements for a given HuggingFace model.

        Args:
            model_id: HuggingFace model identifier (e.g., "black-forest-labs/FLUX.1-schnell")

        Returns:
            Dict with memory information including:
                - total_params: Total parameter count
                - memory_fp32: Memory in GB at FP32 precision
                - memory_fp16: Memory in GB at FP16 precision
                - memory_bf16: Memory in GB at BF16 precision
                - safetensors_files: List of safetensor files and their sizes
        """
        if model_id in self.cache:
            return self.cache[model_id]

        try:
            print(f"Fetching model info for {model_id}...")

            # Confirm the model exists and is accessible
            model_info = self.hf_api.model_info(model_id)
            print("Model info retrieved successfully")

            # Get safetensors metadata for the whole repository
            print("Fetching safetensors metadata...")
            safetensors_metadata = self.hf_api.get_safetensors_metadata(model_id)
            print(f"Found {len(safetensors_metadata.files_metadata)} safetensor files")

            total_params = 0
            safetensors_files = []

            # Iterate through all safetensor files in the repository
            for filename, file_metadata in safetensors_metadata.files_metadata.items():
                file_params = 0
                file_size_bytes = 0

                # Calculate parameters from the per-tensor metadata
                for tensor_name, tensor_info in file_metadata.tensors.items():
                    # Parameter count is the product of the tensor's dimensions
                    tensor_params = 1
                    for dim in tensor_info.shape:
                        tensor_params *= dim
                    file_params += tensor_params

                    # Byte size depends on the tensor's dtype
                    bytes_per_param = self._get_bytes_per_param(tensor_info.dtype)
                    file_size_bytes += tensor_params * bytes_per_param

                total_params += file_params
                safetensors_files.append({
                    'filename': filename,
                    'parameters': file_params,
                    'size_bytes': file_size_bytes,
                    'size_mb': file_size_bytes / (1024 * 1024),
                })
            # Calculate memory requirements for different precisions
            memory_requirements = {
                'model_id': model_id,
                'total_params': total_params,
                'total_params_billions': total_params / 1e9,
                'memory_fp32_gb': (total_params * 4) / (1024**3),  # 4 bytes per param
                'memory_fp16_gb': (total_params * 2) / (1024**3),  # 2 bytes per param
                'memory_bf16_gb': (total_params * 2) / (1024**3),  # 2 bytes per param
                'memory_int8_gb': (total_params * 1) / (1024**3),  # 1 byte per param
                'safetensors_files': safetensors_files,
                'estimated_inference_memory_fp16_gb': self._estimate_inference_memory(total_params, 'fp16'),
                'estimated_inference_memory_bf16_gb': self._estimate_inference_memory(total_params, 'bf16'),
            }

            # Cache the result
            self.cache[model_id] = memory_requirements
            return memory_requirements

        except Exception as e:
            return {
                'error': str(e),
                'model_id': model_id,
                'total_params': 0,
                'memory_fp32_gb': 0,
                'memory_fp16_gb': 0,
                'memory_bf16_gb': 0,
            }
    def _get_bytes_per_param(self, dtype: str) -> int:
        """Get bytes per parameter for different data types."""
        # safetensors metadata reports dtypes in the short uppercase form
        # ("F32", "BF16", ...); the long names are kept as a fallback.
        dtype_map = {
            'F32': 4, 'float32': 4,
            'F16': 2, 'float16': 2,
            'BF16': 2, 'bfloat16': 2,
            'I8': 1, 'int8': 1,
            'I32': 4, 'int32': 4,
            'I64': 8, 'int64': 8,
        }
        return dtype_map.get(dtype, 4)  # Default to 4 bytes (FP32)
    def _estimate_inference_memory(self, total_params: int, precision: str) -> float:
        """
        Estimate memory requirements during inference.
        This includes model weights + activations + intermediate tensors.
        """
        bytes_per_param = 2 if precision in ['fp16', 'bf16'] else 4

        # Model weights
        model_memory = (total_params * bytes_per_param) / (1024**3)

        # Estimate activation memory (rough approximation).
        # For diffusion models, activations can be 1.5-3x model size during inference.
        activation_multiplier = 2.0
        total_inference_memory = model_memory * (1 + activation_multiplier)

        return total_inference_memory
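    # Rough worked example of the estimate above, assuming a hypothetical
    # 12B-parameter model (not a measured value):
    #   weights at FP16:    12e9 * 2 bytes / 1024**3 ≈ 22.4 GB
    #   inference estimate: 22.4 GB * (1 + 2.0)      ≈ 67 GB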
    def get_memory_recommendation(self, model_id: str, available_vram_gb: float) -> Dict:
        """
        Get memory recommendations based on available VRAM.

        Args:
            model_id: HuggingFace model identifier
            available_vram_gb: Available VRAM in GB

        Returns:
            Dict with recommendations for precision, offloading, etc.
        """
        memory_info = self.get_model_memory_requirements(model_id)
        if 'error' in memory_info:
            return {'error': memory_info['error']}

        recommendations = {
            'model_id': model_id,
            'available_vram_gb': available_vram_gb,
            'model_memory_fp16_gb': memory_info['memory_fp16_gb'],
            'estimated_inference_memory_fp16_gb': memory_info['estimated_inference_memory_fp16_gb'],
            'recommendations': []
        }

        inference_memory_fp16 = memory_info['estimated_inference_memory_fp16_gb']
        inference_memory_bf16 = memory_info['estimated_inference_memory_bf16_gb']

        # Determine recommendations
        if available_vram_gb >= inference_memory_bf16:
            recommendations['recommendations'].append("✅ Full model can fit in VRAM with BF16 precision")
            recommendations['recommended_precision'] = 'bfloat16'
            recommendations['cpu_offload'] = False
            recommendations['attention_slicing'] = False
        elif available_vram_gb >= inference_memory_fp16:
            recommendations['recommendations'].append("✅ Full model can fit in VRAM with FP16 precision")
            recommendations['recommended_precision'] = 'float16'
            recommendations['cpu_offload'] = False
            recommendations['attention_slicing'] = False
        elif available_vram_gb >= memory_info['memory_fp16_gb']:
            recommendations['recommendations'].append("⚠️ Model weights fit, but may need memory optimizations")
            recommendations['recommended_precision'] = 'float16'
            recommendations['cpu_offload'] = False
            recommendations['attention_slicing'] = True
            recommendations['vae_slicing'] = True
        else:
            recommendations['recommendations'].append("🔄 Requires CPU offloading and memory optimizations")
            recommendations['recommended_precision'] = 'float16'
            recommendations['cpu_offload'] = True
            recommendations['sequential_offload'] = True
            recommendations['attention_slicing'] = True
            recommendations['vae_slicing'] = True

        return recommendations
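    # Example call (minimal sketch; the 24 GB figure is just an assumed GPU size):
    #   calc = ModelMemoryCalculator()
    #   rec = calc.get_memory_recommendation("black-forest-labs/FLUX.1-schnell", 24.0)
    #   print(rec['recommended_precision'], rec['cpu_offload'])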
    def format_memory_info(self, model_id: str) -> str:
        """Format memory information for display."""
        info = self.get_model_memory_requirements(model_id)

        if 'error' in info:
            return f"❌ Error calculating memory for {model_id}: {info['error']}"

        output = f"""
📊 **Memory Requirements for {model_id}**

🔢 **Parameters**: {info['total_params_billions']:.2f}B parameters

💾 **Model Memory**:
   • FP32: {info['memory_fp32_gb']:.2f} GB
   • FP16/BF16: {info['memory_fp16_gb']:.2f} GB
   • INT8: {info['memory_int8_gb']:.2f} GB

🚀 **Estimated Inference Memory**:
   • FP16: {info['estimated_inference_memory_fp16_gb']:.2f} GB
   • BF16: {info['estimated_inference_memory_bf16_gb']:.2f} GB

📁 **SafeTensor Files**: {len(info['safetensors_files'])} files
"""
        return output.strip()
# Example usage and testing
if __name__ == "__main__":
    calculator = ModelMemoryCalculator()

    # Test with FLUX.1-schnell
    model_id = "black-forest-labs/FLUX.1-schnell"
    print(f"Testing memory calculation for {model_id}...")
    memory_info = calculator.get_model_memory_requirements(model_id)
    print(json.dumps(memory_info, indent=2))

    # Test recommendations
    print("\n" + "=" * 50)
    print("MEMORY RECOMMENDATIONS")
    print("=" * 50)

    vram_options = [8, 16, 24, 40]
    for vram in vram_options:
        rec = calculator.get_memory_recommendation(model_id, vram)
        print(f"\n🎯 For {vram}GB VRAM:")
        if 'recommendations' in rec:
            for r in rec['recommendations']:
                print(f"   {r}")

    # Format for display
    print("\n" + "=" * 50)
    print("FORMATTED OUTPUT")
    print("=" * 50)
    print(calculator.format_memory_info(model_id))