Yadav122 committed
Commit e8434f3 · verified · 1 Parent(s): 9571152

Upload app.py with huggingface_hub

Files changed (1)
  1. app.py +261 -0
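
The commit message indicates the file was pushed programmatically with huggingface_hub; for reference, a call of the following shape produces this kind of commit. This is a sketch, not taken from the commit itself: the Space id is hypothetical, and a write token must already be configured (for example via `huggingface-cli login`).

    # Sketch: upload app.py to a Space with huggingface_hub.
    # The repo_id below is hypothetical; substitute the real Space id.
    from huggingface_hub import HfApi

    api = HfApi()
    api.upload_file(
        path_or_fileobj="app.py",         # local file to upload
        path_in_repo="app.py",            # destination path in the repo
        repo_id="Yadav122/<space-name>",  # hypothetical Space id
        repo_type="space",
    )

By default, `upload_file` auto-generates exactly this "Upload app.py with huggingface_hub" commit message.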
app.py ADDED
@@ -0,0 +1,261 @@
+ import os
+ from typing import Optional, Dict
+ from datetime import datetime
+ import logging
+
+ from fastapi import FastAPI, HTTPException, Depends, Security, status
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel, Field
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+ import uvicorn
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Initialize FastAPI app
+ app = FastAPI(
+     title="LLM AI Agent API",
+     description="Secure AI Agent API with Local LLM deployment",
+     version="1.0.0",
+     docs_url="/docs",
+     redoc_url="/redoc"
+ )
+
+ # CORS middleware for cross-origin requests
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # Restrict to known origins in production; browsers reject "*" for credentialed requests
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Security
+ security = HTTPBearer()
+
+ # Configuration
+ class Config:
+     # API keys - in production, set these via environment variables
+     API_KEYS = {
+         os.getenv("API_KEY_1", "your-secure-api-key-1"): "user1",
+         os.getenv("API_KEY_2", "your-secure-api-key-2"): "user2",
+         # Add more API keys as needed
+     }
+
+     # Model configuration
+     MODEL_NAME = os.getenv("MODEL_NAME", "microsoft/DialoGPT-medium")  # Lightweight model for free tier
+     MAX_LENGTH = int(os.getenv("MAX_LENGTH", "512"))
+     TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
+     TOP_P = float(os.getenv("TOP_P", "0.9"))
+
+     # Rate limiting (requests per minute per API key)
+     RATE_LIMIT = int(os.getenv("RATE_LIMIT", "10"))
+
+ # Global variables for model and tokenizer
+ model = None
+ tokenizer = None
+ text_generator = None
+
+ # Request/Response models
+ class ChatRequest(BaseModel):
+     message: str = Field(..., min_length=1, max_length=1000, description="Input message for the AI agent")
+     max_length: Optional[int] = Field(None, ge=10, le=2048, description="Maximum response length")
+     temperature: Optional[float] = Field(None, ge=0.1, le=2.0, description="Response creativity (0.1-2.0)")
+     system_prompt: Optional[str] = Field(None, max_length=500, description="Optional system prompt")
+
+ class ChatResponse(BaseModel):
+     response: str
+     model_used: str
+     timestamp: str
+     tokens_used: int
+     processing_time: float
+
+ class HealthResponse(BaseModel):
+     status: str
+     model_loaded: bool
+     timestamp: str
+     version: str
+
+ # Rate limiting storage (in production, use Redis)
+ request_counts: Dict[str, Dict[str, int]] = {}
+
+ def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)) -> str:
+     """Verify API key authentication"""
+     api_key = credentials.credentials
+
+     if api_key not in Config.API_KEYS:
+         raise HTTPException(
+             status_code=status.HTTP_401_UNAUTHORIZED,
+             detail="Invalid API key",
+             headers={"WWW-Authenticate": "Bearer"},
+         )
+
+     return Config.API_KEYS[api_key]
+
+ def check_rate_limit(api_key: str) -> bool:
+     """Simple fixed-window rate limiting (per key, per minute)"""
+     current_minute = datetime.now().strftime("%Y-%m-%d-%H-%M")
+
+     if api_key not in request_counts:
+         request_counts[api_key] = {}
+
+     # Drop counters from earlier minutes so the in-memory store stays bounded
+     request_counts[api_key] = {
+         minute: count
+         for minute, count in request_counts[api_key].items()
+         if minute == current_minute
+     }
+
+     if current_minute not in request_counts[api_key]:
+         request_counts[api_key][current_minute] = 0
+
+     if request_counts[api_key][current_minute] >= Config.RATE_LIMIT:
+         return False
+
+     request_counts[api_key][current_minute] += 1
+     return True
+
+ @app.on_event("startup")
+ async def load_model():
+     """Load the LLM model on startup"""
+     global model, tokenizer, text_generator
+
+     try:
+         logger.info(f"Loading model: {Config.MODEL_NAME}")
+
+         # Load tokenizer
+         tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME)
+
+         # Add padding token if it doesn't exist
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
+
+         # Load model with optimizations for free tier
+         model = AutoModelForCausalLM.from_pretrained(
+             Config.MODEL_NAME,
+             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+             device_map="auto" if torch.cuda.is_available() else None,
+             low_cpu_mem_usage=True
+         )
+
+         # Create text generation pipeline. When device_map="auto" was used,
+         # accelerate has already placed the model, and passing a device here
+         # as well raises an error in recent transformers releases.
+         text_generator = pipeline(
+             "text-generation",
+             model=model,
+             tokenizer=tokenizer,
+             device=None if torch.cuda.is_available() else -1
+         )
+
+         logger.info("Model loaded successfully!")
+
+     except Exception as e:
+         logger.error(f"Error loading model: {str(e)}")
+         raise e
+
+ @app.get("/", response_model=HealthResponse)
+ async def root():
+     """Health check endpoint"""
+     return HealthResponse(
+         status="healthy",
+         model_loaded=model is not None,
+         timestamp=datetime.now().isoformat(),
+         version="1.0.0"
+     )
+
+ @app.get("/health", response_model=HealthResponse)
+ async def health_check():
+     """Detailed health check"""
+     return HealthResponse(
+         status="healthy" if model is not None else "model_not_loaded",
+         model_loaded=model is not None,
+         timestamp=datetime.now().isoformat(),
+         version="1.0.0"
+     )
+
+ @app.post("/chat", response_model=ChatResponse)
+ async def chat(
+     request: ChatRequest,
+     user: str = Depends(verify_api_key)
+ ):
+     """Main chat endpoint for AI agent interaction"""
+     start_time = datetime.now()
+
+     # Rate limit per authenticated user (each user maps 1:1 to an API key)
+     if not check_rate_limit(user):
+         raise HTTPException(
+             status_code=status.HTTP_429_TOO_MANY_REQUESTS,
+             detail="Rate limit exceeded. Please try again later."
+         )
+
+     if model is None or tokenizer is None:
+         raise HTTPException(
+             status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+             detail="Model not loaded. Please try again later."
+         )
+
+     try:
+         # Prepare input
+         input_text = request.message
+         if request.system_prompt:
+             input_text = f"System: {request.system_prompt}\nUser: {request.message}\nAssistant:"
+
+         # Generation parameters (request values override the configured defaults)
+         max_length = request.max_length or Config.MAX_LENGTH
+         temperature = request.temperature or Config.TEMPERATURE
+
+         # Generate text
+         generated = text_generator(
+             input_text,
+             max_length=max_length,
+             temperature=temperature,
+             top_p=Config.TOP_P,
+             do_sample=True,
+             pad_token_id=tokenizer.eos_token_id,
+             num_return_sequences=1,
+             truncation=True
+         )
+
+         # The pipeline returns the prompt plus the completion; strip the
+         # prompt prefix so only the model's reply is sent back
+         response_text = generated[0]['generated_text']
+         if response_text.startswith(input_text):
+             response_text = response_text[len(input_text):].strip()
+
+         # Calculate processing time
+         processing_time = (datetime.now() - start_time).total_seconds()
+
+         # Count tokens (approximate)
+         tokens_used = len(tokenizer.encode(response_text))
+
+         return ChatResponse(
+             response=response_text,
+             model_used=Config.MODEL_NAME,
+             timestamp=datetime.now().isoformat(),
+             tokens_used=tokens_used,
+             processing_time=processing_time
+         )
+
+     except Exception as e:
+         logger.error(f"Error generating response: {str(e)}")
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=f"Error generating response: {str(e)}"
+         )
+
+ @app.get("/models")
+ async def get_model_info(user: str = Depends(verify_api_key)):
+     """Get information about the loaded model"""
+     return {
+         "model_name": Config.MODEL_NAME,
+         "model_loaded": model is not None,
+         "max_length": Config.MAX_LENGTH,
+         "temperature": Config.TEMPERATURE,
+         "device": "cuda" if torch.cuda.is_available() else "cpu"
+     }
+
+ if __name__ == "__main__":
+     # For local development
+     uvicorn.run(
+         "app:app",
+         host="0.0.0.0",
+         port=int(os.getenv("PORT", "7860")),  # Hugging Face Spaces uses port 7860
+         reload=False
+     )
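
For reference, a minimal client for the API above might look like the following. This is a sketch that assumes the server is running locally on port 7860 and that the placeholder default key for API_KEY_1 is still active; in a real deployment both would differ.

    # Sketch: call the API above from Python using requests.
    # Assumes a local server on port 7860 and the placeholder default key.
    import requests

    BASE_URL = "http://localhost:7860"
    API_KEY = "your-secure-api-key-1"  # placeholder default from Config.API_KEYS

    headers = {"Authorization": f"Bearer {API_KEY}"}

    # Health check (no authentication required)
    print(requests.get(f"{BASE_URL}/health").json())

    # Chat request (Bearer token required); fields match the ChatRequest model
    payload = {
        "message": "Hello, what can you do?",
        "temperature": 0.7,
        "system_prompt": "You are a helpful assistant.",
    }
    resp = requests.post(f"{BASE_URL}/chat", json=payload, headers=headers)
    resp.raise_for_status()
    print(resp.json()["response"])

The same calls work against a deployed Space by swapping BASE_URL for the Space's public URL.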