Yadav122 committed on
Commit
8aea355
·
verified ·
1 Parent(s): f6dc4cc

Fix: Simplified app with better error handling

Browse files
Files changed (1) hide show
  1. app_simple.py +224 -0
app_simple.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from typing import Optional
4
+ from datetime import datetime
5
+
6
+ from fastapi import FastAPI, HTTPException, Depends, Security, status
7
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from pydantic import BaseModel, Field
10
+ import uvicorn
11
+
12
+ # Configure logging
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+
16
+ # Initialize FastAPI app
17
+ app = FastAPI(
18
+ title="LLM AI Agent API",
19
+ description="Secure AI Agent API with Local LLM deployment",
20
+ version="1.0.0"
21
+ )
22
+
23
+ # CORS middleware
24
+ app.add_middleware(
25
+ CORSMiddleware,
26
+ allow_origins=["*"],
27
+ allow_credentials=True,
28
+ allow_methods=["*"],
29
+ allow_headers=["*"],
30
+ )
31
+
32
+ # Security
33
+ security = HTTPBearer()
34
+
35
+ # Configuration
36
+ API_KEYS = {
37
+ os.getenv("API_KEY_1", "27Eud5J73j6SqPQAT2ioV-CtiCg-p0WNqq6I4U0Ig6E"): "user1",
38
+ os.getenv("API_KEY_2", "QbzG2CqHU1Nn6F1EogZ1d3dp8ilRTMJQBwTJDQBzS-U"): "user2",
39
+ }
40
+
41
+ # Global variables for model
42
+ model = None
43
+ tokenizer = None
44
+ model_loaded = False
45
+
46
+ # Request/Response models
47
+ class ChatRequest(BaseModel):
48
+ message: str = Field(..., min_length=1, max_length=1000)
49
+ max_length: Optional[int] = Field(100, ge=10, le=500)
50
+ temperature: Optional[float] = Field(0.7, ge=0.1, le=2.0)
51
+
52
+ class ChatResponse(BaseModel):
53
+ response: str
54
+ model_used: str
55
+ timestamp: str
56
+ processing_time: float
57
+
58
+ class HealthResponse(BaseModel):
59
+ status: str
60
+ model_loaded: bool
61
+ timestamp: str
62
+
63
+ def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)) -> str:
64
+ """Verify API key authentication"""
65
+ api_key = credentials.credentials
66
+
67
+ if api_key not in API_KEYS:
68
+ raise HTTPException(
69
+ status_code=status.HTTP_401_UNAUTHORIZED,
70
+ detail="Invalid API key"
71
+ )
72
+
73
+ return API_KEYS[api_key]
74
+
75
+ @app.on_event("startup")
76
+ async def load_model():
77
+ """Load the LLM model on startup"""
78
+ global model, tokenizer, model_loaded
79
+
80
+ try:
81
+ logger.info("Loading model...")
82
+
83
+ # Try to import and load transformers
84
+ try:
85
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
86
+ import torch
87
+
88
+ model_name = os.getenv("MODEL_NAME", "microsoft/DialoGPT-small")
89
+ logger.info(f"Loading model: {model_name}")
90
+
91
+ # Load tokenizer
92
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
93
+ if tokenizer.pad_token is None:
94
+ tokenizer.pad_token = tokenizer.eos_token
95
+
96
+ # Load model
97
+ model = AutoModelForCausalLM.from_pretrained(
98
+ model_name,
99
+ torch_dtype=torch.float32, # Use float32 for compatibility
100
+ low_cpu_mem_usage=True
101
+ )
102
+
103
+ model_loaded = True
104
+ logger.info("Model loaded successfully!")
105
+
106
+ except Exception as e:
107
+ logger.warning(f"Could not load transformers model: {e}")
108
+ logger.info("Running in demo mode with simple responses")
109
+ model_loaded = False
110
+
111
+ except Exception as e:
112
+ logger.error(f"Error during startup: {str(e)}")
113
+ model_loaded = False
114
+
115
+ @app.get("/", response_model=HealthResponse)
116
+ async def root():
117
+ """Health check endpoint"""
118
+ return HealthResponse(
119
+ status="healthy",
120
+ model_loaded=model_loaded,
121
+ timestamp=datetime.now().isoformat()
122
+ )
123
+
124
+ @app.get("/health", response_model=HealthResponse)
125
+ async def health_check():
126
+ """Detailed health check"""
127
+ return HealthResponse(
128
+ status="healthy" if model_loaded else "demo_mode",
129
+ model_loaded=model_loaded,
130
+ timestamp=datetime.now().isoformat()
131
+ )
132
+
133
+ @app.post("/chat", response_model=ChatResponse)
134
+ async def chat(
135
+ request: ChatRequest,
136
+ user: str = Depends(verify_api_key)
137
+ ):
138
+ """Main chat endpoint for AI agent interaction"""
139
+ start_time = datetime.now()
140
+
141
+ try:
142
+ if model_loaded and model is not None and tokenizer is not None:
143
+ # Use actual model
144
+ from transformers import pipeline
145
+
146
+ generator = pipeline(
147
+ "text-generation",
148
+ model=model,
149
+ tokenizer=tokenizer,
150
+ device=-1 # Use CPU
151
+ )
152
+
153
+ # Generate response
154
+ generated = generator(
155
+ request.message,
156
+ max_length=request.max_length,
157
+ temperature=request.temperature,
158
+ do_sample=True,
159
+ pad_token_id=tokenizer.eos_token_id,
160
+ num_return_sequences=1
161
+ )
162
+
163
+ response_text = generated[0]['generated_text']
164
+ if request.message in response_text:
165
+ response_text = response_text.replace(request.message, "").strip()
166
+
167
+ model_used = os.getenv("MODEL_NAME", "microsoft/DialoGPT-small")
168
+
169
+ else:
170
+ # Demo mode - simple responses
171
+ demo_responses = {
172
+ "hello": "Hello! I'm your AI assistant. How can I help you today?",
173
+ "hi": "Hi there! I'm ready to assist you.",
174
+ "how are you": "I'm doing well, thank you for asking! How can I help you?",
175
+ "what is ai": "AI (Artificial Intelligence) is the simulation of human intelligence in machines that are programmed to think and learn.",
176
+ "machine learning": "Machine learning is a subset of AI that enables computers to learn and improve from experience without being explicitly programmed.",
177
+ "default": "I'm an AI assistant ready to help you. Could you please rephrase your question?"
178
+ }
179
+
180
+ message_lower = request.message.lower()
181
+ response_text = demo_responses.get("default", "I'm here to help!")
182
+
183
+ for key, response in demo_responses.items():
184
+ if key in message_lower:
185
+ response_text = response
186
+ break
187
+
188
+ model_used = "demo_mode"
189
+
190
+ # Calculate processing time
191
+ processing_time = (datetime.now() - start_time).total_seconds()
192
+
193
+ return ChatResponse(
194
+ response=response_text,
195
+ model_used=model_used,
196
+ timestamp=datetime.now().isoformat(),
197
+ processing_time=processing_time
198
+ )
199
+
200
+ except Exception as e:
201
+ logger.error(f"Error generating response: {str(e)}")
202
+ raise HTTPException(
203
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
204
+ detail=f"Error generating response: {str(e)}"
205
+ )
206
+
207
+ @app.get("/models")
208
+ async def get_model_info(user: str = Depends(verify_api_key)):
209
+ """Get information about the loaded model"""
210
+ return {
211
+ "model_name": os.getenv("MODEL_NAME", "microsoft/DialoGPT-small"),
212
+ "model_loaded": model_loaded,
213
+ "status": "loaded" if model_loaded else "demo_mode"
214
+ }
215
+
216
+ if __name__ == "__main__":
217
+ # For local development and Hugging Face Spaces
218
+ port = int(os.getenv("PORT", "7860"))
219
+ uvicorn.run(
220
+ "app_simple:app",
221
+ host="0.0.0.0",
222
+ port=port,
223
+ reload=False
224
+ )