#!/usr/bin/env python3
"""
Debug script to monitor memory usage during model loading.
Run this to identify exactly where the memory issues occur.
"""
import gc
import logging
import os
import sys
from typing import Optional

import psutil
import torch

# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


def get_memory_info():
    """Get current memory usage information."""
    process = psutil.Process(os.getpid())
    memory_info = process.memory_info()
    virtual_memory = psutil.virtual_memory()
    return {
        "process_rss_gb": memory_info.rss / (1024**3),  # Resident Set Size
        "process_vms_gb": memory_info.vms / (1024**3),  # Virtual Memory Size
        "system_total_gb": virtual_memory.total / (1024**3),
        "system_available_gb": virtual_memory.available / (1024**3),
        "system_used_gb": virtual_memory.used / (1024**3),
        "system_percent": virtual_memory.percent,
    }


def log_memory_usage(step: str):
    """Log current memory usage with a step description."""
    mem_info = get_memory_info()
    logger.info(f"=== {step} ===")
    logger.info(f"Process RSS: {mem_info['process_rss_gb']:.2f} GB")
    logger.info(f"Process VMS: {mem_info['process_vms_gb']:.2f} GB")
    logger.info(f"System Total: {mem_info['system_total_gb']:.2f} GB")
    logger.info(f"System Available: {mem_info['system_available_gb']:.2f} GB")
    logger.info(
        f"System Used: {mem_info['system_used_gb']:.2f} GB ({mem_info['system_percent']:.1f}%)"
    )
    if torch.cuda.is_available():
        logger.info(
            f"CUDA Memory Allocated: {torch.cuda.memory_allocated() / (1024**3):.2f} GB"
        )
        logger.info(
            f"CUDA Memory Cached: {torch.cuda.memory_reserved() / (1024**3):.2f} GB"
        )
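        # memory_reserved() counts blocks held by the CUDA caching allocator, including
        # freed-but-cached ones, so it is normally >= memory_allocated().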
logger.info("") | |
def force_cleanup(): | |
"""Force garbage collection and memory cleanup.""" | |
gc.collect() | |
if torch.cuda.is_available(): | |
torch.cuda.empty_cache() | |
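        # Note: empty_cache() only returns cached, currently-unused blocks to the CUDA
        # driver; it cannot release memory still referenced by live tensors, and it does
        # nothing for CPU RSS, where dropping references before gc.collect() is the lever.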

def debug_model_loading(models_dir: str = "/home/user/app/models"):
    """Debug the model loading process step by step."""
    ckpt_path = os.path.join(models_dir, "mms_XRI.pt")
    tokenizer_path = os.path.join(models_dir, "mms_1143_langs_tokenizer_spm.model")

    logger.info("Starting memory debugging for MMS model loading...")
    logger.info(f"Checkpoint path: {ckpt_path}")
    logger.info(f"Tokenizer path: {tokenizer_path}")

    # Check file sizes
    if os.path.exists(ckpt_path):
        ckpt_size_gb = os.path.getsize(ckpt_path) / (1024**3)
        logger.info(f"Checkpoint file size: {ckpt_size_gb:.2f} GB")
    else:
        logger.error(f"Checkpoint file not found: {ckpt_path}")
        return

    log_memory_usage("Initial state")

    try:
        # Step 1: Check available memory before loading
        mem_info = get_memory_info()
        if mem_info["system_available_gb"] < ckpt_size_gb * 1.5:
            logger.warning(
                f"Available memory ({mem_info['system_available_gb']:.2f} GB) may be insufficient for checkpoint ({ckpt_size_gb:.2f} GB)"
            )

        # Step 2: Try to load checkpoint with memory mapping
        logger.info("Step 1: Loading checkpoint with memory mapping...")
        try:
            # Use mmap=True to avoid loading entire file into memory
            model_params = torch.load(ckpt_path, map_location="cpu", mmap=True)
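            # Note (hedged): the mmap argument to torch.load needs PyTorch >= 2.1 and a
            # zipfile-format checkpoint (the default save format since PyTorch 1.6). When
            # it applies, tensor storages stay file-backed and are paged in lazily, so RSS
            # right after this call should stay well below the checkpoint size.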
log_memory_usage("After loading checkpoint (mmap)") | |
except Exception as e: | |
logger.error(f"Memory-mapped loading failed: {e}") | |
logger.info("Falling back to regular loading...") | |
model_params = torch.load(ckpt_path, map_location="cpu") | |
log_memory_usage("After loading checkpoint (regular)") | |
# Step 3: Setup fairseq2 and configs | |
logger.info("Step 2: Setting up fairseq2 and configs...") | |
from fairseq2 import setup_fairseq2 | |
from fairseq2.context import get_runtime_context | |
from fairseq2.models.llama import LLaMAConfig | |
# Import the model classes | |
sys.path.append("/home/user/app/server") | |
from model import ( | |
register_wav2vec2_asr_configs, | |
register_wav2vec2_configs, | |
Wav2Vec2AsrConfig, | |
Wav2Vec2LlamaConfig, | |
Wav2Vec2LlamaFactory, | |
) | |
setup_fairseq2() | |
context = get_runtime_context() | |
register_wav2vec2_configs(context) | |
register_wav2vec2_asr_configs(context) | |
log_memory_usage("After fairseq2 setup") | |
# Step 4: Create configs | |
logger.info("Step 3: Creating model configuration...") | |
w2v2_ctc_registry = context.get_config_registry(Wav2Vec2AsrConfig) | |
wav2vec_ctc_config = w2v2_ctc_registry.get("7b_bib1143") | |
llama_config = LLaMAConfig( | |
model_dim=4096, | |
max_seq_len=8192, | |
vocab_info=wav2vec_ctc_config.vocab_info, | |
num_layers=12, | |
num_attn_heads=8, | |
num_key_value_heads=8, | |
ffn_inner_dim=4096, | |
rope_theta=10_000.0, | |
dropout_p=0.1, | |
) | |
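        # Note: these decoder hyperparameters are assumed to match the checkpoint's
        # training config; a mismatch surfaces later as size-mismatch errors from
        # load_state_dict rather than as extra memory use.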
        config = Wav2Vec2LlamaConfig()
        config.wav2vec_ctc_config = wav2vec_ctc_config
        config.llama_config = llama_config
        log_memory_usage("After creating configs")

        # Step 5: Create model architecture (without loading weights)
        logger.info("Step 4: Creating model architecture...")
        factory = Wav2Vec2LlamaFactory(config)
        model = factory.create_model()
        log_memory_usage("After creating model architecture")
        # Step 6: Load state dict
        logger.info("Step 5: Loading model weights...")
        try:
            model.load_state_dict(model_params["model"])
            log_memory_usage("After loading model weights")
        except Exception as e:
            logger.error(f"Failed to load model weights: {e}")
            return
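        # Note: load_state_dict copies every checkpoint tensor into the parameters that
        # already exist on the model, so with mmap the data is effectively paged in here
        # and peak usage can approach twice the model size. PyTorch >= 2.1 also accepts
        # load_state_dict(state_dict, assign=True), which adopts the checkpoint tensors
        # instead of copying; whether that is safe for this model is untested here.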
        # Step 7: Clean up checkpoint data
        logger.info("Step 6: Cleaning up checkpoint data...")
        del model_params
        force_cleanup()
        log_memory_usage("After cleanup")
        # Step 8: Move to device (if specified)
        device = torch.device("cpu")  # Force CPU for debugging
        logger.info(f"Step 7: Moving model to device {device}...")
        model = model.to(device).eval()
        log_memory_usage("After moving to device")

        logger.info("Model loading completed successfully!")

        # Step 9: Test a small inference to see memory usage
        logger.info("Step 8: Testing small inference...")
        try:
            # Create a small dummy input
            dummy_input = torch.randn(1, 16000).to(device)  # 1 second of audio
            with torch.no_grad():
                # Just test the encoder part to avoid full inference
                enc_out = model.encoder_frontend.extract_features(dummy_input, None)
            log_memory_usage("After small inference test")
        except Exception as e:
            logger.error(f"Small inference test failed: {e}")

        return model

    except Exception as e:
        logger.error(f"Error during model loading: {str(e)}")
        log_memory_usage("After error")
        raise

def check_docker_memory_limits():
    """Check if we're running in Docker and what the memory limits are."""
    logger.info("Checking Docker memory configuration...")

    # Check if we're in a container
    if os.path.exists("/.dockerenv"):
        logger.info("Running inside Docker container")

        # Check cgroup memory limits
        try:
            with open("/sys/fs/cgroup/memory/memory.limit_in_bytes", "r") as f:
                limit_bytes = int(f.read().strip())
            limit_gb = limit_bytes / (1024**3)
            logger.info(f"Docker memory limit: {limit_gb:.2f} GB")

            # Check if limit is reasonable (not the default huge value)
            if limit_gb > 1000:  # Probably unlimited
                logger.warning("Docker memory limit appears to be unlimited")
            else:
                logger.info(f"Docker memory limit is set to {limit_gb:.2f} GB")
        except Exception as e:
            logger.warning(f"Could not read Docker memory limit: {e}")
        # Check current memory usage in container
        try:
            with open("/sys/fs/cgroup/memory/memory.usage_in_bytes", "r") as f:
                usage_bytes = int(f.read().strip())
            usage_gb = usage_bytes / (1024**3)
            logger.info(f"Current Docker memory usage: {usage_gb:.2f} GB")
        except Exception as e:
            logger.warning(f"Could not read Docker memory usage: {e}")
    else:
        logger.info("Not running in Docker container")


if __name__ == "__main__":
    # Check Docker memory configuration
    check_docker_memory_limits()

    # Get models directory from environment or use default
    models_dir = os.environ.get("MODELS_DIR", "/home/user/app/models")

    # Run the debugging
    try:
        model = debug_model_loading(models_dir)
        logger.info("Memory debugging completed successfully!")
    except Exception as e:
        logger.error(f"Memory debugging failed: {e}")
        sys.exit(1)