|
from transformers import pipeline |
|
from pydantic import BaseModel |
|
import logging |
|
from fastapi import Request, HTTPException |
|
import json |
|
from typing import Optional |
|
|
|
|
|
class ClassificationRequest(BaseModel):
    """Parsed request body for a classification call.

    Produced by ClassificationTaskService.get_classification_request from
    either a JSON or a form-urlencoded (JSON-text) request body.
    """

    # Payload to classify — text for text tasks; for zero-shot image tasks
    # presumably an image URL/path accepted by the pipeline (TODO confirm).
    inputs: str

    # Optional task parameters; 'candidate_labels' (list or comma-separated
    # string) is read for zero-shot tasks.
    parameters: Optional[dict] = None
|
|
|
class ClassificationTaskService:
    """Runs Hugging Face classification pipelines behind FastAPI endpoints.

    Parses an incoming request body into a ClassificationRequest, loads the
    requested model via transformers.pipeline (cached per model name), and
    returns the raw pipeline output.
    """

    __logger: logging.Logger
    # transformers task identifier, e.g. "zero-shot-classification".
    __task_name: str
    # Cache of loaded pipelines keyed by model name, so repeated requests
    # for the same model do not reload weights on every call.
    __pipelines: dict

    def __init__(self, logger: logging.Logger, task_name: str):
        self.__logger = logger
        self.__task_name = task_name
        self.__pipelines = {}

    async def get_classification_request(
        self,
        request: Request
    ) -> ClassificationRequest:
        """Parse the request body into a ClassificationRequest.

        Accepts JSON bodies directly; a form-urlencoded body is treated as
        raw JSON text (matching existing caller behavior).

        Raises:
            HTTPException(400): the body is not valid JSON, fails model
                validation, or the content type is unsupported.
        """
        content_type = request.headers.get("content-type", "")

        if content_type.startswith("application/json"):
            try:
                data = await request.json()
                return ClassificationRequest(**data)
            except Exception:
                # Malformed JSON or pydantic validation failure: report a
                # client error instead of letting it escape as a 500.
                raise HTTPException(status_code=400, detail="Invalid request body")

        if content_type.startswith("application/x-www-form-urlencoded"):
            raw = await request.body()
            try:
                # json.loads accepts UTF-8 bytes directly, so no separate
                # decode step is needed (the old decode fallback was dead code).
                data = json.loads(raw)
                return ClassificationRequest(**data)
            except Exception:
                raise HTTPException(status_code=400, detail="Invalid request body")

        raise HTTPException(status_code=400, detail="Unsupported content type")

    async def classify(
        self,
        request: Request,
        model_name: str
    ):
        """Run the configured classification task on the request payload.

        Args:
            request: incoming FastAPI request; parsed via
                get_classification_request.
            model_name: Hugging Face model identifier to load.

        Returns:
            The raw pipeline output (whatever transformers produces for
            the configured task).

        Raises:
            HTTPException(400): invalid request body (from parsing).
            HTTPException(404): the model could not be loaded.
            HTTPException(500): inference failed.
        """
        classification_request = await self.get_classification_request(request)

        pipe = self.__get_pipeline(model_name)

        try:
            if self.__task_name in ("zero-shot-classification", "zero-shot-image-classification"):
                candidate_labels = self.__get_candidate_labels(classification_request)
                result = pipe(classification_request.inputs, candidate_labels=candidate_labels)
            else:
                result = pipe(classification_request.inputs)
        except Exception as e:
            self.__logger.error(f"Inference failed for model '{model_name}': {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Inference failed: {str(e)}"
            )

        return result

    def __get_pipeline(self, model_name: str):
        """Return a transformers pipeline for model_name, caching on first load.

        Raises:
            HTTPException(404): the model could not be loaded. Failed loads
                are not cached, so a transient failure can be retried.
        """
        pipe = self.__pipelines.get(model_name)
        if pipe is not None:
            return pipe

        try:
            pipe = pipeline(self.__task_name, model=model_name)
        except Exception as e:
            self.__logger.error(f"Failed to load model '{model_name}': {str(e)}")
            raise HTTPException(
                status_code=404,
                detail=f"Model '{model_name}' could not be loaded: {str(e)}"
            )

        self.__pipelines[model_name] = pipe
        return pipe

    @staticmethod
    def __get_candidate_labels(classification_request: ClassificationRequest) -> list:
        """Extract candidate labels for zero-shot tasks.

        Accepts either a list or a comma-separated string under the
        'candidate_labels' parameter; returns [] when no parameters were
        supplied.
        """
        if not classification_request.parameters:
            return []
        candidate_labels = classification_request.parameters.get('candidate_labels', [])
        if isinstance(candidate_labels, str):
            candidate_labels = [label.strip() for label in candidate_labels.split(',')]
        return candidate_labels