# Source: MEME/model_setup.py, commit aa1262d ("Create model_setup.py", by Chanlefe, verified)
#!/usr/bin/env python3
"""
Helper script to prepare models for deployment
"""
import os
import zipfile
import shutil
from pathlib import Path
def setup_bert_model(zip_path="fine_tuned_bert_sentiment.zip",
                     extract_path="./fine_tuned_bert_sentiment"):
    """Extract the fine-tuned BERT model archive and verify its contents.

    Unpacks ``zip_path`` into ``extract_path`` (created if needed) and checks
    that the files needed to reload the model are present directly inside
    ``extract_path``.

    Args:
        zip_path: Path to the zipped model directory. Defaults to
            ``fine_tuned_bert_sentiment.zip`` in the working directory.
        extract_path: Directory to extract into. Defaults to
            ``./fine_tuned_bert_sentiment``.

    Returns:
        True if the archive was extracted and every required file was found;
        False if the archive is missing, is not a valid zip file, or lacks
        any required file. Each failure prints a message explaining why.
    """
    if not os.path.isfile(zip_path):
        print(f"❌ {zip_path} not found. Please upload your fine-tuned BERT model.")
        return False

    print(f"📦 Extracting {zip_path}...")
    os.makedirs(extract_path, exist_ok=True)
    try:
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(extract_path)
    except zipfile.BadZipFile as err:
        # A corrupt or non-zip upload is reported like a missing model
        # instead of aborting the whole setup run with a traceback.
        print(f"❌ {zip_path} is not a valid zip archive: {err}")
        return False

    # Each entry lists interchangeable file names; at least one must exist.
    # The weights may be saved as pytorch_model.bin (legacy) or as
    # model.safetensors (the default of newer transformers save_pretrained).
    required_files = [
        ("config.json",),
        ("pytorch_model.bin", "model.safetensors"),
        ("tokenizer_config.json",),
        ("vocab.txt",),
    ]
    missing_files = [
        " or ".join(options)
        for options in required_files
        if not any(os.path.exists(os.path.join(extract_path, name)) for name in options)
    ]
    if missing_files:
        print(f"⚠️ Missing required files: {missing_files}")
        return False

    print("✅ BERT model setup complete!")
    return True
def download_fallback_models():
    """Download the fallback SigLIP and sentiment models.

    SigLIP-Large is tried first. If it fails for any reason, SigLIP-Base is
    downloaded instead. Failures while downloading SigLIP-Base or the
    sentiment model are not caught and propagate to the caller. The loaded
    tokenizer/model objects are discarded; only the download matters here.
    """
    # Imported lazily so the rest of this script works without transformers.
    from transformers import AutoTokenizer, AutoModel

    def fetch(model_id):
        # Load (and discard) both the tokenizer and the model for model_id.
        AutoTokenizer.from_pretrained(model_id)
        AutoModel.from_pretrained(model_id)

    print("📥 Downloading fallback models...")

    # SigLIP: prefer the large checkpoint and fall back to the base one.
    try:
        fetch("google/siglip-large-patch16-384")
        print("✅ SigLIP-Large downloaded")
    except Exception as e:
        print(f"⚠️ SigLIP-Large download failed: {e}")
        print("📥 Downloading SigLIP-Base as fallback...")
        fetch("google/siglip-base-patch16-224")
        print("✅ SigLIP-Base downloaded")

    # Sentiment model.
    fetch("cardiffnlp/twitter-roberta-base-sentiment-latest")
    print("✅ Sentiment model downloaded")
def main():
    """Set up the BERT model, download fallback models, and print a summary."""
    print("🚀 Setting up Enhanced Ensemble Model...")

    bert_success = setup_bert_model()

    # Fallback models are downloaded unconditionally, so there is something
    # to deploy even when the fine-tuned BERT archive is absent or invalid.
    download_fallback_models()

    if bert_success:
        print("🎉 All models ready for deployment!")
    else:
        print("⚠️ Deployment ready with fallback models. Upload your BERT model for best performance.")


if __name__ == "__main__":
    main()