# app.py
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
import logging
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ModelInput(BaseModel):
prompt: str = Field(..., description="The input prompt for text generation")
max_new_tokens: int = Field(default=2048, gt=0, le=4096, description="Maximum number of tokens to generate")
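# Illustrative request body for the /generate endpoint defined below:
#   {"prompt": "Explain recursion in one paragraph.", "max_new_tokens": 256}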
app = FastAPI()
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Define model paths
BASE_MODEL_PATH = "HuggingFaceTB/SmolLM2-135M-Instruct"
ADAPTER_PATH = "khurrameycon/SmolLM-135M-Instruct-qa_pairs_converted.json-25epochs"
def format_prompt(instruction):
"""Format the prompt according to the model's expected format."""
return f"""### Instruction:
{instruction}
### Response:
"""
def load_model_and_tokenizer():
"""Load the model, tokenizer, and adapter weights."""
try:
logger.info("Loading base model...")
model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL_PATH,
torch_dtype=torch.float16,
trust_remote_code=True,
device_map="auto",
use_cache=True
)
logger.info("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
BASE_MODEL_PATH,
padding_side="left",
truncation_side="left"
)
# Ensure the tokenizer has the necessary special tokens
special_tokens = {
"pad_token": "<|padding|>",
"eos_token": "</s>",
"bos_token": "<s>",
"unk_token": "<|unknown|>"
}
tokenizer.add_special_tokens(special_tokens)
# Resize the model embeddings to match the new tokenizer size
model.resize_token_embeddings(len(tokenizer))
logger.info("Downloading adapter weights...")
adapter_path_local = snapshot_download(repo_id=ADAPTER_PATH)
logger.info("Loading adapter weights...")
adapter_file = f"{adapter_path_local}/adapter_model.safetensors"
state_dict = load_file(adapter_file)
logger.info("Applying adapter weights...")
model.load_state_dict(state_dict, strict=False)
logger.info("Model and adapter loaded successfully!")
return model, tokenizer
except Exception as e:
logger.error(f"Error during model loading: {e}", exc_info=True)
raise
# Load model and tokenizer at startup
try:
model, tokenizer = load_model_and_tokenizer()
except Exception as e:
logger.error(f"Failed to load model at startup: {e}", exc_info=True)
model = None
tokenizer = None
def generate_response(model, tokenizer, instruction, max_new_tokens=2048):
"""Generate a response from the model based on an instruction."""
try:
# Format the prompt
formatted_prompt = format_prompt(instruction)
logger.info(f"Formatted prompt: {formatted_prompt}")
# Encode input with truncation
inputs = tokenizer(
formatted_prompt,
return_tensors="pt",
truncation=True,
max_length=tokenizer.model_max_length,
padding=True,
add_special_tokens=True
).to(model.device)
logger.info(f"Input shape: {inputs.input_ids.shape}")
# Generate response
with torch.inference_mode():
outputs = model.generate(
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
max_new_tokens=max_new_tokens,
temperature=0.7,
top_p=0.9,
top_k=50,
do_sample=True,
num_return_sequences=1,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
repetition_penalty=1.1,
length_penalty=1.0,
no_repeat_ngram_size=3
)
logger.info(f"Output shape: {outputs.shape}")
# Decode the response
response = tokenizer.decode(
outputs[0, inputs.input_ids.shape[1]:],
skip_special_tokens=True,
clean_up_tokenization_spaces=True
)
response = response.strip()
logger.info(f"Generated text length: {len(response)}")
logger.info(f"Generated text preview: {response[:100]}...")
if not response:
logger.warning("Empty response generated")
raise ValueError("Model generated an empty response")
return response
except Exception as e:
logger.error(f"Error generating response: {e}", exc_info=True)
raise ValueError(f"Error generating response: {e}")
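# Illustrative direct call, assuming model and tokenizer loaded as above:
#   text = generate_response(model, tokenizer, "Summarize FastAPI in one sentence.", max_new_tokens=128)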
@app.post("/generate")
async def generate_text(input: ModelInput, request: Request):
"""Generate text based on the input prompt."""
try:
if model is None or tokenizer is None:
raise HTTPException(status_code=503, detail="Model not loaded")
logger.info(f"Received request from {request.client.host}")
logger.info(f"Prompt: {input.prompt[:100]}...")
response = generate_response(
model=model,
tokenizer=tokenizer,
instruction=input.prompt,
max_new_tokens=input.max_new_tokens
)
return {"generated_text": response}
    except HTTPException:
        # Re-raise deliberate HTTP errors (e.g. the 503 above) instead of converting them to 500s.
        raise
    except Exception as e:
        logger.error(f"Error in generate_text endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/")
async def root():
"""Root endpoint that returns a welcome message."""
return {"message": "Welcome to the Model API!", "status": "running"}
@app.get("/health")
async def health_check():
"""Health check endpoint."""
return {
"status": "healthy",
"model_loaded": model is not None and tokenizer is not None,
"model_device": str(next(model.parameters()).device) if model else None,
"tokenizer_vocab_size": len(tokenizer) if tokenizer else None
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
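# Example usage (illustrative; assumes the server is running locally on port 8000):
#
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "What is machine learning?", "max_new_tokens": 256}'
#
# Or from Python with the `requests` package:
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/generate",
#       json={"prompt": "What is machine learning?", "max_new_tokens": 256},
#   )
#   print(resp.json()["generated_text"])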