import logging
import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

logger = logging.getLogger(__name__)


class ModelConfig:
    """Configuration for different LLM models optimized for Hugging Face Spaces."""

    MODELS = {
        "dialogpt-medium": {
            "name": "microsoft/DialoGPT-medium",
            "description": "Conversational AI model, good for chat",
            "max_length": 512,
            "memory_usage": "medium",
            "recommended_for": "chat, conversation",
        },
        "dialogpt-small": {
            "name": "microsoft/DialoGPT-small",
            "description": "Smaller conversational model, faster inference",
            "max_length": 256,
            "memory_usage": "low",
            "recommended_for": "quick responses, limited resources",
        },
        "gpt2": {
            "name": "gpt2",
            "description": "General-purpose text generation",
            "max_length": 1024,
            "memory_usage": "medium",
            "recommended_for": "text generation, creative writing",
        },
        "distilgpt2": {
            "name": "distilgpt2",
            "description": "Distilled GPT-2, faster and smaller",
            "max_length": 512,
            "memory_usage": "low",
            "recommended_for": "fast inference, resource constrained",
        },
"flan-t5-small": { | |
"name": "google/flan-t5-small", | |
"description": "Instruction-tuned T5 model", | |
"max_length": 512, | |
"memory_usage": "low", | |
"recommended_for": "instruction following, Q&A" | |
} | |
} | |
    @classmethod
    def get_model_info(cls, model_key: str = None):
        """Get information about available models."""
        if model_key:
            return cls.MODELS.get(model_key)
        return cls.MODELS

    @classmethod
    def get_recommended_model(cls, use_case: str = "general"):
        """Get the recommended model key for a use case."""
        recommendations = {
            "chat": "dialogpt-medium",
            "fast": "distilgpt2",
            "general": "gpt2",
            "qa": "flan-t5-small",
            "low_memory": "dialogpt-small",
        }
        return recommendations.get(use_case, "dialogpt-medium")
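

# Usage sketch (assumes the @classmethod decorators above): look up the model
# key recommended for a use case, then fetch its metadata.
#
#   key = ModelConfig.get_recommended_model("low_memory")  # -> "dialogpt-small"
#   info = ModelConfig.get_model_info(key)                 # -> metadata dict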


class ModelManager:
    """Manages model loading and inference."""

    def __init__(self, model_name: str = None):
        self.model_name = model_name or os.getenv("MODEL_NAME", "microsoft/DialoGPT-medium")
        self.model = None
        self.tokenizer = None
        self.pipeline = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.loaded = False

    def load_model(self):
        """Load the configured model and tokenizer."""
        try:
            logger.info(f"Loading model: {self.model_name}")
            logger.info(f"Using device: {self.device}")

            # Load tokenizer; left padding suits decoder-only generation
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                padding_side="left",
            )

            # Add a padding token if the tokenizer doesn't define one
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Load model with memory optimizations
            model_kwargs = {
                "low_cpu_mem_usage": True,
                "torch_dtype": torch.float16 if self.device == "cuda" else torch.float32,
            }
            if self.device == "cuda":
                model_kwargs["device_map"] = "auto"

            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                **model_kwargs,
            )

            # Move to device explicitly when device_map isn't handling placement
            if self.device == "cpu":
                self.model = self.model.to(self.device)
            # Create the pipeline. When the model was loaded with
            # device_map="auto", accelerate has already placed it, and passing
            # a `device` argument raises an error in recent transformers
            # versions, so only pass it for CPU.
            pipeline_kwargs = {} if self.device == "cuda" else {"device": -1}
            self.pipeline = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                return_full_text=False,
                **pipeline_kwargs,
            )

            self.loaded = True
            logger.info("Model loaded successfully!")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

    def generate_response(self,
                          prompt: str,
                          max_length: int = 100,
                          temperature: float = 0.7,
                          top_p: float = 0.9,
                          do_sample: bool = True) -> str:
        """Generate a response using the loaded model.

        Note: max_length is forwarded as max_new_tokens, i.e. it bounds the
        number of newly generated tokens, not the total sequence length.
        """
        if not self.loaded:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        try:
            outputs = self.pipeline(
                prompt,
                max_new_tokens=max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                truncation=True,
            )

            # Extract the generated text
            if outputs:
                return outputs[0]["generated_text"].strip()
            return "Sorry, I couldn't generate a response."
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            raise
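
    # Example call (sketch; assumes load_model() has already run):
    #   reply = manager.generate_response("Hello, how are you?", max_length=64)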

    def get_model_info(self):
        """Get information about the loaded model."""
        return {
            "model_name": self.model_name,
            "device": self.device,
            "loaded": self.loaded,
            "tokenizer_vocab_size": len(self.tokenizer) if self.tokenizer else None,
            "model_parameters": sum(p.numel() for p in self.model.parameters()) if self.model else None,
        }

    def unload_model(self):
        """Unload the model to free memory."""
        if self.model:
            del self.model
            self.model = None
        if self.tokenizer:
            del self.tokenizer
            self.tokenizer = None
        if self.pipeline:
            del self.pipeline
            self.pipeline = None

        # Clear the CUDA cache if a GPU is present
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.loaded = False
        logger.info("Model unloaded successfully")


# Global model manager instance
model_manager = None


def get_model_manager(model_name: str = None) -> ModelManager:
    """Get or create the global model manager instance.

    Note: model_name is only honored on the first call; subsequent calls
    return the existing instance regardless of the argument.
    """
    global model_manager
    if model_manager is None:
        model_manager = ModelManager(model_name)
    return model_manager


def initialize_model(model_name: str = None):
    """Initialize and load the model."""
    manager = get_model_manager(model_name)
    if not manager.loaded:
        manager.load_model()
    return manager
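

# Minimal end-to-end sketch (assumes the process can download weights from the
# Hub; "distilgpt2" is chosen here only because it is small and fast):
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    manager = initialize_model("distilgpt2")  # downloads/loads on first call
    print(manager.get_model_info())

    reply = manager.generate_response(
        "Once upon a time",
        max_length=40,       # forwarded as max_new_tokens
        temperature=0.7,
    )
    print(reply)

    manager.unload_model()   # free RAM/VRAM when done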