import os
import secrets
import hashlib
from typing import Optional, Dict, Any
from datetime import datetime, timedelta
import logging

from fastapi import FastAPI, HTTPException, Depends, Security, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import uvicorn
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="LLM AI Agent API",
    description="Secure AI Agent API with local LLM deployment",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# CORS middleware for cross-origin requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure this for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
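# For production, replace the wildcard above with an explicit origin list,
# e.g. allow_origins=["https://your-frontend.example.com"] (hypothetical domain).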
# Security
security = HTTPBearer()

# Configuration
class Config:
    # API keys - in production, use environment variables
    API_KEYS = {
        os.getenv("API_KEY_1", "your-secure-api-key-1"): "user1",
        os.getenv("API_KEY_2", "your-secure-api-key-2"): "user2",
        # Add more API keys as needed
    }

    # Model configuration
    MODEL_NAME = os.getenv("MODEL_NAME", "microsoft/DialoGPT-medium")  # Lightweight model for free tier
    MAX_LENGTH = int(os.getenv("MAX_LENGTH", "512"))
    TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
    TOP_P = float(os.getenv("TOP_P", "0.9"))

    # Rate limiting (requests per minute per API key)
    RATE_LIMIT = int(os.getenv("RATE_LIMIT", "10"))
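# Example environment configuration (hypothetical values; on Hugging Face Spaces
# these can be set in the Space's "Variables and secrets" settings):
#   export API_KEY_1="<generated-secret>"
#   export MODEL_NAME="microsoft/DialoGPT-medium"
#   export RATE_LIMIT="10"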
# Global variables for model and tokenizer
model = None
tokenizer = None
text_generator = None

# Request/response models
class ChatRequest(BaseModel):
    message: str = Field(..., min_length=1, max_length=1000, description="Input message for the AI agent")
    max_length: Optional[int] = Field(None, ge=10, le=2048, description="Maximum response length")
    temperature: Optional[float] = Field(None, ge=0.1, le=2.0, description="Response creativity (0.1-2.0)")
    system_prompt: Optional[str] = Field(None, max_length=500, description="Optional system prompt")

class ChatResponse(BaseModel):
    response: str
    model_used: str
    timestamp: str
    tokens_used: int
    processing_time: float

class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    timestamp: str
    version: str

# Rate limiting storage (in production, use Redis)
request_counts: Dict[str, Dict[str, int]] = {}
def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)) -> str:
    """Verify API key authentication."""
    api_key = credentials.credentials
    if api_key not in Config.API_KEYS:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return Config.API_KEYS[api_key]

def check_rate_limit(api_key: str) -> bool:
    """Simple in-memory rate limiting implementation."""
    current_minute = datetime.now().strftime("%Y-%m-%d-%H-%M")
    if api_key not in request_counts:
        request_counts[api_key] = {}
    if current_minute not in request_counts[api_key]:
        request_counts[api_key][current_minute] = 0
    if request_counts[api_key][current_minute] >= Config.RATE_LIMIT:
        return False
    request_counts[api_key][current_minute] += 1
    return True
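# A minimal Redis-backed alternative to check_rate_limit, following the
# "in production, use Redis" note above (sketch only; assumes a reachable
# Redis instance and the "redis" package, neither of which this app requires):
#
#   import redis
#   r = redis.Redis()
#
#   def check_rate_limit_redis(api_key: str) -> bool:
#       key = f"rl:{api_key}:{datetime.now().strftime('%Y-%m-%d-%H-%M')}"
#       count = r.incr(key)   # atomically increments, creating the key if needed
#       r.expire(key, 60)     # keep per-minute counters for 60 seconds
#       return count <= Config.RATE_LIMIT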
@app.on_event("startup")
async def load_model():
"""Load the LLM model on startup""" | |
global model, tokenizer, text_generator | |
try: | |
logger.info(f"Loading model: {Config.MODEL_NAME}") | |
# Load tokenizer | |
tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME) | |
# Add padding token if it doesn't exist | |
if tokenizer.pad_token is None: | |
tokenizer.pad_token = tokenizer.eos_token | |
# Load model with optimizations for free tier | |
model = AutoModelForCausalLM.from_pretrained( | |
Config.MODEL_NAME, | |
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, | |
device_map="auto" if torch.cuda.is_available() else None, | |
low_cpu_mem_usage=True | |
) | |
# Create text generation pipeline | |
text_generator = pipeline( | |
"text-generation", | |
model=model, | |
tokenizer=tokenizer, | |
device=0 if torch.cuda.is_available() else -1 | |
) | |
logger.info("Model loaded successfully!") | |
except Exception as e: | |
logger.error(f"Error loading model: {str(e)}") | |
raise e | |
# Routes (the paths "/", "/health", "/chat", and "/model-info" are conventional choices)
@app.get("/", response_model=HealthResponse)
async def root():
"""Health check endpoint""" | |
return HealthResponse( | |
status="healthy", | |
model_loaded=model is not None, | |
timestamp=datetime.now().isoformat(), | |
version="1.0.0" | |
) | |
@app.get("/health", response_model=HealthResponse)
async def health_check():
"""Detailed health check""" | |
return HealthResponse( | |
status="healthy" if model is not None else "model_not_loaded", | |
model_loaded=model is not None, | |
timestamp=datetime.now().isoformat(), | |
version="1.0.0" | |
) | |
@app.post("/chat", response_model=ChatResponse)
async def chat(
    request: ChatRequest,
    user: str = Depends(verify_api_key)
):
    """Main chat endpoint for AI agent interaction."""
    start_time = datetime.now()

    # Check rate limiting
    api_key = None  # In a real implementation, you'd extract this from the token
    # if not check_rate_limit(api_key):
    #     raise HTTPException(
    #         status_code=status.HTTP_429_TOO_MANY_REQUESTS,
    #         detail="Rate limit exceeded. Please try again later."
    #     )

    if model is None or tokenizer is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model not loaded. Please try again later."
        )

    try:
        # Prepare input
        input_text = request.message
        if request.system_prompt:
            input_text = f"System: {request.system_prompt}\nUser: {request.message}\nAssistant:"

        # Generate response
        max_length = request.max_length or Config.MAX_LENGTH
        temperature = request.temperature or Config.TEMPERATURE

        # Generate text
        generated = text_generator(
            input_text,
            max_length=max_length,
            temperature=temperature,
            top_p=Config.TOP_P,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1,
            truncation=True
        )

        # Extract response
        response_text = generated[0]['generated_text']
        if input_text in response_text:
            response_text = response_text.replace(input_text, "").strip()

        # Calculate processing time
        processing_time = (datetime.now() - start_time).total_seconds()

        # Count tokens (approximate)
        tokens_used = len(tokenizer.encode(response_text))

        return ChatResponse(
            response=response_text,
            model_used=Config.MODEL_NAME,
            timestamp=datetime.now().isoformat(),
            tokens_used=tokens_used,
            processing_time=processing_time
        )
    except Exception as e:
        logger.error(f"Error generating response: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error generating response: {str(e)}"
        )
@app.get("/model-info")
async def get_model_info(user: str = Depends(verify_api_key)):
"""Get information about the loaded model""" | |
return { | |
"model_name": Config.MODEL_NAME, | |
"model_loaded": model is not None, | |
"max_length": Config.MAX_LENGTH, | |
"temperature": Config.TEMPERATURE, | |
"device": "cuda" if torch.cuda.is_available() else "cpu" | |
} | |
if __name__ == "__main__":
    # For local development
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=int(os.getenv("PORT", "7860")),  # Hugging Face Spaces uses port 7860
        reload=False
    )
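# Example client call (hypothetical values; assumes the server is reachable at
# http://localhost:7860 and that API_KEY_1 is set to "your-secure-api-key-1"):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/chat",
#       headers={"Authorization": "Bearer your-secure-api-key-1"},
#       json={"message": "Hello!", "max_length": 128},
#   )
#   print(resp.json()["response"])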