import os
import logging
from typing import Optional
from datetime import datetime

from fastapi import FastAPI, HTTPException, Depends, Security, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="LLM AI Agent API",
    description="Secure AI Agent API with Local LLM deployment",
    version="1.0.0"
)

# CORS middleware
# NOTE: wildcard origins combined with allow_credentials=True is very
# permissive; restrict allow_origins for anything beyond a demo.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Security
security = HTTPBearer()

# Configuration: API keys are read from the environment; the literal
# defaults below are demo placeholders and must be overridden in any
# real deployment.
API_KEYS = {
    os.getenv("API_KEY_1", "27Eud5J73j6SqPQAT2ioV-CtiCg-p0WNqq6I4U0Ig6E"): "user1",
    os.getenv("API_KEY_2", "QbzG2CqHU1Nn6F1EogZ1d3dp8ilRTMJQBwTJDQBzS-U"): "user2",
}
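
# A fresh key of the same shape can be generated with the standard
# library (an assumption about how the defaults above were produced):
#   python -c "import secrets; print(secrets.token_urlsafe(32))"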

# Global variables for model
model = None
tokenizer = None
model_loaded = False

# Request/Response models
class ChatRequest(BaseModel):
    message: str = Field(..., min_length=1, max_length=1000)
    max_length: Optional[int] = Field(100, ge=10, le=500)
    temperature: Optional[float] = Field(0.7, ge=0.1, le=2.0)

class ChatResponse(BaseModel):
    response: str
    model_used: str
    timestamp: str
    processing_time: float

class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    timestamp: str
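
# Illustrative /chat request body accepted by ChatRequest (values chosen
# to sit within the Field bounds declared above):
#   {"message": "hello", "max_length": 100, "temperature": 0.7}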

def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)) -> str:
    """Verify API key authentication"""
    api_key = credentials.credentials
    if api_key not in API_KEYS:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key"
        )
    return API_KEYS[api_key]

# Startup hook restored: the docstring says the model is loaded on
# startup, but the decorator did not survive extraction, so the standard
# FastAPI startup event is assumed here.
@app.on_event("startup")
async def load_model():
    """Load the LLM model on startup"""
    global model, tokenizer, model_loaded
    try:
        logger.info("Loading model...")
        # Try to import and load transformers
        try:
            from transformers import AutoTokenizer, AutoModelForCausalLM
            import torch

            model_name = os.getenv("MODEL_NAME", "microsoft/DialoGPT-small")
            logger.info(f"Loading model: {model_name}")
            # Load tokenizer
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            # Load model
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float32,  # Use float32 for compatibility
                low_cpu_mem_usage=True
            )
            model_loaded = True
            logger.info("Model loaded successfully!")
        except Exception as e:
            logger.warning(f"Could not load transformers model: {e}")
            logger.info("Running in demo mode with simple responses")
            model_loaded = False
    except Exception as e:
        logger.error(f"Error during startup: {str(e)}")
        model_loaded = False

@app.get("/", response_model=HealthResponse)  # route path assumed; decorator missing from the extracted source
async def root():
    """Health check endpoint"""
    return HealthResponse(
        status="healthy",
        model_loaded=model_loaded,
        timestamp=datetime.now().isoformat()
    )

@app.get("/health", response_model=HealthResponse)  # route path assumed
async def health_check():
    """Detailed health check"""
    return HealthResponse(
        status="healthy" if model_loaded else "demo_mode",
        model_loaded=model_loaded,
        timestamp=datetime.now().isoformat()
    )

@app.post("/chat", response_model=ChatResponse)  # route path assumed
async def chat(
    request: ChatRequest,
    user: str = Depends(verify_api_key)
):
    """Main chat endpoint for AI agent interaction"""
    start_time = datetime.now()
    try:
        if model_loaded and model is not None and tokenizer is not None:
            # Use the actual model. Building the pipeline on every request
            # is simple but slow; caching it would be the obvious optimization.
            from transformers import pipeline

            generator = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                device=-1  # Use CPU
            )
            # Generate response
            generated = generator(
                request.message,
                max_length=request.max_length,
                temperature=request.temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                num_return_sequences=1
            )
            response_text = generated[0]['generated_text']
            # Strip the prompt if the pipeline echoed it back
            if request.message in response_text:
                response_text = response_text.replace(request.message, "").strip()
            model_used = os.getenv("MODEL_NAME", "microsoft/DialoGPT-small")
        else:
            # Demo mode - simple canned responses
            demo_responses = {
                "hello": "Hello! I'm your AI assistant. How can I help you today?",
                "hi": "Hi there! I'm ready to assist you.",
                "how are you": "I'm doing well, thank you for asking! How can I help you?",
                "what is ai": "AI (Artificial Intelligence) is the simulation of human intelligence in machines that are programmed to think and learn.",
                "machine learning": "Machine learning is a subset of AI that enables computers to learn and improve from experience without being explicitly programmed.",
                "default": "I'm an AI assistant ready to help you. Could you please rephrase your question?"
            }
            message_lower = request.message.lower()
            response_text = demo_responses["default"]
            # Naive first-match substring lookup; "default" is excluded so
            # it only ever serves as the fallback.
            for key, response in demo_responses.items():
                if key != "default" and key in message_lower:
                    response_text = response
                    break
            model_used = "demo_mode"
        # Calculate processing time
        processing_time = (datetime.now() - start_time).total_seconds()
        return ChatResponse(
            response=response_text,
            model_used=model_used,
            timestamp=datetime.now().isoformat(),
            processing_time=processing_time
        )
    except Exception as e:
        logger.error(f"Error generating response: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error generating response: {str(e)}"
        )

@app.get("/model/info")  # route path assumed
async def get_model_info(user: str = Depends(verify_api_key)):
    """Get information about the loaded model"""
    return {
        "model_name": os.getenv("MODEL_NAME", "microsoft/DialoGPT-small"),
        "model_loaded": model_loaded,
        "status": "loaded" if model_loaded else "demo_mode"
    }

if __name__ == "__main__":
    # For local development and Hugging Face Spaces (default port 7860)
    port = int(os.getenv("PORT", "7860"))
    uvicorn.run(
        "app_simple:app",
        host="0.0.0.0",
        port=port,
        reload=False
    )
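
# Example client call (a sketch, assuming the service listens on
# localhost:7860, the /chat route assumed above, and the `requests`
# package is installed; <API_KEY_1> must match a configured key):
#
#   import requests
#   r = requests.post(
#       "http://localhost:7860/chat",
#       headers={"Authorization": "Bearer <API_KEY_1>"},
#       json={"message": "hello", "max_length": 100, "temperature": 0.7},
#   )
#   print(r.json()["response"])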