import logging
from typing import Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# 1. Load model & tokenizer once at startup
MODEL_ID = "EQuIP-Queries/EQuIP_3B"
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    raise

# 2. Initialize FastAPI
app = FastAPI(
    title="EQuIP Query Generator",
    description="Generate Elasticsearch queries using the EQuIP model",
    version="1.0.0",
)

# 3. Define request/response schemas
class GenerateRequest(BaseModel):
    prompt: str = Field(..., description="Input prompt for query generation")
    max_new_tokens: int = Field(
        default=50, ge=1, le=512,
        description="Maximum number of tokens to generate",
    )

class GenerateResponse(BaseModel):
    generated_text: str
    input_prompt: str
    token_count: Optional[int] = None

# 4. Health check endpoint
@app.get("/health")
async def health_check():
    return {"status": "healthy", "model": MODEL_ID}

# 5. Inference endpoint
# Declared as a plain `def` so FastAPI runs it in its threadpool; a blocking
# model.generate() call inside an `async def` would stall the event loop for
# every other request.
@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    try:
        logger.info(f"Processing request with prompt: {req.prompt[:50]}...")
        inputs = tokenizer(req.prompt, return_tensors="pt")
        ids = model.generate(
            **inputs,
            max_new_tokens=req.max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,  # avoids a warning for models without a pad token
            num_return_sequences=1,
        )
        generated_text = tokenizer.decode(ids[0], skip_special_tokens=True)
        token_count = len(ids[0])  # prompt tokens + newly generated tokens
        return GenerateResponse(
            generated_text=generated_text,
            input_prompt=req.prompt,
            token_count=token_count,
        )
    except Exception as e:
        logger.error(f"Generation failed: {e}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {e}")
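
# ------------------------------------------------------------------
# Usage sketch (not part of the service itself): the entry point and
# the example request below assume this file is saved as main.py and
# served locally on port 8000 -- both illustrative choices.
# ------------------------------------------------------------------

if __name__ == "__main__":
    # Convenience entry point: `python main.py` starts the server.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request once the server is up (the prompt is illustrative):
#
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Find all orders placed in the last 7 days", "max_new_tokens": 128}'
#
# The response echoes the prompt and reports the combined token count,
# matching the GenerateResponse schema above.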