# server.py
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelInput(BaseModel):
    prompt: str = Field(..., description="The input prompt for text generation")
    max_new_tokens: int = Field(default=2048, gt=0, le=4096, description="Maximum number of tokens to generate")


app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Define model paths
BASE_MODEL_PATH = "HuggingFaceTB/SmolLM2-135M-Instruct"
ADAPTER_PATH = "khurrameycon/SmolLM-135M-Instruct-qa_pairs_converted.json-25epochs"


def format_prompt(instruction):
    """Format the prompt according to the model's expected format."""
    return f"""### Instruction:
{instruction}
### Response:
"""
def load_model_and_tokenizer():
    """Load the model, tokenizer, and adapter weights."""
    try:
        logger.info("Loading base model...")
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_PATH,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map="auto",
            use_cache=True
        )

        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            BASE_MODEL_PATH,
            padding_side="left",
            truncation_side="left"
        )

        # Ensure the tokenizer has the necessary special tokens
        special_tokens = {
            "pad_token": "<|padding|>",
            "eos_token": "</s>",
            "bos_token": "<s>",
            "unk_token": "<|unknown|>"
        }
        tokenizer.add_special_tokens(special_tokens)

        # Resize the model embeddings to match the new tokenizer size
        model.resize_token_embeddings(len(tokenizer))

        logger.info("Downloading adapter weights...")
        adapter_path_local = snapshot_download(repo_id=ADAPTER_PATH)

        logger.info("Loading adapter weights...")
        adapter_file = f"{adapter_path_local}/adapter_model.safetensors"
        state_dict = load_file(adapter_file)

        logger.info("Applying adapter weights...")
        model.load_state_dict(state_dict, strict=False)

        logger.info("Model and adapter loaded successfully!")
        return model, tokenizer
    except Exception as e:
        logger.error(f"Error during model loading: {e}", exc_info=True)
        raise


# Load model and tokenizer at startup
try:
    model, tokenizer = load_model_and_tokenizer()
except Exception as e:
    logger.error(f"Failed to load model at startup: {e}", exc_info=True)
    model = None
    tokenizer = None


def generate_response(model, tokenizer, instruction, max_new_tokens=2048):
    """Generate a response from the model based on an instruction."""
    try:
        # Format the prompt
        formatted_prompt = format_prompt(instruction)
        logger.info(f"Formatted prompt: {formatted_prompt}")

        # Encode input with truncation
        inputs = tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=tokenizer.model_max_length,
            padding=True,
            add_special_tokens=True
        ).to(model.device)
        logger.info(f"Input shape: {inputs.input_ids.shape}")

        # Generate response
        with torch.inference_mode():
            outputs = model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=max_new_tokens,
                temperature=0.7,
                top_p=0.9,
                top_k=50,
                do_sample=True,
                num_return_sequences=1,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1,
                length_penalty=1.0,
                no_repeat_ngram_size=3
            )
        logger.info(f"Output shape: {outputs.shape}")

        # Decode the response (only the newly generated tokens, skipping the prompt)
        response = tokenizer.decode(
            outputs[0, inputs.input_ids.shape[1]:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )
        response = response.strip()
        logger.info(f"Generated text length: {len(response)}")
        logger.info(f"Generated text preview: {response[:100]}...")

        if not response:
            logger.warning("Empty response generated")
            raise ValueError("Model generated an empty response")

        return response
    except Exception as e:
        logger.error(f"Error generating response: {e}", exc_info=True)
        raise ValueError(f"Error generating response: {e}")
@app.post("/generate")  # route path assumed; decorator restored so FastAPI registers the endpoint
async def generate_text(input: ModelInput, request: Request):
    """Generate text based on the input prompt."""
    try:
        if model is None or tokenizer is None:
            raise HTTPException(status_code=503, detail="Model not loaded")

        logger.info(f"Received request from {request.client.host}")
        logger.info(f"Prompt: {input.prompt[:100]}...")

        response = generate_response(
            model=model,
            tokenizer=tokenizer,
            instruction=input.prompt,
            max_new_tokens=input.max_new_tokens
        )
        return {"generated_text": response}
    except HTTPException:
        # Re-raise HTTP errors (e.g. the 503 above) unchanged instead of converting them to 500s
        raise
    except Exception as e:
        logger.error(f"Error in generate_text endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/")
async def root():
    """Root endpoint that returns a welcome message."""
    return {"message": "Welcome to the Model API!", "status": "running"}


@app.get("/health")  # route path assumed; decorator restored so FastAPI registers the endpoint
async def health_check():
    """Health check endpoint."""
    return {
        "status": "healthy",
        "model_loaded": model is not None and tokenizer is not None,
        "model_device": str(next(model.parameters()).device) if model else None,
        "tokenizer_vocab_size": len(tokenizer) if tokenizer else None
    }
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
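
# Example client call (illustrative sketch; the /generate path matches the decorator
# assumed above, and the port matches the uvicorn.run call in this file):
#
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "What is SmolLM2?", "max_new_tokens": 128}'
#
# Or from Python:
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/generate",
#       json={"prompt": "What is SmolLM2?", "max_new_tokens": 128},
#   )
#   print(resp.json()["generated_text"])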