dan92 commited on
Commit
9c2cf37
·
verified ·
1 Parent(s): 4d18321

Upload 2 files

Browse files
Files changed (2) hide show
  1. Dockerfile +1 -2
  2. app.py +111 -150
Dockerfile CHANGED
@@ -23,5 +23,4 @@ ENV PYTHONUNBUFFERED=1
23
  EXPOSE 3000
24
 
25
  # 使用 gunicorn 作为生产级 WSGI 服务器
26
- # Dockerfile 中修改 gunicorn 命令
27
- CMD ["gunicorn", "--bind", "0.0.0.0:3000", "--workers", "4", "--timeout", "120", "--keep-alive", "5", "--worker-class", "sync", "app:app"]
 
23
  EXPOSE 3000
24
 
25
  # 使用 gunicorn 作为生产级 WSGI 服务器
26
+ CMD ["gunicorn", "--bind", "0.0.0.0:3000", "--workers", "4", "app:app"]
 
app.py CHANGED
@@ -18,6 +18,8 @@ from urllib3.util.connection import create_connection
18
  import urllib3
19
  from cachetools import TTLCache
20
  import threading
 
 
21
 
22
  # 新增导入
23
  import register_bot
@@ -98,14 +100,8 @@ class CustomHTTPAdapter(HTTPAdapter):
98
 
99
  # 创建自定义的 Session
100
  def create_custom_session():
101
- """创建自定义的 Session,添加超时设置"""
102
  session = requests.Session()
103
- adapter = CustomHTTPAdapter(
104
- pool_connections=100,
105
- pool_maxsize=100,
106
- max_retries=3,
107
- pool_block=False
108
- )
109
  session.mount('https://', adapter)
110
  session.mount('http://', adapter)
111
  return session
@@ -311,42 +307,18 @@ class MultiAuthManager:
311
  def __init__(self, credentials):
312
  self.auth_managers = [AuthManager(email, password) for email, password in credentials]
313
  self.current_index = 0
314
- self.last_success_index = None
315
  self._last_rotation = time.time()
316
- self._rotation_interval = 300
317
- self._model_usage = {}
318
- self._invalid_accounts = set() # 记录失效的账号
319
-
320
- def remove_invalid_account(self, auth_manager):
321
- """移除失效的账号"""
322
- if auth_manager._email not in self._invalid_accounts:
323
- self._invalid_accounts.add(auth_manager._email)
324
- self.auth_managers = [am for am in self.auth_managers if am._email != auth_manager._email]
325
- logger.info(f"已移除失效账号: {auth_manager._email}")
326
-
327
- # 如果移除的是当前使用的账号,重置索引
328
- if self.last_success_index is not None:
329
- if self.last_success_index >= len(self.auth_managers):
330
- self.last_success_index = None
331
- if self.current_index >= len(self.auth_managers):
332
- self.current_index = 0
333
 
334
  def get_next_auth_manager(self, model):
335
- """改进的账号选择逻辑,优先使用模型对应的上次成功账号"""
336
- if not self.auth_managers: # 如果没有可用账号,返回 None
337
- return None
338
-
339
- # 首先尝试使用该模型上次成功的账号
340
- if model in self._model_usage:
341
- last_success_index = self._model_usage[model]
342
- if last_success_index < len(self.auth_managers):
343
- auth_manager = self.auth_managers[last_success_index]
344
- if auth_manager.is_model_available(model) and auth_manager._should_attempt_auth():
345
- return auth_manager
346
-
347
- # 如果没有该模型的成功记录,或上次成功的账号不可用,则从当前位置开始轮询
348
- if len(self.auth_managers) == 0:
349
- return None
350
 
351
  start_index = self.current_index
352
  for _ in range(len(self.auth_managers)):
@@ -359,34 +331,13 @@ class MultiAuthManager:
359
  return None
360
 
361
  def ensure_valid_token(self, model):
362
- """确保获取有效的token并返回可用的auth_manager"""
363
- auth_manager = self.get_next_auth_manager(model)
364
- if not auth_manager:
365
- return None
366
-
367
- try:
368
- if auth_manager.ensure_valid_token():
369
  return auth_manager
370
- except requests.exceptions.RequestException as e:
371
- if "400 Client Error: Bad Request" in str(e):
372
- logger.error(f"账号 {auth_manager._email} 已失效")
373
- self.remove_invalid_account(auth_manager)
374
- return None
375
-
376
  return None
377
 
378
- def mark_success(self, auth_manager, model):
379
- """记录成功使用的账号索引,并与模型关联"""
380
- for i, manager in enumerate(self.auth_managers):
381
- if manager == auth_manager:
382
- self._model_usage[model] = i
383
- self.last_success_index = i
384
- break
385
-
386
- def reset_model_status(self):
387
- """重置所有账号的模型使用状态"""
388
- self._model_usage.clear()
389
- self._invalid_accounts.clear() # 清除失效账号记录
390
  for auth_manager in self.auth_managers:
391
  auth_manager.reset_model_status()
392
 
@@ -516,54 +467,37 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0301"):
516
  """计算消息列表中的总令牌数量。"""
517
  return sum(count_tokens(str(message), model) for message in messages)
518
 
519
- # 在文件开头添加常量
520
- STREAM_TIMEOUT = 30 # 流式响应超时时间(秒)
521
- REQUEST_TIMEOUT = 10 # 普通请求超时时间(秒)
522
- CHUNK_SIZE = 512 # 减小块大小以加快处理速度
523
-
524
  def stream_notdiamond_response(response, model):
525
- """改进的流式响应处理,添加超时和错误处理"""
526
  buffer = ""
527
  full_content = ""
528
- last_chunk_time = time.time()
529
 
530
- try:
531
- for chunk in response.iter_content(chunk_size=CHUNK_SIZE):
532
- current_time = time.time()
533
- if current_time - last_chunk_time > STREAM_TIMEOUT:
534
- logger.warning("Stream timeout reached")
535
- break
536
 
537
- if chunk:
538
- try:
539
- new_content = chunk.decode('utf-8')
540
- buffer += new_content
541
- full_content += new_content
542
-
543
- # 创建完整的响应块
544
- chunk_data = create_openai_chunk(new_content, model)
545
-
546
- # 确保响应块包含完整的上下文
547
- if 'choices' in chunk_data and chunk_data['choices']:
548
- chunk_data['choices'][0]['delta']['content'] = new_content
549
- chunk_data['choices'][0]['context'] = full_content
550
-
551
- yield chunk_data
552
- last_chunk_time = current_time
553
-
554
- except Exception as e:
555
- logger.error(f"Error processing chunk: {e}")
556
- continue
557
- except requests.exceptions.RequestException as e:
558
- logger.error(f"Stream error: {e}")
559
- except Exception as e:
560
- logger.error(f"Unexpected error in stream processing: {e}")
561
- finally:
562
- # 确保发送完成标记
563
- final_chunk = create_openai_chunk('', model, 'stop')
564
- if 'choices' in final_chunk and final_chunk['choices']:
565
- final_chunk['choices'][0]['context'] = full_content
566
- yield final_chunk
567
 
568
  def handle_non_stream_response(response, model, prompt_tokens):
569
  """改进的非流式响应处理,确保保持完整上下文。"""
@@ -611,32 +545,23 @@ def handle_non_stream_response(response, model, prompt_tokens):
611
  logger.error(f"Error processing non-stream response: {e}")
612
  raise
613
 
614
- # 修改 generate_stream_response 函数
615
  def generate_stream_response(response, model, prompt_tokens):
616
- """改进的流式 HTTP 响应生成器"""
617
  total_completion_tokens = 0
618
- start_time = time.time()
619
 
620
- try:
621
- for chunk in stream_notdiamond_response(response, model):
622
- if time.time() - start_time > STREAM_TIMEOUT:
623
- logger.warning("Response generation timeout")
624
- break
625
-
626
- content = chunk['choices'][0]['delta'].get('content', '')
627
- total_completion_tokens += count_tokens(content, model)
628
-
629
- chunk['usage'] = {
630
- "prompt_tokens": prompt_tokens,
631
- "completion_tokens": total_completion_tokens,
632
- "total_tokens": prompt_tokens + total_completion_tokens
633
- }
634
-
635
- yield f"data: {json.dumps(chunk)}\n\n"
636
- except Exception as e:
637
- logger.error(f"Error generating stream response: {e}")
638
- finally:
639
- yield "data: [DONE]\n\n"
640
 
641
  def get_auth_credentials():
642
  """从API获取认证凭据"""
@@ -898,57 +823,93 @@ def make_request(payload, auth_manager, model_id):
898
  url,
899
  headers=headers,
900
  json=payload,
901
- stream=True,
902
- timeout=REQUEST_TIMEOUT # 添加超时设置
903
  ).result()
904
 
905
  if response.status_code == 200 and response.headers.get('Content-Type') == 'text/event-stream':
906
  logger.info(f"请求成功,使用账号 {auth_manager._email}")
907
- multi_auth_manager.mark_success(auth_manager, model_id)
908
  return response
909
 
910
  headers_cache.clear()
911
 
912
  if response.status_code == 401: # Unauthorized
913
  logger.info(f"Token expired for account {auth_manager._email}, attempting refresh")
914
- try:
915
- if auth_manager.ensure_valid_token():
916
- continue
917
- except requests.exceptions.RequestException as e:
918
- if "400 Client Error: Bad Request" in str(e):
919
- logger.error(f"账号 {auth_manager._email} 已失效")
920
- multi_auth_manager.remove_invalid_account(auth_manager)
921
- break
922
 
923
  if response.status_code == 403: # Forbidden, 模型使用限制
924
  logger.warning(f"Model {model_id} usage limit reached for account {auth_manager._email}")
925
  auth_manager.set_model_unavailable(model_id)
926
- break
927
 
928
  logger.error(f"Request failed with status {response.status_code} for account {auth_manager._email}")
929
 
930
  except Exception as e:
931
  logger.error(f"Request attempt {attempt + 1} failed for account {auth_manager._email}: {e}")
932
- if "400 Client Error: Bad Request" in str(e):
933
- logger.error(f"账号 {auth_manager._email} 已失效")
934
- multi_auth_manager.remove_invalid_account(auth_manager)
935
- break
936
  if attempt < max_retries - 1:
937
  time.sleep(retry_delay)
938
  continue
939
 
940
- # 检查是否需要注册新账号
941
- if not multi_auth_manager.auth_managers or len(tried_accounts) == len(multi_auth_manager.auth_managers):
942
- logger.info("所有现有账号都已尝试或无可用账号,开始注册新账号")
943
  successful_accounts = register_bot.register_and_verify(5)
944
  if successful_accounts:
945
  credentials = [(account['email'], account['password']) for account in successful_accounts]
946
  multi_auth_manager = MultiAuthManager(credentials)
 
947
  return make_request(payload, None, model_id)
948
 
949
  raise Exception("所有账号均不可用,且注册新账号失败")
950
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
951
  if __name__ == "__main__":
 
 
 
952
  port = int(os.environ.get("PORT", 3000))
953
  app.run(debug=False, host='0.0.0.0', port=port, threaded=True)
954
 
 
18
  import urllib3
19
  from cachetools import TTLCache
20
  import threading
21
+ from time import sleep
22
+ from datetime import datetime, timedelta
23
 
24
  # 新增导入
25
  import register_bot
 
100
 
101
  # 创建自定义的 Session
102
  def create_custom_session():
 
103
  session = requests.Session()
104
+ adapter = CustomHTTPAdapter()
 
 
 
 
 
105
  session.mount('https://', adapter)
106
  session.mount('http://', adapter)
107
  return session
 
307
  def __init__(self, credentials):
308
  self.auth_managers = [AuthManager(email, password) for email, password in credentials]
309
  self.current_index = 0
 
310
  self._last_rotation = time.time()
311
+ self._rotation_interval = 300 # 5分钟轮转间隔
312
+
313
+ def _should_rotate(self) -> bool:
314
+ """检查是否应该轮转到下一个账号"""
315
+ return time.time() - self._last_rotation >= self._rotation_interval
 
 
 
 
 
 
 
 
 
 
 
 
316
 
317
  def get_next_auth_manager(self, model):
318
+ """改进的账号选择逻辑"""
319
+ if self._should_rotate():
320
+ self.current_index = (self.current_index + 1) % len(self.auth_managers)
321
+ self._last_rotation = time.time()
 
 
 
 
 
 
 
 
 
 
 
322
 
323
  start_index = self.current_index
324
  for _ in range(len(self.auth_managers)):
 
331
  return None
332
 
333
  def ensure_valid_token(self, model):
334
+ for _ in range(len(self.auth_managers)):
335
+ auth_manager = self.get_next_auth_manager(model)
336
+ if auth_manager and auth_manager.ensure_valid_token():
 
 
 
 
337
  return auth_manager
 
 
 
 
 
 
338
  return None
339
 
340
+ def reset_all_model_status(self):
 
 
 
 
 
 
 
 
 
 
 
341
  for auth_manager in self.auth_managers:
342
  auth_manager.reset_model_status()
343
 
 
467
  """计算消息列表中的总令牌数量。"""
468
  return sum(count_tokens(str(message), model) for message in messages)
469
 
 
 
 
 
 
470
  def stream_notdiamond_response(response, model):
471
+ """改进的流式响应处理,确保保持上下文完整性。"""
472
  buffer = ""
473
  full_content = ""
 
474
 
475
+ for chunk in response.iter_content(chunk_size=1024):
476
+ if chunk:
477
+ try:
478
+ new_content = chunk.decode('utf-8')
479
+ buffer += new_content
480
+ full_content += new_content
481
 
482
+ # 创建完整的响应块
483
+ chunk_data = create_openai_chunk(new_content, model)
484
+
485
+ # 确保响应块包含完整的上下文
486
+ if 'choices' in chunk_data and chunk_data['choices']:
487
+ chunk_data['choices'][0]['delta']['content'] = new_content
488
+ chunk_data['choices'][0]['context'] = full_content # 添加完整上下文
489
+
490
+ yield chunk_data
491
+
492
+ except Exception as e:
493
+ logger.error(f"Error processing chunk: {e}")
494
+ continue
495
+
496
+ # 发送完成标记
497
+ final_chunk = create_openai_chunk('', model, 'stop')
498
+ if 'choices' in final_chunk and final_chunk['choices']:
499
+ final_chunk['choices'][0]['context'] = full_content # 在最终块中包含完整上下文
500
+ yield final_chunk
 
 
 
 
 
 
 
 
 
 
 
501
 
502
  def handle_non_stream_response(response, model, prompt_tokens):
503
  """改进的非流式响应处理,确保保持完整上下文。"""
 
545
  logger.error(f"Error processing non-stream response: {e}")
546
  raise
547
 
 
548
  def generate_stream_response(response, model, prompt_tokens):
549
+ """生成流式 HTTP 响应。"""
550
  total_completion_tokens = 0
 
551
 
552
+ for chunk in stream_notdiamond_response(response, model):
553
+ content = chunk['choices'][0]['delta'].get('content', '')
554
+ total_completion_tokens += count_tokens(content, model)
555
+
556
+ chunk['usage'] = {
557
+ "prompt_tokens": prompt_tokens,
558
+ "completion_tokens": total_completion_tokens,
559
+ "total_tokens": prompt_tokens + total_completion_tokens
560
+ }
561
+
562
+ yield f"data: {json.dumps(chunk)}\n\n"
563
+
564
+ yield "data: [DONE]\n\n"
 
 
 
 
 
 
 
565
 
566
  def get_auth_credentials():
567
  """从API获取认证凭据"""
 
823
  url,
824
  headers=headers,
825
  json=payload,
826
+ stream=True
 
827
  ).result()
828
 
829
  if response.status_code == 200 and response.headers.get('Content-Type') == 'text/event-stream':
830
  logger.info(f"请求成功,使用账号 {auth_manager._email}")
 
831
  return response
832
 
833
  headers_cache.clear()
834
 
835
  if response.status_code == 401: # Unauthorized
836
  logger.info(f"Token expired for account {auth_manager._email}, attempting refresh")
837
+ if auth_manager.ensure_valid_token():
838
+ continue
 
 
 
 
 
 
839
 
840
  if response.status_code == 403: # Forbidden, 模型使用限制
841
  logger.warning(f"Model {model_id} usage limit reached for account {auth_manager._email}")
842
  auth_manager.set_model_unavailable(model_id)
843
+ break # 跳出重试循环,尝试下一个账号
844
 
845
  logger.error(f"Request failed with status {response.status_code} for account {auth_manager._email}")
846
 
847
  except Exception as e:
848
  logger.error(f"Request attempt {attempt + 1} failed for account {auth_manager._email}: {e}")
 
 
 
 
849
  if attempt < max_retries - 1:
850
  time.sleep(retry_delay)
851
  continue
852
 
853
+ # 所有账号都尝试过且失败后,才进行注册
854
+ if len(tried_accounts) == len(multi_auth_manager.auth_managers):
855
+ logger.info("所有现有账号都已尝试,开始注册新账号")
856
  successful_accounts = register_bot.register_and_verify(5)
857
  if successful_accounts:
858
  credentials = [(account['email'], account['password']) for account in successful_accounts]
859
  multi_auth_manager = MultiAuthManager(credentials)
860
+ # 使用新注册的账号重试请求
861
  return make_request(payload, None, model_id)
862
 
863
  raise Exception("所有账号均不可用,且注册新账号失败")
864
 
865
+ def health_check():
866
+ """改进的健康检查函数"""
867
+ last_check_time = {} # 用于跟踪每个账号的最后检查时间
868
+
869
+ while True:
870
+ try:
871
+ if multi_auth_manager:
872
+ current_time = time.time()
873
+
874
+ for auth_manager in multi_auth_manager.auth_managers:
875
+ email = auth_manager._email
876
+
877
+ # 检查是否需要进行健康检查
878
+ if email not in last_check_time or \
879
+ current_time - last_check_time[email] >= AUTH_CHECK_INTERVAL:
880
+
881
+ if not auth_manager._should_attempt_auth():
882
+ logger.info(f"Skipping health check for {email} due to rate limiting")
883
+ continue
884
+
885
+ if not auth_manager.ensure_valid_token():
886
+ logger.warning(f"Auth token validation failed during health check for {email}")
887
+ auth_manager.clear_auth()
888
+ else:
889
+ logger.info(f"Health check passed for {email}")
890
+
891
+ last_check_time[email] = current_time
892
+
893
+ # 每天重置所有账号的模型使用状态
894
+ current_time_local = time.localtime()
895
+ if current_time_local.tm_hour == 0 and current_time_local.tm_min == 0:
896
+ multi_auth_manager.reset_all_model_status()
897
+ logger.info("Reset model status for all accounts")
898
+
899
+ except Exception as e:
900
+ logger.error(f"Health check error: {e}")
901
+
902
+ sleep(60) # 主循环每分钟运行一次
903
+
904
+ # 为了兼容 Flask CLI 和 Gunicorn,修改启动逻辑
905
+ if __name__ != "__main__":
906
+ health_check_thread = threading.Thread(target=health_check, daemon=True)
907
+ health_check_thread.start()
908
+
909
  if __name__ == "__main__":
910
+ health_check_thread = threading.Thread(target=health_check, daemon=True)
911
+ health_check_thread.start()
912
+
913
  port = int(os.environ.get("PORT", 3000))
914
  app.run(debug=False, host='0.0.0.0', port=port, threaded=True)
915