import logging
from typing import Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer

# Basic logging so model loading and request handling are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the tokenizer and model once at startup so every request reuses them.
MODEL_ID = "EQuIP-Queries/EQuIP_3B"
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    raise
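
# Optional GPU placement: a commented sketch, assuming a CUDA device and that
# torch is installed (transformers depends on it). Inputs built in /generate
# would need the same .to(device) call before model.generate():
#
#     import torch
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     model.to(device)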

app = FastAPI(
    title="EQuIP Query Generator",
    description="Generate Elasticsearch queries using the EQuIP model",
    version="1.0.0",
)

class GenerateRequest(BaseModel):
    prompt: str = Field(..., description="Input prompt for query generation")
    max_new_tokens: int = Field(default=50, ge=1, le=512, description="Maximum number of tokens to generate")


class GenerateResponse(BaseModel):
    generated_text: str
    input_prompt: str
    # Optional fields need an explicit default to be optional in Pydantic v2.
    token_count: Optional[int] = None

@app.get("/health")
async def health_check():
    return {"status": "healthy", "model": MODEL_ID}

@app.post("/generate", response_model=GenerateResponse)
async def generate(req: GenerateRequest):
    try:
        logger.info(f"Processing request with prompt: {req.prompt[:50]}...")
        inputs = tokenizer(req.prompt, return_tensors="pt")

        # Generation uses the model's default decoding settings; pad_token_id
        # falls back to EOS because many causal LMs ship without a pad token.
        ids = model.generate(
            **inputs,
            max_new_tokens=req.max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1,
        )

        # ids[0] holds the prompt tokens followed by the newly generated ones,
        # so the decoded text echoes the prompt and token_count covers both.
        generated_text = tokenizer.decode(ids[0], skip_special_tokens=True)
        token_count = len(ids[0])

        return GenerateResponse(
            generated_text=generated_text,
            input_prompt=req.prompt,
            token_count=token_count,
        )
    except Exception as e:
        logger.error(f"Generation failed: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Generation failed: {str(e)}",
        )
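
# model.generate() is compute-bound and blocks the event loop inside an async
# handler. A minimal sketch for heavier traffic, using Starlette's threadpool
# helper (re-exported by FastAPI) in place of the direct call above:
#
#     from fastapi.concurrency import run_in_threadpool
#
#     ids = await run_in_threadpool(
#         model.generate,
#         **inputs,
#         max_new_tokens=req.max_new_tokens,
#         pad_token_id=tokenizer.eos_token_id,
#     )


# Local entrypoint: a sketch assuming uvicorn is installed and port 8000 is
# free; in production the app would typically be served via `uvicorn module:app`.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request once the server is up (localhost and port 8000 assumed):
#
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "find all error logs from the last hour", "max_new_tokens": 128}'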