caidaoli committed on
Commit
1b38e37
·
verified ·
1 Parent(s): bb51b44

Update openai_ondemand_adapter.py

Browse files
Files changed (1) hide show
  1. openai_ondemand_adapter.py +258 -261
openai_ondemand_adapter.py CHANGED
@@ -1,261 +1,258 @@
1
- from flask import Flask, request, Response, jsonify
2
- import requests
3
- import uuid
4
- import time
5
- import json
6
- import threading
7
- import logging
8
- import os
9
-
10
- # ====== 读取 Huggingface Secret 配置的私有key =======
11
- PRIVATE_KEY = os.environ.get("PRIVATE_KEY", "")
12
- SAFE_HEADER = "X-API-KEY"
13
-
14
- # 全局接口访问权限检查
15
- def check_private_key():
16
- # 可以在这里放宽部分接口,比如首页等
17
- if request.path in ["/", "/favicon.ico"]:
18
- return
19
- key = request.headers.get(SAFE_HEADER)
20
- if not key or key != PRIVATE_KEY:
21
- return jsonify({"error": "Unauthorized, must provide correct X-API-KEY"}), 401
22
-
23
- # 应用所有API鉴权
24
- app = Flask(__name__)
25
- app.before_request(check_private_key)
26
-
27
- # ========== KEY池(每行一个)==========
28
- ONDEMAND_APIKEYS = [
29
- "Key1",
30
- "Key2",
31
- ]
32
- BAD_KEY_RETRY_INTERVAL = 600 # 秒
33
-
34
- # ========== OnDemand模型映射 ==========
35
- MODEL_MAP = {
36
- "gpto3-mini": "predefined-openai-gpto3-mini",
37
- "gpt-4o": "predefined-openai-gpt4o",
38
- "gpt-4.1": "predefined-openai-gpt4.1",
39
- "gpt-4.1-mini": "predefined-openai-gpt4.1-mini",
40
- "gpt-4.1-nano": "predefined-openai-gpt4.1-nano",
41
- "gpt-4o-mini": "predefined-openai-gpt4o-mini",
42
- "deepseek-v3": "predefined-deepseek-v3",
43
- "deepseek-r1": "predefined-deepseek-r1",
44
- "claude-3.7-sonnet": "predefined-claude-3.7-sonnet",
45
- "gemini-2.0-flash": "predefined-gemini-2.0-flash",
46
- }
47
- DEFAULT_ONDEMAND_MODEL = "predefined-openai-gpt4o"
48
- # ==========================================
49
-
50
- class KeyManager:
51
- def __init__(self, key_list):
52
- self.key_list = list(key_list)
53
- self.lock = threading.Lock()
54
- self.key_status = {k: {"bad": False, "bad_ts": None} for k in self.key_list}
55
- self.idx = 0
56
-
57
- def display_key(self, key):
58
- return f"{key[:6]}...{key[-4:]}"
59
-
60
- def get(self):
61
- with self.lock:
62
- total = len(self.key_list)
63
- for _ in range(total):
64
- key = self.key_list[self.idx]
65
- self.idx = (self.idx + 1) % total
66
- s = self.key_status[key]
67
- if not s["bad"]:
68
- print(f"【对话请求】【使用API KEY: {self.display_key(key)}】【状态:正常】")
69
- return key
70
- if s["bad"] and s["bad_ts"]:
71
- ago = time.time() - s["bad_ts"]
72
- if ago >= BAD_KEY_RETRY_INTERVAL:
73
- print(f"【KEY自动尝试恢复】API KEY: {self.display_key(key)} 满足重试周期,标记为正常")
74
- self.key_status[key]["bad"] = False
75
- self.key_status[key]["bad_ts"] = None
76
- print(f"【对话请求】【使用API KEY: {self.display_key(key)}】【状态:正常】")
77
- return key
78
- print("【警告】全部KEY已被禁用,强制选用第一个KEY继续尝试:", self.display_key(self.key_list[0]))
79
- for k in self.key_list:
80
- self.key_status[k]["bad"] = False
81
- self.key_status[k]["bad_ts"] = None
82
- self.idx = 0
83
- print(f"【对话请求】【使用API KEY: {self.display_key(self.key_list[0])}】【状态:强制尝试(全部异常)】")
84
- return self.key_list[0]
85
-
86
- def mark_bad(self, key):
87
- with self.lock:
88
- if key in self.key_status and not self.key_status[key]["bad"]:
89
- print(f"【禁用KEY】API KEY: {self.display_key(key)},接口返回无效(将在{BAD_KEY_RETRY_INTERVAL//60}分钟后自动重试)")
90
- self.key_status[key]["bad"] = True
91
- self.key_status[key]["bad_ts"] = time.time()
92
-
93
- keymgr = KeyManager(ONDEMAND_APIKEYS)
94
-
95
- ONDEMAND_API_BASE = "https://api.on-demand.io/chat/v1"
96
-
97
- def get_endpoint_id(openai_model):
98
- m = str(openai_model or "").lower().replace(" ", "")
99
- return MODEL_MAP.get(m, DEFAULT_ONDEMAND_MODEL)
100
-
101
- def create_session(apikey, external_user_id=None, plugin_ids=None):
102
- url = f"{ONDEMAND_API_BASE}/sessions"
103
- payload = {"externalUserId": external_user_id or str(uuid.uuid4())}
104
- if plugin_ids is not None:
105
- payload["pluginIds"] = plugin_ids
106
- headers = {"apikey": apikey, "Content-Type": "application/json"}
107
- resp = requests.post(url, json=payload, headers=headers, timeout=20)
108
- resp.raise_for_status()
109
- return resp.json()["data"]["id"]
110
-
111
- def format_openai_sse_delta(chunk_str):
112
- return f"data: {json.dumps(chunk_str, ensure_ascii=False)}\n\n"
113
-
114
- @app.route("/v1/chat/completions", methods=["POST"])
115
- def chat_completions():
116
- data = request.json
117
- if not data or "messages" not in data:
118
- return jsonify({"error": "请求缺少messages字段"}), 400
119
-
120
- messages = data["messages"]
121
- openai_model = data.get("model", "gpt-4o")
122
- endpoint_id = get_endpoint_id(openai_model)
123
- is_stream = bool(data.get("stream", False))
124
-
125
- user_msg = None
126
- for msg in reversed(messages):
127
- if msg.get("role") == "user":
128
- user_msg = msg.get("content")
129
- break
130
- if user_msg is None:
131
- return jsonify({"error": "未找到用户消息"}), 400
132
-
133
- def with_valid_key(func):
134
- bad_cnt = 0
135
- max_retry = len(keymgr.key_list)*2
136
- while bad_cnt < max_retry:
137
- key = keymgr.get()
138
- try:
139
- return func(key)
140
- except Exception as e:
141
- if hasattr(e, 'response'):
142
- r = e.response
143
- if r.status_code in (401, 403, 429, 500):
144
- keymgr.mark_bad(key)
145
- bad_cnt += 1
146
- continue
147
- raise
148
- return jsonify({"error": "没有可用API KEY,请补充新KEY或联系技术支持"}), 500
149
-
150
- if is_stream:
151
- def generate():
152
- def do_once(apikey):
153
- sid = create_session(apikey)
154
- url = f"{ONDEMAND_API_BASE}/sessions/{sid}/query"
155
- payload = {
156
- "query": user_msg,
157
- "endpointId": endpoint_id,
158
- "pluginIds": [],
159
- "responseMode": "stream"
160
- }
161
- headers = {"apikey": apikey, "Content-Type": "application/json", "Accept": "text/event-stream"}
162
- with requests.post(url, json=payload, headers=headers, stream=True, timeout=120) as resp:
163
- if resp.status_code != 200:
164
- raise requests.HTTPError(response=resp)
165
- answer_acc = ""
166
- first_chunk = True
167
- for line in resp.iter_lines():
168
- if not line:
169
- continue
170
- line = line.decode("utf-8")
171
- if line.startswith("data:"):
172
- datapart = line[5:].strip()
173
- if datapart == "[DONE]":
174
- yield "data: [DONE]\n\n"
175
- break
176
- elif datapart.startswith("[ERROR]:"):
177
- err_json = datapart[len("[ERROR]:"):].strip()
178
- yield format_openai_sse_delta({"error": err_json})
179
- break
180
- else:
181
- try:
182
- js = json.loads(datapart)
183
- except Exception:
184
- continue
185
- if js.get("eventType") == "fulfillment":
186
- delta = js.get("answer", "")
187
- answer_acc += delta
188
- chunk = {
189
- "id": "chatcmpl-" + str(uuid.uuid4())[:8],
190
- "object": "chat.completion.chunk",
191
- "created": int(time.time()),
192
- "model": openai_model,
193
- "choices": [{
194
- "delta": {
195
- "role": "assistant",
196
- "content": delta
197
- } if first_chunk else {
198
- "content": delta
199
- },
200
- "index": 0,
201
- "finish_reason": None
202
- }]
203
- }
204
- yield format_openai_sse_delta(chunk)
205
- first_chunk = False
206
- yield "data: [DONE]\n\n"
207
- yield from with_valid_key(do_once)
208
- return Response(generate(), content_type='text/event-stream')
209
-
210
- def nonstream(apikey):
211
- sid = create_session(apikey)
212
- url = f"{ONDEMAND_API_BASE}/sessions/{sid}/query"
213
- payload = {
214
- "query": user_msg,
215
- "endpointId": endpoint_id,
216
- "pluginIds": [],
217
- "responseMode": "sync"
218
- }
219
- headers = {"apikey": apikey, "Content-Type": "application/json"}
220
- resp = requests.post(url, json=payload, headers=headers, timeout=120)
221
- if resp.status_code != 200:
222
- raise requests.HTTPError(response=resp)
223
- ai_response = resp.json()["data"]["answer"]
224
- resp_obj = {
225
- "id": "chatcmpl-" + str(uuid.uuid4())[:8],
226
- "object": "chat.completion",
227
- "created": int(time.time()),
228
- "model": openai_model,
229
- "choices": [
230
- {
231
- "index": 0,
232
- "message": {"role": "assistant", "content": ai_response},
233
- "finish_reason": "stop"
234
- }
235
- ],
236
- "usage": {}
237
- }
238
- return jsonify(resp_obj)
239
-
240
- return with_valid_key(nonstream)
241
-
242
- @app.route("/v1/models", methods=["GET"])
243
- def models():
244
- model_objs = []
245
- for mdl in MODEL_MAP.keys():
246
- model_objs.append({
247
- "id": mdl,
248
- "object": "model",
249
- "owned_by": "ondemand-proxy"
250
- })
251
- uniq = {m["id"]: m for m in model_objs}.values()
252
- return jsonify({
253
- "object": "list",
254
- "data": list(uniq)
255
- })
256
-
257
- if __name__ == "__main__":
258
- log_fmt = '[%(asctime)s] %(levelname)s: %(message)s'
259
- logging.basicConfig(level=logging.INFO, format=log_fmt)
260
- print("======== OnDemand KEY池数量:", len(ONDEMAND_APIKEYS), "========")
261
- app.run(host="0.0.0.0", port=7860, debug=False)
 
1
import hmac
import json
import logging
import os
import threading
import time
import uuid

import requests
from flask import Flask, request, Response, jsonify
9
+
10
# ====== Private access key read from a Huggingface Secret ======
# Every client request must present this value (see check_private_key).
PRIVATE_KEY = os.environ.get("PRIVATE_KEY", "")
# Name of the request header that carries the client's key.
SAFE_HEADER = "X-API-KEY"
13
+
14
# Global access check applied to every request via app.before_request.
def check_private_key():
    """Reject any request that lacks the correct X-API-KEY header.

    Returns None (allow) for whitelisted paths or when the presented key
    matches PRIVATE_KEY; otherwise returns a (response, 401) pair which
    Flask sends as-is, short-circuiting the request.
    """
    # Public paths that skip authentication (landing page, favicon).
    if request.path in ["/", "/favicon.ico"]:
        return
    key = request.headers.get(SAFE_HEADER)
    # hmac.compare_digest is constant-time, so the comparison does not leak
    # how many leading characters of the key were correct; encode to bytes
    # so non-ASCII header values compare cleanly instead of raising.
    if not key or not hmac.compare_digest(key.encode("utf-8"), PRIVATE_KEY.encode("utf-8")):
        return jsonify({"error": "Unauthorized, must provide correct X-API-KEY"}), 401
22
+
23
# Apply the key check to every incoming request (API-wide authentication).
app = Flask(__name__)
app.before_request(check_private_key)
26
+
27
# ========== KEY pool (one key per line in the env var) ==========
# Parse the ONDEMAND_APIKEYS secret into a list of keys. Previously the raw
# string was passed straight to KeyManager, whose list(key_list) call split
# it into single characters — every "key" was one letter of the secret.
ONDEMAND_APIKEYS = [
    k.strip()
    for k in os.environ.get("ONDEMAND_APIKEYS", "").splitlines()
    if k.strip()
]
# Seconds a disabled key waits before it is automatically retried.
BAD_KEY_RETRY_INTERVAL = 600
30
+
31
# ========== OnDemand model mapping ==========
# Maps OpenAI-style model names (as sent by clients) to OnDemand
# "predefined" endpoint ids; lookup happens in get_endpoint_id().
MODEL_MAP = {
    "gpto3-mini": "predefined-openai-gpto3-mini",
    "gpt-4o": "predefined-openai-gpt4o",
    "gpt-4.1": "predefined-openai-gpt4.1",
    "gpt-4.1-mini": "predefined-openai-gpt4.1-mini",
    "gpt-4.1-nano": "predefined-openai-gpt4.1-nano",
    "gpt-4o-mini": "predefined-openai-gpt4o-mini",
    "deepseek-v3": "predefined-deepseek-v3",
    "deepseek-r1": "predefined-deepseek-r1",
    "claude-3.7-sonnet": "predefined-claude-3.7-sonnet",
    "gemini-2.0-flash": "predefined-gemini-2.0-flash",
}
# Fallback endpoint used when the requested model is not in MODEL_MAP.
DEFAULT_ONDEMAND_MODEL = "predefined-openai-gpt4o"
# ==========================================
46
+
47
class KeyManager:
    """Round-robin pool of OnDemand API keys with automatic bad-key cooldown.

    Keys that fail with auth/quota errors are marked bad via mark_bad() and
    skipped by get() until ``retry_interval`` seconds have passed, after
    which they are automatically re-enabled. All state changes happen under
    a single lock, so the manager is safe to share across request threads.
    """

    def __init__(self, key_list, retry_interval=None):
        """Build a pool from *key_list*.

        retry_interval: seconds before a disabled key is retried; defaults
        to the module-level BAD_KEY_RETRY_INTERVAL.
        """
        self.key_list = list(key_list)
        self.retry_interval = (
            BAD_KEY_RETRY_INTERVAL if retry_interval is None else retry_interval
        )
        self.lock = threading.Lock()
        # Per-key status: "bad" flag plus the timestamp it was disabled at.
        self.key_status = {k: {"bad": False, "bad_ts": None} for k in self.key_list}
        self.idx = 0  # round-robin cursor into key_list

    def display_key(self, key):
        """Return a redacted form of *key* that is safe to write to logs."""
        return f"{key[:6]}...{key[-4:]}"

    def get(self):
        """Return the next usable key, round-robin.

        Bad keys past their cooldown are re-enabled and returned. If every
        key is bad and still cooling down, the whole pool is force-reset and
        the first key is returned as a last resort. Raises RuntimeError when
        the pool is empty (previously this crashed with a bare IndexError).
        """
        with self.lock:
            total = len(self.key_list)
            if total == 0:
                raise RuntimeError("KeyManager has no API keys configured")
            for _ in range(total):
                key = self.key_list[self.idx]
                self.idx = (self.idx + 1) % total
                status = self.key_status[key]
                if not status["bad"]:
                    print(f"【对话请求】【使用API KEY: {self.display_key(key)}】【状态:正常】")
                    return key
                if status["bad_ts"] is not None:
                    if time.time() - status["bad_ts"] >= self.retry_interval:
                        print(f"【KEY自动尝试恢复】API KEY: {self.display_key(key)} 满足重试周期,标记为正常")
                        status["bad"] = False
                        status["bad_ts"] = None
                        print(f"【对话请求】【使用API KEY: {self.display_key(key)}】【状态:正常】")
                        return key
            # Every key is disabled and still cooling down: reset the whole
            # pool and force-try the first key rather than failing outright.
            print("【警告】全部KEY已被禁用,强制选用第一个KEY继续尝试:", self.display_key(self.key_list[0]))
            for k in self.key_list:
                self.key_status[k]["bad"] = False
                self.key_status[k]["bad_ts"] = None
            self.idx = 0
            print(f"【对话请求】【使用API KEY: {self.display_key(self.key_list[0])}】【状态:强制尝试(全部异常)】")
            return self.key_list[0]

    def mark_bad(self, key):
        """Disable *key* for retry_interval seconds. Idempotent per key."""
        with self.lock:
            if key in self.key_status and not self.key_status[key]["bad"]:
                print(f"【禁用KEY】API KEY: {self.display_key(key)},接口返回无效(将在{self.retry_interval//60}分钟后自动重试)")
                self.key_status[key]["bad"] = True
                self.key_status[key]["bad_ts"] = time.time()
89
+
90
# Shared key pool used by all request handlers.
keymgr = KeyManager(ONDEMAND_APIKEYS)

# Base URL of the OnDemand chat API.
ONDEMAND_API_BASE = "https://api.on-demand.io/chat/v1"
93
+
94
def get_endpoint_id(openai_model):
    """Map an OpenAI-style model name onto an OnDemand endpoint id.

    Matching is case-insensitive and ignores spaces; unknown (or empty)
    names fall back to DEFAULT_ONDEMAND_MODEL.
    """
    normalized = str(openai_model or "").lower().replace(" ", "")
    if normalized in MODEL_MAP:
        return MODEL_MAP[normalized]
    return DEFAULT_ONDEMAND_MODEL
97
+
98
def create_session(apikey, external_user_id=None, plugin_ids=None):
    """Create an OnDemand chat session and return its session id.

    A random UUID is used as the external user id when none is supplied.
    Raises requests.HTTPError (via raise_for_status) on any non-2xx reply.
    """
    headers = {"apikey": apikey, "Content-Type": "application/json"}
    payload = {"externalUserId": external_user_id or str(uuid.uuid4())}
    if plugin_ids is not None:
        payload["pluginIds"] = plugin_ids
    response = requests.post(
        f"{ONDEMAND_API_BASE}/sessions",
        json=payload,
        headers=headers,
        timeout=20,
    )
    response.raise_for_status()
    return response.json()["data"]["id"]
107
+
108
def format_openai_sse_delta(chunk_str):
    """Serialize a JSON-serializable object as a single SSE ``data:`` event."""
    body = json.dumps(chunk_str, ensure_ascii=False)
    return f"data: {body}\n\n"
110
+
111
@app.route("/v1/chat/completions", methods=["POST"])
def chat_completions():
    """OpenAI-compatible chat endpoint proxied onto OnDemand sessions.

    Accepts the usual OpenAI payload ({"model", "messages", "stream", ...}),
    forwards the most recent user message to OnDemand, and returns either a
    streaming SSE response or a single chat.completion JSON object.
    """
    # silent=True: malformed JSON / wrong Content-Type yields None instead of
    # Flask aborting with 400/415 before we can return our own error shape.
    data = request.get_json(silent=True)
    if not data or "messages" not in data:
        return jsonify({"error": "请求缺少messages字段"}), 400

    messages = data["messages"]
    openai_model = data.get("model", "gpt-4o")
    endpoint_id = get_endpoint_id(openai_model)
    is_stream = bool(data.get("stream", False))

    # OnDemand takes a single query string, so only the latest user message
    # is forwarded; earlier conversation turns are dropped.
    user_msg = None
    for msg in reversed(messages):
        if msg.get("role") == "user":
            user_msg = msg.get("content")
            break
    if user_msg is None:
        return jsonify({"error": "未找到用户消息"}), 400

    def with_valid_key(func):
        """Call func(apikey), rotating away from keys that fail with 401/403/429/500.

        Raises RuntimeError once every key has been tried (twice over the
        pool) without success.
        NOTE(review): on the streaming path func is a generator function, so
        func(key) returns immediately and exceptions raised during iteration
        are NOT retried here — confirm whether mid-stream retry is required.
        """
        bad_cnt = 0
        max_retry = len(keymgr.key_list) * 2
        while bad_cnt < max_retry:
            key = keymgr.get()
            try:
                return func(key)
            except Exception as e:
                # requests attaches .response to HTTPError, but it can be
                # None (e.g. connection errors) — guard before reading it,
                # otherwise r.status_code raises AttributeError here.
                resp = getattr(e, "response", None)
                if resp is not None and resp.status_code in (401, 403, 429, 500):
                    keymgr.mark_bad(key)
                    bad_cnt += 1
                    continue
                raise
        raise RuntimeError("没有可用API KEY,请补充新KEY或联系技术支持")

    if is_stream:
        def generate():
            def do_once(apikey):
                # One streaming attempt against OnDemand; yields OpenAI-style
                # chat.completion.chunk SSE events.
                sid = create_session(apikey)
                url = f"{ONDEMAND_API_BASE}/sessions/{sid}/query"
                payload = {
                    "query": user_msg,
                    "endpointId": endpoint_id,
                    "pluginIds": [],
                    "responseMode": "stream"
                }
                headers = {"apikey": apikey, "Content-Type": "application/json", "Accept": "text/event-stream"}
                with requests.post(url, json=payload, headers=headers, stream=True, timeout=120) as resp:
                    if resp.status_code != 200:
                        raise requests.HTTPError(response=resp)
                    first_chunk = True
                    for raw_line in resp.iter_lines():
                        if not raw_line:
                            continue
                        line = raw_line.decode("utf-8")
                        if not line.startswith("data:"):
                            continue
                        datapart = line[5:].strip()
                        if datapart == "[DONE]":
                            yield "data: [DONE]\n\n"
                            break
                        if datapart.startswith("[ERROR]:"):
                            err_json = datapart[len("[ERROR]:"):].strip()
                            yield format_openai_sse_delta({"error": err_json})
                            break
                        try:
                            js = json.loads(datapart)
                        except Exception:
                            # Skip non-JSON keep-alive / partial lines.
                            continue
                        if js.get("eventType") == "fulfillment":
                            delta = js.get("answer", "")
                            chunk = {
                                "id": "chatcmpl-" + str(uuid.uuid4())[:8],
                                "object": "chat.completion.chunk",
                                "created": int(time.time()),
                                "model": openai_model,
                                "choices": [{
                                    # The first chunk carries the assistant
                                    # role, per the OpenAI streaming format.
                                    "delta": {
                                        "role": "assistant",
                                        "content": delta
                                    } if first_chunk else {
                                        "content": delta
                                    },
                                    "index": 0,
                                    "finish_reason": None
                                }]
                            }
                            yield format_openai_sse_delta(chunk)
                            first_chunk = False
                    yield "data: [DONE]\n\n"
            try:
                yield from with_valid_key(do_once)
            except RuntimeError as exc:
                # Key pool exhausted: emit a proper SSE error event. The old
                # code yielded from a (Response, 500) tuple here, leaking the
                # raw Flask objects into the stream.
                yield format_openai_sse_delta({"error": str(exc)})
                yield "data: [DONE]\n\n"
        return Response(generate(), content_type='text/event-stream')

    def nonstream(apikey):
        # One synchronous attempt against OnDemand; returns a Flask response.
        sid = create_session(apikey)
        url = f"{ONDEMAND_API_BASE}/sessions/{sid}/query"
        payload = {
            "query": user_msg,
            "endpointId": endpoint_id,
            "pluginIds": [],
            "responseMode": "sync"
        }
        headers = {"apikey": apikey, "Content-Type": "application/json"}
        resp = requests.post(url, json=payload, headers=headers, timeout=120)
        if resp.status_code != 200:
            raise requests.HTTPError(response=resp)
        ai_response = resp.json()["data"]["answer"]
        resp_obj = {
            "id": "chatcmpl-" + str(uuid.uuid4())[:8],
            "object": "chat.completion",
            "created": int(time.time()),
            "model": openai_model,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": ai_response},
                    "finish_reason": "stop"
                }
            ],
            "usage": {}
        }
        return jsonify(resp_obj)

    try:
        return with_valid_key(nonstream)
    except RuntimeError as exc:
        return jsonify({"error": str(exc)}), 500
238
+
239
@app.route("/v1/models", methods=["GET"])
def models():
    """List the model ids this proxy exposes, in the OpenAI list format."""
    # Keyed by id to deduplicate while preserving MODEL_MAP insertion order
    # (a no-op today since dict keys are already unique, kept for parity).
    unique_models = {}
    for model_id in MODEL_MAP:
        unique_models[model_id] = {
            "id": model_id,
            "object": "model",
            "owned_by": "ondemand-proxy"
        }
    return jsonify({
        "object": "list",
        "data": list(unique_models.values())
    })
253
+
254
if __name__ == "__main__":
    # Configure logging and start the development server on the standard
    # Huggingface Spaces port.
    logging.basicConfig(
        level=logging.INFO,
        format='[%(asctime)s] %(levelname)s: %(message)s',
    )
    print("======== OnDemand KEY池数量:", len(ONDEMAND_APIKEYS), "========")
    app.run(host="0.0.0.0", port=7860, debug=False)