import os

# HF_HOME must be set before transformers / huggingface_hub are imported:
# both libraries resolve their cache paths at import time, so setting it
# after the imports has no effect.
os.environ["HF_HOME"] = "/tmp/cache"

import secrets

import torch
from flask import Flask, request, jsonify
from huggingface_hub import login
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

# Log in to the Hugging Face Hub (needed for gated or private models).
# login(None) would fall back to an interactive prompt, so guard it.
hub_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if hub_token:
    login(hub_token)

# Separate bearer token that clients of this API must present.
API_TOKEN = os.getenv("HF_API_TOKEN")

# Model and loading config
model_name = "cerebras/btlm-3b-8k-chat"
revision = "main"  # "main" floats; pin a commit hash here for reproducible deployments
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

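# BTLM ships custom modeling code on the Hub, so trust_remote_code=True is
# required for both the tokenizer and the model.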
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, revision=revision)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch_dtype,
    device_map="auto",
    trust_remote_code=True,
    revision=revision
)

# The model and tokenizer are already instantiated above, so device placement
# and dtype are not passed again here; the pipeline reuses whatever the model
# was loaded with (including its accelerate device map).
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    pad_token_id=tokenizer.eos_token_id
)

app = Flask(__name__)

@app.route("/")
def home():
    return "API is running"

@app.route("/v1/chat/completions", methods=["POST"])
def chat_completions():
    # Require a matching bearer token; compare in constant time, and reject
    # all requests if the server was started without HF_API_TOKEN set.
    auth_header = request.headers.get("Authorization", "")
    provided = auth_header[len("Bearer "):] if auth_header.startswith("Bearer ") else ""
    if not API_TOKEN or not secrets.compare_digest(provided, API_TOKEN):
        return jsonify({"error": "Unauthorized"}), 401

    # Tolerate a missing or non-JSON body instead of raising a 500.
    data = request.get_json(silent=True) or {}
    messages = data.get("messages", [])
    max_tokens = data.get("max_tokens", 256)
    temperature = data.get("temperature", 0.7)

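    # Flatten the chat history into a simple role-prefixed transcript
    # ("User: ...\nAssistant: ...") and cue the model to reply. This is a
    # generic format, which may differ from the model's own chat template.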
    prompt = ""
    for msg in messages:
        role = msg.get("role", "")
        content = msg.get("content", "")
        if role and content:
            prompt += f"{role.capitalize()}: {content}\n"
    prompt += "Assistant:"

    output = generator(
        prompt,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=0.9,
        repetition_penalty=1.1,
        do_sample=True,
        return_full_text=False  # return only the newly generated text
    )

    # With return_full_text=False the prompt is already stripped, which is
    # safer than str.replace (the prompt text could recur in the output).
    assistant_reply = output[0]["generated_text"].strip()

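    # Mirror the shape of an OpenAI-style chat completion response so
    # existing client libraries can be pointed at this endpoint.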
    return jsonify({
        "choices": [{
            "message": {
                "role": "assistant",
                "content": assistant_reply
            },
            "finish_reason": "stop",
            "index": 0
        }]
    })

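# Example request (hypothetical values; assumes HF_API_TOKEN is exported):
#   curl -s http://localhost:8080/v1/chat/completions \
#     -H "Authorization: Bearer $HF_API_TOKEN" \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello"}]}'

# Flask's built-in server is for development; use a production WSGI server
# (e.g. gunicorn) for real traffic.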
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8080)