#!/usr/bin/env python3
"""
Lightweight Backend Service - Memory-Optimized for HF Spaces
Uses CPU-only transformers with memory-efficient loading instead of GGUF
"""

import os
import logging
import time
from contextlib import asynccontextmanager
from typing import List, Dict, Any, Optional
import uuid

from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, field_validator

# Import transformers with optimizations
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Pydantic models for OpenAI-compatible API
class ChatMessage(BaseModel):
    role: str = Field(..., description="The role of the message author")
    content: str = Field(..., description="The content of the message")

    @field_validator('role')
    @classmethod
    def validate_role(cls, v: str) -> str:
        if v not in ["system", "user", "assistant"]:
            raise ValueError("Role must be one of: system, user, assistant")
        return v


class ChatCompletionRequest(BaseModel):
    model: str = Field(default="gemma-2-2b-it", description="The model to use for completion")
    messages: List[ChatMessage] = Field(..., description="List of messages in the conversation")
    max_tokens: Optional[int] = Field(default=256, ge=1, le=1024, description="Maximum tokens to generate (memory-optimized)")
    temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0, description="Sampling temperature")
    top_p: Optional[float] = Field(default=0.95, ge=0.0, le=1.0, description="Top-p sampling")
    stream: Optional[bool] = Field(default=False, description="Whether to stream responses")


class ChatCompletionChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: str


class ChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[ChatCompletionChoice]


class HealthResponse(BaseModel):
    status: str
    model: str
    version: str
    backend: str
    memory_optimization: str


# Global variables for model management
# Use the smaller Gemma 2B model for better memory efficiency
current_model = os.environ.get("AI_MODEL", "google/gemma-2-2b-it")
tokenizer = None
model = None
text_pipeline = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager with memory-optimized model loading."""
    global tokenizer, model, text_pipeline

    logger.info("🚀 Starting Lightweight Backend Service...")

    if os.environ.get("DEMO_MODE", "").strip() not in ("", "0", "false", "False"):
        logger.info("🧪 DEMO_MODE enabled: skipping model load")
        yield
        logger.info("🔄 Shutting down Lightweight Backend Service (demo mode)...")
        return

    try:
        logger.info(f"📥 Loading lightweight model: {current_model}")

        # Let accelerate handle device and thread management automatically
        logger.info("⚙️ Configuring accelerate-optimized model loading...")

        # Load tokenizer first
        tokenizer = AutoTokenizer.from_pretrained(
            current_model,
            trust_remote_code=True,
            use_fast=True,
        )

        # Ensure pad token exists
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Load model with memory optimizations
        model = AutoModelForCausalLM.from_pretrained(
            current_model,
            torch_dtype=torch.float32,  # Use float32 for CPU (more compatible)
            device_map="auto",          # Let accelerate handle device placement automatically
            low_cpu_mem_usage=True,     # Enable memory-efficient loading
            trust_remote_code=True,
            # Additional memory optimizations
            attn_implementation="eager",  # Use basic attention (less memory)
        )
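        # Note: device_map="auto" requires the `accelerate` package; on a
        # CPU-only Space there is no GPU to place weights on, so all layers
        # end up on the CPU.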

        # Create pipeline for efficient generation (no explicit device=-1,
        # to avoid conflicting with accelerate's device placement)
        text_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=256,  # Default limit
            do_sample=True,
            temperature=1.0,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        )

        logger.info("✅ Successfully loaded lightweight model with accelerate optimizations")
        logger.info(f"📊 Model: {current_model}")
        logger.info("🔧 Device: auto (managed by accelerate)")
        logger.info("🧠 Memory Mode: CPU-optimized with float32")

    except Exception as e:
        logger.error(f"❌ Failed to initialize model: {e}")
        logger.info("🔄 Starting service in demo mode")
        model = None
        tokenizer = None
        text_pipeline = None

    yield

    logger.info("🔄 Shutting down Lightweight Backend Service...")

    # Clean up model resources
    if model:
        del model
    if tokenizer:
        del tokenizer
    if text_pipeline:
        del text_pipeline


# Initialize FastAPI app
app = FastAPI(
    title="Lightweight Gemma Backend Service",
    description="Memory-optimized OpenAI-compatible chat completion API",
    version="1.0.0",
    lifespan=lifespan,
)

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def convert_messages_to_prompt(messages: List[ChatMessage]) -> str:
    """Convert OpenAI messages format to Gemma chat format."""
    prompt_parts = []

    for message in messages:
        role = message.role
        content = message.content.strip()

        if role == "system":
            prompt_parts.append(f"<start_of_turn>system\n{content}<end_of_turn>")
        elif role == "user":
            prompt_parts.append(f"<start_of_turn>user\n{content}<end_of_turn>")
        elif role == "assistant":
            prompt_parts.append(f"<start_of_turn>model\n{content}<end_of_turn>")

    # Add the start of the model's response turn
    prompt_parts.append("<start_of_turn>model\n")

    return "\n".join(prompt_parts)


def generate_response(messages: List[ChatMessage],
                      max_tokens: int = 256,
                      temperature: float = 1.0,
                      top_p: float = 0.95) -> str:
    """Generate a response using the lightweight transformers pipeline."""
    if text_pipeline is None:
        return "🤖 Demo mode: Model not loaded. This would be a real response from the Gemma model."

    try:
        # Convert messages to a Gemma-format prompt
        prompt = convert_messages_to_prompt(messages)

        # Limit max_tokens for memory efficiency
        max_tokens = min(max_tokens, 512)

        # Generate response
        result = text_pipeline(
            prompt,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            return_full_text=False,  # Only return the new tokens
            pad_token_id=tokenizer.eos_token_id,
        )

        # Extract generated text
        if result and len(result) > 0:
            response_text = result[0]['generated_text'].strip()

            # Clean up the end-of-turn marker and anything after it
            if "<end_of_turn>" in response_text:
                response_text = response_text.split("<end_of_turn>")[0].strip()

            return response_text
        else:
            return "I apologize, but I'm having trouble generating a response right now."

    except Exception as e:
        logger.error(f"Generation failed: {e}")
        return "I apologize, but I'm having trouble generating a response right now. Please try again."
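
# For reference, the prompt these helpers feed to the pipeline: with a single
# user message "Hello", convert_messages_to_prompt() above produces a
# Gemma-style prompt of the form:
#
#   <start_of_turn>user
#   Hello<end_of_turn>
#   <start_of_turn>model
#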
@app.get("/", response_class=JSONResponse) async def root() -> Dict[str, Any]: """Root endpoint with service information""" return { "service": "Lightweight Gemma Backend", "version": "1.0.0", "model": current_model, "backend": "transformers-cpu", "optimization": "memory-efficient", "endpoints": { "health": "/health", "chat": "/v1/chat/completions", "docs": "/docs" } } @app.get("/health", response_model=HealthResponse) async def health_check(): """Health check endpoint""" status = "healthy" if text_pipeline is not None else "demo_mode" return HealthResponse( status=status, model=current_model, version="1.0.0", backend="transformers-cpu", memory_optimization="float32-cpu-lowmem" ) @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) async def create_chat_completion(request: ChatCompletionRequest) -> ChatCompletionResponse: """Create a chat completion (OpenAI-compatible) using lightweight model""" try: # Generate response response_text = generate_response( messages=request.messages, max_tokens=request.max_tokens or 256, temperature=request.temperature or 1.0, top_p=request.top_p or 0.95 ) # Create response message response_message = ChatMessage(role="assistant", content=response_text) # Create choice choice = ChatCompletionChoice( index=0, message=response_message, finish_reason="stop" ) # Create completion response completion = ChatCompletionResponse( id=f"chatcmpl-{uuid.uuid4().hex[:8]}", object="chat.completion", created=int(time.time()), model=request.model, choices=[choice] ) return completion except Exception as e: logger.error(f"Chat completion failed: {e}") raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}") if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000)