File size: 3,693 Bytes
4965971
b767805
4965971
6504035
 
b767805
6504035
 
 
 
b767805
 
6504035
b767805
07a1a55
fafe777
6504035
b767805
4965971
fafe777
6504035
b767805
 
 
 
 
 
 
cd0b413
fafe777
b767805
 
 
 
 
 
 
 
 
6504035
b767805
 
 
 
6504035
b767805
 
 
 
 
 
 
 
 
 
fafe777
b767805
 
 
 
 
 
 
 
 
 
fafe777
b767805
6504035
 
b767805
 
fafe777
b767805
 
 
fafe777
b767805
 
 
6504035
fafe777
 
 
6504035
fafe777
6504035
 
 
b767805
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from flask import Flask, request, Response, stream_with_context, jsonify
import requests
import json

# The WSGI application object; every route in this module registers on it.
app = Flask(import_name=__name__)

@app.route('/', methods=['GET'])
def index():
    """Serve a fixed plain greeting at the application root."""
    greeting = "Hello, this is the root page of your Flask application!"
    return greeting

# NOTE(review): this is an open forwarding proxy -- any caller can make this
# server POST to whatever host name they put in the URL path (including
# internal addresses). Restrict the allowed targets if this is exposed publicly.

# Incoming request headers that are not copied onto the upstream request:
# hop-by-hop headers describe only the client<->proxy connection, Host and
# Content-Length are recomputed by `requests`, and Accept-Encoding is dropped so
# `requests` only negotiates encodings it can decode (the body is re-served
# already decoded, without a Content-Encoding header).
_SKIPPED_REQUEST_HEADERS = frozenset({
    'host', 'content-length', 'transfer-encoding', 'connection', 'keep-alive',
    'te', 'trailer', 'upgrade', 'proxy-authorization', 'proxy-connection',
    'accept-encoding',
})

# (connect, read) timeout for upstream calls: fail fast when the host is
# unreachable, but never cut off a slow or long-running (streamed) response.
_UPSTREAM_TIMEOUT = (10, None)


def _redact(headers):
    """Return a copy of *headers* with any Authorization value masked, for logging."""
    return {
        key: ('<redacted>' if key.lower() == 'authorization' else value)
        for key, value in headers.items()
    }


def _relay(upstream):
    """Wrap an upstream `requests` response as a Flask response (body, status, Content-Type)."""
    return Response(
        upstream.content,
        status=upstream.status_code,
        content_type=upstream.headers.get('Content-Type'),
    )


def _forward_chat_completion(target_url, data):
    """Forward a chat-completions request to *target_url* using only the caller's bearer key.

    Returns 401 when no non-empty ``Bearer`` key is supplied and 400 when the
    JSON body is not an object containing ``model`` and ``messages``. All other
    body fields are forwarded; ``temperature``, ``top_p``, ``n`` and ``stream``
    get defaults (0.7, 1.0, 1, False) when absent.
    """
    auth_header = request.headers.get('Authorization', '')
    if not auth_header.startswith('Bearer '):
        return jsonify({"error": "Unauthorized"}), 401
    api_key = auth_header[len('Bearer '):].strip()
    if not api_key:
        return jsonify({"error": "Unauthorized"}), 401

    if not isinstance(data, dict) or 'model' not in data or 'messages' not in data:
        return jsonify({"error": "Request body must be a JSON object with 'model' and 'messages'"}), 400

    # Forward every client field (max_tokens, stop, tools, ...) instead of a
    # fixed whitelist, filling in the same defaults as before.
    payload = dict(data)
    payload.setdefault('temperature', 0.7)
    payload.setdefault('top_p', 1.0)
    payload.setdefault('n', 1)
    payload.setdefault('stream', False)
    # Omit these optional fields when unset rather than sending explicit nulls,
    # which strict upstream validators may reject.
    for key in ('functions', 'function_call'):
        if payload.get(key) is None:
            payload.pop(key, None)

    headers = {
        'Authorization': f'Bearer {api_key}',
        'Content-Type': 'application/json',
    }
    print(f"Payload: {payload}")  # debug output

    if payload['stream']:
        # Open the upstream call here (not lazily in the generator) so that
        # connection errors reach forward_to_target's exception handlers and the
        # upstream status code can be propagated to the client.
        upstream = requests.post(target_url, headers=headers, json=payload,
                                 stream=True, timeout=_UPSTREAM_TIMEOUT)

        def generate():
            # Relay upstream bytes untouched: an OpenAI-style upstream presumably
            # already emits SSE "data: ..." frames, and decoding arbitrary 8 KiB
            # slices could split a multi-byte UTF-8 character.
            try:
                for chunk in upstream.iter_content(chunk_size=8192):
                    if chunk:
                        yield chunk
            except requests.exceptions.RequestException as e:
                # Headers are already sent; the best we can do is end the stream.
                print(f"RequestException while streaming: {e}")  # debug output

        response = Response(
            stream_with_context(generate()),
            status=upstream.status_code,
            content_type=upstream.headers.get('Content-Type', 'text/event-stream'),
        )
        # Release the upstream connection even if the client never reads the body.
        response.call_on_close(upstream.close)
        return response

    upstream = requests.post(target_url, headers=headers, json=payload,
                             timeout=_UPSTREAM_TIMEOUT)
    return _relay(upstream)


def _forward_generic(target_url):
    """POST the incoming request's raw body and end-to-end headers to *target_url*."""
    headers = {
        key: value
        for key, value in request.headers
        if key.lower() not in _SKIPPED_REQUEST_HEADERS
    }
    print(f"Headers: {_redact(headers)}")  # debug output; never log credentials

    # Forward the original bytes instead of re-serialising parsed JSON, so
    # non-JSON bodies survive and the declared Content-Type stays accurate.
    upstream = requests.post(target_url, headers=headers, data=request.get_data(),
                             timeout=_UPSTREAM_TIMEOUT)
    return _relay(upstream)


@app.route('/<path:subpath>', methods=['POST'])
def forward_to_target(subpath):
    """Proxy ``POST /<host>/<path>`` to ``https://<host>/<path>``.

    Paths containing ``/v1/chat/completions`` are re-issued with a rebuilt
    payload and only the caller's bearer key; every other path forwards the raw
    body and end-to-end headers. The upstream status, body and Content-Type are
    relayed to the client. Network failures and unexpected errors are reported
    as ``{"error": ...}`` with status 500.
    """
    try:
        # Build the target URL, keeping any query string the caller sent.
        target_url = f'https://{subpath}'
        if request.query_string:
            target_url += '?' + request.query_string.decode('latin-1')
        print(f"Target URL: {target_url}")  # debug output

        # silent=True: a missing or invalid JSON body yields None instead of raising.
        data = request.get_json(silent=True)
        print(f"Request data: {data}")  # debug output

        if '/v1/chat/completions' in subpath:
            return _forward_chat_completion(target_url, data)
        return _forward_generic(target_url)

    except requests.exceptions.RequestException as e:
        print(f"RequestException: {e}")  # debug output
        return jsonify({"error": str(e)}), 500
    except Exception as e:
        print(f"Exception: {e}")  # debug output
        return jsonify({"error": str(e)}), 500

if __name__ == "__main__":
    # Development-server entry point: listen on every interface with one
    # thread per request. Port 7860 is presumably what the hosting
    # environment expects -- confirm before changing.
    bind_host, bind_port = '0.0.0.0', 7860
    app.run(host=bind_host, port=bind_port, threaded=True)