# Hugging Face Space page header captured with this file (Space status: Sleeping)
from fastapi import FastAPI, HTTPException, Request | |
from fastapi.middleware.cors import CORSMiddleware | |
from pydantic import BaseModel, Field | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import torch | |
from huggingface_hub import snapshot_download | |
from safetensors.torch import load_file | |
import logging | |
# Module-wide logging: configure the root handler for INFO-level output,
# then log from this module through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
class ModelInput(BaseModel):
    """Request body for the text-generation endpoint.

    Attributes:
        prompt: text passed to the model; required and must be non-empty.
        max_new_tokens: cap on the number of generated tokens; validated to the
            range 1..4096, default 2048.
    """

    # min_length=1: an empty prompt likely tokenizes to no input ids (depends on
    # whether the tokenizer adds special tokens — confirm), and generation then
    # fails with a 500. Rejecting it here gives the client a 422 validation error.
    prompt: str = Field(..., min_length=1, description="The input prompt for text generation")
    max_new_tokens: int = Field(default=2048, gt=0, le=4096, description="Maximum number of tokens to generate")
app = FastAPI()

# Allow cross-origin calls from any site. Credentials are disabled on purpose:
# the CORS spec forbids combining them with a wildcard origin, and Starlette
# works around that by echoing back whatever Origin the browser sent, which
# would let any website make credentialed requests. Nothing in the visible
# code reads cookies or auth headers, so credentials are not needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Hugging Face Hub repository ids (not filesystem paths): the base checkpoint,
# and the repo whose adapter weights are loaded on top of it.
BASE_MODEL_PATH: str = "HuggingFaceTB/SmolLM2-135M-Instruct"
ADAPTER_PATH: str = "khurrameycon/SmolLM-135M-Instruct-qa_pairs_converted.json-25epochs"
def load_model_and_tokenizer():
    """Load the base model and tokenizer, then overlay the adapter weights.

    The base checkpoint is ``BASE_MODEL_PATH``. The adapter repo ``ADAPTER_PATH``
    is downloaded with ``snapshot_download``, and its ``adapter_model.safetensors``
    file is loaded into the model via ``load_state_dict(strict=False)``.

    Returns:
        tuple: ``(model, tokenizer)``.

    Raises:
        Exception: any error while downloading or loading is logged and re-raised
            unchanged.
    """
    import os  # only needed to build the adapter file path portably

    try:
        logger.info("Loading base model...")
        # Half precision only pays off on CUDA. On CPU, float16 is slow and some
        # torch versions lack Half kernels for ops such as addmm, so use float32.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_PATH,
            torch_dtype=dtype,
            # NOTE(review): permits executing code shipped in the model repo;
            # confirm this checkpoint actually requires it.
            trust_remote_code=True,
            device_map="auto",
        )

        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)

        logger.info("Downloading adapter weights...")
        adapter_path_local = snapshot_download(repo_id=ADAPTER_PATH)

        logger.info("Loading adapter weights...")
        adapter_file = os.path.join(adapter_path_local, "adapter_model.safetensors")
        state_dict = load_file(adapter_file)

        logger.info("Applying adapter weights...")
        # strict=False silently skips tensors whose names do not exist in the
        # model. If this is a PEFT/LoRA adapter (the file name suggests so —
        # confirm), its keys are prefixed LoRA names that the base model lacks,
        # and nothing is applied. Report that instead of claiming success.
        load_result = model.load_state_dict(state_dict, strict=False)
        ignored = load_result.unexpected_keys
        applied = len(state_dict) - len(ignored)
        if ignored:
            logger.warning(
                "%d of %d adapter tensors matched no model parameter and were ignored (e.g. %s)",
                len(ignored),
                len(state_dict),
                ignored[0],
            )
        if applied == 0:
            logger.warning("No adapter weights were applied; serving the base model unchanged.")
        else:
            logger.info("Model and adapter loaded successfully!")
        return model, tokenizer
    except Exception as e:
        logger.error("Error during model loading: %s", e, exc_info=True)
        raise
# Load once at import time. On failure the error is logged and both globals are
# left as None, so callers can detect an unavailable model by checking for None.
model = tokenizer = None
try:
    model, tokenizer = load_model_and_tokenizer()
except Exception as e:
    logger.error(f"Failed to load model at startup: {e}", exc_info=True)
def generate_response(model, tokenizer, instruction, max_new_tokens=2048, temperature=0.7, top_p=0.9):
    """Generate a sampled continuation of ``instruction``.

    Args:
        model: causal LM exposing ``.device`` and ``.generate``.
        tokenizer: tokenizer matching ``model``.
        instruction: prompt text. It is encoded as-is (no chat template is
            applied) and truncated to ``tokenizer.model_max_length`` tokens.
        max_new_tokens: upper bound on the number of generated tokens.
        temperature: sampling temperature forwarded to ``generate``.
        top_p: nucleus-sampling threshold forwarded to ``generate``.

    Returns:
        str: only the newly generated text, decoded with special tokens removed
        and surrounding whitespace stripped.

    Raises:
        ValueError: wrapping any error raised while encoding, generating or
            decoding (the original exception is chained as the cause).
    """
    try:
        logger.info("Generating response for instruction: %s...", instruction[:100])

        input_ids = tokenizer.encode(
            instruction,
            return_tensors="pt",
            truncation=True,
            max_length=tokenizer.model_max_length,
        ).to(model.device)
        logger.info("Input shape: %s", input_ids.shape)

        # A single unpadded sequence: every position is a real token. ones_like
        # keeps the integer dtype and device of the input ids.
        attention_mask = torch.ones_like(input_ids)

        # Fall back to EOS when the tokenizer defines no pad token.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        logger.info("Output shape: %s", outputs.shape)

        # Drop the prompt by token position, not by character count. Decoding
        # does not always reproduce the prompt text exactly (normalization,
        # truncation, skipped special tokens), so slicing the decoded string by
        # len(instruction) can cut into the answer or leave prompt residue.
        new_tokens = outputs[0][input_ids.shape[-1]:]
        generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        logger.info("Generated text length: %d", len(generated_text))
        return generated_text
    except Exception as e:
        logger.error("Error generating response: %s", e, exc_info=True)
        raise ValueError(f"Error generating response: {e}") from e
async def generate_text(input: ModelInput, request: Request):
    """Handle a text-generation request.

    Args:
        input: validated request body carrying the prompt and token budget.
        request: incoming request; used only to log the client address.

    Returns:
        dict: ``{"generated_text": str}``, plus a ``"warning"`` key when the
        model produced no text.

    Raises:
        HTTPException: 503 when the model/tokenizer failed to load at startup;
            500 for any error during generation.

    NOTE(review): no route decorator is visible for this handler in this file;
    confirm how it is registered on ``app``.
    """
    from fastapi.concurrency import run_in_threadpool

    # Checked outside the try below so the 503 is not caught by the generic
    # handler and re-raised as a 500.
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # request.client may be None depending on the transport/test client.
        client_host = request.client.host if request.client else "unknown"
        logger.info("Received request from %s", client_host)
        logger.info("Prompt: %s...", input.prompt[:100])

        # generate_response is synchronous and compute-bound. Running it in the
        # threadpool keeps the event loop free for other requests (health checks
        # included). NOTE(review): concurrent requests can now run generation in
        # parallel on the shared model; confirm memory headroom is acceptable.
        response = await run_in_threadpool(
            generate_response,
            model=model,
            tokenizer=tokenizer,
            instruction=input.prompt,
            max_new_tokens=input.max_new_tokens,
        )

        if not response:
            logger.warning("Generated empty response")
            return {"generated_text": "", "warning": "Empty response generated"}

        logger.info("Generated response length: %d", len(response))
        return {"generated_text": response}
    except Exception as e:
        logger.error("Error in generate_text endpoint: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
async def root():
    """Return a fixed greeting payload indicating the service is running.

    NOTE(review): no route decorator is visible for this handler in this file;
    confirm how it is registered on ``app``.
    """
    return dict(message="Welcome to the Model API!", status="running")
async def health_check():
    """Report liveness and whether the model finished loading.

    Returns:
        dict: ``status`` is always ``"healthy"``. ``model_loaded`` is True when
        both the module-level model and tokenizer are set. ``model_device`` is
        the device of the model's first parameter as a string, or None when no
        model is loaded or the model has no parameters.

    NOTE(review): no route decorator is visible for this handler in this file;
    confirm how it is registered on ``app``.
    """
    model_device = None
    if model is not None:
        # If the model is sharded across devices, only the first parameter's
        # device is reported. next(..., None) avoids StopIteration on a
        # parameterless module.
        first_param = next(model.parameters(), None)
        if first_param is not None:
            model_device = str(first_param.device)

    return {
        "status": "healthy",
        "model_loaded": model is not None and tokenizer is not None,
        "model_device": model_device,
    }