File size: 1,941 Bytes
522275f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import asyncio
import json
import logging
import os
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from litellm import completion
from pydantic import BaseModel

from classifiers.llm import LLMClassifier

app = FastAPI()

# Allowed CORS origins, read from the comma-separated CORS_ALLOW_ORIGINS
# environment variable (e.g. "https://app.example.com,https://admin.example.com").
# When the variable is unset or empty this falls back to "*" (any origin),
# which was the previous hard-coded value.
CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in (os.getenv("CORS_ALLOW_ORIGINS") or "*").split(",")
    if origin.strip()
]

# Configure CORS
# NOTE(review): a wildcard origin combined with allow_credentials=True is
# maximally permissive. Confirm credentials are actually needed, and set
# CORS_ALLOW_ORIGINS to specific origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize the LLM classifier used by both endpoints below. It calls
# litellm's `completion` function. The model name can be overridden with the
# CLASSIFIER_MODEL environment variable; unset or empty keeps the previous
# default, "gpt-3.5-turbo".
classifier = LLMClassifier(
    client=completion,
    model=os.getenv("CLASSIFIER_MODEL") or "gpt-3.5-turbo",
)

class TextInput(BaseModel):
    """Input model for the ``POST /classify`` endpoint."""

    text: str  # the single text to classify
    # Optional candidate labels, forwarded unchanged to
    # classifier.classify_async; None means the caller supplied none
    # (how the classifier handles None is not visible here — verify).
    categories: Optional[List[str]] = None

class ClassificationResponse(BaseModel):
    """Response model for ``POST /classify``: one classification result."""

    category: str  # label chosen by the classifier
    # Confidence as reported by the classifier; no range is enforced here.
    confidence: float
    explanation: str  # classifier-provided explanation of the choice

class CategorySuggestionResponse(BaseModel):
    """Response model for ``POST /suggest-categories``."""

    categories: List[str]  # category labels suggested by the classifier

@app.post("/classify", response_model=ClassificationResponse)
async def classify_text(text_input: TextInput):
    """Classify a single text with the shared LLM classifier.

    The request's ``text`` is sent as a one-element batch to
    ``classifier.classify_async`` along with the optional ``categories``.
    The first (only) result is returned.

    Args:
        text_input: the text to classify and optional candidate categories.

    Returns:
        ClassificationResponse built from the classifier's result dict.

    Raises:
        HTTPException: status 500 in either of these cases:
            - the classifier returns no result;
            - classification fails for any other reason, including a result
              missing "category", "confidence" or "explanation".
    """
    try:
        # classify_async works on batches, so wrap the single text in a list.
        results = await classifier.classify_async(
            [text_input.text],
            text_input.categories
        )
        # Guard against an empty/None batch instead of failing with an
        # opaque "list index out of range" message.
        if not results:
            raise HTTPException(
                status_code=500, detail="Classifier returned no result"
            )
        result = results[0]  # one input text -> one result

        return ClassificationResponse(
            category=result["category"],
            confidence=result["confidence"],
            explanation=result["explanation"]
        )
    except HTTPException:
        # Already a well-formed HTTP error; don't re-wrap it as a generic 500.
        raise
    except Exception as e:
        # Record the traceback server-side; it is otherwise lost once the
        # error is converted into an HTTP response.
        logging.getLogger(__name__).exception("Classification failed")
        # NOTE(review): str(e) is returned to the client verbatim and may
        # expose internal or LLM-provider error details; consider a generic
        # message.
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.post("/suggest-categories", response_model=CategorySuggestionResponse)
async def suggest_categories(texts: List[str]):
    """Ask the shared LLM classifier to suggest categories for some texts.

    Args:
        texts: the input strings to derive candidate categories from.

    Returns:
        CategorySuggestionResponse wrapping the suggested category list.

    Raises:
        HTTPException: status 500 if the classifier call fails or its output
            cannot be wrapped in the response model.
    """
    try:
        # NOTE(review): this calls a private (underscore) method of
        # LLMClassifier; switch to a public API if the classifier offers one.
        categories = await classifier._suggest_categories_async(texts)
        return CategorySuggestionResponse(categories=categories)
    except HTTPException:
        # Already a well-formed HTTP error; don't re-wrap it as a generic 500.
        raise
    except Exception as e:
        # Record the traceback server-side; it is otherwise lost once the
        # error is converted into an HTTP response.
        logging.getLogger(__name__).exception("Category suggestion failed")
        # NOTE(review): str(e) is returned to the client verbatim and may
        # expose internal or LLM-provider error details; consider a generic
        # message.
        raise HTTPException(status_code=500, detail=str(e)) from e

if __name__ == "__main__":
    # Imported inside the guard so uvicorn is only required when this module
    # is executed directly as a script.
    from uvicorn import run as serve

    # Serve the app on all interfaces, port 8000.
    serve(app, port=8000, host="0.0.0.0")