|
import asyncio
import functools
import json
import logging
from typing import Optional

from fastapi import Request, HTTPException
from pydantic import BaseModel
from transformers import pipeline
|
|
|
|
|
class ImageClassificationRequest(BaseModel):
    """Body of an image-classification request, validated by pydantic.

    Built from the decoded JSON body by
    ImageClassificationTaskService.get_image_classification_request.
    """

    # Image reference passed unchanged as the first positional argument to the
    # transformers pipeline (URL / local path / base64 — TODO confirm which forms
    # clients are expected to send).
    inputs: str

    # Optional task options. The service only reads "candidate_labels" (for the
    # zero-shot task); any other keys are currently ignored.
    parameters: Optional[dict] = None
|
|
|
class ImageClassificationTaskService:
    """Serve (zero-shot) image classification through ``transformers.pipeline``.

    The task name is fixed per instance. ``classify`` parses the request body,
    loads the requested model and returns the pipeline output unchanged.
    """

    # Task for which "candidate_labels" are parsed from the request and forwarded.
    _ZERO_SHOT_TASK = "zero-shot-image-classification"

    __logger: logging.Logger

    __task_name: str

    def __init__(self, logger: logging.Logger, task_name: str = "image-classification"):
        """Create the service.

        Args:
            logger: Receives model-loading and inference failures.
            task_name: Task identifier passed as the first argument to
                ``transformers.pipeline``.
        """
        self.__logger = logger
        self.__task_name = task_name

    @staticmethod
    def _build_request(data: object) -> ImageClassificationRequest:
        """Validate decoded JSON into a request model; any failure becomes HTTP 400."""
        if not isinstance(data, dict):
            # A JSON array/scalar would make ``**data`` raise TypeError (an HTTP 500).
            raise HTTPException(status_code=400, detail="Invalid request body")
        try:
            return ImageClassificationRequest(**data)
        except ValueError as err:
            # pydantic's ValidationError subclasses ValueError in both v1 and v2.
            raise HTTPException(status_code=400, detail="Invalid request body") from err

    async def get_image_classification_request(
        self,
        request: Request,
    ) -> ImageClassificationRequest:
        """Parse the request body into an ImageClassificationRequest.

        The body is parsed as JSON both for ``application/json`` and, to tolerate
        clients that mislabel JSON, for ``application/x-www-form-urlencoded``.

        Raises:
            HTTPException: 400 for an unsupported content type, a body that is not
                valid JSON, or JSON that does not match the request schema.
        """
        # Media types are case-insensitive, so normalise before prefix matching.
        content_type = request.headers.get("content-type", "").lower()

        if content_type.startswith("application/json"):
            try:
                data = await request.json()
            except (ValueError, RecursionError) as err:
                # JSONDecodeError / UnicodeDecodeError are ValueErrors; absurdly deep
                # nesting raises RecursionError.
                raise HTTPException(status_code=400, detail="Invalid request body") from err
            return self._build_request(data)

        if content_type.startswith("application/x-www-form-urlencoded"):
            raw = await request.body()
            try:
                # json.loads accepts bytes directly and detects UTF-8/16/32 itself,
                # so no separate decode step is needed.
                data = json.loads(raw)
            except (ValueError, RecursionError) as err:
                raise HTTPException(status_code=400, detail="Invalid request body") from err
            return self._build_request(data)

        raise HTTPException(status_code=400, detail="Unsupported content type")

    @staticmethod
    def _parse_candidate_labels(parameters: Optional[dict]) -> list:
        """Extract zero-shot labels from the request parameters.

        Accepts a list of strings or one comma-separated string. Surrounding
        whitespace is stripped and blank labels are dropped.

        Raises:
            HTTPException: 400 if the labels have the wrong type or none remain.
        """
        raw = (parameters or {}).get("candidate_labels") or []
        if isinstance(raw, str):
            raw = raw.split(",")
        if not isinstance(raw, list) or not all(isinstance(label, str) for label in raw):
            raise HTTPException(
                status_code=400,
                detail="'candidate_labels' must be a list of strings or a comma-separated string",
            )
        labels = [label.strip() for label in raw if label.strip()]
        if not labels:
            raise HTTPException(
                status_code=400,
                detail="'candidate_labels' is required for zero-shot image classification",
            )
        return labels

    async def classify(
        self,
        request: Request,
        model_name: str,
    ):
        """Classify the image described by the request body using *model_name*.

        Args:
            request: Incoming request; its body is parsed by
                ``get_image_classification_request``.
            model_name: Passed as ``model=`` to ``transformers.pipeline``.

        Returns:
            The raw pipeline output.

        Raises:
            HTTPException: 400 for an invalid body or (zero-shot task) missing or
                invalid candidate labels; 404 if the model cannot be loaded; 500 if
                inference fails.
        """
        image_request = await self.get_image_classification_request(request)

        is_zero_shot = self.__task_name == self._ZERO_SHOT_TASK
        # Validate labels before loading the model so a bad request fails fast with
        # 400 instead of paying the load cost and then failing in the pipeline (500).
        candidate_labels = (
            self._parse_candidate_labels(image_request.parameters) if is_zero_shot else None
        )

        # Pipeline construction and inference are synchronous and potentially slow;
        # run them in the default executor so they do not block the event loop.
        # NOTE(review): a new pipeline is built for every request — consider a bounded
        # cache keyed by model_name if the set of allowed models is known.
        # NOTE(review): model_name and inputs are client-controlled. If this endpoint
        # is reachable by untrusted callers, confirm which models may be loaded and
        # whether the pipeline fetches URLs / reads local paths given as `inputs`.
        loop = asyncio.get_running_loop()

        try:
            pipe = await loop.run_in_executor(
                None, functools.partial(pipeline, self.__task_name, model=model_name)
            )
        except Exception as e:
            self.__logger.error(
                "Failed to load model '%s': %s", model_name, e, exc_info=True
            )
            raise HTTPException(
                status_code=404,
                detail=f"Model '{model_name}' could not be loaded: {str(e)}",
            ) from e

        if is_zero_shot:
            run_inference = functools.partial(
                pipe, image_request.inputs, candidate_labels=candidate_labels
            )
        else:
            run_inference = functools.partial(pipe, image_request.inputs)

        try:
            result = await loop.run_in_executor(None, run_inference)
        except Exception as e:
            self.__logger.error(
                "Inference failed for model '%s': %s", model_name, e, exc_info=True
            )
            raise HTTPException(
                status_code=500,
                detail=f"Inference failed: {str(e)}",
            ) from e

        return result