from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch
import os
import json
import time
from huggingface_hub import login
from flask import Flask, request, jsonify, Response
import gradio as gr

# Hugging Face Auth (skip login when no token is configured;
# login(None) would fall back to an interactive prompt and hang)
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if hf_token:
    login(hf_token)
API_TOKEN = os.getenv("HF_API_TOKEN")  # bearer token expected from API clients

# Model config
model_name = "cerebras/btlm-3b-8k-chat"
revision = "main"
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
# Cache model downloads under /tmp (writable on Hugging Face Spaces).
# HF_HOME is read at download time, so it must be set before from_pretrained runs.
os.environ["HF_HOME"] = "/tmp/cache"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, revision=revision)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch_dtype,
    device_map="auto",
    trust_remote_code=True,
    revision=revision
)

# The model is already dtype-cast and device-mapped above, so those
# arguments are not passed to the pipeline again here.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    pad_token_id=tokenizer.eos_token_id
)
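
# Optional sanity check (illustrative; uncomment to confirm the pipeline
# generates before serving requests):
# print(generator("User: Hi\nAssistant:", max_new_tokens=16)[0]["generated_text"])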

# Flask backend
app = Flask(__name__)

@app.route("/")
def home():
    return "API is running"

@app.route("/v1/chat/completions", methods=["POST"])
def chat():
    auth_header = request.headers.get("Authorization", "")
    # Expect "Authorization: Bearer <HF_API_TOKEN>"; maxsplit=1 keeps tokens
    # that contain spaces intact.
    if not auth_header.startswith("Bearer ") or auth_header.split(" ", 1)[1] != API_TOKEN:
        return jsonify({"error": "Unauthorized"}), 401

    data = request.get_json(silent=True) or {}  # tolerate missing/invalid JSON bodies
    messages = data.get("messages", [])
    max_tokens = data.get("max_tokens", 256)
    temperature = data.get("temperature", 0.7)
    stream = data.get("stream", False)

    # Flatten OpenAI-style messages into a simple role-tagged transcript.
    prompt = ""
    for msg in messages:
        role = msg.get("role", "user").capitalize()
        content = msg.get("content", "")
        prompt += f"{role}: {content}\n"
    prompt += "Assistant:"

    if stream:
        # Simulated streaming: the full reply is generated first, then
        # re-emitted word by word as OpenAI-style SSE delta chunks.
        def generate_stream():
            output = generator(
                prompt,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=0.9,
                repetition_penalty=1.1,
                do_sample=True
            )
            # Slice off the prompt rather than using str.replace(), which
            # could also delete matching text inside the generated reply.
            reply = output[0]["generated_text"][len(prompt):].strip()
            for word in reply.split():
                chunk = {
                    "choices": [{
                        "delta": {"content": word + " "},
                        "index": 0,
                        "finish_reason": None
                    }]
                }
                yield f"data: {json.dumps(chunk)}\n\n"
                time.sleep(0.01)
            yield "data: " + json.dumps({
                "choices": [{
                    "delta": {},
                    "index": 0,
                    "finish_reason": "stop"
                }]
            }) + "\n\n"
            yield "data: [DONE]\n\n"

        return Response(generate_stream(), content_type="text/event-stream")

    output = generator(
        prompt,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=0.9,
        repetition_penalty=1.1,
        do_sample=True
    )
    reply = output[0]["generated_text"][len(prompt):].strip()

    return jsonify({
        "choices": [{
            "message": {
                "role": "assistant",
                "content": reply
            },
            "finish_reason": "stop",
            "index": 0
        }]
    })
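
# Example request against this endpoint (illustrative; assumes the server is
# reachable on localhost:8080 and $HF_API_TOKEN matches the token checked above):
#
#   curl http://localhost:8080/v1/chat/completions \
#     -H "Authorization: Bearer $HF_API_TOKEN" \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello!"}], "max_tokens": 64}'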

# βœ… Gradio Chat UI
def gradio_chat(user_input, history=None):
    # Avoid a mutable default argument; gr.State supplies the real history.
    history = history or []
    full_prompt = ""
    for turn in history:
        full_prompt += f"User: {turn[0]}\nAssistant: {turn[1]}\n"
    full_prompt += f"User: {user_input}\nAssistant:"

    output = generator(
        full_prompt,
        max_new_tokens=256,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        do_sample=True
    )
    reply = output[0]["generated_text"][len(full_prompt):].strip()
    history.append((user_input, reply))
    return history, history

with gr.Blocks() as demo:
    gr.Markdown("## πŸ’¬ Chat with Ariphes (LLM-powered)")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Ask me anything...", label="Message")
    clear = gr.Button("Clear")

    state = gr.State([])

    msg.submit(gradio_chat, [msg, state], [chatbot, state])
    clear.click(lambda: ([], []), None, [chatbot, state])

if __name__ == "__main__":
    # βœ… Launch the Gradio UI without blocking the main thread; otherwise
    # app.run() below would never be reached. share=True requests a public
    # link (unnecessary inside a Hugging Face Space, but harmless locally).
    demo.launch(share=True, prevent_thread_lock=True)

    # βœ… Still serve the OpenAI-compatible API endpoint via Flask
    app.run(host="0.0.0.0", port=8080)
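
# Consuming the streamed variant as server-sent events (illustrative sketch;
# `url`, `headers`, and `payload` stand in for the values shown in the curl
# example above, with "stream": true enabled):
#
#   import requests, json
#   with requests.post(url, headers=headers, json={**payload, "stream": True}, stream=True) as r:
#       for line in r.iter_lines():
#           if line and line != b"data: [DONE]":
#               chunk = json.loads(line[len(b"data: "):])
#               print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)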