import streamlit as st
import torch
import time
import os
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from huggingface_hub import login
# Hugging Face token, read from the environment
HF_TOKEN = os.environ.get("HF_TOKEN")  # or directly "hf_xxxxxx"
# App config
st.set_page_config(
    page_title="GM Fine-tune Assistant 🚀",
    page_icon="🚀",
    layout="centered"
)
st.title("🚀 GM Fine-tune Assistant 🚀")
# Avatars
USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"
# Log in to Hugging Face (skipped when no token is set)
if HF_TOKEN:
    login(token=HF_TOKEN)
# Load Model
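# st.cache_resource loads the model once and shares it across reruns and sessions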
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained("amiguel/GM_finetune", token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        "amiguel/GM_finetune",
        device_map="auto",
        torch_dtype=torch.bfloat16,
        token=HF_TOKEN
    )
    return model, tokenizer
model, tokenizer = load_model()
# Session state
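# st.session_state survives Streamlit reruns, so the chat history persists between messages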
if "messages" not in st.session_state:
st.session_state.messages = []
# Streamer
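# Kicks off model.generate on a worker thread and returns a streamer that
# yields decoded text chunks as they are produced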
def generate_response(prompt, model, tokenizer):
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": 1024,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
        "do_sample": True,
        "streamer": streamer
    }
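    # model.generate blocks until done, so run it on a background thread
    # and consume the streamer from the main thread as chunks arrive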
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    return streamer
# Display chat history
for message in st.session_state.messages:
    avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
    with st.chat_message(message["role"], avatar=avatar):
        st.markdown(message["content"])
# Chat input
if prompt := st.chat_input("Ask me anything about General Knowledge..."):
    # Display user message
    with st.chat_message("user", avatar=USER_AVATAR):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})
    # Bot generating response
    if model and tokenizer:
        try:
            with st.chat_message("assistant", avatar=BOT_AVATAR):
                start_time = time.time()
                streamer = generate_response(prompt, model, tokenizer)

                response_container = st.empty()
                full_response = ""
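                # Re-render the partial response with a cursor after each chunk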
                for chunk in streamer:
                    full_response += chunk
                    response_container.markdown(full_response + "▌", unsafe_allow_html=True)
                end_time = time.time()
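                # Rough throughput stats: re-tokenize prompt and response to count tokens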
                input_tokens = len(tokenizer(prompt)["input_ids"])
                output_tokens = len(tokenizer(full_response)["input_ids"])
                speed = output_tokens / (end_time - start_time)

                st.caption(
                    f"🔢 Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
                    f"🚀 Speed: {speed:.1f} tokens/sec"
                )

                response_container.markdown(full_response)

            st.session_state.messages.append({"role": "assistant", "content": full_response})
        except Exception as e:
            st.error(f"⚡ Generation error: {str(e)}")
    else:
        st.error("🤖 Model not loaded!")