import os

# HF_HOME must be set before transformers/huggingface_hub are imported,
# since both resolve their cache paths at import time.
os.environ["HF_HOME"] = "/tmp/cache"

import json
import time

import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
from flask import Flask, request, jsonify, Response
import gradio as gr
# Hugging Face auth: HUGGINGFACEHUB_API_TOKEN grants read access to the model
# on the Hub; HF_API_TOKEN is a separate shared secret that API clients must
# present as a bearer token.
login(os.getenv("HUGGINGFACEHUB_API_TOKEN"))
API_TOKEN = os.getenv("HF_API_TOKEN")
# Model config
model_name = "cerebras/btlm-3b-8k-chat"
revision = "main"
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, revision=revision)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch_dtype,
    device_map="auto",
    trust_remote_code=True,
    revision=revision,
)
# The model instance above already carries its dtype and device placement,
# so the pipeline only needs the model/tokenizer pair.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    pad_token_id=tokenizer.eos_token_id,
    trust_remote_code=True,
)
# Flask backend
app = Flask(__name__)

@app.route("/")
def home():
    return "API is running"

# Route path assumed: the decorator was missing, and the response shape below
# matches OpenAI's chat-completions schema.
@app.route("/v1/chat/completions", methods=["POST"])
def chat():
    # Require "Authorization: Bearer <HF_API_TOKEN>" on every request
    auth_header = request.headers.get("Authorization", "")
    if not auth_header.startswith("Bearer ") or auth_header.split(" ", 1)[1] != API_TOKEN:
        return jsonify({"error": "Unauthorized"}), 401
    data = request.json
    messages = data.get("messages", [])
    max_tokens = data.get("max_tokens", 256)
    temperature = data.get("temperature", 0.7)
    stream = data.get("stream", False)

    # Flatten the chat history into a plain role-prefixed prompt
    prompt = ""
    for msg in messages:
        role = msg.get("role", "user").capitalize()
        content = msg.get("content", "")
        prompt += f"{role}: {content}\n"
    prompt += "Assistant:"
    if stream:
        # "Streaming" is simulated: the full completion is generated first,
        # then re-emitted word by word as OpenAI-style SSE chunks.
        def generate_stream():
            output = generator(
                prompt,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=0.9,
                repetition_penalty=1.1,
                do_sample=True,
            )
            # Slice off the echoed prompt rather than str.replace(), which
            # could also delete matching text inside the completion.
            reply = output[0]["generated_text"][len(prompt):].strip()
            for word in reply.split():
                chunk = {
                    "choices": [{
                        "delta": {"content": word + " "},
                        "index": 0,
                        "finish_reason": None
                    }]
                }
                yield f"data: {json.dumps(chunk)}\n\n"
                time.sleep(0.01)
            yield "data: " + json.dumps({
                "choices": [{
                    "delta": {},
                    "index": 0,
                    "finish_reason": "stop"
                }]
            }) + "\n\n"
            yield "data: [DONE]\n\n"
        return Response(generate_stream(), content_type="text/event-stream")
    # Non-streaming path: single OpenAI-style chat-completion response
    output = generator(
        prompt,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=0.9,
        repetition_penalty=1.1,
        do_sample=True,
    )
    reply = output[0]["generated_text"][len(prompt):].strip()
    return jsonify({
        "choices": [{
            "message": {
                "role": "assistant",
                "content": reply
            },
            "finish_reason": "stop",
            "index": 0
        }]
    })
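# Example client calls (a sketch, not part of the app; assumes the
# /v1/chat/completions route above, the server on localhost:8080, and
# HF_API_TOKEN exported in the client's environment):
#
#   import os, json, requests
#   url = "http://localhost:8080/v1/chat/completions"
#   headers = {"Authorization": f"Bearer {os.environ['HF_API_TOKEN']}"}
#   body = {"messages": [{"role": "user", "content": "Hello"}], "max_tokens": 64}
#
#   # Non-streaming
#   r = requests.post(url, headers=headers, json=body)
#   print(r.json()["choices"][0]["message"]["content"])
#
#   # Streaming (SSE): read "data: ..." lines until the [DONE] sentinel
#   with requests.post(url, headers=headers, json={**body, "stream": True}, stream=True) as r:
#       for line in r.iter_lines():
#           if line and line.startswith(b"data: ") and line != b"data: [DONE]":
#               delta = json.loads(line[6:])["choices"][0]["delta"]
#               print(delta.get("content", ""), end="", flush=True)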
# Gradio chat UI
def gradio_chat(user_input, history=None):
    # History is a list of (user, assistant) tuples, the format gr.Chatbot
    # expects. Avoid a mutable default argument, which is shared across calls.
    history = history or []
    full_prompt = ""
    for turn in history:
        full_prompt += f"User: {turn[0]}\nAssistant: {turn[1]}\n"
    full_prompt += f"User: {user_input}\nAssistant:"
    output = generator(
        full_prompt,
        max_new_tokens=256,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        do_sample=True,
    )
    reply = output[0]["generated_text"][len(full_prompt):].strip()
    history.append((user_input, reply))
    return history, history
with gr.Blocks() as demo:
    gr.Markdown("## 💬 Chat with Ariphes (LLM-powered)")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Ask me anything...", label="Message")
    clear = gr.Button("Clear")
    state = gr.State([])

    msg.submit(gradio_chat, [msg, state], [chatbot, state])
    clear.click(lambda: ([], []), None, [chatbot, state])
# Enable share=True so Hugging Face can access it; prevent_thread_lock keeps
# launch() from blocking the main thread so the Flask server below can start.
demo.launch(share=True, prevent_thread_lock=True)
# Still serve the API endpoint for the OpenAI-compatible connector
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8080)