Niansuh committed on
Commit
b59c7b9
·
verified ·
1 Parent(s): 6baac50

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +172 -53
main.py CHANGED
@@ -8,118 +8,237 @@ from typing import Any, Dict, List, Optional
8
  import httpx
9
  import uvicorn
10
  from dotenv import load_dotenv
11
- from fastapi import FastAPI, HTTPException
 
12
  from pydantic import BaseModel
13
- from starlette.responses import StreamingResponse
 
14
 
 
15
  logging.basicConfig(
16
  level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
17
  )
18
  logger = logging.getLogger(__name__)
19
 
 
20
  load_dotenv()
21
  app = FastAPI()
 
 
 
 
22
  ACCESS_TOKEN = os.getenv("SD_ACCESS_TOKEN", "")
23
  headers = {
 
 
24
  'authorization': f'Bearer {ACCESS_TOKEN}',
25
- 'user-agent': 'Mozilla/5.0',
 
 
 
 
 
 
 
26
  }
27
 
 
28
  ALLOWED_MODELS = [
29
- "claude-3.5-sonnet", "sider", "gpt-4o-mini", "claude-3-haiku", "claude-3.5-haiku",
30
- "gemini-1.5-flash", "llama-3", "gpt-4o", "gemini-1.5-pro", "llama-3.1-405b"
 
 
 
 
 
 
 
 
31
  ]
32
 
 
 
 
 
 
 
 
 
 
 
 
33
 
 
34
  class Message(BaseModel):
35
  role: str
36
  content: str
37
 
38
-
39
  class ChatRequest(BaseModel):
40
  model: str
41
  messages: List[Message]
42
  stream: Optional[bool] = False
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
 
 
 
 
 
 
45
  @app.post("/hf/v1/chat/completions")
46
- async def chat_completions(request: ChatRequest):
 
 
47
  logger.info(f"Received chat completion request for model: {request.model}")
48
 
49
- if request.model not in ALLOWED_MODELS:
 
 
50
  logger.error(f"Model {request.model} is not allowed.")
51
  raise HTTPException(
52
  status_code=400,
53
- detail=f"Model {request.model} is not allowed. Allowed models are: {', '.join(ALLOWED_MODELS)}",
54
  )
55
 
56
- # Build the payload as per the required structure
57
- json_data = {
58
- 'prompt': "\n".join(
59
- [f"{'User' if msg.role == 'user' else 'Assistant'}: {msg.content}" for msg in request.messages]
60
- ),
61
- 'stream': request.stream,
62
- 'app_name': "ChitChat_Chrome_Ext",
63
- 'app_version': "4.28.0",
64
- 'auto_search': False,
65
- 'chat_models': [request.model], # Try setting model here
66
- 'cid': "C082SLBVNR9J",
67
- 'extra_info': {
68
- 'origin_url': "chrome-extension://difoiogjjojoaoomphldepapgpbgkhkb/standalone.html?from=sidebar",
69
- 'origin_title': "Sider"
70
- },
71
- 'files': [],
72
- 'filter_search_history': False,
73
- 'from': "chat",
74
- 'group_id': "default",
75
- 'model': request.model, # Keep original model field
76
- 'prompt_template': {
77
- 'key': "artifacts",
78
- 'attributes': {
79
- 'lang': "original"
80
- }
81
- },
82
- 'search': False,
83
- 'tools': {
84
- 'auto': ["search", "text_to_image", "data_analysis"]
85
- },
86
- 'tz_name': "Asia/Karachi"
87
- }
88
 
 
 
 
 
89
 
90
- logger.info(f"Outgoing API request payload: {json.dumps(json_data, indent=2)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  async def generate():
93
  async with httpx.AsyncClient() as client:
94
  try:
95
- async with client.stream('POST', 'https://sider.ai/api/v3/completion/text', headers=headers, json=json_data, timeout=120.0) as response:
 
 
 
 
 
 
96
  response.raise_for_status()
97
  async for line in response.aiter_lines():
98
  if line and ("[DONE]" not in line):
99
- content = json.loads(line[5:])["data"]
100
- yield f"data: {json.dumps(create_chat_completion_data(content.get('text', ''), request.model))}\n\n"
 
 
 
 
 
 
 
 
 
 
101
  yield f"data: {json.dumps(create_chat_completion_data('', request.model, 'stop'))}\n\n"
102
  yield "data: [DONE]\n\n"
103
  except httpx.HTTPStatusError as e:
104
- logger.error(f"HTTP error occurred: {e}")
105
  raise HTTPException(status_code=e.response.status_code, detail=str(e))
106
  except httpx.RequestError as e:
107
  logger.error(f"An error occurred while requesting: {e}")
108
  raise HTTPException(status_code=500, detail=str(e))
109
 
110
  if request.stream:
111
- logger.info("Streaming response")
112
  return StreamingResponse(generate(), media_type="text/event-stream")
113
  else:
114
- logger.info("Non-streaming response")
115
  full_response = ""
116
  async for chunk in generate():
117
  if chunk.startswith("data: ") and not chunk[6:].startswith("[DONE]"):
118
- data = json.loads(chunk[6:])
119
- if data["choices"][0]["delta"].get("content"):
120
- full_response += data["choices"][0]["delta"]["content"]
 
 
 
 
121
 
122
- logger.info(f"Full response generated: {full_response}")
123
  return {
124
  "id": f"chatcmpl-{uuid.uuid4()}",
125
  "object": "chat.completion",
@@ -135,6 +254,6 @@ json_data = {
135
  "usage": None,
136
  }
137
 
138
-
139
  if __name__ == "__main__":
140
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
import secrets

import httpx
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import StreamingResponse, Response
16
 
17
# Configure Logging: timestamped INFO-level records for the whole module.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Load Environment Variables from a local .env file (if present), then create
# the FastAPI application instance used by all route decorators below.
load_dotenv()
app = FastAPI()
26
+
27
# Configuration Constants
# NOTE(review): BASE_URL is defined but never referenced in the visible code
# (upstream requests go to sider.ai directly) — confirm whether it is still needed.
BASE_URL = "https://aichatonlineorg.erweima.ai/aichatonline"
# Shared secret that API clients must present as a Bearer token
# (checked by verify_app_secret); defaults to "666" when unset.
APP_SECRET = os.getenv("APP_SECRET", "666")
# Token used to authenticate against the upstream sider.ai API.
ACCESS_TOKEN = os.getenv("SD_ACCESS_TOKEN", "")
# Headers sent on every upstream request; they imitate the Sider browser
# extension (origin, user-agent, sec-fetch-*) so the upstream accepts the call.
headers = {
    'accept': '*/*',
    'accept-language': 'zh-CN,zh;q=0.9,en;q=0.8,en-GB;q=0.7,en-US;q=0.6',
    'authorization': f'Bearer {ACCESS_TOKEN}',
    'cache-control': 'no-cache',
    'origin': 'chrome-extension://dhoenijjpgpeimemopealfcbiecgceod',
    'pragma': 'no-cache',
    'priority': 'u=1, i',
    'sec-fetch-dest': 'empty',
    'sec-fetch-mode': 'cors',
    'sec-fetch-site': 'none',
    'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36 Edg/129.0.0.0',
}
44
 
45
# Catalogue of models exposed via /hf/v1/models and accepted by the chat
# endpoint: (wire id, human-readable display name) pairs, no duplicates.
_MODEL_CATALOG = [
    ("claude-3.5-sonnet", "Claude 3.5 Sonnet"),
    ("sider", "Sider"),
    ("gpt-4o-mini", "GPT-4o Mini"),
    ("claude-3-haiku", "Claude 3 Haiku"),
    ("claude-3.5-haiku", "Claude 3.5 Haiku"),
    ("gemini-1.5-flash", "Gemini 1.5 Flash"),
    ("llama-3", "Llama 3"),
    ("gpt-4o", "GPT-4o"),
    ("gemini-1.5-pro", "Gemini 1.5 Pro"),
    ("llama-3.1-405b", "Llama 3.1 405b"),
]
ALLOWED_MODELS = [
    {"id": model_id, "name": display_name}
    for model_id, display_name in _MODEL_CATALOG
]
58
 
59
# Configure CORS so browser-based clients can call this API cross-origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Restrict this to specific origins in production
    allow_credentials=True,
    allow_methods=["*"],  # All methods allowed
    allow_headers=["*"],  # Allow all headers
)

# Security Configuration: Bearer-token scheme consumed by verify_app_secret().
security = HTTPBearer()
70
 
71
# Pydantic Models
class Message(BaseModel):
    """One chat turn: who spoke (``role``) and what was said (``content``)."""

    role: str
    content: str
75
 
 
76
class ChatRequest(BaseModel):
    """Incoming /hf/v1/chat/completions payload.

    ``stream`` defaults to False, i.e. a single aggregated JSON response.
    """

    model: str
    messages: List[Message]
    stream: Optional[bool] = False
80
 
81
# Utility Functions
def create_chat_completion_data(
    content: str, model: str, finish_reason: Optional[str] = None
) -> Dict[str, Any]:
    """Wrap *content* in an OpenAI-style ``chat.completion.chunk`` dict.

    A fresh ``chatcmpl-<uuid>`` id and current timestamp are generated per
    call; ``finish_reason`` is None for intermediate chunks, "stop" at the end.
    """
    delta = {"content": content, "role": "assistant"}
    choice = {
        "index": 0,
        "delta": delta,
        "finish_reason": finish_reason,
    }
    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion.chunk",
        "created": int(datetime.now().timestamp()),
        "model": model,
        "choices": [choice],
        "usage": None,
    }
97
+
98
def verify_app_secret(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency: validate the client's Bearer token against APP_SECRET.

    Returns the presented token on success; raises HTTPException(403) otherwise.
    """
    # secrets.compare_digest gives a constant-time comparison (no timing
    # side-channel), and the presented credential is deliberately NOT logged
    # so invalid/mistyped secrets never leak into log files.
    if not secrets.compare_digest(credentials.credentials, APP_SECRET):
        logger.warning("Invalid APP_SECRET provided.")
        raise HTTPException(status_code=403, detail="Invalid APP_SECRET")
    logger.info("APP_SECRET verified successfully.")
    return credentials.credentials
104
+
105
# CORS Preflight Options Endpoint
@app.options("/hf/v1/chat/completions")
async def chat_completions_options():
    """Answer CORS preflight (OPTIONS) requests for the chat endpoint."""
    preflight_headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "POST, OPTIONS",
        "Access-Control-Allow-Headers": "Content-Type, Authorization",
    }
    return Response(status_code=200, headers=preflight_headers)
116
+
117
def replace_escaped_newlines(input_string: str) -> str:
    """Convert literal backslash-n sequences in *input_string* to real newlines."""
    return "\n".join(input_string.split("\\n"))
120
 
121
# List Available Models
@app.get("/hf/v1/models")
async def list_models():
    """Expose the model catalogue in OpenAI's ``/models`` list format."""
    payload = {"object": "list", "data": ALLOWED_MODELS}
    return payload
125
+
126
+ # Chat Completions Endpoint
127
  @app.post("/hf/v1/chat/completions")
128
+ async def chat_completions(
129
+ request: ChatRequest, app_secret: str = Depends(verify_app_secret)
130
+ ):
131
  logger.info(f"Received chat completion request for model: {request.model}")
132
 
133
+ # Validate Selected Model
134
+ if request.model not in [model['id'] for model in ALLOWED_MODELS]:
135
+ allowed = ', '.join(model['id'] for model in ALLOWED_MODELS)
136
  logger.error(f"Model {request.model} is not allowed.")
137
  raise HTTPException(
138
  status_code=400,
139
+ detail=f"Model '{request.model}' is not allowed. Allowed models are: {allowed}",
140
  )
141
 
142
+ logger.info(f"Using model: {request.model}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
+ # Generate a UUID
145
+ original_uuid = uuid.uuid4()
146
+ uuid_str = str(original_uuid).replace("-", "")
147
+ logger.debug(f"Generated UUID: {uuid_str}")
148
 
149
+ # Prepare JSON Payload for External API
150
+ json_data = {
151
+ 'prompt': "\n".join(
152
+ [
153
+ f"{'User' if msg.role == 'user' else 'Assistant'}: {msg.content}"
154
+ for msg in request.messages
155
+ ]
156
+ ),
157
+ 'stream': True,
158
+ 'app_name': 'ChitChat_Edge_Ext',
159
+ 'app_version': '4.26.1',
160
+ 'tz_name': 'Asia/Karachi',
161
+ 'cid': '',
162
+ 'model': request.model, # Use the selected model directly
163
+ 'search': False,
164
+ 'auto_search': False,
165
+ 'filter_search_history': False,
166
+ 'from': 'chat',
167
+ 'group_id': 'default',
168
+ 'chat_models': [request.model], # Include the model in chat_models
169
+ 'files': [],
170
+ 'prompt_template': {
171
+ 'key': '',
172
+ 'attributes': {
173
+ 'lang': 'original',
174
+ },
175
+ },
176
+ 'tools': {
177
+ 'auto': [
178
+ 'search',
179
+ 'text_to_image',
180
+ 'data_analysis',
181
+ ],
182
+ },
183
+ 'extra_info': {
184
+ 'origin_url': '',
185
+ 'origin_title': '',
186
+ },
187
+ }
188
+
189
+ logger.debug(f"JSON Data Sent to External API: {json.dumps(json_data, indent=2)}")
190
 
191
  async def generate():
192
  async with httpx.AsyncClient() as client:
193
  try:
194
+ async with client.stream(
195
+ 'POST',
196
+ 'https://sider.ai/api/v2/completion/text',
197
+ headers=headers,
198
+ json=json_data,
199
+ timeout=120.0
200
+ ) as response:
201
  response.raise_for_status()
202
  async for line in response.aiter_lines():
203
  if line and ("[DONE]" not in line):
204
+ # Assuming the line starts with 'data: ' followed by JSON
205
+ if line.startswith("data: "):
206
+ json_line = line[6:]
207
+ if json_line.startswith("{"):
208
+ try:
209
+ data = json.loads(json_line)
210
+ content = data.get("data", {}).get("text", "")
211
+ logger.debug(f"Received content: {content}")
212
+ yield f"data: {json.dumps(create_chat_completion_data(content, request.model))}\n\n"
213
+ except json.JSONDecodeError as e:
214
+ logger.error(f"JSON decode error: {e} - Line: {json_line}")
215
+ # Send the stop signal
216
  yield f"data: {json.dumps(create_chat_completion_data('', request.model, 'stop'))}\n\n"
217
  yield "data: [DONE]\n\n"
218
  except httpx.HTTPStatusError as e:
219
+ logger.error(f"HTTP error occurred: {e} - Response: {e.response.text}")
220
  raise HTTPException(status_code=e.response.status_code, detail=str(e))
221
  except httpx.RequestError as e:
222
  logger.error(f"An error occurred while requesting: {e}")
223
  raise HTTPException(status_code=500, detail=str(e))
224
 
225
  if request.stream:
226
+ logger.info("Streaming response initiated.")
227
  return StreamingResponse(generate(), media_type="text/event-stream")
228
  else:
229
+ logger.info("Non-streaming response initiated.")
230
  full_response = ""
231
  async for chunk in generate():
232
  if chunk.startswith("data: ") and not chunk[6:].startswith("[DONE]"):
233
+ # Parse the JSON part after 'data: '
234
+ try:
235
+ data = json.loads(chunk[6:])
236
+ if data["choices"][0]["delta"].get("content"):
237
+ full_response += data["choices"][0]["delta"]["content"]
238
+ except json.JSONDecodeError:
239
+ logger.warning(f"Failed to decode JSON from chunk: {chunk}")
240
 
241
+ # Final Response Structure
242
  return {
243
  "id": f"chatcmpl-{uuid.uuid4()}",
244
  "object": "chat.completion",
 
254
  "usage": None,
255
  }
256
 
257
+ # Entry Point
258
  if __name__ == "__main__":
259
  uvicorn.run(app, host="0.0.0.0", port=7860)