tianlong12's picture
Update app.py
31853b5 verified
raw
history blame
6.45 kB
import hmac
import json
import logging
import os

import requests
from flask import Flask, request, Response, stream_with_context, jsonify
app = Flask(__name__)

# Root logging level. Defaults to DEBUG (the historical behaviour), but can be
# raised via the LOG_LEVEL environment variable (e.g. "INFO"), because the
# request handlers log full request bodies at DEBUG.
logging.basicConfig(level=os.environ.get("LOG_LEVEL", "DEBUG").upper())

# OpenAI-compatible chat-completions endpoint that requests are forwarded to.
DEEPINFRA_API_URL = "https://api.deepinfra.com/v1/openai/chat/completions"

# Shared secret that clients must present as "Authorization: Bearer <key>".
# None when the API_KEY environment variable is not set, in which case
# authenticate() rejects every request.
API_KEY = os.environ.get("API_KEY")
def authenticate():
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Bearer "):
return False
token = auth_header.split(" ")[1]
return token == API_KEY
# (connect, read) timeouts in seconds for upstream calls. Per requests'
# semantics, the read timeout bounds the wait between received bytes, not the
# total duration, so long streamed generations are still allowed.
_DEEPINFRA_TIMEOUT = (10, 300)

# Header names (lower-cased) whose values must never be written to the logs.
_SENSITIVE_HEADERS = frozenset({"authorization", "proxy-authorization", "cookie"})


def _redacted_headers(headers):
    """Return a dict copy of *headers* with credential-bearing values masked."""
    return {
        name: ("<redacted>" if name.lower() in _SENSITIVE_HEADERS else value)
        for name, value in headers.items()
    }


def _build_deepinfra_request(openai_request):
    """Map an OpenAI-style request dict onto the DeepInfra payload, filling defaults."""
    return {
        "model": openai_request.get("model", "meta-llama/Meta-Llama-3.1-405B-Instruct"),
        "temperature": openai_request.get("temperature", 0.7),
        "max_tokens": openai_request.get("max_tokens", 1000),
        "stream": openai_request.get("stream", False),
        "messages": openai_request.get("messages", []),
    }


def _stream_openai_chunks(response):
    """Translate DeepInfra's SSE stream into OpenAI ``chat.completion.chunk`` frames.

    Yields ``"data: ...\\n\\n"`` strings.
    - When the last successfully parsed upstream event contains ``usage``,
      a final chunk carrying it is sent *before* the ``[DONE]`` sentinel,
      because OpenAI clients stop reading at ``[DONE]``.
    - ``[DONE]`` is forwarded only if upstream sent it.
    - Malformed lines are logged and skipped.
    - The upstream response is always closed, including when the client
      disconnects and the generator is abandoned.
    """
    last_event = None
    done_seen = False
    try:
        for line in response.iter_lines():
            if not line:
                logging.warning("Received empty line from DeepInfra API")
                continue
            try:
                line_text = line.decode('utf-8')
                if not line_text.startswith('data: '):
                    continue
                data_text = line_text.split('data: ', 1)[1]
                if data_text == "[DONE]":
                    done_seen = True
                    break
                data = json.loads(data_text)
                # Remember the event before touching its fields, so a trailing
                # usage-only event (empty choices) still feeds the usage chunk.
                last_event = data
                choices = data.get('choices') or []
                if not choices:
                    continue
                choice = choices[0]
                # Upstream may send "content": null (e.g. on the finishing
                # event); normalise to "" so that event is still forwarded.
                delta_content = (choice.get('delta') or {}).get('content') or ''
                openai_format = {
                    "id": data['id'],
                    "object": "chat.completion.chunk",
                    "created": data['created'],
                    "model": data['model'],
                    "choices": [
                        {
                            "index": 0,
                            "delta": {"content": delta_content},
                            "finish_reason": choice.get('finish_reason'),
                        }
                    ],
                }
                yield f"data: {json.dumps(openai_format)}\n\n"
            except json.JSONDecodeError as e:
                logging.error("JSON decode error: %s. Raw line: %s", e, line)
            except Exception as e:  # best-effort: one bad event must not kill the stream
                logging.error("Error processing line: %s. Raw line: %s", e, line)

        if isinstance(last_event, dict) and 'usage' in last_event:
            final_chunk = {
                "id": last_event.get('id'),
                "object": "chat.completion.chunk",
                "created": last_event.get('created'),
                "model": last_event.get('model'),
                "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
                "usage": last_event['usage'],
            }
            yield f"data: {json.dumps(final_chunk)}\n\n"
        if done_seen:
            yield "data: [DONE]\n\n"
    finally:
        response.close()


def _to_openai_completion(deepinfra_response):
    """Convert a non-streaming DeepInfra completion dict into OpenAI ``chat.completion`` shape.

    Raises KeyError/IndexError/TypeError if the first choice has no
    ``message.content``.
    """
    first_choice = deepinfra_response["choices"][0]
    return {
        "id": deepinfra_response.get("id", ""),
        "object": "chat.completion",
        "created": deepinfra_response.get("created", 0),
        "model": deepinfra_response.get("model", ""),
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": first_choice["message"]["content"],
                },
                "finish_reason": first_choice.get("finish_reason", "stop"),
            }
        ],
        "usage": deepinfra_response.get("usage", {}),
    }


@app.route('/hf/v1/chat/completions', methods=['POST'])
def chat_completions():
    """Proxy an OpenAI-style chat-completions request to DeepInfra.

    Returns:
        - 401 if authentication fails.
        - 400 if the body is not a JSON object, or DeepInfra reports an error.
        - 500 if the upstream call fails or its reply cannot be translated.
        - Otherwise an OpenAI-format JSON completion, or an SSE stream of
          OpenAI-format chunks when the request sets "stream".
    """
    logging.debug("Received headers: %s", _redacted_headers(request.headers))
    logging.debug("Received body: %s", request.get_data(as_text=True))
    if not authenticate():
        logging.warning("Unauthorized access attempt")
        return jsonify({"error": "Unauthorized"}), 401

    # silent=True returns None instead of raising werkzeug's BadRequest /
    # UnsupportedMediaType, so malformed input gets the intended 400. The
    # isinstance check also rejects valid JSON that is not an object.
    openai_request = request.get_json(silent=True)
    if not isinstance(openai_request, dict):
        logging.error("Invalid JSON in request body")
        return jsonify({"error": "Invalid JSON in request body"}), 400
    logging.info("Received request: %s", openai_request)

    deepinfra_request = _build_deepinfra_request(openai_request)
    is_stream = deepinfra_request["stream"]
    headers = {
        "Content-Type": "application/json",
        "Accept": "text/event-stream" if is_stream else "application/json",
    }

    response = None
    try:
        response = requests.post(
            DEEPINFRA_API_URL,
            json=deepinfra_request,
            headers=headers,
            stream=is_stream,
            timeout=_DEEPINFRA_TIMEOUT,
        )
        response.raise_for_status()
        logging.debug("DeepInfra API response status: %s", response.status_code)
        logging.debug("DeepInfra API response headers: %s", response.headers)
    except requests.RequestException as e:
        # A streamed error response holds its connection open until closed.
        if response is not None:
            response.close()
        logging.error("Error calling DeepInfra API: %s", e)
        return jsonify({"error": "Failed to call DeepInfra API"}), 500

    if is_stream:
        return Response(
            stream_with_context(_stream_openai_chunks(response)),
            content_type='text/event-stream',
        )

    try:
        deepinfra_response = response.json()
        logging.info("Received response from DeepInfra: %s", deepinfra_response)
        if 'error' in deepinfra_response:
            return jsonify({"error": deepinfra_response['error']}), 400
        if 'choices' not in deepinfra_response or not deepinfra_response['choices']:
            return jsonify({"error": "Unexpected response format from DeepInfra"}), 500
        openai_response = _to_openai_completion(deepinfra_response)
    except Exception as e:  # request boundary: report any translation failure as 500
        logging.error("Error processing DeepInfra response: %s", e)
        return jsonify({"error": "Failed to process DeepInfra response"}), 500
    return json.dumps(openai_response), 200, {'Content-Type': 'application/json'}
@app.route('/')
def home():
    """Serve a plain-text pointer to the chat-completions endpoint."""
    greeting = (
        "Welcome to the API proxy server. "
        "Please use the /hf/v1/chat/completions endpoint for chat completions."
    )
    return greeting
if __name__ == '__main__':
    # Bind on all interfaces. The port defaults to 7860 but can be overridden
    # with the PORT environment variable.
    app.run(host='0.0.0.0', port=int(os.environ.get("PORT", "7860")))