Niansuh committed on
Commit
fd28e64
·
verified ·
1 Parent(s): e41f64b

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +95 -68
main.py CHANGED
@@ -14,16 +14,22 @@ from pydantic import BaseModel
14
  from starlette.middleware.cors import CORSMiddleware
15
  from starlette.responses import StreamingResponse, Response
16
 
 
17
  logging.basicConfig(
18
  level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
19
  )
20
  logger = logging.getLogger(__name__)
21
 
 
22
  load_dotenv()
 
 
23
  app = FastAPI()
 
 
24
  BASE_URL = "https://aichatonlineorg.erweima.ai/aichatonline"
25
- APP_SECRET = os.getenv("APP_SECRET","666")
26
- ACCESS_TOKEN = os.getenv("SD_ACCESS_TOKEN","")
27
  headers = {
28
  'accept': '*/*',
29
  'accept-language': 'en-US,en;q=0.9',
@@ -38,6 +44,7 @@ headers = {
38
  'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36',
39
  }
40
 
 
41
  ALLOWED_MODELS = [
42
  {"id": "claude-3.5-sonnet", "name": "claude-3.5-sonnet"},
43
  {"id": "claude-3-opus", "name": "claude-3-opus"},
@@ -47,28 +54,30 @@ ALLOWED_MODELS = [
47
  {"id": "o1-mini", "name": "o1-mini"},
48
  {"id": "gpt-4o-mini", "name": "gpt-4o-mini"},
49
  ]
50
- # Configure CORS
 
51
  app.add_middleware(
52
  CORSMiddleware,
53
- allow_origins=["*"], # Allow all sources, you can restrict specific sources if needed
54
  allow_credentials=True,
55
- allow_methods=["*"], # All methods allowed
56
  allow_headers=["*"], # Allow all headers
57
  )
58
- security = HTTPBearer()
59
 
 
 
60
 
 
61
  class Message(BaseModel):
62
  role: str
63
  content: str
64
 
65
-
66
  class ChatRequest(BaseModel):
67
  model: str
68
  messages: List[Message]
69
  stream: Optional[bool] = False
70
 
71
-
72
  def simulate_data(content, model):
73
  return {
74
  "id": f"chatcmpl-{uuid.uuid4()}",
@@ -85,7 +94,6 @@ def simulate_data(content, model):
85
  "usage": None,
86
  }
87
 
88
-
89
  def stop_data(content, model):
90
  return {
91
  "id": f"chatcmpl-{uuid.uuid4()}",
@@ -101,8 +109,7 @@ def stop_data(content, model):
101
  ],
102
  "usage": None,
103
  }
104
-
105
-
106
  def create_chat_completion_data(content: str, model: str, finish_reason: Optional[str] = None) -> Dict[str, Any]:
107
  return {
108
  "id": f"chatcmpl-{uuid.uuid4()}",
@@ -119,12 +126,16 @@ def create_chat_completion_data(content: str, model: str, finish_reason: Optiona
119
  "usage": None,
120
  }
121
 
122
-
123
  def verify_app_secret(credentials: HTTPAuthorizationCredentials = Depends(security)):
124
  if credentials.credentials != APP_SECRET:
125
  raise HTTPException(status_code=403, detail="Invalid APP_SECRET")
126
  return credentials.credentials
127
 
 
 
 
 
 
128
 
129
  @app.options("/hf/v1/chat/completions")
130
  async def chat_completions_options():
@@ -137,72 +148,68 @@ async def chat_completions_options():
137
  },
138
  )
139
 
140
-
141
- def replace_escaped_newlines(input_string: str) -> str:
142
- return input_string.replace("\\n", "\n")
143
-
144
-
145
  @app.get("/hf/v1/models")
146
  async def list_models():
147
  return {"object": "list", "data": ALLOWED_MODELS}
148
 
149
-
150
  @app.post("/hf/v1/chat/completions")
151
  async def chat_completions(
152
  request: ChatRequest, app_secret: str = Depends(verify_app_secret)
153
  ):
154
  logger.info(f"Received chat completion request for model: {request.model}")
155
 
 
156
  if request.model not in [model['id'] for model in ALLOWED_MODELS]:
157
  raise HTTPException(
158
  status_code=400,
159
  detail=f"Model {request.model} is not allowed. Allowed models are: {', '.join(model['id'] for model in ALLOWED_MODELS)}",
160
  )
 
161
  # Generate a UUID
162
  original_uuid = uuid.uuid4()
163
  uuid_str = str(original_uuid).replace("-", "")
164
 
165
- json_data = {
166
- 'prompt': "\n".join(
167
- [
168
- f"{'User' if msg.role == 'user' else 'Assistant'}: {msg.content}"
169
- for msg in request.messages
170
- ]
171
- ),
172
- 'stream': True,
173
- 'app_name': 'ChitChat_Edge_Ext',
174
- 'app_version': '4.28.0',
175
- 'tz_name': 'Asia/Karachi',
176
- 'cid': 'C092SEMXM9BJ',
177
- 'model': request.model,
178
- 'search': False, # Ensure search is disabled
179
- 'auto_search': False, # Ensure auto_search is disabled
180
- 'filter_search_history': False,
181
- 'from': 'chat',
182
- 'group_id': 'default',
183
- 'chat_models': [],
184
- 'files': [],
185
- 'prompt_template': {
186
- 'key': '',
187
- 'attributes': {
188
- 'lang': 'original',
 
 
189
  },
190
- },
191
- 'tools': {
192
- 'auto': [
193
- 'text_to_image',
194
- 'data_analysis',
195
- # Removed 'search' from the list
196
- ],
197
- },
198
- 'extra_info': {
199
- 'origin_url': '',
200
- 'origin_title': '',
201
- },
202
- }
203
-
204
-
205
 
 
206
  async def generate():
207
  async with httpx.AsyncClient() as client:
208
  try:
@@ -210,10 +217,28 @@ json_data = {
210
  response.raise_for_status()
211
  async for line in response.aiter_lines():
212
  if line and ("[DONE]" not in line):
213
- content = json.loads(line[5:])["data"]
214
- yield f"data: {json.dumps(create_chat_completion_data(content.get('text',''), request.model))}\n\n"
215
- yield f"data: {json.dumps(create_chat_completion_data('', request.model, 'stop'))}\n\n"
216
- yield "data: [DONE]\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  except httpx.HTTPStatusError as e:
218
  logger.error(f"HTTP error occurred: {e}")
219
  raise HTTPException(status_code=e.response.status_code, detail=str(e))
@@ -229,11 +254,14 @@ json_data = {
229
  full_response = ""
230
  async for chunk in generate():
231
  if chunk.startswith("data: ") and not chunk[6:].startswith("[DONE]"):
232
- # print(chunk)
233
- data = json.loads(chunk[6:])
234
- if data["choices"][0]["delta"].get("content"):
235
- full_response += data["choices"][0]["delta"]["content"]
236
-
 
 
 
237
  return {
238
  "id": f"chatcmpl-{uuid.uuid4()}",
239
  "object": "chat.completion",
@@ -249,7 +277,6 @@ json_data = {
249
  "usage": None,
250
  }
251
 
252
-
253
-
254
  if __name__ == "__main__":
255
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
14
  from starlette.middleware.cors import CORSMiddleware
15
  from starlette.responses import StreamingResponse, Response
16
 
17
# Configure logging: timestamped, named, leveled records at INFO for the whole app.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
22
 
23
# Load environment variables from a local .env file (no-op if the file is absent).
load_dotenv()

# Initialize the FastAPI application instance.
app = FastAPI()
28
+
29
# Constants and configuration pulled from the environment.
BASE_URL = "https://aichatonlineorg.erweima.ai/aichatonline"
# Shared secret expected as the bearer token; the "666" default is for local dev only.
APP_SECRET = os.getenv("APP_SECRET", "666")
# Upstream access token; empty string when unset — presumably sent to the upstream API (verify against caller).
ACCESS_TOKEN = os.getenv("SD_ACCESS_TOKEN", "")
33
  headers = {
34
  'accept': '*/*',
35
  'accept-language': 'en-US,en;q=0.9',
 
44
  'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36',
45
  }
46
 
47
+ # Define allowed models
48
  ALLOWED_MODELS = [
49
  {"id": "claude-3.5-sonnet", "name": "claude-3.5-sonnet"},
50
  {"id": "claude-3-opus", "name": "claude-3-opus"},
 
54
  {"id": "o1-mini", "name": "o1-mini"},
55
  {"id": "gpt-4o-mini", "name": "gpt-4o-mini"},
56
  ]
57
+
58
# Configure CORS middleware so browser clients from any origin can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins; restrict if necessary
    allow_credentials=True,
    allow_methods=["*"],  # Allow all HTTP methods
    allow_headers=["*"],  # Allow all headers
)
 
66
 
67
# HTTP Bearer auth scheme; verify_app_secret depends on this to read the Authorization header.
security = HTTPBearer()
69
 
70
+ # Pydantic models
71
  class Message(BaseModel):
72
  role: str
73
  content: str
74
 
 
75
class ChatRequest(BaseModel):
    """Body of a /hf/v1/chat/completions request (OpenAI-compatible subset)."""

    # Target model id; validated against ALLOWED_MODELS by the endpoint.
    model: str
    # Ordered conversation history.
    messages: List[Message]
    # When True the endpoint streams SSE chunks; otherwise a single JSON response.
    stream: Optional[bool] = False
79
 
80
+ # Helper functions
81
  def simulate_data(content, model):
82
  return {
83
  "id": f"chatcmpl-{uuid.uuid4()}",
 
94
  "usage": None,
95
  }
96
 
 
97
  def stop_data(content, model):
98
  return {
99
  "id": f"chatcmpl-{uuid.uuid4()}",
 
109
  ],
110
  "usage": None,
111
  }
112
+
 
113
  def create_chat_completion_data(content: str, model: str, finish_reason: Optional[str] = None) -> Dict[str, Any]:
114
  return {
115
  "id": f"chatcmpl-{uuid.uuid4()}",
 
126
  "usage": None,
127
  }
128
 
 
129
def verify_app_secret(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency: validate the bearer token against APP_SECRET.

    Raises:
        HTTPException: 403 when the presented token does not match APP_SECRET.

    Returns:
        The validated token string.
    """
    if credentials.credentials != APP_SECRET:
        raise HTTPException(status_code=403, detail="Invalid APP_SECRET")
    return credentials.credentials
133
 
134
# Utility function to replace escaped newlines
def replace_escaped_newlines(input_string: str) -> str:
    """Return *input_string* with literal backslash-n sequences turned into real newlines."""
    return input_string.replace("\\n", "\n")
137
+
138
+ # API Endpoints
139
 
140
  @app.options("/hf/v1/chat/completions")
141
  async def chat_completions_options():
 
148
  },
149
  )
150
 
 
 
 
 
 
151
@app.get("/hf/v1/models")
async def list_models():
    """Return the catalogue of models this proxy accepts, in OpenAI list format."""
    return {"object": "list", "data": ALLOWED_MODELS}
154
 
 
155
  @app.post("/hf/v1/chat/completions")
156
  async def chat_completions(
157
  request: ChatRequest, app_secret: str = Depends(verify_app_secret)
158
  ):
159
  logger.info(f"Received chat completion request for model: {request.model}")
160
 
161
+ # Validate model
162
  if request.model not in [model['id'] for model in ALLOWED_MODELS]:
163
  raise HTTPException(
164
  status_code=400,
165
  detail=f"Model {request.model} is not allowed. Allowed models are: {', '.join(model['id'] for model in ALLOWED_MODELS)}",
166
  )
167
+
168
  # Generate a UUID
169
  original_uuid = uuid.uuid4()
170
  uuid_str = str(original_uuid).replace("-", "")
171
 
172
+ # Construct the payload to send to the external API
173
+ json_data = {
174
+ 'prompt': "\n".join(
175
+ [
176
+ f"{'User' if msg.role == 'user' else 'Assistant'}: {msg.content}"
177
+ for msg in request.messages
178
+ ]
179
+ ),
180
+ 'stream': True,
181
+ 'app_name': 'ChitChat_Edge_Ext',
182
+ 'app_version': '4.28.0',
183
+ 'tz_name': 'Asia/Karachi',
184
+ 'cid': 'C092SEMXM9BJ',
185
+ 'model': request.model,
186
+ 'search': False, # Ensure search is disabled
187
+ 'auto_search': False, # Ensure auto_search is disabled
188
+ 'filter_search_history': False,
189
+ 'from': 'chat',
190
+ 'group_id': 'default',
191
+ 'chat_models': [],
192
+ 'files': [],
193
+ 'prompt_template': {
194
+ 'key': '',
195
+ 'attributes': {
196
+ 'lang': 'original',
197
+ },
198
  },
199
+ 'tools': {
200
+ 'auto': [
201
+ 'text_to_image',
202
+ 'data_analysis',
203
+ # 'search' has been removed to disable search functionality
204
+ ],
205
+ },
206
+ 'extra_info': {
207
+ 'origin_url': '',
208
+ 'origin_title': '',
209
+ },
210
+ }
 
 
 
211
 
212
+ # Define the asynchronous generator for streaming responses
213
  async def generate():
214
  async with httpx.AsyncClient() as client:
215
  try:
 
217
  response.raise_for_status()
218
  async for line in response.aiter_lines():
219
  if line and ("[DONE]" not in line):
220
+ # Assuming the line starts with some prefix before JSON, e.g., "data: "
221
+ # Adjust if necessary based on actual response format
222
+ try:
223
+ # Remove any prefix before JSON if present
224
+ if line.startswith("data: "):
225
+ line_content = line[6:]
226
+ else:
227
+ line_content = line
228
+
229
+ # Parse the JSON content
230
+ content = json.loads(line_content)["data"]
231
+
232
+ # Yield the formatted data
233
+ yield f"data: {json.dumps(create_chat_completion_data(content.get('text',''), request.model))}\n\n"
234
+ except json.JSONDecodeError as e:
235
+ logger.error(f"JSON decode error: {e}")
236
+ continue
237
+ else:
238
+ # Signal the end of the stream
239
+ if line and "[DONE]" in line:
240
+ yield f"data: {json.dumps(create_chat_completion_data('', request.model, 'stop'))}\n\n"
241
+ yield "data: [DONE]\n\n"
242
  except httpx.HTTPStatusError as e:
243
  logger.error(f"HTTP error occurred: {e}")
244
  raise HTTPException(status_code=e.response.status_code, detail=str(e))
 
254
  full_response = ""
255
  async for chunk in generate():
256
  if chunk.startswith("data: ") and not chunk[6:].startswith("[DONE]"):
257
+ try:
258
+ data = json.loads(chunk[6:])
259
+ if data["choices"][0]["delta"].get("content"):
260
+ full_response += data["choices"][0]["delta"]["content"]
261
+ except json.JSONDecodeError as e:
262
+ logger.error(f"JSON decode error in non-streaming response: {e}")
263
+ continue
264
+
265
  return {
266
  "id": f"chatcmpl-{uuid.uuid4()}",
267
  "object": "chat.completion",
 
277
  "usage": None,
278
  }
279
 
280
# Entry point: serve the ASGI app with uvicorn when executed directly.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)