tianlong12 commited on
Commit
b767805
·
verified ·
1 Parent(s): 8b82275

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -61
app.py CHANGED
@@ -1,83 +1,80 @@
1
  from flask import Flask, request, Response, stream_with_context, jsonify
2
- from openai import OpenAI
3
  import json
4
 
5
  app = Flask(__name__)
6
-
7
  @app.route('/')
8
  def index():
9
  return "Hello, this is the root page of your Flask application!"
10
 
11
- @app.route('/hf/v1/chat/completions', methods=['POST'])
12
- def chat():
13
  try:
14
- # 验证请求头中的API密钥
15
- auth_header = request.headers.get('Authorization')
16
- if not auth_header or not auth_header.startswith('Bearer '):
17
- return jsonify({"error": "Unauthorized"}), 401
18
-
19
- api_key = auth_header.split(" ")[1]
20
- base_url= auth_header.split(" ")[2]
21
 
 
22
  data = request.json
23
- #print("Received data:", data) # 打印请求体以进行调试
24
-
25
- # 验证请求格式
26
- if not data or 'messages' not in data or 'model' not in data:
27
- return jsonify({"error": "Missing 'messages' or 'model' in request body"}), 400
28
 
29
- model = data['model']
30
- messages = data['messages']
31
- temperature = data.get('temperature', 0.7) # 默认值0.7
32
- #max_tokens = calculate_max_tokens(model, messages, requested_max_tokens)
33
- top_p = data.get('top_p', 1.0) # 默认值1.0
34
- n = data.get('n', 1) # 默认值1
35
- stream = data.get('stream', False) # 默认值False
36
- functions = data.get('functions', None) # Functions for function calling
37
- function_call = data.get('function_call', None) # Specific function call request
 
 
 
 
 
 
 
 
38
 
39
- # 创建每个请求的 OpenAI 客户端实例
40
- client = OpenAI(
41
- api_key=api_key,
42
- base_url=base_url,
43
- )
44
 
45
- # 处理模型响应
46
- if stream:
47
- # 处理流式响应
48
- def generate():
49
- response = client.chat.completions.create(
50
- model=model,
51
- messages=messages,
52
- temperature=temperature,
53
- #max_tokens=max_tokens,
54
- top_p=top_p,
55
- n=n,
56
- stream=True,
57
- functions=functions,
58
- function_call=function_call
59
- )
60
- for chunk in response:
61
- yield f"data: {json.dumps(chunk.to_dict())}\n\n"
 
 
 
 
62
 
63
- return Response(stream_with_context(generate()), content_type='text/event-stream')
64
  else:
65
- # 处理非流式响应
66
- response = client.chat.completions.create(
67
- model=model,
68
- messages=messages,
69
- temperature=temperature,
70
- #max_tokens=max_tokens,
71
- top_p=top_p,
72
- n=n,
73
- functions=functions,
74
- function_call=function_call,
75
- )
76
- return jsonify(response.to_dict())
77
 
78
  except Exception as e:
79
  print("Exception:", e)
80
  return jsonify({"error": str(e)}), 500
81
 
82
  if __name__ == "__main__":
83
- app.run(host='0.0.0.0', port=7860, threaded=True)
 
1
  from flask import Flask, request, Response, stream_with_context, jsonify
2
+ import requests
3
  import json
4
 
5
  app = Flask(__name__)
6
+
7
  @app.route('/')
8
  def index():
9
  return "Hello, this is the root page of your Flask application!"
10
 
11
+ @app.route('/<path:subpath>', methods=['POST'])
12
+ def forward_to_target(subpath):
13
  try:
14
+ # 构建目标 URL
15
+ target_url = f'http://{subpath}'
 
 
 
 
 
16
 
17
+ # 获取请求数据
18
  data = request.json
 
 
 
 
 
19
 
20
+ # 检查是否是特定路径需要特殊处理
21
+ if '/v1/chat/completions' in subpath:
22
+ auth_header = request.headers.get('Authorization')
23
+ if not auth_header or not auth_header.startswith('Bearer '):
24
+ return jsonify({"error": "Unauthorized"}), 401
25
+
26
+ api_key = auth_header.split(" ")[1]
27
+ target_url = f"http://{subpath.split('/')[0]}"
28
+
29
+ model = data['model']
30
+ messages = data['messages']
31
+ temperature = data.get('temperature', 0.7) # 默认值0.7
32
+ top_p = data.get('top_p', 1.0) # 默认值1.0
33
+ n = data.get('n', 1) # 默认值1
34
+ stream = data.get('stream', False) # 默认值False
35
+ functions = data.get('functions', None) # Functions for function calling
36
+ function_call = data.get('function_call', None) # Specific function call request
37
 
38
+ headers = {
39
+ 'Authorization': f'Bearer {api_key}',
40
+ 'Content-Type': 'application/json'
41
+ }
 
42
 
43
+ payload = {
44
+ 'model': model,
45
+ 'messages': messages,
46
+ 'temperature': temperature,
47
+ 'top_p': top_p,
48
+ 'n': n,
49
+ 'stream': stream,
50
+ 'functions': functions,
51
+ 'function_call': function_call
52
+ }
53
+
54
+ if stream:
55
+ def generate():
56
+ with requests.post(target_url, headers=headers, json=payload, stream=True) as r:
57
+ for chunk in r.iter_content(chunk_size=8192):
58
+ if chunk:
59
+ yield f"data: {chunk.decode('utf-8')}\n\n"
60
+ return Response(stream_with_context(generate()), content_type='text/event-stream')
61
+ else:
62
+ response = requests.post(target_url, headers=headers, json=payload)
63
+ return jsonify(response.json())
64
 
 
65
  else:
66
+ # 获取请求头
67
+ headers = {key: value for key, value in request.headers if key != 'Host'}
68
+
69
+ # 转发请求到目标 URL
70
+ response = requests.post(target_url, headers=headers, json=data)
71
+
72
+ # 返回目标 URL 的响应
73
+ return Response(response.content, status=response.status_code, content_type=response.headers['Content-Type'])
 
 
 
 
74
 
75
  except Exception as e:
76
  print("Exception:", e)
77
  return jsonify({"error": str(e)}), 500
78
 
79
  if __name__ == "__main__":
80
+ app.run(host='0.0.0.0', port=7860, threaded=True)