import streamlit as st
import torch
import time
import os
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from huggingface_hub import login
# Hugging Face token, read from the environment (avoid hardcoding secrets in source)
HF_TOKEN = os.environ.get("HF_TOKEN")
# App config
st.set_page_config(
    page_title="GM Fine-tune Assistant 🚀",
    page_icon="🚀",
    layout="centered"
)
st.title("🚀 GM Fine-tune Assistant 🚀")
# Avatars
USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"
# Log in to Hugging Face (skipped when no token is set)
if HF_TOKEN:
    login(token=HF_TOKEN)
# Load Model
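# @st.cache_resource loads the weights once and reuses them across Streamlit reruns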
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained("amiguel/GM_finetune", token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        "amiguel/GM_finetune",
        device_map="auto",
        torch_dtype=torch.bfloat16,
        token=HF_TOKEN
    )
    return model, tokenizer
model, tokenizer = load_model()
# Session state
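# st.session_state persists across reruns, so the chat history survives each interaction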
if "messages" not in st.session_state:
st.session_state.messages = []
# Stream tokens from the model as they are generated
def generate_response(prompt, model, tokenizer):
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": 1024,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
        "do_sample": True,
        "streamer": streamer
    }
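    # Run generate() on a background thread so the caller can iterate the streamer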
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    return streamer
# Display chat history
for message in st.session_state.messages:
    avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
    with st.chat_message(message["role"], avatar=avatar):
        st.markdown(message["content"])
# Chat input
if prompt := st.chat_input("Ask me anything about General Knowledge..."):
    # Display user message
    with st.chat_message("user", avatar=USER_AVATAR):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})
    # Bot generates the response
    if model and tokenizer:
        try:
            with st.chat_message("assistant", avatar=BOT_AVATAR):
                start_time = time.time()
                streamer = generate_response(prompt, model, tokenizer)
                response_container = st.empty()
                full_response = ""
                for chunk in streamer:
                    full_response += chunk
                    response_container.markdown(full_response + "▌", unsafe_allow_html=True)
                end_time = time.time()
                input_tokens = len(tokenizer(prompt)["input_ids"])
                output_tokens = len(tokenizer(full_response)["input_ids"])
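                # Rough throughput: output tokens over wall-clock time (includes prompt prefill)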
                speed = output_tokens / (end_time - start_time)
                st.caption(
                    f"🔑 Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
                    f"🕒 Speed: {speed:.1f} tokens/sec"
                )
                response_container.markdown(full_response)
                st.session_state.messages.append({"role": "assistant", "content": full_response})
        except Exception as e:
            st.error(f"⚡ Generation error: {str(e)}")
    else:
        st.error("🤖 Model not loaded!")