import streamlit as st
import torch
import time
import os
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from huggingface_hub import login
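
# Hugging Face access token, read from the environment (set HF_TOKEN before launching the app).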
HF_TOKEN = os.environ.get("HF_TOKEN")

st.set_page_config(
    page_title="GM Fine-tune Assistant 🤖",
    page_icon="🤖",
    layout="centered"
)

st.title("🤖 GM Fine-tune Assistant 🤖")
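
# Avatar images shown next to user and assistant messages in the chat.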
USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"

# Authenticate with the Hugging Face Hub only when a token is available;
# login(token=None) would attempt an interactive prompt, which fails inside Streamlit.
if HF_TOKEN:
    login(token=HF_TOKEN)
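
# Load the tokenizer and fine-tuned model once and cache them across Streamlit reruns.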
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained("amiguel/GM_finetune", token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        "amiguel/GM_finetune",
        device_map="auto",
        torch_dtype=torch.bfloat16,
        token=HF_TOKEN
    )
    return model, tokenizer


model, tokenizer = load_model()
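
# Keep the chat history in session state so it survives Streamlit reruns.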
if "messages" not in st.session_state:
    st.session_state.messages = []
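
# Run model.generate on a background thread and stream tokens through a
# TextIteratorStreamer so the UI can render the reply incrementally.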
def generate_response(prompt, model, tokenizer):
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": 1024,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
        "do_sample": True,
        "streamer": streamer
    }

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    return streamer
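
# Replay the stored conversation so earlier messages stay visible after each rerun.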
for message in st.session_state.messages:
    avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
    with st.chat_message(message["role"], avatar=avatar):
        st.markdown(message["content"])
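
# Handle a new user prompt: echo it, save it to the history, then stream the assistant's reply.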
if prompt := st.chat_input("Ask me anything about General Knowledge..."):

    with st.chat_message("user", avatar=USER_AVATAR):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    if model and tokenizer:
        try:
            with st.chat_message("assistant", avatar=BOT_AVATAR):
                start_time = time.time()
                streamer = generate_response(prompt, model, tokenizer)

                response_container = st.empty()
                full_response = ""

                # Append each streamed chunk and redraw the message with a cursor marker.
                for chunk in streamer:
                    full_response += chunk
                    response_container.markdown(full_response + "▌", unsafe_allow_html=True)

                end_time = time.time()
                input_tokens = len(tokenizer(prompt)["input_ids"])
                output_tokens = len(tokenizer(full_response)["input_ids"])
                speed = output_tokens / (end_time - start_time)

                st.caption(
                    f"📊 Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
                    f"🚀 Speed: {speed:.1f} tokens/sec"
                )

                # Replace the streaming cursor with the final text and store the reply.
                response_container.markdown(full_response)
                st.session_state.messages.append({"role": "assistant", "content": full_response})

        except Exception as e:
            st.error(f"⚡ Generation error: {str(e)}")
    else:
        st.error("🤖 Model not loaded!")