Niansuh committed on
Commit
a932354
·
verified ·
1 Parent(s): 7939e38

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +77 -56
main.py CHANGED
@@ -7,21 +7,24 @@ from typing import Any, Dict, List, Optional
7
 
8
  import httpx
9
  import uvicorn
10
- from fastapi import FastAPI, HTTPException, Depends
11
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
12
  from pydantic import BaseModel
13
  from starlette.middleware.cors import CORSMiddleware
14
  from starlette.responses import StreamingResponse, Response
15
 
16
- # Configure Logging
 
 
 
 
 
17
  logging.basicConfig(
18
  level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
19
  )
20
  logger = logging.getLogger(__name__)
21
 
22
  # Load Environment Variables for Sensitive Information
23
- from dotenv import load_dotenv
24
-
25
  load_dotenv()
26
 
27
  app = FastAPI()
@@ -55,21 +58,23 @@ OUTGOING_HEADERS = {
55
  'sec-fetch-dest': 'empty',
56
  'sec-fetch-mode': 'cors',
57
  'sec-fetch-site': 'none',
58
- 'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36',
 
 
59
  }
60
 
61
  # Updated ALLOWED_MODELS List (No Duplicates)
62
  ALLOWED_MODELS = [
63
- {"id": "claude-3.5-sonnet", "name": "claude-3.5-sonnet"},
64
- {"id": "sider", "name": "sider"},
65
- {"id": "gpt-4o-mini", "name": "gpt-4o-mini"},
66
- {"id": "claude-3-haiku", "name": "claude-3-haiku"},
67
- {"id": "claude-3.5-haiku", "name": "claude-3.5-haiku"},
68
- {"id": "gemini-1.5-flash", "name": "gemini-1.5-flash"},
69
- {"id": "llama-3", "name": "llama-3"},
70
- {"id": "gpt-4o", "name": "gpt-4o"},
71
- {"id": "gemini-1.5-pro", "name": "gemini-1.5-pro"},
72
- {"id": "llama-3.1-405b", "name": "llama-3.1-405b"},
73
  ]
74
 
75
  # Configure CORS
@@ -84,19 +89,16 @@ app.add_middleware(
84
  # Security Dependency
85
  security = HTTPBearer()
86
 
87
-
88
  # Pydantic Models
89
  class Message(BaseModel):
90
  role: str
91
  content: str
92
 
93
-
94
  class ChatRequest(BaseModel):
95
  model: str
96
  messages: List[Message]
97
  stream: Optional[bool] = False
98
 
99
-
100
  # Utility Functions
101
  def create_chat_completion_data(content: str, model: str, finish_reason: Optional[str] = None) -> Dict[str, Any]:
102
  return {
@@ -114,7 +116,6 @@ def create_chat_completion_data(content: str, model: str, finish_reason: Optiona
114
  "usage": None,
115
  }
116
 
117
-
118
  def verify_app_secret(credentials: HTTPAuthorizationCredentials = Depends(security)):
119
  if credentials.credentials != APP_SECRET:
120
  logger.warning(f"Invalid APP_SECRET provided: {credentials.credentials}")
@@ -122,6 +123,26 @@ def verify_app_secret(credentials: HTTPAuthorizationCredentials = Depends(securi
122
  logger.info("APP_SECRET verified successfully.")
123
  return credentials.credentials
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  # CORS Preflight Options Endpoint
127
  @app.options("/hf/v1/chat/completions")
@@ -135,19 +156,18 @@ async def chat_completions_options():
135
  },
136
  )
137
 
138
-
139
  # List Available Models
140
  @app.get("/hf/v1/models")
141
  async def list_models():
142
  return {"object": "list", "data": ALLOWED_MODELS}
143
 
144
-
145
  # Chat Completions Endpoint
146
  @app.post("/hf/v1/chat/completions")
147
  async def chat_completions(
148
- request: ChatRequest, app_secret: str = Depends(verify_app_secret)
149
  ):
150
- logger.info(f"Received chat completion request for model: {request.model}")
 
151
 
152
  # Validate Selected Model
153
  if request.model not in [model['id'] for model in ALLOWED_MODELS]:
@@ -207,39 +227,41 @@ async def chat_completions(
207
 
208
  logger.debug(f"JSON Data Sent to External API: {json.dumps(json_data, indent=2)}")
209
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  async def generate():
211
- async with httpx.AsyncClient() as client:
212
- try:
213
- async with client.stream(
214
- 'POST',
215
- 'https://sider.ai/api/v2/completion/text', # Updated endpoint
216
- headers=OUTGOING_HEADERS,
217
- json=json_data,
218
- timeout=120.0
219
- ) as response:
220
- response.raise_for_status()
221
- async for line in response.aiter_lines():
222
- if line and ("[DONE]" not in line):
223
- # Assuming the line starts with 'data: ' followed by JSON
224
- if line.startswith("data: "):
225
- json_line = line[6:]
226
- if json_line.startswith("{"):
227
- try:
228
- data = json.loads(json_line)
229
- content = data.get("data", {}).get("text", "")
230
- logger.debug(f"Received content: {content}")
231
- yield f"data: {json.dumps(create_chat_completion_data(content, request.model))}\n\n"
232
- except json.JSONDecodeError as e:
233
- logger.error(f"JSON decode error: {e} - Line: {json_line}")
234
- # Send the stop signal
235
- yield f"data: {json.dumps(create_chat_completion_data('', request.model, 'stop'))}\n\n"
236
- yield "data: [DONE]\n\n"
237
- except httpx.HTTPStatusError as e:
238
- logger.error(f"HTTP error occurred: {e} - Response: {e.response.text}")
239
- raise HTTPException(status_code=e.response.status_code, detail=str(e))
240
- except httpx.RequestError as e:
241
- logger.error(f"An error occurred while requesting: {e}")
242
- raise HTTPException(status_code=500, detail=str(e))
243
 
244
  if request.stream:
245
  logger.info("Streaming response initiated.")
@@ -273,7 +295,6 @@ async def chat_completions(
273
  "usage": None,
274
  }
275
 
276
-
277
  # Entry Point
278
  if __name__ == "__main__":
279
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
7
 
8
import httpx
import uvicorn

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Depends, Request, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import StreamingResponse, Response

# Retry Mechanism Libraries
from tenacity import (
    retry,
    retry_if_exception,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)
20
+
21
+ # Initialize Logging
22
  logging.basicConfig(
23
  level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
24
  )
25
  logger = logging.getLogger(__name__)
26
 
27
  # Load Environment Variables for Sensitive Information
 
 
28
  load_dotenv()
29
 
30
  app = FastAPI()
 
58
  'sec-fetch-dest': 'empty',
59
  'sec-fetch-mode': 'cors',
60
  'sec-fetch-site': 'none',
61
+ 'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) '
62
+ 'AppleWebKit/537.36 (KHTML, like Gecko) '
63
+ 'Chrome/130.0.0.0 Safari/537.36',
64
  }
65
 
66
# Updated ALLOWED_MODELS List (No Duplicates)
# Built from (id, display-name) pairs; served verbatim by the /hf/v1/models
# endpoint and used to validate the "model" field of incoming requests.
ALLOWED_MODELS = [
    {"id": model_id, "name": display_name}
    for model_id, display_name in (
        ("claude-3.5-sonnet", "Claude 3.5 Sonnet"),
        ("sider", "Sider"),
        ("gpt-4o-mini", "GPT-4o Mini"),
        ("claude-3-haiku", "Claude 3 Haiku"),
        ("claude-3.5-haiku", "Claude 3.5 Haiku"),
        ("gemini-1.5-flash", "Gemini 1.5 Flash"),
        ("llama-3", "Llama 3"),
        ("gpt-4o", "GPT-4o"),
        ("gemini-1.5-pro", "Gemini 1.5 Pro"),
        ("llama-3.1-405b", "Llama 3.1 405b"),
    )
]
79
 
80
  # Configure CORS
 
89
  # Security Dependency
90
  security = HTTPBearer()
91
 
 
92
# Pydantic Models
class Message(BaseModel):
    """One chat turn in the OpenAI-style request body."""
    # Sender role string (e.g. "system" / "user" / "assistant") — not
    # validated here; passed through to the upstream payload.
    role: str
    # Plain-text message content.
    content: str
96
 
 
97
class ChatRequest(BaseModel):
    """Incoming /hf/v1/chat/completions payload (OpenAI-compatible subset)."""
    # Must match an "id" in ALLOWED_MODELS; the endpoint rejects others.
    model: str
    # Conversation history, oldest first.
    messages: List[Message]
    # True -> server-sent-event streaming response; False -> single JSON body.
    stream: Optional[bool] = False
101
 
 
102
  # Utility Functions
103
  def create_chat_completion_data(content: str, model: str, finish_reason: Optional[str] = None) -> Dict[str, Any]:
104
  return {
 
116
  "usage": None,
117
  }
118
 
 
119
  def verify_app_secret(credentials: HTTPAuthorizationCredentials = Depends(security)):
120
  if credentials.credentials != APP_SECRET:
121
  logger.warning(f"Invalid APP_SECRET provided: {credentials.credentials}")
 
123
  logger.info("APP_SECRET verified successfully.")
124
  return credentials.credentials
125
 
126
# Retry Configuration using Tenacity
def is_retryable_exception(exception):
    """Return True only for HTTP 429 (rate-limit) errors, which merit a retry."""
    if not isinstance(exception, httpx.HTTPStatusError):
        return False
    return exception.response.status_code == 429
129
+
130
@retry(
    # BUG FIX: the previous condition was
    #     retry_if_exception_type(httpx.HTTPStatusError) & is_retryable_exception
    # Tenacity's `&` builds retry_all(...), which invokes each operand with a
    # RetryCallState — so the bare predicate received a RetryCallState (never an
    # httpx.HTTPStatusError), always returned False, and the retry never fired.
    # retry_if_exception() correctly feeds the raised exception to the predicate.
    retry=retry_if_exception(is_retryable_exception),
    wait=wait_exponential(multiplier=1, min=2, max=10),
    stop=stop_after_attempt(5),
    reraise=True
)
async def send_request_with_retry(json_data: Dict[str, Any]) -> httpx.Response:
    """POST the completion payload to Sider AI, retrying on HTTP 429.

    Args:
        json_data: request body forwarded verbatim to the external API.

    Returns:
        The fully-read httpx.Response (client.post buffers the whole body, so
        downstream aiter_lines() replays from memory rather than streaming
        live — NOTE(review): confirm that is acceptable for long completions).

    Raises:
        httpx.HTTPStatusError: non-2xx response (after retries for 429).
        httpx.RequestError: network-level failure.
    """
    async with httpx.AsyncClient() as client:
        response = await client.post(
            'https://sider.ai/api/v3/completion/text',  # Updated endpoint
            headers=OUTGOING_HEADERS,
            json=json_data,
            timeout=120.0
        )
        response.raise_for_status()
        return response
146
 
147
  # CORS Preflight Options Endpoint
148
  @app.options("/hf/v1/chat/completions")
 
156
  },
157
  )
158
 
 
159
# List Available Models
@app.get("/hf/v1/models")
async def list_models():
    """Expose the allow-listed models in OpenAI's /models list shape."""
    catalogue = {"object": "list", "data": ALLOWED_MODELS}
    return catalogue
163
 
 
164
  # Chat Completions Endpoint
165
  @app.post("/hf/v1/chat/completions")
166
  async def chat_completions(
167
+ request: ChatRequest, app_secret: str = Depends(verify_app_secret), req: Request = None
168
  ):
169
+ client_ip = req.client.host if req else "unknown"
170
+ logger.info(f"Received chat completion request from {client_ip} for model: {request.model}")
171
 
172
  # Validate Selected Model
173
  if request.model not in [model['id'] for model in ALLOWED_MODELS]:
 
227
 
228
  logger.debug(f"JSON Data Sent to External API: {json.dumps(json_data, indent=2)}")
229
 
230
+ try:
231
+ response = await send_request_with_retry(json_data)
232
+ except httpx.HTTPStatusError as e:
233
+ status_code = e.response.status_code
234
+ if status_code == 429:
235
+ retry_after = e.response.headers.get("Retry-After", "60")
236
+ logger.warning(f"Rate limited by Sider AI. Retry after {retry_after} seconds.")
237
+ raise HTTPException(
238
+ status_code=429,
239
+ detail=f"Rate limited by external service. Please retry after {retry_after} seconds."
240
+ )
241
+ else:
242
+ logger.error(f"HTTP error occurred: {e} - Response: {e.response.text}")
243
+ raise HTTPException(status_code=status_code, detail=str(e))
244
+ except httpx.RequestError as e:
245
+ logger.error(f"An error occurred while requesting: {e}")
246
+ raise HTTPException(status_code=500, detail=str(e))
247
+
248
  async def generate():
249
+ async for line in response.aiter_lines():
250
+ if line and ("[DONE]" not in line):
251
+ # Assuming the line starts with 'data: ' followed by JSON
252
+ if line.startswith("data: "):
253
+ json_line = line[6:]
254
+ if json_line.startswith("{"):
255
+ try:
256
+ data = json.loads(json_line)
257
+ content = data.get("data", {}).get("text", "")
258
+ logger.debug(f"Received content: {content}")
259
+ yield f"data: {json.dumps(create_chat_completion_data(content, request.model))}\n\n"
260
+ except json.JSONDecodeError as e:
261
+ logger.error(f"JSON decode error: {e} - Line: {json_line}")
262
+ # Send the stop signal
263
+ yield f"data: {json.dumps(create_chat_completion_data('', request.model, 'stop'))}\n\n"
264
+ yield "data: [DONE]\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
 
266
  if request.stream:
267
  logger.info("Streaming response initiated.")
 
295
  "usage": None,
296
  }
297
 
 
298
# Entry Point
if __name__ == "__main__":
    # Serve the FastAPI app with uvicorn on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)