File size: 3,467 Bytes
97ce33c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102

import random
import json
import aiohttp
import asyncio
from aiohttp import web
from datetime import datetime

# When True, request payloads and headers are additionally printed to stdout.
DEBUG_MODE: bool = False

# Upstream endpoint that every proxied request is posted to.
FIXED_URL: str = "https://api.cerebras.ai/v1/chat/completions"

# Model names: DEFAULT_MODEL is used unless the client explicitly asks for
# ALTERNATE_MODEL.
DEFAULT_MODEL: str = "llama3.1-8b"
ALTERNATE_MODEL: str = "llama3.1-70b"

# Sampling parameters sent upstream with every request.
FIXED_TEMPERATURE: float = 0.2
FIXED_TOP_P: int = 1
FIXED_MAX_TOKENS: int = 4096

# Timestamped logging helper
def log_basic_info(message):
    """Print *message* to stdout, prefixed with the current local time.

    The emitted line has the form ``[YYYY-MM-DD HH:MM:SS] message``.
    """
    now = datetime.now()
    print("[{:%Y-%m-%d %H:%M:%S}] {}".format(now, message))

# Asynchronous function to send request and print debug information
async def send_request(auth_tokens, data):
    """Forward a chat-completion request to the upstream API.

    Args:
        auth_tokens: Sequence of bearer tokens. Only the first entry is used
            to build the ``authorization`` header.
        data: Parsed client request body. Only ``messages`` and ``model`` are
            read; temperature, top_p and max_tokens are always replaced by the
            fixed module-level values.

    Returns:
        The raw upstream response body as text, or ``None`` when anything
        fails (the failure is logged rather than raised).
    """
    try:
        headers = {
            "accept": "application/json",
            "authorization": f"Bearer {auth_tokens[0]}",
            "content-type": "application/json"
        }

        # Only the two known models are allowed; anything else falls back to
        # the default model.
        requested_model = data.get("model", DEFAULT_MODEL)
        model_to_use = ALTERNATE_MODEL if requested_model == ALTERNATE_MODEL else DEFAULT_MODEL

        log_basic_info(f"Requested model: {requested_model}, Using model: {model_to_use}")

        payload = {
            "messages": data.get("messages", []),
            "model": model_to_use,
            "temperature": FIXED_TEMPERATURE,
            "top_p": FIXED_TOP_P,
            "max_tokens": FIXED_MAX_TOKENS
        }

        if DEBUG_MODE:
            print("Request Payload:", json.dumps(payload, indent=4))
            # NOTE(review): this prints the bearer token in clear text.
            print("Request Headers:", headers)

        async with aiohttp.ClientSession() as session:
            async with session.post(FIXED_URL, headers=headers, json=payload) as resp:
                response_text = await resp.text()
                response_json = json.loads(response_text)

                # ``or {}`` also covers keys that are present but null.
                usage = response_json.get('usage') or {}
                time_info = response_json.get('time_info') or {}
                total_tokens = usage.get('total_tokens', 'N/A')
                total_time = time_info.get('total_time', 'N/A')

                # Responses without timing data (e.g. upstream errors) leave
                # the 'N/A' placeholder here. Applying ':.3f' to a string
                # raises ValueError, which previously discarded the response,
                # so only numeric values are formatted.
                if isinstance(total_time, (int, float)):
                    total_time = f"{total_time:.3f}"

                log_basic_info(f"Path: {FIXED_URL}, Status Code: {resp.status}, Total Tokens Used: {total_tokens}, Total Time: {total_time} seconds")

                return response_text

    except Exception as e:
        # Deliberate best-effort: log and signal failure to the caller.
        log_basic_info(f"Exception occurred: {str(e)}")
        return None

# Main handler function
async def handle_request(request):
    """Handle a proxied chat-completion POST request.

    Reads the JSON body and the ``Authorization`` header, which may carry
    several comma-separated bearer tokens. One token is picked at random,
    the request is forwarded through ``send_request`` and the upstream JSON
    is returned to the client.

    Args:
        request: The incoming aiohttp request.

    Returns:
        A JSON response: the upstream payload on success, status 400 when no
        token was supplied, or status 500 with an ``error`` message on any
        other failure.
    """
    try:
        request_data = await request.json()

        # Look the header up on request.headers directly: it is
        # case-insensitive, whereas a plain dict copy is not.
        authorization_header = request.headers.get('Authorization', '')

        # ''.split(',') yields [''], so empty entries must be dropped for the
        # missing-token check below to ever trigger.
        candidates = authorization_header.replace('Bearer ', '').split(',')
        auth_tokens = [token.strip() for token in candidates if token.strip()]

        if not auth_tokens:
            return web.json_response({"error": "Missing Authorization token"}, status=400)

        auth_token = random.choice(auth_tokens)

        log_basic_info(f"Received request for path: {request.path}")

        if DEBUG_MODE:
            headers = dict(request.headers)
            headers['Authorization'] = f"Bearer {auth_token}"
            print("Received Request Data:", json.dumps(request_data, indent=4))
            # NOTE(review): this prints the bearer token in clear text.
            print("Received Headers:", headers)

        # Pass only the selected token so the random choice takes effect;
        # send_request uses the first entry of the list it receives.
        response_text = await send_request([auth_token], request_data)

        if response_text is None:
            # send_request logs its own failure and returns None; report that
            # instead of failing inside json.loads(None).
            return web.json_response({"error": "Upstream request failed"}, status=500)

        return web.json_response(json.loads(response_text))

    except Exception as e:
        log_basic_info(f"Exception occurred in handling request: {str(e)}")
        return web.json_response({"error": str(e)}, status=500)

# Build the application and register the single POST endpoint
app = web.Application()
app.router.add_route('POST', '/hf/v1/chat/completions', handle_request)

# Start the HTTP server only when this file is executed as a script
if __name__ == '__main__':
    web.run_app(app, port=7860, host='0.0.0.0')