File size: 2,489 Bytes
522275f
 
 
 
 
 
 
 
49520a1
 
 
 
 
 
 
522275f
 
 
 
 
 
 
 
 
 
 
49520a1
 
 
 
 
 
 
3c4ab41
49520a1
 
3c4ab41
522275f
3c4ab41
522275f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c4ab41
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional
import json
from classifiers.llm import LLMClassifier
from litellm import completion
import asyncio
from client import get_client, initialize_client
import os
from dotenv import load_dotenv

# Load environment variables from a local .env file (if present) so that
# OPENAI_API_KEY can be supplied without exporting it in the shell.
load_dotenv()

app = FastAPI()

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, replace with specific origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize client with API key from environment.
# NOTE: initialization order matters — initialize_client() must run before
# get_client(), which presumably returns the client it stored (TODO confirm
# against client.py).
api_key = os.environ.get("OPENAI_API_KEY")
if api_key:
    success, message = initialize_client(api_key)
    if not success:
        # A present-but-invalid key is a configuration error: fail fast at import.
        raise RuntimeError(f"Failed to initialize OpenAI client: {message}")

client = get_client()
if not client:
    # Covers both the missing-key case above and any other uninitialized state.
    raise RuntimeError("OpenAI client not initialized. Please set OPENAI_API_KEY environment variable.")

# Initialize the LLM classifier shared by all request handlers below.
classifier = LLMClassifier(client=client, model="gpt-3.5-turbo")

class TextInput(BaseModel):
    """Request body for POST /classify."""
    text: str  # the text to classify
    categories: Optional[List[str]] = None  # candidate labels; None lets the classifier choose

class ClassificationResponse(BaseModel):
    """Response body for POST /classify."""
    category: str  # chosen label
    confidence: float  # classifier-reported confidence score
    explanation: str  # human-readable rationale for the choice

class CategorySuggestionResponse(BaseModel):
    """Response body for POST /suggest-categories."""
    categories: List[str]  # suggested category labels for the sample texts

@app.post("/classify", response_model=ClassificationResponse)
async def classify_text(text_input: TextInput):
    """Classify a single text into a category.

    Runs the shared LLM classifier asynchronously over the one supplied
    text (optionally restricted to ``text_input.categories``) and returns
    the category, confidence, and explanation.

    Raises:
        HTTPException: 500 with the underlying error message if
            classification fails for any reason.
    """
    try:
        # The classifier operates on batches; wrap the single text in a
        # list and unwrap the single result afterwards.
        results = await classifier.classify_async(
            [text_input.text],
            text_input.categories
        )
        result = results[0]

        return ClassificationResponse(
            category=result["category"],
            confidence=result["confidence"],
            explanation=result["explanation"],
        )
    except Exception as e:
        # Chain the original exception (`from e`) so the real traceback is
        # preserved in server logs instead of being masked by the HTTP error.
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.post("/suggest-categories", response_model=CategorySuggestionResponse)
async def suggest_categories(texts: List[str]):
    """Suggest a set of category labels for the given sample texts.

    Raises:
        HTTPException: 500 with the underlying error message if the
            suggestion call fails.
    """
    try:
        # NOTE(review): this reaches into a private method of LLMClassifier;
        # consider exposing a public suggest-categories API on the classifier.
        categories = await classifier._suggest_categories_async(texts)
        return CategorySuggestionResponse(categories=categories)
    except Exception as e:
        # Chain the original exception so the real traceback is preserved.
        raise HTTPException(status_code=500, detail=str(e)) from e

if __name__ == "__main__":
    import uvicorn
    # Dev entry point: serve this module's `app` on all interfaces with
    # auto-reload; the "server:app" import string is required for reload mode.
    uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=True)