# server.py
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelInput(BaseModel):
    prompt: str = Field(..., description="The input prompt for text generation")
    max_new_tokens: int = Field(default=2048, gt=0, le=4096, description="Maximum number of tokens to generate")


app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Define model paths
BASE_MODEL_PATH = "HuggingFaceTB/SmolLM2-135M-Instruct"
ADAPTER_PATH = "khurrameycon/SmolLM-135M-Instruct-qa_pairs_converted.json-25epochs"


def format_prompt(instruction):
    """Format the prompt according to the model's expected format."""
    return f"""### Instruction:
{instruction}

### Response:
"""


def load_model_and_tokenizer():
    """Load the model, tokenizer, and adapter weights."""
    try:
        logger.info("Loading base model...")
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_PATH,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map="auto",
            use_cache=True
        )

        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            BASE_MODEL_PATH,
            padding_side="left",
            truncation_side="left"
        )

        # Ensure the tokenizer has the necessary special tokens
        special_tokens = {
            "pad_token": "<|padding|>",
            "eos_token": "",
            "bos_token": "",
            "unk_token": "<|unknown|>"
        }
        tokenizer.add_special_tokens(special_tokens)

        # Resize the model embeddings to match the new tokenizer size
        model.resize_token_embeddings(len(tokenizer))

        logger.info("Downloading adapter weights...")
        adapter_path_local = snapshot_download(repo_id=ADAPTER_PATH)

        logger.info("Loading adapter weights...")
        adapter_file = f"{adapter_path_local}/adapter_model.safetensors"
        state_dict = load_file(adapter_file)

        logger.info("Applying adapter weights...")
        model.load_state_dict(state_dict, strict=False)

        logger.info("Model and adapter loaded successfully!")
        return model, tokenizer

    except Exception as e:
        logger.error(f"Error during model loading: {e}", exc_info=True)
        raise


# Load model and tokenizer at startup
try:
    model, tokenizer = load_model_and_tokenizer()
except Exception as e:
    logger.error(f"Failed to load model at startup: {e}", exc_info=True)
    model = None
    tokenizer = None


def generate_response(model, tokenizer, instruction, max_new_tokens=2048):
    """Generate a response from the model based on an instruction."""
    try:
        # Format the prompt
        formatted_prompt = format_prompt(instruction)
        logger.info(f"Formatted prompt: {formatted_prompt}")

        # Encode input with truncation
        inputs = tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=tokenizer.model_max_length,
            padding=True,
            add_special_tokens=True
        ).to(model.device)

        logger.info(f"Input shape: {inputs.input_ids.shape}")

        # Generate response
        with torch.inference_mode():
            outputs = model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=max_new_tokens,
                temperature=0.7,
                top_p=0.9,
                top_k=50,
                do_sample=True,
                num_return_sequences=1,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1,
                length_penalty=1.0,
                no_repeat_ngram_size=3
            )

        logger.info(f"Output shape: {outputs.shape}")
        # Decode only the newly generated tokens (skip the prompt)
        response = tokenizer.decode(
            outputs[0, inputs.input_ids.shape[1]:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )
        response = response.strip()

        logger.info(f"Generated text length: {len(response)}")
        logger.info(f"Generated text preview: {response[:100]}...")

        if not response:
            logger.warning("Empty response generated")
            raise ValueError("Model generated an empty response")

        return response

    except Exception as e:
        logger.error(f"Error generating response: {e}", exc_info=True)
        raise ValueError(f"Error generating response: {e}")


@app.post("/generate")
async def generate_text(input: ModelInput, request: Request):
    """Generate text based on the input prompt."""
    try:
        if model is None or tokenizer is None:
            raise HTTPException(status_code=503, detail="Model not loaded")

        logger.info(f"Received request from {request.client.host}")
        logger.info(f"Prompt: {input.prompt[:100]}...")

        response = generate_response(
            model=model,
            tokenizer=tokenizer,
            instruction=input.prompt,
            max_new_tokens=input.max_new_tokens
        )

        return {"generated_text": response}

    except HTTPException:
        # Re-raise HTTP errors (e.g. the 503 above) without wrapping them in a 500
        raise
    except Exception as e:
        logger.error(f"Error in generate_text endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/")
async def root():
    """Root endpoint that returns a welcome message."""
    return {"message": "Welcome to the Model API!", "status": "running"}


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {
        "status": "healthy",
        "model_loaded": model is not None and tokenizer is not None,
        "model_device": str(next(model.parameters()).device) if model else None,
        "tokenizer_vocab_size": len(tokenizer) if tokenizer else None
    }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
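

# ---------------------------------------------------------------------------
# Example client usage (reference only, not executed by the server).
# A minimal sketch assuming the server runs with the default host/port
# configured above (http://localhost:8000); it posts a prompt to /generate
# and prints the returned text.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/generate",
#       json={"prompt": "What is FastAPI?", "max_new_tokens": 256},
#       timeout=120,
#   )
#   resp.raise_for_status()
#   print(resp.json()["generated_text"])
# ---------------------------------------------------------------------------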