from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch
import os
import json
import threading
import time
from huggingface_hub import login
from flask import Flask, request, jsonify, Response
import gradio as gr

# Point the Hugging Face cache at a writable location before any hub access,
# so both the login token and downloaded weights land in /tmp/cache.
os.environ["HF_HOME"] = "/tmp/cache"

# Log in to Hugging Face using the secret token stored in the Space secrets.
hub_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if hub_token:
    login(hub_token)

# Bearer token that clients must present on API requests.
API_TOKEN = os.getenv("HF_API_TOKEN")

# Set up model loading: bfloat16 on GPU, float32 fallback on CPU.
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
model_name = "cerebras/btlm-3b-8k-chat"
revision = "main"  # Pin to a stable revision

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, revision=revision)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch_dtype,
    device_map="auto",
    trust_remote_code=True,
    revision=revision
)

# Wrap the preloaded model and tokenizer in a generation pipeline. Device
# placement and dtype were already handled by from_pretrained above, so they
# are not passed again here.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    pad_token_id=tokenizer.eos_token_id
)
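
# Hypothetical smoke test, assuming the plain "Role: content" prompt format
# used by the /v1/chat/completions handler below (left commented out so it
# does not run at import time):
# print(generator("User: Hello\nAssistant:", max_new_tokens=16)[0]["generated_text"])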

app = Flask(__name__)

@app.route("/")
def home():
    return "API is running"

@app.route("/v1/chat/completions", methods=["POST"])
def chat():
    # Require "Authorization: Bearer <token>" matching the configured token;
    # reject everything if no token is configured.
    auth_header = request.headers.get("Authorization", "")
    if not API_TOKEN or not auth_header.startswith("Bearer ") or auth_header.split(" ", 1)[1] != API_TOKEN:
        return jsonify({"error": "Unauthorized"}), 401

    data = request.get_json(silent=True) or {}
    messages = data.get("messages", [])
    max_tokens = data.get("max_tokens", 256)
    temperature = data.get("temperature", 0.7)
    stream = data.get("stream", False)

    prompt = ""
    for msg in messages:
        role = msg.get("role", "user").capitalize()
        content = msg.get("content", "")
        prompt += f"{role}: {content}\n"
    prompt += "Assistant:"
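    # For example, [{"role": "user", "content": "Hi"}] flattens to
    # "User: Hi\nAssistant:"; this is a plain-text fallback, no
    # model-specific chat template is applied.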

    if stream:
        # Pseudo-streaming: the full reply is generated up front, then
        # replayed word by word as OpenAI-style server-sent events.
        def generate_stream():
            output = generator(
                prompt,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=0.9,
                repetition_penalty=1.1,
                do_sample=True
            )
            # Slice off the prompt rather than using str.replace, which could
            # also remove repeats of the prompt text inside the reply.
            reply = output[0]["generated_text"][len(prompt):].strip()
            for word in reply.split():
                chunk = {
                    "choices": [{
                        "delta": {"content": word + " "},
                        "index": 0,
                        "finish_reason": None
                    }]
                }
                yield f"data: {json.dumps(chunk)}\n\n"
                time.sleep(0.01)
            yield "data: " + json.dumps({
                "choices": [{
                    "delta": {},
                    "index": 0,
                    "finish_reason": "stop"
                }]
            }) + "\n\n"
            yield "data: [DONE]\n\n"

        return Response(generate_stream(), content_type="text/event-stream")
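
    # A client could consume the stream like this (hypothetical sketch; url,
    # headers, and payload are placeholders, json is the stdlib module):
    #   resp = requests.post(url, headers=headers, json={**payload, "stream": True}, stream=True)
    #   for line in resp.iter_lines():
    #       if line.startswith(b"data: ") and line != b"data: [DONE]":
    #           print(json.loads(line[6:])["choices"][0]["delta"].get("content", ""), end="")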

    output = generator(
        prompt,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=0.9,
        repetition_penalty=1.1,
        do_sample=True
    )
    # Same prompt-slicing as in the streaming path above.
    reply = output[0]["generated_text"][len(prompt):].strip()

    return jsonify({
        "choices": [{
            "message": {
                "role": "assistant",
                "content": reply
            },
            "finish_reason": "stop",
            "index": 0
        }]
    })
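
# Hypothetical non-streaming usage sketch (placeholder Space URL; the token
# must match the value configured in HF_API_TOKEN):
#   import requests
#   resp = requests.post(
#       "https://<your-space>.hf.space/v1/chat/completions",
#       headers={"Authorization": "Bearer <HF_API_TOKEN>"},
#       json={"messages": [{"role": "user", "content": "Hello"}], "max_tokens": 64},
#   )
#   print(resp.json()["choices"][0]["message"]["content"])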

# Optional Gradio frontend to keep the Space alive and signal readiness.
with gr.Blocks() as demo:
    gr.Markdown("### LLM backend is running and ready for API calls.")

if __name__ == "__main__":
    # demo.launch() blocks, so serve the Flask API from a daemon thread and
    # let the Gradio frontend hold the main thread to keep the Space running.
    threading.Thread(
        target=lambda: app.run(host="0.0.0.0", port=8080),
        daemon=True,
    ).start()
    demo.launch(server_name="0.0.0.0", server_port=7860)