# Provenance (Hugging Face file page): uploaded by Lyon28, commit 90dbd26
# "Rename app.py to main.py" (verified); file size 3.12 kB.
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import pipeline
app = FastAPI(title="Model Inference API")

# Allow CORS for external frontends: every origin, HTTP method and request
# header is accepted. allow_credentials is not passed, so the middleware's
# own default applies.
app.add_middleware(
    CORSMiddleware,
    **{
        setting: ["*"]
        for setting in ("allow_origins", "allow_methods", "allow_headers")
    },
)
# Public model id (the ``{model_id}`` URL segment) -> model repository id that
# is handed to ``transformers.pipeline(model=...)``.
# Insertion order matters: GET /models lists the keys in this order.
MODEL_MAP = {
    "tinny-llama": "Lyon28/Tinny-Llama",
    "pythia": "Lyon28/Pythia",
    "bert-tinny": "Lyon28/Bert-Tinny",
    "albert-base-v2": "Lyon28/Albert-Base-V2",
    "t5-small": "Lyon28/T5-Small",
    "gpt-2": "Lyon28/GPT-2",
    "gpt-neo": "Lyon28/GPT-Neo",
    "distilbert-base-uncased": "Lyon28/Distilbert-Base-Uncased",
    "distil-gpt-2": "Lyon28/Distil_GPT-2",
    "gpt-2-tinny": "Lyon28/GPT-2-Tinny",
    "electra-small": "Lyon28/Electra-Small",
}

# Pipeline task name -> public model ids served with that task.
# Every MODEL_MAP key is expected to appear in exactly one list; get_task()
# falls back to "text-generation" for ids that are missing here.
TASK_MAP = {
    "text-generation": [
        "gpt-2",
        "gpt-neo",
        "distil-gpt-2",
        "gpt-2-tinny",
        "tinny-llama",
        "pythia",
    ],
    "text-classification": [
        "bert-tinny",
        "albert-base-v2",
        "distilbert-base-uncased",
        "electra-small",
    ],
    "text2text-generation": [
        "t5-small",
    ],
}
class InferenceRequest(BaseModel):
    """Request body for ``POST /inference/{model_id}``.

    Field names and defaults are part of the public API. Only value ranges
    are constrained, so malformed numbers are rejected by request validation
    instead of failing later inside the model pipeline.
    """

    # Input text handed to the model pipeline.
    text: str
    # Generation length limit; must be strictly positive.
    max_length: int = Field(100, gt=0)
    # Sampling temperature; negative values are rejected, 0 is allowed.
    temperature: float = Field(0.9, ge=0)
def get_task(model_id: str):
    """Return the pipeline task that ``TASK_MAP`` assigns to *model_id*.

    The first task whose model list contains *model_id* wins. Ids not
    listed anywhere fall back to ``"text-generation"``.
    """
    matching_tasks = (
        task_name
        for task_name, model_ids in TASK_MAP.items()
        if model_id in model_ids
    )
    return next(matching_tasks, "text-generation")
@app.on_event("startup")
async def load_models():
    """Create the empty pipeline cache when the application starts.

    No model is actually loaded here, despite the log message. The inference
    endpoint builds pipelines lazily on first use and stores them in
    ``app.state.pipelines``. Pre-loading frequently used models would be
    done at this point if desired.
    """
    pipeline_cache = dict()
    app.state.pipelines = pipeline_cache
    print("Models initialized in memory")
def _get_pipeline(model_id: str, task: str):
    """Return the cached pipeline for *model_id*, building it on first use.

    Pipelines live in ``app.state.pipelines`` (created by the startup hook)
    so later requests for the same model reuse the loaded instance.
    """
    cache = app.state.pipelines
    if model_id not in cache:
        # NOTE(review): pipeline() loads the model synchronously (possibly
        # downloading it) and is reached from an async endpoint, so it blocks
        # the event loop, including /health, while it runs. Moving load and
        # inference to a worker thread would need a per-model lock to avoid
        # duplicate loads.
        cache[model_id] = pipeline(
            task=task,
            model=MODEL_MAP[model_id],
            device_map="auto",
            # Half precision only when CUDA is available; CPU stays float32.
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        )
    return cache[model_id]


def _run_pipeline(pipe, task: str, request: InferenceRequest):
    """Run *pipe* on ``request.text`` and shape its output for *task*.

    Returns the generated string for generative tasks, or a
    ``{"label": ..., "confidence": ...}`` dict for text classification.

    Raises:
        ValueError: if *task* is not one of the tasks handled here.
    """
    if task == "text-generation":
        gen_kwargs = {"max_length": request.max_length}
        if request.temperature > 0:
            # transformers only applies `temperature` when sampling is on.
            # Set do_sample explicitly instead of relying on each model's
            # generation-config default, which would otherwise silently
            # ignore the requested temperature.
            gen_kwargs.update(do_sample=True, temperature=request.temperature)
        else:
            # A temperature of 0 is treated as greedy (deterministic) decoding.
            gen_kwargs["do_sample"] = False
        return pipe(request.text, **gen_kwargs)[0]["generated_text"]
    if task == "text-classification":
        output = pipe(request.text)[0]
        return {
            "label": output["label"],
            "confidence": round(output["score"], 4),
        }
    if task == "text2text-generation":
        # Apply the request's length limit to every generative task.
        return pipe(request.text, max_length=request.max_length)[0]["generated_text"]
    raise ValueError(f"Unsupported task: {task}")


@app.post("/inference/{model_id}")
async def model_inference(model_id: str, request: InferenceRequest):
    """Run the model registered as *model_id* on the request text.

    Returns ``{"result": ...}``. The result is a string for generative
    models, or a ``{"label", "confidence"}`` dict for classifiers.

    Raises:
        HTTPException: 404 if *model_id* is not in ``MODEL_MAP``; 500 if
            loading the model or running inference fails.
    """
    # Validate outside the try block. Inside it, the blanket handler below
    # would catch this HTTPException and turn the 404 into a 500.
    if model_id not in MODEL_MAP:
        raise HTTPException(status_code=404, detail="Model not found")
    task = get_task(model_id)
    try:
        pipe = _get_pipeline(model_id, task)
        result = _run_pipeline(pipe, task, request)
    except Exception as e:
        # Chain the cause so the original exception stays attached for debugging.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"result": result}
@app.get("/models")
async def list_models():
    """List the public model ids accepted by ``/inference/{model_id}``.

    The ids are returned in ``MODEL_MAP`` insertion order.
    """
    model_ids = list(MODEL_MAP)
    return {"available_models": model_ids}
@app.get("/health")
async def health_check():
    """Report service health.

    No checks are performed; the endpoint always answers ``healthy``.
    """
    status = "healthy"
    return {"status": status}