Spaces:
Sleeping
Sleeping
File size: 5,304 Bytes
0f1938f 36183d4 0f1938f e5c1bae 0f1938f 720c911 0f1938f 36183d4 0f1938f 36183d4 0f1938f e5c1bae 0f1938f 36183d4 0f1938f 36183d4 0f1938f 36183d4 e5c1bae 36183d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity
import random
import json
import asyncio
from typing import List, Dict, Any, Optional
import sys
import os
# Add the project root to the Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from prompts import CATEGORY_SUGGESTION_PROMPT, TEXT_CLASSIFICATION_PROMPT
from .base import BaseClassifier
class LLMClassifier(BaseClassifier):
    """Classifier using a Large Language Model for more accurate but slower classification.

    The model is asked to answer each classification request with a JSON
    object containing ``category``, ``confidence`` and ``explanation``.
    Every failure mode (transport error, malformed reply, unknown category)
    degrades to a best-effort result instead of raising, so a single bad
    reply never aborts a whole batch.
    """

    # Keys a structured classification reply must contain.
    _REQUIRED_FIELDS = ("category", "confidence", "explanation")
    # Confidence reported whenever the model's own value is missing or unusable.
    _DEFAULT_CONFIDENCE = 50

    def __init__(self, client, model="gpt-3.5-turbo"):
        """Store the chat client and the model name used for every request.

        Args:
            client: Object exposing ``chat.completions.create(...)``
                (presumably an OpenAI-style client -- confirm with caller).
            model: Model identifier passed through to the client unchanged.
        """
        super().__init__()
        self.client = client
        self.model = model

    async def _complete(self, prompt: str, temperature: float, max_tokens: int) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        The client method is called synchronously, so it is handed to the
        running loop's default executor to keep the event loop responsive.

        Returns:
            The reply content with surrounding whitespace removed, or ``""``
            when the reply carries no content.
        """
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                temperature=temperature,
                max_tokens=max_tokens,
            ),
        )
        content = response.choices[0].message.content
        return (content or "").strip()

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Return ``text`` without a surrounding Markdown code fence, if any.

        Handles both a bare fence and one with a language tag on the opening
        line (for example a ``json`` tag). Text that is not fenced is
        returned unchanged.
        """
        if len(text) < 6 or not (text.startswith("```") and text.endswith("```")):
            return text
        inner = text[3:-3]
        first_line, newline, rest = inner.partition("\n")
        # The opening line is a language tag unless it already starts the payload.
        if newline and not first_line.strip().startswith(("{", "[")):
            inner = rest
        return inner.strip()

    @staticmethod
    def _match_category(raw: Any, categories: List[str]) -> str:
        """Map the model's category onto one of ``categories``.

        An exact match wins; otherwise the comparison ignores case and
        surrounding whitespace. Anything still unmatched defaults to the
        first category.
        """
        if raw in categories:
            return raw
        wanted = str(raw).strip().casefold()
        for candidate in categories:
            if candidate.strip().casefold() == wanted:
                return candidate
        return categories[0]

    async def _classify_text_async(self, text: str, categories: List[str]) -> Dict[str, Any]:
        """Classify one text into one of ``categories``.

        Args:
            text: The text to classify.
            categories: Candidate category names; must be non-empty because
                the first entry is used as the default.

        Returns:
            A dict with ``category`` (always a member of ``categories``),
            ``confidence`` (a number in ``[0, 100]``) and ``explanation``.
            If the reply is not valid JSON, the category is guessed by
            looking for a category name inside the reply; on any other error
            the first category is returned with the error in ``explanation``.
        """
        prompt = TEXT_CLASSIFICATION_PROMPT.format(
            categories=", ".join(categories),
            text=text,
        )
        # Bound up front so the JSON-failure handler can always read it.
        response_text = ""
        try:
            response_text = await self._complete(prompt, temperature=0, max_tokens=200)
            result = json.loads(self._strip_code_fence(response_text))

            # Ensure the reply is an object carrying every required field.
            if not isinstance(result, dict) or not all(
                field in result for field in self._REQUIRED_FIELDS
            ):
                raise ValueError("Missing required fields in LLM response")

            result["category"] = self._match_category(result["category"], categories)

            # Confidence must be a number between 0 and 100.
            try:
                confidence = float(result["confidence"])
            except (TypeError, ValueError, OverflowError):
                confidence = self._DEFAULT_CONFIDENCE
            else:
                if not 0 <= confidence <= 100:
                    confidence = self._DEFAULT_CONFIDENCE
            result["confidence"] = confidence
            return result
        except json.JSONDecodeError:
            # Fall back to simple parsing: first category mentioned in the reply.
            lowered_reply = response_text.lower()
            category = next(
                (cat for cat in categories if cat.lower() in lowered_reply),
                categories[0],
            )
            return {
                "category": category,
                "confidence": self._DEFAULT_CONFIDENCE,
                "explanation": "Classification based on language model analysis. (Note: Structured response parsing failed)",
            }
        except Exception as e:
            # Best-effort by design: report the failure inside the result.
            return {
                "category": categories[0],
                "confidence": self._DEFAULT_CONFIDENCE,
                "explanation": f"Error during classification: {str(e)}",
            }

    async def _suggest_categories_async(self, texts: List[str], sample_size: int = 20) -> List[str]:
        """Ask the model to propose category names for ``texts``.

        Args:
            texts: Texts the categories should cover.
            sample_size: Maximum number of texts included in the prompt; a
                random sample is taken when there are more, to avoid token
                limitations.

        Returns:
            The non-empty, whitespace-trimmed names from the model's
            comma-separated reply. On any error, or when the reply contains
            no names, the result of ``self._generate_default_categories``
            (not defined in this class; presumably inherited from
            ``BaseClassifier``).
        """
        if len(texts) > sample_size:
            sample_texts = random.sample(texts, sample_size)
        else:
            sample_texts = texts
        prompt = CATEGORY_SUGGESTION_PROMPT.format("\n---\n".join(sample_texts))
        try:
            categories_text = await self._complete(prompt, temperature=0.2, max_tokens=100)
            stripped_names = (name.strip() for name in categories_text.split(","))
            # Drop blanks produced by an empty reply or a trailing comma.
            categories = [name for name in stripped_names if name]
            if not categories:
                raise ValueError("LLM response contained no category names")
            return categories
        except Exception as e:
            # Fallback to default categories on error
            print(f"Error suggesting categories: {str(e)}")
            return self._generate_default_categories(texts)

    async def classify_async(
        self, texts: List[str], categories: Optional[List[str]] = None
    ) -> List[Dict[str, Any]]:
        """Classify every text concurrently.

        Args:
            texts: Texts to classify.
            categories: Candidate categories; when omitted or empty they are
                suggested by the model from ``texts``.

        Returns:
            One result dict per input text, in input order.
        """
        # len() rather than truthiness so array-like inputs are not evaluated
        # as booleans; nothing to classify means no request is sent at all.
        if len(texts) == 0:
            return []
        if not categories:
            categories = await self._suggest_categories_async(texts)
        tasks = [self._classify_text_async(text, categories) for text in texts]
        return await asyncio.gather(*tasks)

    def classify(
        self, texts: List[str], categories: Optional[List[str]] = None
    ) -> List[Dict[str, Any]]:
        """Synchronous wrapper for backwards compatibility.

        Runs ``classify_async`` via ``asyncio.run``, so it must not be called
        from code that is already running inside an event loop.
        """
        return asyncio.run(self.classify_async(texts, categories))
|