|
import hmac
import json
import logging
import os

import requests
from flask import Flask, request, Response, stream_with_context, jsonify
|
|
|
app = Flask(__name__)

# Root log level. Defaults to DEBUG (the level this service has always used), but
# can be lowered via the LOG_LEVEL environment variable because DEBUG logging
# records every request body. Unrecognised names fall back to DEBUG.
_log_level = logging.getLevelName(os.environ.get("LOG_LEVEL", "DEBUG").upper())
logging.basicConfig(level=_log_level if isinstance(_log_level, int) else logging.DEBUG)

# Upstream OpenAI-compatible chat completions endpoint this proxy forwards to.
DEEPINFRA_API_URL = "https://api.deepinfra.com/v1/openai/chat/completions"

# Shared secret clients must present as "Authorization: Bearer <API_KEY>".
API_KEY = os.environ.get("API_KEY")
if not API_KEY:
    # Surface the misconfiguration at startup instead of only through 401s.
    logging.warning("API_KEY environment variable is not set or is empty")
|
|
|
def authenticate():
    """Check the current request's bearer token against the proxy's ``API_KEY``.

    Returns:
        True only when the ``Authorization`` header has the form
        ``Bearer <token>`` and ``<token>`` equals ``API_KEY``; False otherwise.
        Fails closed when ``API_KEY`` is unset or empty, so an empty token can
        never match an empty key.
    """
    if not API_KEY:
        return False

    auth_header = request.headers.get("Authorization")
    prefix = "Bearer "
    if not auth_header or not auth_header.startswith(prefix):
        return False

    # Everything after the scheme, trimmed; split(" ")[1] picked the wrong piece
    # (often "") when the header contained extra spaces.
    token = auth_header[len(prefix):].strip()

    # Constant-time comparison so the key cannot be probed via response timing.
    return hmac.compare_digest(token.encode("utf-8"), API_KEY.encode("utf-8"))
|
|
|
@app.route('/hf/v1/chat/completions', methods=['POST'])
def chat_completions():
    """Proxy an OpenAI-style chat completion request to DeepInfra.

    The caller must pass ``authenticate()``. The JSON body's ``model``,
    ``temperature``, ``max_tokens``, ``stream`` and ``messages`` fields are
    forwarded to ``DEEPINFRA_API_URL`` (with defaults for missing fields), and
    DeepInfra's reply is reshaped into OpenAI's format: a ``chat.completion``
    object, or ``chat.completion.chunk`` server-sent events when ``stream`` is
    truthy.

    Returns:
        401 if unauthenticated; 400 if the body is not a JSON object or
        DeepInfra returns an ``error`` payload; 500 if DeepInfra cannot be
        reached, answers with an HTTP error, or returns an unusable body;
        otherwise 200 with the converted completion or event stream.
    """
    logging.debug(f"Received headers: {_redact_headers(request.headers)}")
    logging.debug(f"Received body: {request.get_data(as_text=True)}")

    if not authenticate():
        logging.warning("Unauthorized access attempt")
        return jsonify({"error": "Unauthorized"}), 401

    # request.json never raises json.JSONDecodeError: it raises werkzeug HTTP
    # errors or returns None. silent=True yields None for any unparsable or
    # non-JSON body so every bad input gets this endpoint's own 400.
    openai_request = request.get_json(silent=True)
    if not isinstance(openai_request, dict):
        logging.error("Invalid JSON in request body")
        return jsonify({"error": "Invalid JSON in request body"}), 400

    logging.info(f"Received request: {openai_request}")

    deepinfra_request = {
        "model": openai_request.get("model", "meta-llama/Meta-Llama-3.1-405B-Instruct"),
        "temperature": openai_request.get("temperature", 0.7),
        "max_tokens": openai_request.get("max_tokens", 1000),
        "stream": openai_request.get("stream", False),
        "messages": openai_request.get("messages", []),
    }
    stream = bool(deepinfra_request["stream"])

    headers = {
        "Content-Type": "application/json",
        "Accept": "text/event-stream" if stream else "application/json",
    }

    try:
        response = requests.post(
            DEEPINFRA_API_URL,
            json=deepinfra_request,
            headers=headers,
            stream=stream,
            # (connect, read) seconds. Without a timeout a stalled upstream holds
            # this worker forever; the read timeout bounds the wait between
            # received bytes, not the total generation time.
            timeout=(10, 600),
        )
        response.raise_for_status()
    except requests.RequestException as e:
        logging.error(f"Error calling DeepInfra API: {str(e)}")
        # An HTTPError carries the (possibly still-streaming) response; release
        # its connection instead of leaking it.
        if e.response is not None:
            e.response.close()
        return jsonify({"error": "Failed to call DeepInfra API"}), 500

    logging.debug(f"DeepInfra API response status: {response.status_code}")
    logging.debug(f"DeepInfra API response headers: {response.headers}")

    if stream:
        return Response(stream_with_context(_relay_stream(response)), content_type='text/event-stream')

    try:
        deepinfra_response = response.json()
        logging.info(f"Received response from DeepInfra: {deepinfra_response}")

        if 'error' in deepinfra_response:
            return jsonify({"error": deepinfra_response['error']}), 400

        if 'choices' not in deepinfra_response or not deepinfra_response['choices']:
            return jsonify({"error": "Unexpected response format from DeepInfra"}), 500

        openai_response = _openai_completion(deepinfra_response)
        return json.dumps(openai_response), 200, {'Content-Type': 'application/json'}
    except Exception as e:
        logging.error(f"Error processing DeepInfra response: {str(e)}")
        return jsonify({"error": "Failed to process DeepInfra response"}), 500


def _redact_headers(headers):
    """Return *headers* as a plain dict with credential-bearing values masked, for logging."""
    sensitive = {"authorization", "proxy-authorization", "cookie"}
    return {
        name: "<redacted>" if name.lower() in sensitive else value
        for name, value in headers.items()
    }


def _openai_chunk(data):
    """Convert one parsed DeepInfra stream chunk into an OpenAI ``chat.completion.chunk`` dict.

    Only the first choice's delta text and finish_reason are kept. Raises
    KeyError if *data* lacks ``id``, ``created`` or ``model``.
    """
    chunk = {
        "id": data['id'],
        "object": "chat.completion.chunk",
        "created": data['created'],
        "model": data['model'],
        "choices": [],
    }
    choices = data.get('choices') or []
    if choices:
        first = choices[0]
        chunk["choices"].append({
            "index": 0,
            # .get('content') still returns None for an explicit JSON null;
            # coerce to "" so the chunk (and its finish_reason) is relayed.
            "delta": {"content": (first.get('delta') or {}).get('content') or ""},
            "finish_reason": first.get('finish_reason'),
        })
    # Forward usage on the chunk that carries it rather than in a separate chunk
    # after "data: [DONE]", which marks the end of the stream for OpenAI clients.
    if data.get('usage'):
        chunk["usage"] = data['usage']
    return chunk


def _relay_stream(response):
    """Yield OpenAI-format SSE events converted from DeepInfra's streaming *response*.

    ``data: [DONE]`` is forwarded and ends the relay. Blank lines and lines not
    starting with ``data: `` are skipped; undecodable or malformed chunks are
    logged and skipped. The upstream response is always closed, including when
    the generator is closed early (e.g. the client disconnects).
    """
    try:
        for line in response.iter_lines():
            if not line:
                # SSE separates events with blank lines, so this is routine.
                logging.debug("Received empty line from DeepInfra API")
                continue
            try:
                line_text = line.decode('utf-8')
                if not line_text.startswith('data: '):
                    continue
                data_text = line_text[len('data: '):]
                if data_text == "[DONE]":
                    event = "data: [DONE]\n\n"
                else:
                    event = f"data: {json.dumps(_openai_chunk(json.loads(data_text)))}\n\n"
            except json.JSONDecodeError as e:
                logging.error(f"JSON decode error: {e}. Raw line: {line}")
                continue
            except Exception as e:
                logging.error(f"Error processing line: {e}. Raw line: {line}")
                continue
            yield event
            if data_text == "[DONE]":
                return
    finally:
        response.close()


def _openai_completion(deepinfra_response):
    """Reshape a non-streaming DeepInfra completion into an OpenAI ``chat.completion`` dict.

    Only the first choice is kept. Raises (KeyError/TypeError) if that choice
    has no ``message.content``; the caller turns this into a 500.
    """
    first = deepinfra_response["choices"][0]
    return {
        "id": deepinfra_response.get("id", ""),
        "object": "chat.completion",
        "created": deepinfra_response.get("created", 0),
        "model": deepinfra_response.get("model", ""),
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": first["message"]["content"],
                },
                "finish_reason": first.get("finish_reason", "stop"),
            }
        ],
        "usage": deepinfra_response.get("usage", {}),
    }
|
|
|
@app.route('/')
def home():
    """Landing page: point visitors at the chat completions endpoint."""
    greeting = "Welcome to the API proxy server."
    hint = "Please use the /hf/v1/chat/completions endpoint for chat completions."
    return f"{greeting} {hint}"
|
|
|
if __name__ == '__main__':
    # Run Flask's built-in server on every interface, port 7860.
    bind_address, bind_port = '0.0.0.0', 7860
    app.run(host=bind_address, port=bind_port)