Yadav122 commited on
Commit
25f8aca
·
verified ·
1 Parent(s): 8aea355

Fix: Use simplified app for better compatibility

Browse files
Files changed (1) hide show
  1. app.py +108 -145
app.py CHANGED
@@ -1,16 +1,12 @@
1
  import os
2
- import secrets
3
- import hashlib
4
- from typing import Optional, Dict, Any
5
- from datetime import datetime, timedelta
6
  import logging
 
 
7
 
8
  from fastapi import FastAPI, HTTPException, Depends, Security, status
9
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from pydantic import BaseModel, Field
12
- import torch
13
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
14
  import uvicorn
15
 
16
  # Configure logging
@@ -21,15 +17,13 @@ logger = logging.getLogger(__name__)
21
  app = FastAPI(
22
  title="LLM AI Agent API",
23
  description="Secure AI Agent API with Local LLM deployment",
24
- version="1.0.0",
25
- docs_url="/docs",
26
- redoc_url="/redoc"
27
  )
28
 
29
- # CORS middleware for cross-origin requests
30
  app.add_middleware(
31
  CORSMiddleware,
32
- allow_origins=["*"], # Configure this for production
33
  allow_credentials=True,
34
  allow_methods=["*"],
35
  allow_headers=["*"],
@@ -39,135 +33,101 @@ app.add_middleware(
39
  security = HTTPBearer()
40
 
41
  # Configuration
42
- class Config:
43
- # API Keys - In production, use environment variables
44
- API_KEYS = {
45
- os.getenv("API_KEY_1", "your-secure-api-key-1"): "user1",
46
- os.getenv("API_KEY_2", "your-secure-api-key-2"): "user2",
47
- # Add more API keys as needed
48
- }
49
-
50
- # Model configuration
51
- MODEL_NAME = os.getenv("MODEL_NAME", "microsoft/DialoGPT-medium") # Lightweight model for free tier
52
- MAX_LENGTH = int(os.getenv("MAX_LENGTH", "512"))
53
- TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
54
- TOP_P = float(os.getenv("TOP_P", "0.9"))
55
-
56
- # Rate limiting (requests per minute per API key)
57
- RATE_LIMIT = int(os.getenv("RATE_LIMIT", "10"))
58
 
59
- # Global variables for model and tokenizer
60
  model = None
61
  tokenizer = None
62
- text_generator = None
63
 
64
  # Request/Response models
65
  class ChatRequest(BaseModel):
66
- message: str = Field(..., min_length=1, max_length=1000, description="Input message for the AI agent")
67
- max_length: Optional[int] = Field(None, ge=10, le=2048, description="Maximum response length")
68
- temperature: Optional[float] = Field(None, ge=0.1, le=2.0, description="Response creativity (0.1-2.0)")
69
- system_prompt: Optional[str] = Field(None, max_length=500, description="Optional system prompt")
70
 
71
  class ChatResponse(BaseModel):
72
  response: str
73
  model_used: str
74
  timestamp: str
75
- tokens_used: int
76
  processing_time: float
77
 
78
  class HealthResponse(BaseModel):
79
  status: str
80
  model_loaded: bool
81
  timestamp: str
82
- version: str
83
-
84
- # Rate limiting storage (in production, use Redis)
85
- request_counts: Dict[str, Dict[str, int]] = {}
86
 
87
  def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)) -> str:
88
  """Verify API key authentication"""
89
  api_key = credentials.credentials
90
 
91
- if api_key not in Config.API_KEYS:
92
  raise HTTPException(
93
  status_code=status.HTTP_401_UNAUTHORIZED,
94
- detail="Invalid API key",
95
- headers={"WWW-Authenticate": "Bearer"},
96
  )
97
 
98
- return Config.API_KEYS[api_key]
99
-
100
- def check_rate_limit(api_key: str) -> bool:
101
- """Simple rate limiting implementation"""
102
- current_minute = datetime.now().strftime("%Y-%m-%d-%H-%M")
103
-
104
- if api_key not in request_counts:
105
- request_counts[api_key] = {}
106
-
107
- if current_minute not in request_counts[api_key]:
108
- request_counts[api_key][current_minute] = 0
109
-
110
- if request_counts[api_key][current_minute] >= Config.RATE_LIMIT:
111
- return False
112
-
113
- request_counts[api_key][current_minute] += 1
114
- return True
115
 
116
  @app.on_event("startup")
117
  async def load_model():
118
  """Load the LLM model on startup"""
119
- global model, tokenizer, text_generator
120
 
121
  try:
122
- logger.info(f"Loading model: {Config.MODEL_NAME}")
123
-
124
- # Load tokenizer
125
- tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME)
126
-
127
- # Add padding token if it doesn't exist
128
- if tokenizer.pad_token is None:
129
- tokenizer.pad_token = tokenizer.eos_token
130
-
131
- # Load model with optimizations for free tier
132
- model = AutoModelForCausalLM.from_pretrained(
133
- Config.MODEL_NAME,
134
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
135
- device_map="auto" if torch.cuda.is_available() else None,
136
- low_cpu_mem_usage=True
137
- )
138
-
139
- # Create text generation pipeline
140
- text_generator = pipeline(
141
- "text-generation",
142
- model=model,
143
- tokenizer=tokenizer,
144
- device=0 if torch.cuda.is_available() else -1
145
- )
146
 
147
- logger.info("Model loaded successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
  except Exception as e:
150
- logger.error(f"Error loading model: {str(e)}")
151
- raise e
152
 
153
  @app.get("/", response_model=HealthResponse)
154
  async def root():
155
  """Health check endpoint"""
156
  return HealthResponse(
157
  status="healthy",
158
- model_loaded=model is not None,
159
- timestamp=datetime.now().isoformat(),
160
- version="1.0.0"
161
  )
162
 
163
  @app.get("/health", response_model=HealthResponse)
164
  async def health_check():
165
  """Detailed health check"""
166
  return HealthResponse(
167
- status="healthy" if model is not None else "model_not_loaded",
168
- model_loaded=model is not None,
169
- timestamp=datetime.now().isoformat(),
170
- version="1.0.0"
171
  )
172
 
173
  @app.post("/chat", response_model=ChatResponse)
@@ -178,58 +138,62 @@ async def chat(
178
  """Main chat endpoint for AI agent interaction"""
179
  start_time = datetime.now()
180
 
181
- # Check rate limiting
182
- api_key = None # In a real implementation, you'd extract this from the token
183
- # if not check_rate_limit(api_key):
184
- # raise HTTPException(
185
- # status_code=status.HTTP_429_TOO_MANY_REQUESTS,
186
- # detail="Rate limit exceeded. Please try again later."
187
- # )
188
-
189
- if model is None or tokenizer is None:
190
- raise HTTPException(
191
- status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
192
- detail="Model not loaded. Please try again later."
193
- )
194
-
195
  try:
196
- # Prepare input
197
- input_text = request.message
198
- if request.system_prompt:
199
- input_text = f"System: {request.system_prompt}\nUser: {request.message}\nAssistant:"
200
-
201
- # Generate response
202
- max_length = request.max_length or Config.MAX_LENGTH
203
- temperature = request.temperature or Config.TEMPERATURE
204
-
205
- # Generate text
206
- generated = text_generator(
207
- input_text,
208
- max_length=max_length,
209
- temperature=temperature,
210
- top_p=Config.TOP_P,
211
- do_sample=True,
212
- pad_token_id=tokenizer.eos_token_id,
213
- num_return_sequences=1,
214
- truncation=True
215
- )
216
-
217
- # Extract response
218
- response_text = generated[0]['generated_text']
219
- if input_text in response_text:
220
- response_text = response_text.replace(input_text, "").strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
  # Calculate processing time
223
  processing_time = (datetime.now() - start_time).total_seconds()
224
 
225
- # Count tokens (approximate)
226
- tokens_used = len(tokenizer.encode(response_text))
227
-
228
  return ChatResponse(
229
  response=response_text,
230
- model_used=Config.MODEL_NAME,
231
  timestamp=datetime.now().isoformat(),
232
- tokens_used=tokens_used,
233
  processing_time=processing_time
234
  )
235
 
@@ -244,18 +208,17 @@ async def chat(
244
  async def get_model_info(user: str = Depends(verify_api_key)):
245
  """Get information about the loaded model"""
246
  return {
247
- "model_name": Config.MODEL_NAME,
248
- "model_loaded": model is not None,
249
- "max_length": Config.MAX_LENGTH,
250
- "temperature": Config.TEMPERATURE,
251
- "device": "cuda" if torch.cuda.is_available() else "cpu"
252
  }
253
 
254
  if __name__ == "__main__":
255
- # For local development
 
256
  uvicorn.run(
257
- "app:app",
258
  host="0.0.0.0",
259
- port=int(os.getenv("PORT", "7860")), # Hugging Face Spaces uses port 7860
260
  reload=False
261
  )
 
1
  import os
 
 
 
 
2
  import logging
3
+ from typing import Optional
4
+ from datetime import datetime
5
 
6
  from fastapi import FastAPI, HTTPException, Depends, Security, status
7
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
8
  from fastapi.middleware.cors import CORSMiddleware
9
  from pydantic import BaseModel, Field
 
 
10
  import uvicorn
11
 
12
  # Configure logging
 
17
  app = FastAPI(
18
  title="LLM AI Agent API",
19
  description="Secure AI Agent API with Local LLM deployment",
20
+ version="1.0.0"
 
 
21
  )
22
 
23
+ # CORS middleware
24
  app.add_middleware(
25
  CORSMiddleware,
26
+ allow_origins=["*"],
27
  allow_credentials=True,
28
  allow_methods=["*"],
29
  allow_headers=["*"],
 
33
  security = HTTPBearer()
34
 
35
  # Configuration
36
+ API_KEYS = {
37
+ os.getenv("API_KEY_1", "27Eud5J73j6SqPQAT2ioV-CtiCg-p0WNqq6I4U0Ig6E"): "user1",
38
+ os.getenv("API_KEY_2", "QbzG2CqHU1Nn6F1EogZ1d3dp8ilRTMJQBwTJDQBzS-U"): "user2",
39
+ }
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
+ # Global variables for model
42
  model = None
43
  tokenizer = None
44
+ model_loaded = False
45
 
46
  # Request/Response models
47
  class ChatRequest(BaseModel):
48
+ message: str = Field(..., min_length=1, max_length=1000)
49
+ max_length: Optional[int] = Field(100, ge=10, le=500)
50
+ temperature: Optional[float] = Field(0.7, ge=0.1, le=2.0)
 
51
 
52
  class ChatResponse(BaseModel):
53
  response: str
54
  model_used: str
55
  timestamp: str
 
56
  processing_time: float
57
 
58
  class HealthResponse(BaseModel):
59
  status: str
60
  model_loaded: bool
61
  timestamp: str
 
 
 
 
62
 
63
  def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)) -> str:
64
  """Verify API key authentication"""
65
  api_key = credentials.credentials
66
 
67
+ if api_key not in API_KEYS:
68
  raise HTTPException(
69
  status_code=status.HTTP_401_UNAUTHORIZED,
70
+ detail="Invalid API key"
 
71
  )
72
 
73
+ return API_KEYS[api_key]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  @app.on_event("startup")
76
  async def load_model():
77
  """Load the LLM model on startup"""
78
+ global model, tokenizer, model_loaded
79
 
80
  try:
81
+ logger.info("Loading model...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
+ # Try to import and load transformers
84
+ try:
85
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
86
+ import torch
87
+
88
+ model_name = os.getenv("MODEL_NAME", "microsoft/DialoGPT-small")
89
+ logger.info(f"Loading model: {model_name}")
90
+
91
+ # Load tokenizer
92
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
93
+ if tokenizer.pad_token is None:
94
+ tokenizer.pad_token = tokenizer.eos_token
95
+
96
+ # Load model
97
+ model = AutoModelForCausalLM.from_pretrained(
98
+ model_name,
99
+ torch_dtype=torch.float32, # Use float32 for compatibility
100
+ low_cpu_mem_usage=True
101
+ )
102
+
103
+ model_loaded = True
104
+ logger.info("Model loaded successfully!")
105
+
106
+ except Exception as e:
107
+ logger.warning(f"Could not load transformers model: {e}")
108
+ logger.info("Running in demo mode with simple responses")
109
+ model_loaded = False
110
 
111
  except Exception as e:
112
+ logger.error(f"Error during startup: {str(e)}")
113
+ model_loaded = False
114
 
115
  @app.get("/", response_model=HealthResponse)
116
  async def root():
117
  """Health check endpoint"""
118
  return HealthResponse(
119
  status="healthy",
120
+ model_loaded=model_loaded,
121
+ timestamp=datetime.now().isoformat()
 
122
  )
123
 
124
  @app.get("/health", response_model=HealthResponse)
125
  async def health_check():
126
  """Detailed health check"""
127
  return HealthResponse(
128
+ status="healthy" if model_loaded else "demo_mode",
129
+ model_loaded=model_loaded,
130
+ timestamp=datetime.now().isoformat()
 
131
  )
132
 
133
  @app.post("/chat", response_model=ChatResponse)
 
138
  """Main chat endpoint for AI agent interaction"""
139
  start_time = datetime.now()
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  try:
142
+ if model_loaded and model is not None and tokenizer is not None:
143
+ # Use actual model
144
+ from transformers import pipeline
145
+
146
+ generator = pipeline(
147
+ "text-generation",
148
+ model=model,
149
+ tokenizer=tokenizer,
150
+ device=-1 # Use CPU
151
+ )
152
+
153
+ # Generate response
154
+ generated = generator(
155
+ request.message,
156
+ max_length=request.max_length,
157
+ temperature=request.temperature,
158
+ do_sample=True,
159
+ pad_token_id=tokenizer.eos_token_id,
160
+ num_return_sequences=1
161
+ )
162
+
163
+ response_text = generated[0]['generated_text']
164
+ if request.message in response_text:
165
+ response_text = response_text.replace(request.message, "").strip()
166
+
167
+ model_used = os.getenv("MODEL_NAME", "microsoft/DialoGPT-small")
168
+
169
+ else:
170
+ # Demo mode - simple responses
171
+ demo_responses = {
172
+ "hello": "Hello! I'm your AI assistant. How can I help you today?",
173
+ "hi": "Hi there! I'm ready to assist you.",
174
+ "how are you": "I'm doing well, thank you for asking! How can I help you?",
175
+ "what is ai": "AI (Artificial Intelligence) is the simulation of human intelligence in machines that are programmed to think and learn.",
176
+ "machine learning": "Machine learning is a subset of AI that enables computers to learn and improve from experience without being explicitly programmed.",
177
+ "default": "I'm an AI assistant ready to help you. Could you please rephrase your question?"
178
+ }
179
+
180
+ message_lower = request.message.lower()
181
+ response_text = demo_responses.get("default", "I'm here to help!")
182
+
183
+ for key, response in demo_responses.items():
184
+ if key in message_lower:
185
+ response_text = response
186
+ break
187
+
188
+ model_used = "demo_mode"
189
 
190
  # Calculate processing time
191
  processing_time = (datetime.now() - start_time).total_seconds()
192
 
 
 
 
193
  return ChatResponse(
194
  response=response_text,
195
+ model_used=model_used,
196
  timestamp=datetime.now().isoformat(),
 
197
  processing_time=processing_time
198
  )
199
 
 
208
  async def get_model_info(user: str = Depends(verify_api_key)):
209
  """Get information about the loaded model"""
210
  return {
211
+ "model_name": os.getenv("MODEL_NAME", "microsoft/DialoGPT-small"),
212
+ "model_loaded": model_loaded,
213
+ "status": "loaded" if model_loaded else "demo_mode"
 
 
214
  }
215
 
216
  if __name__ == "__main__":
217
+ # For local development and Hugging Face Spaces
218
+ port = int(os.getenv("PORT", "7860"))
219
  uvicorn.run(
220
+ "app_simple:app",
221
  host="0.0.0.0",
222
+ port=port,
223
  reload=False
224
  )