# classifieur / server.py
# simondh's picture
# add types
# 535a3a5
# raw
# history blame
# 2.69 kB
import asyncio
import json
import logging
import os
from typing import List, Optional, Dict, Any, Tuple

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from litellm import completion
from pydantic import BaseModel

from classifiers.llm import LLMClassifier
from client import get_client, initialize_client
# Pull any variables defined in a local .env file (if one is found) into the
# process environment before the configuration code below reads os.environ.
load_dotenv()

# Application object; the CORS middleware and the routes below attach to it.
app: FastAPI = FastAPI()
# Configure CORS.
# Allowed origins are read from the comma-separated CORS_ALLOW_ORIGINS
# environment variable (e.g. "https://app.example.com,https://admin.example.com").
# When it is unset or blank, every origin is allowed ("*"), which keeps the
# permissive development default; set it explicitly in production.
# NOTE(review): allow_credentials=True combined with wildcard origins is very
# permissive — confirm that clients actually need credentialed requests.
_cors_allow_origins: List[str] = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allow_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize the shared LLM client from the environment. The key may come from
# the real process environment or from the .env file loaded earlier.
api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
if api_key:
    # initialize_client hands back a (success, message) pair; a failed
    # initialization becomes a hard error at import/startup time.
    success: bool
    message: str
    success, message = initialize_client(api_key)
    if not success:
        raise RuntimeError(f"Failed to initialize OpenAI client: {message}")

# Whether or not a key was supplied above, get_client() must now return a
# truthy client; otherwise the server refuses to start.
client = get_client()
if not client:
    raise RuntimeError(
        "OpenAI client not initialized. "
        "Please set OPENAI_API_KEY environment variable."
    )

# Single classifier instance shared by every request handler below.
classifier: LLMClassifier = LLMClassifier(
    client=client,
    model="gpt-3.5-turbo",
)
# Request body for POST /classify.
class TextInput(BaseModel):
    # The text to classify; the handler sends it to the classifier as a
    # one-element batch.
    text: str
    # Optional candidate categories, forwarded unchanged to
    # classifier.classify_async; how None is handled is up to the classifier.
    categories: Optional[List[str]] = None
# Response body for POST /classify; each field is copied from the first
# result dict returned by classifier.classify_async.
class ClassificationResponse(BaseModel):
    # Category label chosen by the classifier.
    category: str
    # Confidence value reported by the classifier.
    # NOTE(review): no range is enforced here (e.g. 0-1) — confirm with LLMClassifier.
    confidence: float
    # Explanation text the classifier returned alongside its choice.
    explanation: str
# Response body for POST /suggest-categories.
class CategorySuggestionResponse(BaseModel):
    # Category names produced by classifier._suggest_categories_async.
    categories: List[str]
@app.post("/classify", response_model=ClassificationResponse)
async def classify_text(text_input: TextInput) -> ClassificationResponse:
    """Classify a single text.

    The text (and the optional candidate categories) is passed to
    ``classifier.classify_async`` as a one-element batch, and the first result
    is returned.

    Args:
        text_input: Request body holding the text and optional categories.

    Returns:
        The category, confidence and explanation of the first result.

    Raises:
        HTTPException: status 500 if classification fails for any reason. The
            underlying error message becomes the response detail, and the full
            traceback is logged server-side.
    """
    try:
        results: List[Dict[str, Any]] = await classifier.classify_async(
            [text_input.text],
            text_input.categories,
        )
        if not results:
            # Without this guard, an empty result list surfaces as an opaque
            # "list index out of range" message.
            raise ValueError("Classifier returned no result for the input text")
        # Only one text was sent, so only the first result is relevant.
        result: Dict[str, Any] = results[0]
        return ClassificationResponse(
            category=result["category"],
            confidence=result["confidence"],
            explanation=result["explanation"],
        )
    except Exception as e:
        # The HTTP response only carries str(e); log the traceback here so the
        # root cause is not lost, and chain it onto the HTTPException.
        logging.getLogger(__name__).exception("Classification failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/suggest-categories", response_model=CategorySuggestionResponse)
async def suggest_categories(texts: List[str]) -> CategorySuggestionResponse:
    """Suggest candidate categories for a collection of texts.

    Args:
        texts: The texts to analyse. FastAPI reads a bare ``List[str]``
            parameter from the JSON request body (an array of strings).

    Returns:
        The suggested category names.

    Raises:
        HTTPException: status 500 if suggestion fails for any reason. The
            underlying error message becomes the response detail, and the full
            traceback is logged server-side.
    """
    try:
        # NOTE(review): this calls a private LLMClassifier method; consider
        # exposing a public wrapper on the classifier instead.
        categories: List[str] = await classifier._suggest_categories_async(texts)
        return CategorySuggestionResponse(categories=categories)
    except Exception as e:
        # The HTTP response only carries str(e); log the traceback here so the
        # root cause is not lost, and chain it onto the HTTPException.
        logging.getLogger(__name__).exception("Category suggestion failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Imported here so uvicorn is only required when this file is run directly.
    import uvicorn

    # The app is passed as an import string rather than an object; uvicorn
    # needs that form to re-import the module when reload=True.
    uvicorn.run(
        "server:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )