|
from transformers import pipeline |
|
from pydantic import BaseModel |
|
import logging |
|
from fastapi import Request, HTTPException |
|
import json |
|
from typing import Optional |
|
|
|
class TranslationRequest(BaseModel):
    """Payload accepted by the translation endpoint.

    Only ``inputs`` is required; ``parameters`` and ``options`` are
    optional dictionaries that default to ``None`` when omitted.
    """

    # Required text payload of the request.
    inputs: str

    # Optional free-form dictionary of parameters supplied by the caller.
    parameters: Optional[dict] = None

    # Optional free-form dictionary of options supplied by the caller.
    options: Optional[dict] = None
|
|
|
class TranslationTaskService:
    """Parse translation requests and run them through a ``translation`` pipeline.

    The service reads a JSON document from the incoming HTTP request,
    validates it as a ``TranslationRequest`` and feeds its ``inputs`` to a
    pipeline created for the requested model. Failures are reported to the
    client as ``HTTPException`` and logged through the injected logger.
    """

    __logger: logging.Logger

    def __init__(self, logger: logging.Logger):
        """Store the logger used to report model-loading and inference errors."""
        self.__logger = logger

    async def get_translation_request(
        self,
        request: Request
    ) -> TranslationRequest:
        """Build a ``TranslationRequest`` from the body of *request*.

        The body must be a JSON object, sent either as ``application/json``
        or as ``application/x-www-form-urlencoded`` (some clients post JSON
        with that content type).

        Args:
            request: Incoming HTTP request whose body holds the JSON payload.

        Returns:
            The validated ``TranslationRequest``.

        Raises:
            HTTPException: 400 ``"Unsupported content type"`` when the
                content type is neither of the two accepted ones; 400
                ``"Invalid request body"`` when the body is not valid JSON,
                is not a JSON object, or fails model validation.
        """
        # Media types are case-insensitive, so normalise before comparing.
        content_type = request.headers.get("content-type", "").lower()
        accepted = ("application/json", "application/x-www-form-urlencoded")
        if not content_type.startswith(accepted):
            raise HTTPException(status_code=400, detail="Unsupported content type")

        raw = await request.body()
        try:
            # json.loads accepts bytes directly, so no explicit decode is needed.
            data = json.loads(raw)
            return TranslationRequest(**data)
        except (ValueError, TypeError, RecursionError) as e:
            # ValueError: malformed JSON, undecodable bytes, or a validation
            # failure; TypeError: the payload is not a JSON object (``**data``
            # needs a mapping with string keys); RecursionError: pathologically
            # nested JSON. All of these are client errors, not server faults.
            raise HTTPException(status_code=400, detail="Invalid request body") from e

    async def translate(
        self,
        request: Request,
        model_name: str
    ):
        """Translate the text carried by *request* using model *model_name*.

        Args:
            request: Incoming HTTP request; parsed with
                ``get_translation_request``.
            model_name: Identifier passed as ``model=`` to ``pipeline``.

        Returns:
            Whatever the pipeline call returns for the request's ``inputs``;
            the request's ``parameters`` (if any) are forwarded as keyword
            arguments. ``options`` is not used by this method.

        Raises:
            HTTPException: 400 for an unusable request body or content type
                (from ``get_translation_request``); 404 when the pipeline for
                *model_name* cannot be created; 500 when the pipeline call
                itself fails.
        """
        translation_request = await self.get_translation_request(request)

        # NOTE(review): the pipeline is rebuilt on every call and both the
        # load and the inference run synchronously inside this coroutine.
        # Caching pipelines and/or off-loading to a worker thread is probably
        # desirable, but depends on memory budget and thread-safety of the
        # models in use -- confirm before changing.
        try:
            pipe = pipeline("translation", model=model_name)
        except Exception as e:
            self.__logger.error("Failed to load model '%s': %s", model_name, e)
            raise HTTPException(
                status_code=404,
                detail=f"Model '{model_name}' could not be loaded: {str(e)}"
            ) from e

        try:
            return pipe(
                translation_request.inputs,
                **(translation_request.parameters or {})
            )
        except Exception as e:
            self.__logger.error("Inference failed for model '%s': %s", model_name, e)
            raise HTTPException(
                status_code=500,
                detail=f"Inference failed: {str(e)}"
            ) from e