import os
import logging
from typing import Optional
from datetime import datetime

from fastapi import FastAPI, HTTPException, Depends, Security, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="LLM AI Agent API",
    description="Secure AI Agent API with Local LLM deployment",
    version="1.0.0"
)

# CORS middleware
# NOTE: allow_origins=["*"] combined with allow_credentials=True is very
# permissive (browsers reject wildcard origins on credentialed requests);
# restrict allow_origins to known frontends in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Security
security = HTTPBearer()

# Configuration
# API keys are read from the environment; the literal defaults below are
# demo placeholders only and should never ship to production.
API_KEYS = {
    os.getenv("API_KEY_1", "27Eud5J73j6SqPQAT2ioV-CtiCg-p0WNqq6I4U0Ig6E"): "user1",
    os.getenv("API_KEY_2", "QbzG2CqHU1Nn6F1EogZ1d3dp8ilRTMJQBwTJDQBzS-U"): "user2",
}
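
# Example of providing real keys at deploy time (illustrative shell commands,
# not part of the original file):
#
#   export API_KEY_1="your-own-secret"
#   export API_KEY_2="another-secret"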

# Global variables for model
model = None
tokenizer = None
model_loaded = False

# Request/Response models
class ChatRequest(BaseModel):
    message: str = Field(..., min_length=1, max_length=1000)
    max_length: Optional[int] = Field(100, ge=10, le=500)
    temperature: Optional[float] = Field(0.7, ge=0.1, le=2.0)

class ChatResponse(BaseModel):
    response: str
    model_used: str
    timestamp: str
    processing_time: float

class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    timestamp: str
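
# Example /chat exchange (illustrative values) using the models above:
#
#   request:  {"message": "hello", "max_length": 100, "temperature": 0.7}
#   response: {"response": "...", "model_used": "demo_mode",
#              "timestamp": "2024-05-01T12:00:00", "processing_time": 0.01}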

def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)) -> str:
    """Verify API key authentication"""
    api_key = credentials.credentials
    
    if api_key not in API_KEYS:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
            headers={"WWW-Authenticate": "Bearer"},  # per RFC 6750 for Bearer auth
        )
    
    return API_KEYS[api_key]

@app.on_event("startup")  # note: recent FastAPI versions prefer lifespan handlers
async def load_model():
    """Load the LLM model on startup, falling back to demo mode on failure."""
    global model, tokenizer, model_loaded
    
    try:
        from transformers import AutoTokenizer, AutoModelForCausalLM
        import torch

        model_name = os.getenv("MODEL_NAME", "microsoft/DialoGPT-small")
        logger.info(f"Loading model: {model_name}")

        # Load tokenizer (DialoGPT defines no pad token, so reuse EOS)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Load model
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,  # float32 for CPU compatibility
            low_cpu_mem_usage=True
        )

        model_loaded = True
        logger.info("Model loaded successfully!")

    except Exception as e:
        logger.warning(f"Could not load transformers model: {e}")
        logger.info("Running in demo mode with simple responses")
        model_loaded = False

@app.get("/", response_model=HealthResponse)
async def root():
    """Health check endpoint"""
    return HealthResponse(
        status="healthy",
        model_loaded=model_loaded,
        timestamp=datetime.now().isoformat()
    )

@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Detailed health check"""
    return HealthResponse(
        status="healthy" if model_loaded else "demo_mode",
        model_loaded=model_loaded,
        timestamp=datetime.now().isoformat()
    )

@app.post("/chat", response_model=ChatResponse)
async def chat(
    request: ChatRequest,
    user: str = Depends(verify_api_key)
):
    """Main chat endpoint for AI agent interaction"""
    start_time = datetime.now()
    
    try:
        if model_loaded and model is not None and tokenizer is not None:
            # Use the actual model. Building the pipeline on every request is
            # simple but wasteful; caching it at startup would be faster.
            from transformers import pipeline

            generator = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                device=-1  # Use CPU
            )

            # Generate response (max_length counts prompt + reply tokens)
            generated = generator(
                request.message,
                max_length=request.max_length,
                temperature=request.temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                num_return_sequences=1
            )
            
            # The pipeline returns prompt + continuation; strip the prompt so
            # only the model's reply is returned.
            response_text = generated[0]['generated_text']
            if request.message in response_text:
                response_text = response_text.replace(request.message, "").strip()
            
            model_used = os.getenv("MODEL_NAME", "microsoft/DialoGPT-small")
            
        else:
            # Demo mode - simple keyword-matched canned responses
            demo_responses = {
                "hello": "Hello! I'm your AI assistant. How can I help you today?",
                "hi": "Hi there! I'm ready to assist you.",
                "how are you": "I'm doing well, thank you for asking! How can I help you?",
                "what is ai": "AI (Artificial Intelligence) is the simulation of human intelligence in machines that are programmed to think and learn.",
                "machine learning": "Machine learning is a subset of AI that enables computers to learn and improve from experience without being explicitly programmed.",
            }
            default_response = "I'm an AI assistant ready to help you. Could you please rephrase your question?"

            # Return the first canned answer whose keyword appears in the message
            message_lower = request.message.lower()
            response_text = default_response
            for key, response in demo_responses.items():
                if key in message_lower:
                    response_text = response
                    break

            model_used = "demo_mode"
        
        # Calculate processing time
        processing_time = (datetime.now() - start_time).total_seconds()
        
        return ChatResponse(
            response=response_text,
            model_used=model_used,
            timestamp=datetime.now().isoformat(),
            processing_time=processing_time
        )
        
    except Exception as e:
        logger.error(f"Error generating response: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error generating response: {str(e)}"
        )

@app.get("/models")
async def get_model_info(user: str = Depends(verify_api_key)):
    """Get information about the loaded model"""
    return {
        "model_name": os.getenv("MODEL_NAME", "microsoft/DialoGPT-small"),
        "model_loaded": model_loaded,
        "status": "loaded" if model_loaded else "demo_mode"
    }

if __name__ == "__main__":
    # For local development and Hugging Face Spaces (Spaces expects port 7860)
    port = int(os.getenv("PORT", "7860"))
    uvicorn.run(
        "app_simple:app",  # assumes this file is saved as app_simple.py
        host="0.0.0.0",
        port=port,
        reload=False
    )
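
# Example request against a local instance (illustrative; substitute one of
# the configured API keys for <API_KEY>):
#
#   curl -X POST http://localhost:7860/chat \
#     -H "Authorization: Bearer <API_KEY>" \
#     -H "Content-Type: application/json" \
#     -d '{"message": "hello"}'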