Spaces:
Running
Running
# ------------------------------------------------------------------- | |
# This source file is available under the terms of the | |
# Pimcore Open Core License (POCL) | |
# Full copyright and license information is available in | |
# LICENSE.md which is distributed with this source code. | |
# | |
# @copyright Copyright (c) Pimcore GmbH (https://www.pimcore.com) | |
# @license Pimcore Open Core License (POCL) | |
# ------------------------------------------------------------------- | |
import os | |
import torch | |
from fastapi import FastAPI, Path, Depends, HTTPException, UploadFile, Form, File, status, Request | |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
from pydantic import BaseModel | |
from typing import Annotated | |
import json | |
import logging | |
import sys | |
import base64 | |
from transformers import pipeline | |
# FastAPI application object. title/description/version are what FastAPI
# publishes in the generated OpenAPI schema and interactive docs.
app = FastAPI(
    title="Pimcore Local Inference Service",
    description="This service allows HF inference provider compatible inference for models which are not available at HF inference providers.",
    version="1.0.0"
)
# Root handler layout: timestamp, level name left-aligned to 8 characters,
# then the message text.
logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s')

# Module-wide logger. Its own level is DEBUG so that this logger does not
# discard lower-severity records itself.
logger = logging.getLogger(name=__name__)
logger.setLevel('DEBUG')
class StreamToLogger(object):
    """
    Minimal file-like object that forwards written text to a logger.

    Text is split into lines, and each complete, non-blank line becomes one
    log record at ``log_level``, with trailing whitespace removed.

    A trailing fragment without a line terminator is held in ``linebuf``
    until a later ``write`` completes it or ``flush()`` is called. This means
    ``print("a", "b")``, which writes its pieces separately, is logged as the
    single record ``"a b"`` rather than as two records.

    Args:
        logger: The ``logging.Logger`` that receives the records.
        log_level: Level used for every record emitted through this stream.
    """

    def __init__(self, logger, log_level):
        self.logger = logger
        self.log_level = log_level
        # Text written so far that has not yet been terminated by a line break.
        self.linebuf = ''

    def write(self, buf):
        """Log every complete line in ``buf`` and return the number of characters accepted."""
        text = self.linebuf + buf
        self.linebuf = ''
        lines = text.splitlines()
        # splitlines() drops terminators. If the keepends variant's last piece
        # equals the stripped one, that piece is still unterminated, so hold it
        # back for the next write.
        if lines and text.splitlines(keepends=True)[-1] == lines[-1]:
            self.linebuf = lines.pop()
        for line in lines:
            self._log_line(line)
        # File-like protocol: write() reports how much text it consumed.
        return len(buf)

    def flush(self):
        """Emit any buffered partial line as its own record."""
        if self.linebuf:
            pending, self.linebuf = self.linebuf, ''
            self._log_line(pending)

    def isatty(self):
        """Report that this stream is not an interactive terminal."""
        # Libraries commonly probe isatty() on sys.stdout/sys.stderr (for
        # example to decide on coloured output). Without this method they
        # would fail with AttributeError.
        return False

    def _log_line(self, line):
        """Log ``line`` without trailing whitespace, skipping blank lines."""
        message = line.rstrip()
        if message:
            self.logger.log(self.log_level, message)
# Route everything written to the standard streams (print() output, library
# chatter, tracebacks on stderr) into the module logger. stdout becomes INFO
# records and stderr becomes ERROR records.
# The handler installed by logging.basicConfig above captured the original
# sys.stderr when it was created, so log output still reaches the console
# rather than looping back here. This assumes basicConfig actually installed
# that handler; it does nothing if the root logger was already configured.
sys.stdout, sys.stderr = (
    StreamToLogger(logger, logging.INFO),
    StreamToLogger(logger, logging.ERROR),
)
class ResponseModel(BaseModel):
    """Default response model for endpoints: a message plus a success flag.

    NOTE(review): not referenced by any endpoint in the visible part of this
    file. Confirm it is still used before relying on it.
    """
    # Free-form message text.
    message: str
    # Outcome flag; defaults to True when not supplied.
    success: bool = True
async def gpu_check():
    """
    Report whether torch can see a CUDA device.

    The outcome is also printed, so it ends up wherever ``sys.stdout``
    currently points.

    Returns:
        dict: ``{'success': True, 'gpu': <status text>}``.
    """
    cuda_ready = torch.cuda.is_available()
    print("GPU is available" if cuda_ready else "GPU is not available")
    status_text = 'GPU is available' if cuda_ready else 'GPU not available'
    return {'success': True, 'gpu': status_text}
from typing import Optional | |
# ========================= | |
# Translation Task | |
# ========================= | |
class TranslationRequest(BaseModel):
    """Payload for the translation endpoint, validated by pydantic."""
    # Source text, passed as the first positional argument to the translation pipeline.
    inputs: str
    # Optional keyword arguments, forwarded as-is to the pipeline call.
    parameters: Optional[dict] = None
    # Optional extra settings; not read by any code visible in this part of the file.
    options: Optional[dict] = None
async def get_translation_request(
    request: Request
) -> TranslationRequest:
    """
    Parse the body of a translation request into a ``TranslationRequest``.

    The body must be JSON. It is accepted under either an
    ``application/json`` or an ``application/x-www-form-urlencoded`` content
    type, and in both cases the raw body is decoded as JSON. The second form
    is presumably for clients that send JSON with a form content type;
    confirm with the callers.

    Args:
        request: The incoming HTTP request.

    Returns:
        TranslationRequest: The validated request payload.

    Raises:
        HTTPException: 400 in either of these cases:
            - the content type is not one of the two above;
            - the body is not valid JSON, or does not match the
              ``TranslationRequest`` schema.
    """
    # Media types are case-insensitive, so normalise before matching.
    content_type = request.headers.get("content-type", "").lower()
    if not content_type.startswith(("application/json", "application/x-www-form-urlencoded")):
        raise HTTPException(status_code=400, detail="Unsupported content type")
    raw = await request.body()
    try:
        # json.loads accepts bytes directly (UTF-8/16/32 are auto-detected),
        # so a separate decode-and-retry pass is unnecessary.
        data = json.loads(raw)
        # A non-object JSON body makes ``**data`` raise TypeError. Schema
        # violations raise pydantic's ValidationError, a ValueError subclass.
        # Both are client errors and should not surface as a 500.
        return TranslationRequest(**data)
    except (ValueError, TypeError) as exc:
        raise HTTPException(status_code=400, detail="Invalid request body") from exc
async def translate(
    request: Request,
    model_name: str = Path(
        ...,
        description="The name of the translation model (e.g. Helsinki-NLP/opus-mt-en-de)",
        example="Helsinki-NLP/opus-mt-en-de"
    )
):
    """
    Execute translation tasks.

    The body is parsed by ``get_translation_request``. Its ``inputs`` text is
    passed to a Hugging Face ``translation`` pipeline built for
    ``model_name``, and its ``parameters`` mapping (if any) is forwarded as
    keyword arguments.

    Args:
        request: The incoming HTTP request carrying the translation payload.
        model_name: Model identifier passed to ``transformers.pipeline``.

    Returns:
        list: The translation result(s) as returned by the pipeline.

    Raises:
        HTTPException:
            - 404 if the pipeline cannot be constructed for ``model_name``;
            - 500 if running the pipeline fails.
            Errors raised by ``get_translation_request`` propagate unchanged.
    """
    translation_request: TranslationRequest = await get_translation_request(request)
    # NOTE(review): a new pipeline is built on every request, and both
    # construction and inference run synchronously inside this coroutine,
    # blocking the event loop while they execute. Consider caching and/or
    # offloading if latency or concurrency matter.
    try:
        pipe = pipeline("translation", model=model_name)
    except Exception as e:
        # logger.exception also records the traceback that str(e) alone loses.
        logger.exception("Failed to load model '%s': %s", model_name, e)
        raise HTTPException(
            status_code=404,
            detail=f"Model '{model_name}' could not be loaded: {e}"
        ) from e
    try:
        result = pipe(translation_request.inputs, **(translation_request.parameters or {}))
    except Exception as e:
        logger.exception("Inference failed for model '%s': %s", model_name, e)
        raise HTTPException(
            status_code=500,
            detail=f"Inference failed: {e}"
        ) from e
    return result
# ========================= | |
# Zero-Shot Image Classification Task | |
# ========================= | |
class ZeroShotImageClassificationRequest(BaseModel):
    """Payload for the zero-shot image classification endpoint, validated by pydantic."""
    # Image reference, handed to the pipeline as its first argument
    # (presumably a URL or base64 string; TODO confirm the accepted forms).
    inputs: str
    # Optional settings. The endpoint reads ``candidate_labels`` from here,
    # as either a list or a comma-separated string that it splits.
    parameters: Optional[dict] = None
async def get_zero_shot_image_classification_request(
    request: Request
) -> ZeroShotImageClassificationRequest:
    """
    Parse a zero-shot image classification request body into its model.

    The body must be JSON. It is accepted under either an
    ``application/json`` or an ``application/x-www-form-urlencoded`` content
    type, and in both cases the raw body is decoded as JSON.

    Args:
        request: The incoming HTTP request.

    Returns:
        ZeroShotImageClassificationRequest: The validated request payload.

    Raises:
        HTTPException: 400 in either of these cases:
            - the content type is not one of the two above;
            - the body is not valid JSON, or does not match the
              ``ZeroShotImageClassificationRequest`` schema.
    """
    # Media types are case-insensitive, so normalise before matching.
    content_type = request.headers.get("content-type", "").lower()
    if not content_type.startswith(("application/json", "application/x-www-form-urlencoded")):
        raise HTTPException(status_code=400, detail="Unsupported content type")
    raw = await request.body()
    try:
        # json.loads accepts bytes directly (UTF-8/16/32 are auto-detected).
        data = json.loads(raw)
        # Non-object JSON raises TypeError on ``**data``. Schema violations
        # raise pydantic's ValidationError, a ValueError subclass.
        return ZeroShotImageClassificationRequest(**data)
    except (ValueError, TypeError) as exc:
        raise HTTPException(status_code=400, detail="Invalid request body") from exc
def _parse_candidate_labels(parameters):
    """
    Return the request's candidate labels as a list of stripped, non-empty strings.

    ``parameters`` (which may be None) may hold ``candidate_labels`` as
    either a list/tuple of strings or a single comma-separated string.

    Raises:
        HTTPException: 400 if the labels have the wrong type, or if none
            remain after blank entries are dropped.
    """
    raw = (parameters or {}).get('candidate_labels', [])
    if isinstance(raw, str):
        raw = raw.split(',')
    if not isinstance(raw, (list, tuple)) or not all(isinstance(label, str) for label in raw):
        raise HTTPException(
            status_code=400,
            detail="candidate_labels must be a list of strings or a comma-separated string"
        )
    # Drop blanks so that inputs like "cat,,dog" do not produce empty labels.
    labels = [label.strip() for label in raw if label.strip()]
    if not labels:
        raise HTTPException(status_code=400, detail="candidate_labels must contain at least one label")
    return labels


async def zero_shot_image_classification(
    request: Request,
    model_name: str = Path(
        ...,
        description="The name of the zero-shot classification model (e.g., openai/clip-vit-large-patch14-336)",
        example="openai/clip-vit-large-patch14-336"
    )
):
    """
    Execute zero-shot image classification tasks.

    Parameters read from the request's ``parameters`` mapping:
        - ``candidate_labels`` (a list of strings, or one comma-separated
          string) supplies the labels to score.
        - ``hypothesis_template``, if present, is forwarded to the pipeline.

    Args:
        request: The incoming HTTP request carrying the classification payload.
        model_name: Model identifier passed to ``transformers.pipeline``.

    Returns:
        list: The classification result(s) as returned by the pipeline.

    Raises:
        HTTPException:
            - 400 if no usable candidate labels are supplied;
            - 404 if the pipeline cannot be constructed for ``model_name``;
            - 500 if running the pipeline fails.
            Errors raised while parsing the request body propagate unchanged.
    """
    zero_shot_request: ZeroShotImageClassificationRequest = await get_zero_shot_image_classification_request(request)
    parameters = zero_shot_request.parameters or {}
    # Validate client-supplied labels before constructing the pipeline, and
    # report bad labels as a client error rather than an inference failure.
    call_kwargs = {'candidate_labels': _parse_candidate_labels(parameters)}
    if parameters.get('hypothesis_template') is not None:
        call_kwargs['hypothesis_template'] = parameters['hypothesis_template']
    try:
        pipe = pipeline("zero-shot-image-classification", model=model_name)
    except Exception as e:
        # logger.exception also records the traceback that str(e) alone loses.
        logger.exception("Failed to load model '%s': %s", model_name, e)
        raise HTTPException(
            status_code=404,
            detail=f"Model '{model_name}' could not be loaded: {e}"
        ) from e
    try:
        # NOTE(review): ``inputs`` is client-controlled and goes straight to
        # the pipeline's image loader, which presumably also accepts URLs and
        # local file paths. Confirm this cannot be used to read server files
        # or reach internal hosts.
        result = pipe(zero_shot_request.inputs, **call_kwargs)
    except Exception as e:
        logger.exception("Inference failed for model '%s': %s", model_name, e)
        raise HTTPException(
            status_code=500,
            detail=f"Inference failed: {e}"
        ) from e
    return result
# ========================= | |
# Image to Text Task | |
# ========================= | |
async def get_encoded_image(
    request: Request
) -> str:
    """
    Extract the uploaded image from the request and return it base64-encoded.

    Two request shapes are accepted:

    * ``multipart/form-data`` with the image in a file field named ``image``;
    * a raw body whose ``Content-Type`` starts with ``image/``.

    Args:
        request: The incoming HTTP request.

    Returns:
        str: The image bytes, base64-encoded, as a string.

    Raises:
        HTTPException: 400 in any of these cases:
            - the content type is unsupported;
            - the ``image`` form field is missing or is not a file upload;
            - the image is empty.
    """
    # Media types are case-insensitive, so normalise before matching.
    content_type = request.headers.get("content-type", "").lower()
    if content_type.startswith("multipart/form-data"):
        form = await request.form()
        image = form.get("image")
        # Starlette form values are either str (plain fields) or upload
        # objects; only the latter can be read as file content.
        if image is None or isinstance(image, str):
            raise HTTPException(status_code=400, detail="Missing 'image' file in multipart form data")
        image_bytes = await image.read()
    elif content_type.startswith("image/"):
        image_bytes = await request.body()
    else:
        raise HTTPException(status_code=400, detail="Unsupported content type")
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Empty image")
    return base64.b64encode(image_bytes).decode("utf-8")
async def image_to_text(
    request: Request,
    model_name: str = Path(
        ...,
        description="The name of the image-to-text model (e.g., Salesforce/blip-image-captioning-base)",
        example="Salesforce/blip-image-captioning-base"
    )
):
    """
    Execute image-to-text tasks.

    The image is read from the request by ``get_encoded_image`` and passed,
    base64-encoded, to a Hugging Face ``image-to-text`` pipeline built for
    ``model_name`` (with ``use_fast=True``).

    Args:
        request: The incoming HTTP request carrying the image.
        model_name: Model identifier passed to ``transformers.pipeline``.

    Returns:
        list: The generated text as returned by the pipeline.

    Raises:
        HTTPException:
            - 404 if the pipeline cannot be constructed for ``model_name``;
            - 500 if running the pipeline fails.
            Errors raised by ``get_encoded_image`` propagate unchanged.
    """
    encoded_image = await get_encoded_image(request)
    # NOTE(review): a new pipeline is built on every request, and both
    # construction and inference run synchronously inside this coroutine,
    # blocking the event loop while they execute.
    try:
        pipe = pipeline("image-to-text", model=model_name, use_fast=True)
    except Exception as e:
        # logger.exception also records the traceback that str(e) alone loses.
        logger.exception("Failed to load model '%s': %s", model_name, e)
        raise HTTPException(
            status_code=404,
            detail=f"Model '{model_name}' could not be loaded: {e}"
        ) from e
    try:
        result = pipe(encoded_image)
    except Exception as e:
        logger.exception("Inference failed for model '%s': %s", model_name, e)
        raise HTTPException(
            status_code=500,
            detail=f"Inference failed: {e}"
        ) from e
    return result