ndc8
Refactor model loading to utilize accelerate for device management; add test script to verify loading fix and prevent device conflicts
8a3c5dd
#!/usr/bin/env python3
"""
Lightweight Backend Service - Memory-Optimized for HF Spaces
Runs a CPU-only transformers model (accelerate device_map, float32) instead of GGUF
"""
import os
import logging
import time
from contextlib import asynccontextmanager
from typing import List, Dict, Any, Optional
import uuid

from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, field_validator

# Import transformers with optimizations
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline
)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Pydantic models for OpenAI-compatible API
class ChatMessage(BaseModel):
    role: str = Field(..., description="The role of the message author")
    content: str = Field(..., description="The content of the message")

    @field_validator("role")
    @classmethod
    def validate_role(cls, v: str) -> str:
        if v not in ["system", "user", "assistant"]:
            raise ValueError("Role must be one of: system, user, assistant")
        return v


class ChatCompletionRequest(BaseModel):
    model: str = Field(default="gemma-2-2b-it", description="The model to use for completion")
    messages: List[ChatMessage] = Field(..., description="List of messages in the conversation")
    max_tokens: Optional[int] = Field(default=256, ge=1, le=1024, description="Maximum tokens to generate (memory-optimized)")
    temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0, description="Sampling temperature")
    top_p: Optional[float] = Field(default=0.95, ge=0.0, le=1.0, description="Top-p sampling")
    stream: Optional[bool] = Field(default=False, description="Whether to stream responses")
class ChatCompletionChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: str


class ChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[ChatCompletionChoice]


class HealthResponse(BaseModel):
    status: str
    model: str
    version: str
    backend: str
    memory_optimization: str


# Global variables for model management
# Use smaller Gemma 2B model for better memory efficiency
current_model = os.environ.get("AI_MODEL", "google/gemma-2-2b-it")
tokenizer = None
model = None
text_pipeline = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager with memory-optimized model loading"""
    global tokenizer, model, text_pipeline
    logger.info("🚀 Starting Lightweight Backend Service...")

    if os.environ.get("DEMO_MODE", "").strip() not in ("", "0", "false", "False"):
        logger.info("🧪 DEMO_MODE enabled: skipping model load")
        yield
        logger.info("🔄 Shutting down Lightweight Backend Service (demo mode)...")
        return

    try:
        logger.info(f"📥 Loading lightweight model: {current_model}")

        # Let accelerate handle device and thread management automatically
        logger.info("⚙️ Configuring accelerate-optimized model loading...")

        # Load tokenizer first
        tokenizer = AutoTokenizer.from_pretrained(
            current_model,
            trust_remote_code=True,
            use_fast=True
        )

        # Ensure pad token exists
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Load model with memory optimizations
        model = AutoModelForCausalLM.from_pretrained(
            current_model,
            torch_dtype=torch.float32,    # float32 is the most compatible dtype on CPU
            device_map="auto",            # let accelerate handle device placement
            low_cpu_mem_usage=True,       # memory-efficient loading
            trust_remote_code=True,
            attn_implementation="eager",  # basic attention uses less memory
        )

        # Create pipeline for efficient generation; no explicit device argument,
        # otherwise it conflicts with the accelerate-managed placement above
        text_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=256,  # default generation limit
            do_sample=True,
            temperature=1.0,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        )
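        # Optional sanity check: models loaded with device_map="auto" carry the
        # accelerate placement on `hf_device_map` (assumption: attribute name as
        # set by transformers/accelerate). Logging it makes device conflicts
        # visible in the startup logs.
        device_map_info = getattr(model, "hf_device_map", None)
        if device_map_info:
            logger.info(f"🗺️ Device map: {device_map_info}")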
logger.info("✅ Successfully loaded lightweight model with accelerate optimizations") | |
logger.info(f"📊 Model: {current_model}") | |
logger.info(f"🔧 Device: auto (managed by accelerate)") | |
logger.info(f"🧠 Memory Mode: CPU-optimized with float32") | |
except Exception as e: | |
logger.error(f"❌ Failed to initialize model: {e}") | |
logger.info("🔄 Starting service in demo mode") | |
model = None | |
tokenizer = None | |
text_pipeline = None | |
yield | |
logger.info("🔄 Shutting down Lightweight Backend Service...") | |
# Clean up model resources | |
if model: | |
del model | |
if tokenizer: | |
del tokenizer | |
if text_pipeline: | |
del text_pipeline | |
# Initialize FastAPI app
app = FastAPI(
    title="Lightweight Gemma Backend Service",
    description="Memory-optimized OpenAI-compatible chat completion API",
    version="1.0.0",
    lifespan=lifespan
)

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def convert_messages_to_prompt(messages: List[ChatMessage]) -> str:
    """Convert OpenAI messages format to Gemma chat format."""
    prompt_parts = []
    for message in messages:
        role = message.role
        content = message.content.strip()
        if role == "system":
            prompt_parts.append(f"<start_of_turn>system\n{content}<end_of_turn>")
        elif role == "user":
            prompt_parts.append(f"<start_of_turn>user\n{content}<end_of_turn>")
        elif role == "assistant":
            prompt_parts.append(f"<start_of_turn>model\n{content}<end_of_turn>")
    # Add the opening turn for the model response
    prompt_parts.append("<start_of_turn>model\n")
    return "\n".join(prompt_parts)
def generate_response(messages: List[ChatMessage], max_tokens: int = 256, temperature: float = 1.0, top_p: float = 0.95) -> str:
    """Generate a response using the lightweight transformers pipeline."""
    if text_pipeline is None:
        return "🤖 Demo mode: Model not loaded. This would be a real response from the Gemma model."

    try:
        # Convert messages to a Gemma-formatted prompt
        prompt = convert_messages_to_prompt(messages)

        # Limit max_tokens for memory efficiency
        max_tokens = min(max_tokens, 512)

        # Generate response
        result = text_pipeline(
            prompt,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            return_full_text=False,  # only return the newly generated tokens
            pad_token_id=tokenizer.eos_token_id,
        )

        # Extract generated text
        if result and len(result) > 0:
            response_text = result[0]['generated_text'].strip()
            # Trim anything after the end-of-turn marker
            if "<end_of_turn>" in response_text:
                response_text = response_text.split("<end_of_turn>")[0].strip()
            return response_text
        else:
            return "I apologize, but I'm having trouble generating a response right now."
    except Exception as e:
        logger.error(f"Generation failed: {e}")
        return "I apologize, but I'm having trouble generating a response right now. Please try again."
@app.get("/")
async def root() -> Dict[str, Any]:
    """Root endpoint with service information"""
    return {
        "service": "Lightweight Gemma Backend",
        "version": "1.0.0",
        "model": current_model,
        "backend": "transformers-cpu",
        "optimization": "memory-efficient",
        "endpoints": {
            "health": "/health",
            "chat": "/v1/chat/completions",
            "docs": "/docs"
        }
    }


@app.get("/health", response_model=HealthResponse)
async def health_check() -> HealthResponse:
    """Health check endpoint"""
    status = "healthy" if text_pipeline is not None else "demo_mode"
    return HealthResponse(
        status=status,
        model=current_model,
        version="1.0.0",
        backend="transformers-cpu",
        memory_optimization="float32-cpu-lowmem"
    )
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest) -> ChatCompletionResponse:
    """Create a chat completion (OpenAI-compatible) using the lightweight model"""
    try:
        # Generate response
        response_text = generate_response(
            messages=request.messages,
            max_tokens=request.max_tokens or 256,
            temperature=request.temperature or 1.0,
            top_p=request.top_p or 0.95
        )

        # Create response message
        response_message = ChatMessage(role="assistant", content=response_text)

        # Create choice
        choice = ChatCompletionChoice(
            index=0,
            message=response_message,
            finish_reason="stop"
        )

        # Create completion response
        completion = ChatCompletionResponse(
            id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
            object="chat.completion",
            created=int(time.time()),
            model=request.model,
            choices=[choice]
        )
        return completion
    except Exception as e:
        logger.error(f"Chat completion failed: {e}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
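# Example request (assuming the service runs locally on port 8000):
#   curl -s http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "gemma-2-2b-it", "messages": [{"role": "user", "content": "Hello"}]}'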
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
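# ---------------------------------------------------------------------------
# Sketch of the loading-verification test mentioned in the commit message.
# The actual test script is not shown in this diff; the file name and checks
# below are illustrative assumptions. The key point it exercises is that the
# pipeline is created WITHOUT an explicit `device` argument, so it does not
# conflict with the accelerate-managed placement from device_map="auto".
#
#   # test_model_loading.py (hypothetical name)
#   import torch
#   from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
#
#   def main() -> None:
#       model_id = "google/gemma-2-2b-it"
#       tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
#       mdl = AutoModelForCausalLM.from_pretrained(
#           model_id,
#           torch_dtype=torch.float32,
#           device_map="auto",
#           low_cpu_mem_usage=True,
#       )
#       # No device= here: passing one alongside an accelerate-loaded model
#       # is exactly the conflict the refactor removes.
#       pipe = pipeline("text-generation", model=mdl, tokenizer=tok)
#       out = pipe("Hello", max_new_tokens=8, do_sample=False)
#       assert out and out[0]["generated_text"]
#       print("loading OK; device map:", getattr(mdl, "hf_device_map", None))
#
#   if __name__ == "__main__":
#       main()
# ---------------------------------------------------------------------------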