|
|
|
"""
FastAPI Backend AI Service using Gemma-3n-E4B-it

Provides OpenAI-compatible chat completion endpoints powered by google/gemma-3n-E4B-it.
"""

# Load variables from a local .env file before anything reads os.environ.
from dotenv import load_dotenv

load_dotenv()

import os
import warnings

import httpx

# Silence known-noisy transformers warnings that do not affect behavior.
warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")
warnings.filterwarnings("ignore", message=".*slow image processor.*")
warnings.filterwarnings("ignore", message=".*rope_scaling.*")

# Point the Hugging Face cache at a writable path and quiet advisory warnings
# before transformers is imported.
os.environ.setdefault("HF_HOME", "/tmp/.cache/huggingface")
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"

hf_token = os.environ.get("HF_TOKEN")
|
import asyncio
import logging
import time
from contextlib import asynccontextmanager
from io import BytesIO
from typing import Any, Dict, List, Optional, Union
|
|
|
from fastapi import FastAPI, HTTPException, Depends, Request |
|
from fastapi.responses import StreamingResponse, JSONResponse |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from pydantic import BaseModel, Field, field_validator |
|
|
|
import uvicorn |
|
import requests |
|
from PIL import Image |
|
|
|
|
|
|
|
|
|
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    BitsAndBytesConfig,
    Gemma3nForConditionalGeneration,
    pipeline,
)
|
import torch |
|
|
|
logging.basicConfig(level=logging.INFO) |
|
logger = logging.getLogger(__name__) |
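
# The Pydantic models below mirror the OpenAI chat/completions request and
# response shapes, so existing OpenAI-style clients can call this service unchanged.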
|
|
|
|
|
|
|
|
|
class TextContent(BaseModel): |
|
type: str = Field(default="text", description="Content type") |
|
text: str = Field(..., description="Text content") |
|
|
|
@field_validator('type') |
|
@classmethod |
|
def validate_type(cls, v: str) -> str: |
|
if v != "text": |
|
raise ValueError("Type must be 'text'") |
|
return v |
|
|
|
class ImageContent(BaseModel): |
|
type: str = Field(default="image", description="Content type") |
|
url: str = Field(..., description="Image URL") |
|
|
|
@field_validator('type') |
|
@classmethod |
|
def validate_type(cls, v: str) -> str: |
|
if v != "image": |
|
raise ValueError("Type must be 'image'") |
|
return v |
|
|
|
|
|
class ChatMessage(BaseModel): |
|
role: str = Field(..., description="The role of the message author") |
|
content: Union[str, List[Union[TextContent, ImageContent]]] = Field(..., description="The content of the message - either string or list of content items") |
|
|
|
@field_validator('role') |
|
@classmethod |
|
def validate_role(cls, v: str) -> str: |
|
if v not in ["system", "user", "assistant"]: |
|
raise ValueError("Role must be one of: system, user, assistant") |
|
return v |
|
|
|
class ChatCompletionRequest(BaseModel): |
|
    model: str = Field(default="google/gemma-3n-E4B-it", description="The model to use for completion")
|
messages: List[ChatMessage] = Field(..., description="List of messages in the conversation") |
|
max_tokens: Optional[int] = Field(default=512, ge=1, le=2048, description="Maximum tokens to generate") |
|
temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature") |
|
stream: Optional[bool] = Field(default=False, description="Whether to stream responses") |
|
top_p: Optional[float] = Field(default=0.95, ge=0.0, le=1.0, description="Top-p sampling") |
|
|
|
class ChatCompletionChoice(BaseModel): |
|
index: int |
|
message: ChatMessage |
|
finish_reason: str |
|
|
|
class ChatCompletionResponse(BaseModel): |
|
id: str |
|
object: str = "chat.completion" |
|
created: int |
|
model: str |
|
choices: List[ChatCompletionChoice] |
|
|
|
class ChatCompletionChunk(BaseModel): |
|
id: str |
|
object: str = "chat.completion.chunk" |
|
created: int |
|
model: str |
|
choices: List[Dict[str, Any]] |
|
|
|
class HealthResponse(BaseModel): |
|
status: str |
|
model: str |
|
version: str |
|
|
|
class ModelInfo(BaseModel): |
|
id: str |
|
object: str = "model" |
|
created: int |
|
owned_by: str = "huggingface" |
|
|
|
class ModelsResponse(BaseModel): |
|
object: str = "list" |
|
data: List[ModelInfo] |
|
|
|
class CompletionRequest(BaseModel): |
|
prompt: str = Field(..., description="The prompt to complete") |
|
max_tokens: Optional[int] = Field(default=512, ge=1, le=2048) |
|
temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0) |
|
|
|
|
|
|
|
|
|
|
|
ai_model_env = os.environ.get("AI_MODEL", "google/gemma-3n-E4B-it") |
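
# GGUF builds are meant for llama.cpp-style runtimes; the from_pretrained calls
# below expect a standard transformers checkpoint, so substitute the base model.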
|
|
|
if "GGUF" in ai_model_env: |
|
current_model = "google/gemma-3n-E4B-it" |
|
print(f"🔄 Overriding GGUF model {ai_model_env} with transformers-compatible model: {current_model}") |
|
else: |
|
current_model = ai_model_env |
|
vision_model = os.environ.get("VISION_MODEL", "Salesforce/blip-image-captioning-base") |
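
# Model handles; populated once at startup by the lifespan() hook below.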
|
|
|
|
|
processor = None |
|
model = None |
|
image_text_pipeline = None |
|
|
|
|
|
|
|
|
|
async def download_image(url: str) -> Image.Image: |
|
"""Download and process image from URL""" |
|
try: |
|
        # Run the blocking HTTP call in a worker thread so the event loop stays responsive.
        response = await asyncio.to_thread(requests.get, url, timeout=10)
        response.raise_for_status()
        return Image.open(BytesIO(response.content))
    except Exception as e:
        logger.error(f"Failed to download image from {url}: {e}")
        raise HTTPException(status_code=400, detail=f"Failed to download image: {str(e)}")
|
|
|
def extract_text_and_images(content: Union[str, List[Any]]) -> tuple[str, List[str]]: |
|
"""Extract text and image URLs from message content""" |
|
if isinstance(content, str): |
|
return content, [] |
|
|
|
text_parts: List[str] = [] |
|
image_urls: List[str] = [] |
|
|
|
for item in content: |
|
if hasattr(item, 'type'): |
|
if item.type == "text" and hasattr(item, 'text'): |
|
text_parts.append(str(item.text)) |
|
elif item.type == "image" and hasattr(item, 'url'): |
|
image_urls.append(str(item.url)) |
|
|
|
return " ".join(text_parts), image_urls |
|
|
|
def has_images(messages: List[ChatMessage]) -> bool: |
|
"""Check if any messages contain images""" |
|
for message in messages: |
|
if isinstance(message.content, list): |
|
for item in message.content: |
|
if hasattr(item, 'type') and item.type == "image": |
|
return True |
|
return False |
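

# The model is loaded once at startup and released at shutdown via the FastAPI
# lifespan context below.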
|
|
|
|
|
|
|
@asynccontextmanager |
|
async def lifespan(app: FastAPI): |
|
"""Application lifespan manager for startup and shutdown events""" |
|
global processor, model, image_text_pipeline, current_model |
|
logger.info("🚀 Starting AI Backend Service (Hugging Face Spaces mode)...") |
|
logger.info(f"🔧 Using model: {current_model}") |
|
try: |
|
logger.info(f"📥 Loading model with transformers: {current_model}") |
|
|
|
|
|
if "gemma-3n" in current_model.lower(): |
|
logger.info("🔍 Detected Gemma 3n model - using specialized classes") |
|
processor = AutoProcessor.from_pretrained(current_model) |
|
model = Gemma3nForConditionalGeneration.from_pretrained( |
|
current_model, |
|
low_cpu_mem_usage=True, |
|
trust_remote_code=True, |
|
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32, |
|
).eval() |
|
else: |
|
|
|
logger.info("🔍 Using standard transformers classes") |
|
processor = AutoTokenizer.from_pretrained(current_model) |
|
model = AutoModelForCausalLM.from_pretrained( |
|
current_model, |
|
low_cpu_mem_usage=True, |
|
trust_remote_code=True, |
|
) |
|
|
|
logger.info(f"✅ Successfully loaded model and processor: {current_model}") |
|
|
|
|
|
if "gemma-3n" not in current_model.lower(): |
|
|
|
try: |
|
logger.info(f"🖼️ Initializing image captioning pipeline with model: {vision_model}") |
|
image_text_pipeline = pipeline("image-to-text", model=vision_model) |
|
logger.info("✅ Image captioning pipeline loaded successfully") |
|
except Exception as e: |
|
logger.warning(f"⚠️ Could not load image captioning pipeline: {e}") |
|
image_text_pipeline = None |
|
else: |
|
logger.info("✅ Gemma 3n has built-in multimodal support") |
|
image_text_pipeline = None |
|
|
|
except Exception as e: |
|
logger.error(f"❌ Failed to initialize model: {e}") |
|
raise RuntimeError(f"Service initialization failed: {e}") |
|
yield |
|
logger.info("🔄 Shutting down AI Backend Service...") |
|
processor = None |
|
model = None |
|
image_text_pipeline = None |
|
|
|
|
|
app = FastAPI( |
|
title="AI Backend Service - Gemma 3n", |
|
description="OpenAI-compatible chat completion API powered by google/gemma-3n-E4B-it", |
|
version="1.0.0", |
|
lifespan=lifespan |
|
) |
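
# Permissive CORS: any origin, method, and header is accepted.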
|
|
|
|
|
app.add_middleware( |
|
CORSMiddleware, |
|
allow_origins=["*"], |
|
allow_credentials=True, |
|
allow_methods=["*"], |
|
allow_headers=["*"], |
|
) |
|
|
|
|
|
def ensure_model_ready(): |
|
"""Check if transformers model is loaded and ready""" |
|
if processor is None or model is None: |
|
raise HTTPException(status_code=503, detail="Service not ready - no model initialized (transformers)") |
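

# Note: generate_response_transformers below relies on the chat template exposed by
# the processor/tokenizer; convert_messages_to_prompt builds a plain-text prompt and
# is not called by that path.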
|
|
|
def convert_messages_to_prompt(messages: List[ChatMessage]) -> str: |
|
"""Convert OpenAI messages format to a single prompt string""" |
|
prompt_parts: List[str] = [] |
|
|
|
for message in messages: |
|
role = message.role |
|
|
|
|
|
if isinstance(message.content, str): |
|
content = message.content |
|
else: |
|
content, _ = extract_text_and_images(message.content) |
|
|
|
if role == "system": |
|
prompt_parts.append(f"System: {content}") |
|
elif role == "user": |
|
prompt_parts.append(f"Human: {content}") |
|
elif role == "assistant": |
|
prompt_parts.append(f"Assistant: {content}") |
|
|
|
|
|
prompt_parts.append("Assistant:") |
|
|
|
return "\n".join(prompt_parts) |
|
|
|
async def generate_multimodal_response( |
|
messages: List[ChatMessage], |
|
request: ChatCompletionRequest |
|
) -> str: |
|
"""Generate response using image-text-to-text pipeline for multimodal content""" |
|
if not image_text_pipeline: |
|
raise HTTPException(status_code=503, detail="Image processing not available - pipeline not initialized") |
|
|
|
try: |
|
|
|
last_user_message = None |
|
for message in reversed(messages): |
|
if message.role == "user" and isinstance(message.content, list): |
|
last_user_message = message |
|
break |
|
|
|
if not last_user_message: |
|
raise HTTPException(status_code=400, detail="No user message with images found") |
|
|
|
|
|
text_content, image_urls = extract_text_and_images(last_user_message.content) |
|
|
|
if not image_urls: |
|
raise HTTPException(status_code=400, detail="No images found in the message") |
|
|
|
|
|
image_url = image_urls[0] |
|
|
|
|
|
logger.info(f"🖼️ Processing image: {image_url}") |
|
try: |
|
|
|
result = await asyncio.to_thread(lambda: image_text_pipeline(image_url)) |
|
|
|
|
|
            if isinstance(result, list) and result:
                first_result = result[0]
                if isinstance(first_result, dict):
                    generated_text = first_result.get('generated_text', f'I can see an image at {image_url}.')
                else:
                    generated_text = str(first_result)
|
|
|
|
|
if text_content: |
|
response = f"Looking at this image, I can see: {generated_text}. " |
|
if "what" in text_content.lower() or "?" in text_content: |
|
response += f"Regarding your question '{text_content}': Based on what I can see, this appears to be {generated_text.lower()}." |
|
else: |
|
response += f"You mentioned: {text_content}" |
|
return response |
|
else: |
|
return f"I can see: {generated_text}" |
|
else: |
|
return f"I can see there's an image at {image_url}, but cannot process it right now." |
|
|
|
except Exception as pipeline_error: |
|
logger.warning(f"Pipeline error: {pipeline_error}") |
|
return f"I can see there's an image at {image_url}. The image appears to contain visual content that I'm having trouble processing right now." |
|
|
|
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in multimodal generation: {e}")
        return f"I'm having trouble processing the image. Error: {str(e)}"
|
|
|
|
|
def generate_response_local(messages: List[ChatMessage], max_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.95) -> str: |
|
"""Generate response using local transformers model with chat template.""" |
|
ensure_model_ready() |
|
try: |
|
logger.info(" Generating response using transformers model") |
|
return generate_response_transformers(messages, max_tokens, temperature, top_p) |
|
except Exception as e: |
|
logger.error(f"Local generation failed: {e}") |
|
return "I apologize, but I'm having trouble generating a response right now. Please try again." |
|
|
|
|
|
|
|
def convert_messages_to_gemma_prompt(messages: List[ChatMessage]) -> str: |
|
"""Convert OpenAI messages format to Gemma 3n chat format.""" |
|
|
|
prompt_parts = ["<bos>"] |
|
|
|
    for message in messages:
        role = message.role
        # Flatten structured (text/image) content down to plain text for the prompt.
        if isinstance(message.content, str):
            content = message.content
        else:
            content, _ = extract_text_and_images(message.content)

        if role == "system":
            prompt_parts.append(f"<start_of_turn>system\n{content}<end_of_turn>")
        elif role == "user":
            prompt_parts.append(f"<start_of_turn>user\n{content}<end_of_turn>")
        elif role == "assistant":
            prompt_parts.append(f"<start_of_turn>model\n{content}<end_of_turn>")
|
|
|
|
|
prompt_parts.append("<start_of_turn>model\n") |
|
|
|
return "\n".join(prompt_parts) |
|
|
|
def generate_response_transformers(messages: List[ChatMessage], max_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.95) -> str: |
|
"""Generate response using transformers model with chat template.""" |
|
try: |
|
|
|
if "gemma-3n" in current_model.lower(): |
|
|
|
|
|
chat_messages = [] |
|
for m in messages: |
|
|
|
if isinstance(m.content, str): |
|
content = [{"type": "text", "text": m.content}] |
|
else: |
|
|
|
text_content, _ = extract_text_and_images(m.content) |
|
content = [{"type": "text", "text": text_content}] |
|
|
|
chat_messages.append({"role": m.role, "content": content}) |
|
|
|
|
|
            inputs = processor.apply_chat_template(
                chat_messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            ).to(model.device)
|
|
|
|
|
input_len = inputs["input_ids"].shape[-1] |
|
with torch.inference_mode(): |
|
generation = model.generate( |
|
**inputs, |
|
max_new_tokens=max_tokens, |
|
temperature=temperature, |
|
top_p=top_p, |
|
do_sample=temperature > 0, |
|
) |
|
generation = generation[0][input_len:] |
|
|
|
|
|
generated_text = processor.decode(generation, skip_special_tokens=True) |
|
return generated_text.strip() |
|
|
|
else: |
|
|
|
|
|
chat_messages = [] |
|
for m in messages: |
|
content_str = m.content if isinstance(m.content, str) else extract_text_and_images(m.content)[0] |
|
chat_messages.append({"role": m.role, "content": content_str}) |
|
|
|
|
|
inputs = processor.apply_chat_template( |
|
chat_messages, |
|
add_generation_prompt=True, |
|
tokenize=True, |
|
return_dict=True, |
|
return_tensors="pt", |
|
) |
|
|
|
outputs = model.generate( |
|
input_ids=inputs["input_ids"], |
|
attention_mask=inputs.get("attention_mask"), |
|
max_new_tokens=max_tokens, |
|
temperature=temperature, |
|
top_p=top_p, |
|
do_sample=temperature > 0, |
|
) |
|
|
|
generated_text = processor.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True) |
|
return generated_text.strip() |
|
|
|
except Exception as e: |
|
logger.error(f"Transformers generation failed: {e}") |
|
return "I apologize, but I'm having trouble generating a response right now. Please try again." |
|
|
|
|
|
@app.get("/", response_class=JSONResponse) |
|
async def root() -> Dict[str, Any]: |
|
"""Root endpoint with service information""" |
|
return { |
|
"message": "AI Backend Service is running with Mistral Nemo!", |
|
"model": current_model, |
|
"version": "1.0.0", |
|
"endpoints": { |
|
"health": "/health", |
|
"models": "/v1/models", |
|
"chat_completions": "/v1/chat/completions" |
|
} |
|
} |
|
|
|
@app.get("/health", response_model=HealthResponse) |
|
async def health_check(): |
|
"""Health check endpoint""" |
|
    return HealthResponse(
        status="healthy" if (processor is not None and model is not None) else "unhealthy",
|
model=current_model, |
|
version="1.0.0" |
|
) |
|
|
|
@app.get("/v1/models", response_model=ModelsResponse) |
|
async def list_models(): |
|
"""List available models (OpenAI-compatible)""" |
|
|
|
models = [ |
|
ModelInfo( |
|
id=current_model, |
|
created=int(time.time()), |
|
owned_by="huggingface" |
|
) |
|
] |
|
|
|
|
|
if image_text_pipeline: |
|
models.append( |
|
ModelInfo( |
|
id=vision_model, |
|
created=int(time.time()), |
|
owned_by="huggingface" |
|
) |
|
) |
|
|
|
return ModelsResponse(data=models) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse) |
|
async def create_chat_completion(request: ChatCompletionRequest) -> ChatCompletionResponse: |
|
"""Create a chat completion (OpenAI-compatible) with multimodal support. Hugging Face Spaces: Only transformers backend supported.""" |
|
try: |
|
if not request.messages: |
|
raise HTTPException(status_code=400, detail="Messages cannot be empty") |
|
is_multimodal = has_images(request.messages) |
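        # With Gemma 3n loaded, image_text_pipeline is None (images are not routed
        # through the processor here), so image requests currently return 503.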
|
if is_multimodal: |
|
if not image_text_pipeline: |
|
raise HTTPException(status_code=503, detail="Image processing not available") |
|
response_text = await generate_multimodal_response(request.messages, request) |
|
else: |
|
logger.info(f"Generating local response for messages: {request.messages}") |
|
response_text = await asyncio.to_thread( |
|
generate_response_local, |
|
request.messages, |
|
request.max_tokens or 512, |
|
request.temperature or 0.7, |
|
request.top_p or 0.95 |
|
) |
|
response_text = response_text.strip() if response_text else "No response generated." |
|
return ChatCompletionResponse( |
|
id=f"chatcmpl-{int(time.time())}", |
|
created=int(time.time()), |
|
model=request.model, |
|
choices=[ChatCompletionChoice( |
|
index=0, |
|
message=ChatMessage(role="assistant", content=response_text), |
|
finish_reason="stop" |
|
)] |
|
) |
|
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in chat completion: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
|
|
|
|
|
|
@app.post("/v1/completions") |
|
async def create_completion( |
|
request: CompletionRequest |
|
) -> Dict[str, Any]: |
|
"""Create a text completion (OpenAI-compatible)""" |
|
try: |
|
if not request.prompt: |
|
raise HTTPException(status_code=400, detail="Prompt cannot be empty") |
|
ensure_model_ready() |
|
|
|
messages = [ChatMessage(role="user", content=request.prompt)] |
|
response_text = await asyncio.to_thread( |
|
generate_response_local, |
|
messages, |
|
request.max_tokens or 512, |
|
request.temperature or 0.7, |
|
0.95 |
|
) |
|
return { |
|
"id": f"cmpl-{int(time.time())}", |
|
"object": "text_completion", |
|
"created": int(time.time()), |
|
"model": current_model, |
|
"choices": [{ |
|
"text": response_text, |
|
"index": 0, |
|
"finish_reason": "stop" |
|
}] |
|
} |
|
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in completion: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
|
|
@app.post("/api/response") |
|
async def api_response(request: Request) -> JSONResponse: |
|
"""Endpoint to receive and send responses via API.""" |
|
try: |
|
data = await request.json() |
|
message = data.get("message", "No message provided") |
|
return JSONResponse(content={ |
|
"status": "success", |
|
"received_message": message, |
|
"response_message": f"You sent: {message}" |
|
}) |
|
except Exception as e: |
|
logger.error(f"Error processing API response: {e}") |
|
raise HTTPException(status_code=500, detail="Internal server error") |
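

# Example client call (a sketch, assuming the service runs locally on port 8000;
# any OpenAI-compatible client can also be pointed at the /v1 endpoints):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/v1/chat/completions",
#       json={
#           "model": "google/gemma-3n-E4B-it",
#           "messages": [{"role": "user", "content": "Hello!"}],
#           "max_tokens": 128,
#       },
#       timeout=120,
#   )
#   print(resp.json()["choices"][0]["message"]["content"])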
|
|
|
|
|
if __name__ == "__main__":
    uvicorn.run("backend_service:app", host="0.0.0.0", port=8000, reload=True)
|
|