|
from transformers import pipeline |
|
import logging |
|
from fastapi import Request, HTTPException |
|
import base64 |
|
|
|
|
|
class TextToImageTaskService:
    """Extracts text from an uploaded image via a Hugging Face pipeline.

    NOTE(review): despite the class name ("TextToImage"), the pipeline task
    used below is "image-to-text" — confirm which direction is intended and
    rename accordingly.
    """

    # Injected application logger used for error reporting.
    __logger: logging.Logger

    def __init__(self, logger: logging.Logger):
        self.__logger = logger

    async def get_encoded_image(
        self,
        request: Request
    ) -> str:
        """Read the image payload from *request* and return it base64-encoded.

        Accepts either ``multipart/form-data`` with an ``image`` field, or a
        raw ``image/*`` request body.

        Args:
            request: Incoming FastAPI/Starlette request.

        Returns:
            The image bytes encoded as a base64 UTF-8 string.

        Raises:
            HTTPException: 400 when the multipart form lacks an ``image``
                field, or when the content type is neither multipart nor
                ``image/*``.
        """
        content_type = request.headers.get("content-type", "")

        if content_type.startswith("multipart/form-data"):
            form = await request.form()
            image = form.get("image")
            if image is None:
                # Previously a missing field fell through to the generic
                # "Unsupported content type" error, which misled clients who
                # used the right content type but the wrong field name.
                raise HTTPException(
                    status_code=400,
                    detail="Missing 'image' field in multipart form data"
                )
            image_bytes = await image.read()
            return base64.b64encode(image_bytes).decode("utf-8")

        if content_type.startswith("image/"):
            image_bytes = await request.body()
            return base64.b64encode(image_bytes).decode("utf-8")

        raise HTTPException(status_code=400, detail="Unsupported content type")

    async def extract(
        self,
        request: Request,
        model_name: str
    ):
        """Run the image-to-text pipeline *model_name* on the request's image.

        Args:
            request: Incoming request carrying the image (see
                :meth:`get_encoded_image` for accepted formats).
            model_name: Hugging Face model identifier to load.

        Returns:
            The raw pipeline output (a list of generated-text dicts).

        Raises:
            HTTPException: 400 for a bad request body, 404 when the model
                cannot be loaded, 500 when inference fails.
        """
        encoded_image = await self.get_encoded_image(request)

        try:
            # NOTE(review): the pipeline is rebuilt (and the model reloaded)
            # on every request — consider caching per model_name if this
            # endpoint is hot.
            pipe = pipeline("image-to-text", model=model_name, use_fast=True)
        except Exception as e:
            # Lazy %-args: no string formatting unless the record is emitted.
            self.__logger.error("Failed to load model '%s': %s", model_name, e)
            # `from e` preserves the original traceback for debugging.
            raise HTTPException(
                status_code=404,
                detail=f"Model '{model_name}' could not be loaded: {str(e)}"
            ) from e

        try:
            # NOTE(review): passing a bare base64 string relies on
            # transformers' image loading accepting base64 input — this is
            # version-dependent (older versions expect a URL, path, or PIL
            # image). TODO confirm against the pinned transformers version.
            result = pipe(encoded_image)
        except Exception as e:
            self.__logger.error(
                "Inference failed for model '%s': %s", model_name, e
            )
            raise HTTPException(
                status_code=500,
                detail=f"Inference failed: {str(e)}"
            ) from e

        return result