# LLM_Ariphes / app.py
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch
import os
import json
import time
from huggingface_hub import login
from flask import Flask, request, jsonify, Response
import gradio as gr
# Cache downloads under /tmp (writable on Hugging Face Spaces); set this
# before the first hub call below so the token and model files land there.
os.environ["HF_HOME"] = "/tmp/cache"

# Hugging Face auth
login(os.getenv("HUGGINGFACEHUB_API_TOKEN"))
API_TOKEN = os.getenv("HF_API_TOKEN")

# Model config
model_name = "cerebras/btlm-3b-8k-chat"
revision = "main"
# Prefer bfloat16 when a GPU is available; fall back to float32 on CPU.
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, revision=revision)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch_dtype,
    device_map="auto",
    trust_remote_code=True,
    revision=revision,
)
# The pipeline reuses the already-loaded model and tokenizer, so device
# placement and dtype are inherited from the model itself and do not need
# to be passed again here.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    pad_token_id=tokenizer.eos_token_id,
)

# Flask backend
app = Flask(__name__)


@app.route("/")
def home():
    return "API is running"


@app.route("/v1/chat/completions", methods=["POST"])
def chat():
    # Simple bearer-token check against the HF_API_TOKEN secret.
    auth_header = request.headers.get("Authorization", "")
    if not auth_header.startswith("Bearer ") or auth_header.split(" ", 1)[1] != API_TOKEN:
        return jsonify({"error": "Unauthorized"}), 401
    data = request.get_json(silent=True) or {}
    messages = data.get("messages", [])
    max_tokens = data.get("max_tokens", 256)
    temperature = data.get("temperature", 0.7)
    stream = data.get("stream", False)

    # Flatten the OpenAI-style message list into a plain chat transcript.
    prompt = ""
    for msg in messages:
        role = msg.get("role", "user").capitalize()
        content = msg.get("content", "")
        prompt += f"{role}: {content}\n"
    prompt += "Assistant:"
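    # A hypothetical body {"messages": [{"role": "user", "content": "Hi"}]}
    # flattens to:
    #   User: Hi
    #   Assistant: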

    if stream:
        def generate_stream():
            # Note: generation is not token-streamed. The full reply is
            # produced first, then re-chunked word by word as server-sent
            # events to mimic OpenAI's streaming format.
            output = generator(
                prompt,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=0.9,
                repetition_penalty=1.1,
                do_sample=True,
            )
            # Slice the prompt off the front rather than str.replace(),
            # which would also delete later repetitions of the prompt text.
            reply = output[0]["generated_text"][len(prompt):].strip()
            for word in reply.split():
                chunk = {
                    "choices": [{
                        "delta": {"content": word + " "},
                        "index": 0,
                        "finish_reason": None,
                    }]
                }
                yield f"data: {json.dumps(chunk)}\n\n"
                time.sleep(0.01)
            yield "data: " + json.dumps({
                "choices": [{
                    "delta": {},
                    "index": 0,
                    "finish_reason": "stop",
                }]
            }) + "\n\n"
            yield "data: [DONE]\n\n"

        return Response(generate_stream(), content_type="text/event-stream")
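
    # Non-streaming path: generate once and return a single completion object.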
    output = generator(
        prompt,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=0.9,
        repetition_penalty=1.1,
        do_sample=True,
    )
    reply = output[0]["generated_text"][len(prompt):].strip()
    return jsonify({
        "choices": [{
            "message": {
                "role": "assistant",
                "content": reply,
            },
            "finish_reason": "stop",
            "index": 0,
        }]
    })
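

# Example client call against /v1/chat/completions (a sketch: assumes the
# app is reachable on localhost:8080 and the token matches HF_API_TOKEN):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8080/v1/chat/completions",
#       headers={"Authorization": "Bearer " + os.environ["HF_API_TOKEN"]},
#       json={"messages": [{"role": "user", "content": "Hello"}], "max_tokens": 64},
#   )
#   print(resp.json()["choices"][0]["message"]["content"])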


# ✅ Gradio Chat UI
def gradio_chat(user_input, history=None):
    # Avoid a mutable default argument; Gradio passes the session state in.
    history = history or []
    full_prompt = ""
    for turn in history:
        full_prompt += f"User: {turn[0]}\nAssistant: {turn[1]}\n"
    full_prompt += f"User: {user_input}\nAssistant:"
    output = generator(
        full_prompt,
        max_new_tokens=256,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        do_sample=True,
    )
    reply = output[0]["generated_text"][len(full_prompt):].strip()
    history.append((user_input, reply))
    return history, history


with gr.Blocks() as demo:
    gr.Markdown("## 💬 Chat with Ariphes (LLM-powered)")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Ask me anything...", label="Message")
    clear = gr.Button("Clear")
    state = gr.State([])
    msg.submit(gradio_chat, [msg, state], [chatbot, state])
    clear.click(lambda: ([], []), None, [chatbot, state])

# ✅ share=True so Hugging Face can access it; prevent_thread_lock keeps
# launch() from blocking so the Flask server below can still start.
demo.launch(share=True, prevent_thread_lock=True)

# ✅ Still serve the API endpoint for OpenAI-compatible connectors
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8080)