# api-smollm135m / app.py
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
import logging
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ModelInput(BaseModel):
    prompt: str = Field(..., description="The input prompt for text generation")
    max_new_tokens: int = Field(default=2048, gt=0, le=4096, description="Maximum number of tokens to generate")
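
# Example request body for /generate:
#   {"prompt": "What is an adapter?", "max_new_tokens": 256}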

app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Define model paths
BASE_MODEL_PATH = "HuggingFaceTB/SmolLM2-135M-Instruct"
ADAPTER_PATH = "khurrameycon/SmolLM-135M-Instruct-qa_pairs_converted.json-25epochs"
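
# The base checkpoint is pulled from the Hugging Face Hub; the adapter repo is
# expected to provide fine-tuned weights in adapter_model.safetensors (see
# load_model_and_tokenizer below).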

def load_model_and_tokenizer():
    """Load the base model, tokenizer, and adapter weights."""
    try:
        logger.info("Loading base model...")
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_PATH,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map="auto",
        )

        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)

        logger.info("Downloading adapter weights...")
        adapter_path_local = snapshot_download(repo_id=ADAPTER_PATH)

        logger.info("Loading adapter weights...")
        adapter_file = f"{adapter_path_local}/adapter_model.safetensors"
        state_dict = load_file(adapter_file)

        logger.info("Applying adapter weights...")
        # strict=False only applies tensors whose names match the base model;
        # log anything that was skipped so a key mismatch doesn't pass silently.
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if unexpected:
            logger.warning(f"Skipped {len(unexpected)} unexpected keys from the adapter file")

        logger.info("Model and adapter loaded successfully!")
        return model, tokenizer
    except Exception as e:
        logger.error(f"Error during model loading: {e}", exc_info=True)
        raise
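
# Note: load_state_dict(strict=False) assumes the safetensors file stores full
# model weights under the same names as the base model. If the repo actually
# holds a PEFT/LoRA adapter (keys prefixed with "base_model.model..."), nothing
# would match; peft.PeftModel.from_pretrained would be the usual loading path.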

# Load model and tokenizer at startup
try:
    model, tokenizer = load_model_and_tokenizer()
except Exception as e:
    logger.error(f"Failed to load model at startup: {e}", exc_info=True)
    model = None
    tokenizer = None
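# If loading fails, the app still starts: /generate returns 503 and /health
# reports model_loaded=False instead of the process crashing on import.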

def generate_response(model, tokenizer, instruction, max_new_tokens=2048):
    """Generate a response from the model based on an instruction."""
    try:
        logger.info(f"Generating response for instruction: {instruction[:100]}...")

        # Encode input with truncation
        inputs = tokenizer.encode(
            instruction,
            return_tensors="pt",
            truncation=True,
            max_length=tokenizer.model_max_length,
        ).to(model.device)
        logger.info(f"Input shape: {inputs.shape}")

        # Create attention mask (all ones: no padding in a single-sequence batch)
        attention_mask = torch.ones_like(inputs)

        # Generate response; fall back to EOS if the tokenizer defines no pad token
        outputs = model.generate(
            inputs,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        logger.info(f"Output shape: {outputs.shape}")

        # Decode only the newly generated tokens; slicing the decoded string by
        # len(instruction) is brittle when the tokenizer normalizes the text.
        generated_text = tokenizer.decode(
            outputs[0][inputs.shape[-1]:], skip_special_tokens=True
        ).strip()
        logger.info(f"Generated text length: {len(generated_text)}")
        return generated_text
    except Exception as e:
        logger.error(f"Error generating response: {e}", exc_info=True)
        raise ValueError(f"Error generating response: {e}") from e
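
# Example direct call (assuming the model loaded successfully):
#   text = generate_response(model, tokenizer, "Explain what an adapter is.", max_new_tokens=64)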

@app.post("/generate")
async def generate_text(input: ModelInput, request: Request):
    """Generate text based on the input prompt."""
    try:
        if model is None or tokenizer is None:
            raise HTTPException(status_code=503, detail="Model not loaded")

        client_host = request.client.host if request.client else "unknown"
        logger.info(f"Received request from {client_host}")
        logger.info(f"Prompt: {input.prompt[:100]}...")

        response = generate_response(
            model=model,
            tokenizer=tokenizer,
            instruction=input.prompt,
            max_new_tokens=input.max_new_tokens,
        )

        if not response:
            logger.warning("Generated empty response")
            return {"generated_text": "", "warning": "Empty response generated"}

        logger.info(f"Generated response length: {len(response)}")
        return {"generated_text": response}
    except HTTPException:
        # Re-raise as-is so the 503 above isn't rewrapped as a 500
        raise
    except Exception as e:
        logger.error(f"Error in generate_text endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
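
# Example request (assuming the app is served locally on port 8000):
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "What is FastAPI?", "max_new_tokens": 256}'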

@app.get("/")
async def root():
    """Root endpoint that returns a welcome message."""
    return {"message": "Welcome to the Model API!", "status": "running"}

@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {
        "status": "healthy",
        "model_loaded": model is not None and tokenizer is not None,
        "model_device": str(next(model.parameters()).device) if model is not None else None,
    }
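
# To serve the app locally (assumes uvicorn is installed):
#   uvicorn app:app --host 0.0.0.0 --port 8000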