"""Minimal OpenAI-compatible proxy that forwards chat completions to DeepInfra.

Clients call ``POST /v1/chat/completions`` with an OpenAI-style JSON body and an
``Authorization: Bearer <token>`` header whose token must equal the ``API_KEY``
environment variable.
"""

import hmac
import json
import os

import requests
from flask import Flask, Response, jsonify, request, stream_with_context

app = Flask(__name__)

DEEPINFRA_API_URL = "https://api.deepinfra.com/v1/openai/chat/completions"
API_KEY = os.environ.get("API_KEY")

# Model used when the client request does not name one.
DEFAULT_MODEL = "meta-llama/Meta-Llama-3.1-70B-Instruct"

# (connect, read) timeouts in seconds for the upstream call. Without a timeout a
# stalled upstream would block the worker indefinitely; the read timeout is the
# maximum gap between received bytes, so it is generous for slow generations.
UPSTREAM_TIMEOUT = (10, 300)


def authenticate():
    """Return True if the request carries ``Authorization: Bearer <API_KEY>``.

    Fails closed: when ``API_KEY`` is unset or empty, every request is rejected.
    """
    if not API_KEY:
        return False
    auth_header = request.headers.get("Authorization", "")
    scheme, _, token = auth_header.partition(" ")
    token = token.strip()
    if scheme != "Bearer" or not token:
        return False
    # Constant-time comparison so the key cannot be probed via response timing;
    # comparing bytes also avoids TypeError on non-ASCII input.
    return hmac.compare_digest(token.encode("utf-8"), API_KEY.encode("utf-8"))


def _build_deepinfra_request(openai_request):
    """Map the supported OpenAI request fields onto a DeepInfra request body."""
    return {
        "model": openai_request.get("model", DEFAULT_MODEL),
        "temperature": openai_request.get("temperature", 0.7),
        "max_tokens": openai_request.get("max_tokens", 1000),
        "stream": bool(openai_request.get("stream", False)),
        "messages": openai_request.get("messages", []),
    }


def _relay_sse(upstream):
    """Yield the upstream event stream as SSE chunks, then close the upstream.

    Lines that are already SSE ``data:`` fields or ``:`` comments are forwarded
    unchanged (re-prefixing them would produce ``data: data: ...``); any bare
    payload line is wrapped in a ``data:`` field. The ``finally`` releases the
    upstream connection when the stream ends or the client disconnects.
    """
    try:
        for raw_line in upstream.iter_lines():
            if not raw_line:
                continue
            line = raw_line.decode("utf-8")
            if not line.startswith(("data:", ":")):
                line = f"data: {line}"
            yield f"{line}\n\n"
    finally:
        upstream.close()


def _to_openai_response(deepinfra_response):
    """Reshape a DeepInfra completion into an OpenAI ``chat.completion`` body.

    Only the first choice is kept. A missing ``message.content`` (e.g. a
    tool-call reply) becomes ``None`` instead of raising ``KeyError``.
    """
    first_choice = deepinfra_response["choices"][0]
    message = first_choice.get("message") or {}
    return {
        "id": deepinfra_response.get("id", ""),
        "object": "chat.completion",
        "created": deepinfra_response.get("created", 0),
        "model": deepinfra_response.get("model", ""),
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": message.get("content"),
                },
                "finish_reason": first_choice.get("finish_reason", "stop"),
            }
        ],
        "usage": deepinfra_response.get("usage", {}),
    }


@app.route('/v1/chat/completions', methods=['POST'])
def chat_completions():
    """Proxy an OpenAI-style chat completion request to DeepInfra.

    Responses:
        401 if authentication fails; 400 if the body is not a JSON object;
        502 if DeepInfra is unreachable or returns an unusable body; DeepInfra's
        own status code and error body when it reports an error; otherwise an
        SSE stream (``"stream": true``) or an OpenAI ``chat.completion`` JSON.
    """
    if not authenticate():
        return jsonify({"error": "Unauthorized"}), 401

    # Parse the OpenAI-format request; silent=True turns malformed JSON into None.
    openai_request = request.get_json(silent=True)
    if not isinstance(openai_request, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    deepinfra_request = _build_deepinfra_request(openai_request)
    is_stream = deepinfra_request["stream"]
    headers = {
        "Content-Type": "application/json",
        "Accept": "text/event-stream" if is_stream else "application/json",
    }

    # Send the request to the DeepInfra API.
    try:
        upstream = requests.post(
            DEEPINFRA_API_URL,
            json=deepinfra_request,
            headers=headers,
            stream=is_stream,
            timeout=UPSTREAM_TIMEOUT,
        )
    except requests.RequestException:
        return jsonify({"error": "Upstream request failed"}), 502

    if not upstream.ok:
        # Relay the upstream error instead of crashing on a missing "choices"
        # key or wrapping an error body as SSE events.
        try:
            error_body = upstream.json()
        except ValueError:
            error_body = {"error": upstream.text}
        finally:
            upstream.close()
        return jsonify(error_body), upstream.status_code

    if is_stream:
        # Streaming response: _relay_sse owns (and closes) the upstream.
        return Response(stream_with_context(_relay_sse(upstream)),
                        content_type='text/event-stream')

    # Non-streaming response.
    try:
        deepinfra_response = upstream.json()
    except ValueError:
        return jsonify({"error": "Upstream returned invalid JSON"}), 502
    finally:
        upstream.close()

    choices = (deepinfra_response.get("choices")
               if isinstance(deepinfra_response, dict) else None)
    if not isinstance(choices, list) or not choices:
        return jsonify({"error": "Upstream response contained no choices"}), 502

    openai_response = _to_openai_response(deepinfra_response)
    return json.dumps(openai_response), 200, {'Content-Type': 'application/json'}


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)