from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional, Dict, Any, Tuple
import json
from classifiers.llm import LLMClassifier
from litellm import completion
import asyncio
from client import get_client, initialize_client
import os
from dotenv import load_dotenv

# Load environment variables
load_dotenv()
app: FastAPI = FastAPI()

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, replace with specific origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize client with API key from environment
api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
if api_key:
    success: bool
    message: str
    success, message = initialize_client(api_key)
    if not success:
        raise RuntimeError(f"Failed to initialize OpenAI client: {message}")

client = get_client()
if not client:
    raise RuntimeError("OpenAI client not initialized. Please set OPENAI_API_KEY environment variable.")

# Initialize the LLM classifier
classifier: LLMClassifier = LLMClassifier(client=client, model="gpt-3.5-turbo")
class TextInput(BaseModel):
    text: str
    categories: Optional[List[str]] = None


class ClassificationResponse(BaseModel):
    category: str
    confidence: float
    explanation: str


class CategorySuggestionResponse(BaseModel):
    categories: List[str]
@app.post("/classify", response_model=ClassificationResponse)  # NOTE: route path assumed
async def classify_text(text_input: TextInput) -> ClassificationResponse:
    try:
        # Use async classification
        results: List[Dict[str, Any]] = await classifier.classify_async(
            [text_input.text],
            text_input.categories
        )
        result: Dict[str, Any] = results[0]  # Get first result since we're classifying one text
        return ClassificationResponse(
            category=result["category"],
            confidence=result["confidence"],
            explanation=result["explanation"]
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/suggest-categories", response_model=CategorySuggestionResponse)  # NOTE: route path assumed
async def suggest_categories(texts: List[str]) -> CategorySuggestionResponse:
    try:
        categories: List[str] = await classifier._suggest_categories_async(texts)
        return CategorySuggestionResponse(categories=categories)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=True)
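# Example requests (a sketch, assuming the server is running locally on port 8000 and the
# route paths registered above; the example texts and categories are illustrative only):
#
#   curl -X POST http://localhost:8000/classify \
#        -H "Content-Type: application/json" \
#        -d '{"text": "The delivery was late and the box was damaged.", "categories": ["positive", "negative", "neutral"]}'
#
#   curl -X POST http://localhost:8000/suggest-categories \
#        -H "Content-Type: application/json" \
#        -d '["Great battery life", "Screen cracked after a week"]'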