File size: 2,309 Bytes
91b2dc1
 
 
 
aaea5e8
461845f
aaea5e8
91b2dc1
 
aaea5e8
461845f
aaea5e8
91b2dc1
 
 
 
 
 
 
 
 
 
 
c60c816
91b2dc1
 
 
 
 
 
aaea5e8
 
91b2dc1
 
aaea5e8
 
 
 
 
 
91b2dc1
 
aaea5e8
 
 
91b2dc1
aaea5e8
 
 
 
91b2dc1
aaea5e8
 
 
 
 
 
 
91b2dc1
aaea5e8
 
 
 
 
 
 
 
c60c816
aaea5e8
 
91b2dc1
aaea5e8
 
 
 
 
 
 
 
 
 
8ce7317
91b2dc1
47f9485
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import hmac
import os

import torch
from flask import Flask, request, jsonify
from huggingface_hub import login
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

# Authenticate with Hugging Face token from Secrets.
# NOTE(review): os.getenv returns None when the secret is missing and
# login(None) will fail — presumably the Space always defines
# HUGGINGFACEHUB_API_TOKEN; confirm.
login(os.getenv("HUGGINGFACEHUB_API_TOKEN"))

# Bearer token that clients must present on /v1/chat/completions.
API_TOKEN = os.getenv("HF_API_TOKEN")  # Set this token in your Space Secrets

# Setup
model_name = "cerebras/btlm-3b-8k-chat"
# bfloat16 on GPU, full float32 on CPU.
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
# NOTE(review): HF_HOME is set *after* login() above, which may already have
# touched the default cache location — confirm this ordering is intentional.
os.environ['HF_HOME'] = '/tmp/cache'

# trust_remote_code executes Python shipped in the model repo; the model id
# is pinned above, so this trusts exactly that repository.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch_dtype,
    device_map="auto",
    trust_remote_code=True
)

# One pipeline instance, created at import time and reused by every request.
# EOS doubles as the pad token for generation.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device_map="auto",
    torch_dtype=torch_dtype,
    pad_token_id=tokenizer.eos_token_id,
    trust_remote_code=True
)

# WSGI application object served below.
app = Flask(__name__)

@app.route("/")
def home():
    """Health-check endpoint confirming the service is up."""
    status_message = "API is running"
    return status_message

@app.route("/v1/chat/completions", methods=["POST"])
def chat_completions():
    """OpenAI-style chat completion endpoint.

    Expects a JSON body with "messages" (list of {role, content} dicts)
    and optional "max_tokens" / "temperature", plus an
    "Authorization: Bearer <token>" header matching API_TOKEN.

    Returns a single-choice OpenAI-compatible JSON response, or a 401
    JSON error when authentication fails.
    """
    # Bearer auth. compare_digest is a constant-time comparison (avoids a
    # timing side channel); split(" ", 1) keeps tokens that themselves
    # contain spaces intact. A missing API_TOKEN rejects every request,
    # matching the previous behaviour.
    auth_header = request.headers.get("Authorization", "")
    if (
        API_TOKEN is None
        or not auth_header.startswith("Bearer ")
        or not hmac.compare_digest(auth_header.split(" ", 1)[1], API_TOKEN)
    ):
        return jsonify({"error": "Unauthorized"}), 401

    # silent=True yields None (instead of raising) on a missing or
    # malformed JSON body; fall back to {} so the .get() calls are safe.
    data = request.get_json(silent=True) or {}
    messages = data.get("messages", [])
    max_tokens = data.get("max_tokens", 256)
    temperature = data.get("temperature", 0.7)

    # Flatten the chat history into a plain-text prompt, e.g.
    # "User: hi\nAssistant:".
    prompt = ""
    for msg in messages:
        role = msg.get("role", "")
        content = msg.get("content", "")
        if role and content:
            prompt += f"{role.capitalize()}: {content}\n"
    prompt += "Assistant:"

    output = generator(
        prompt,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=0.9,
        repetition_penalty=1.1,
        do_sample=True
    )

    # The pipeline echoes the prompt before the reply. Strip exactly the
    # leading prompt; str.replace() would also delete any later occurrence
    # of the prompt text inside the generated reply.
    generated_text = output[0]['generated_text']
    if generated_text.startswith(prompt):
        assistant_reply = generated_text[len(prompt):].strip()
    else:
        assistant_reply = generated_text.strip()

    return jsonify({
        "choices": [{
            "message": {
                "role": "assistant",
                "content": assistant_reply
            },
            "finish_reason": "stop",
            "index": 0
        }]
    })

# Script entry point: run Flask's built-in development server, listening on
# all interfaces on port 8080.
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8080)