import streamlit as st
import torch
import time
import os
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from huggingface_hub import login

# Hugging Face access token (read from the environment; a literal "hf_xxxxxx" string also works for local testing)
HF_TOKEN = os.environ.get("HF_TOKEN")

# App config
st.set_page_config(
    page_title="GM Fine-tune Assistant 🚀",
    page_icon="🚀",
    layout="centered"
)

st.title("🚀 GM Fine-tune Assistant 🚀")

# Avatars
USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"

# Log in to Hugging Face (skipped when no token is provided, e.g. for public models)
if HF_TOKEN:
    login(token=HF_TOKEN)

# Load Model
@st.cache_resource
def load_model():
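    """Load the fine-tuned model and tokenizer once; st.cache_resource reuses them across Streamlit reruns."""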
    tokenizer = AutoTokenizer.from_pretrained("amiguel/GM_finetune", token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        "amiguel/GM_finetune",
        device_map="auto",
        torch_dtype=torch.bfloat16,
        token=HF_TOKEN
    )
    return model, tokenizer

model, tokenizer = load_model()

# Session state
if "messages" not in st.session_state:
    st.session_state.messages = []

# Streaming generation helper
def generate_response(prompt, model, tokenizer):
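    """Run model.generate in a background thread and return a streamer that yields decoded text chunks."""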
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": 1024,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
        "do_sample": True,
        "streamer": streamer
    }

    # model.generate blocks until generation finishes, so run it in a background
    # thread and let the caller consume tokens from the streamer as they arrive
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    return streamer

# Display chat history
for message in st.session_state.messages:
    avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
    with st.chat_message(message["role"], avatar=avatar):
        st.markdown(message["content"])

# Chat input
if prompt := st.chat_input("Ask me anything about General Knowledge..."):

    # Display user message
    with st.chat_message("user", avatar=USER_AVATAR):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Bot generating response
    if model and tokenizer:
        try:
            with st.chat_message("assistant", avatar=BOT_AVATAR):
                start_time = time.time()
                streamer = generate_response(prompt, model, tokenizer)

                response_container = st.empty()
                full_response = ""

                for chunk in streamer:
                    full_response += chunk
                    response_container.markdown(full_response + "▌", unsafe_allow_html=True)

                end_time = time.time()
                input_tokens = len(tokenizer(prompt)["input_ids"])
                output_tokens = len(tokenizer(full_response)["input_ids"])
                speed = output_tokens / (end_time - start_time)

                st.caption(
                    f"🔑 Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
                    f"🕒 Speed: {speed:.1f} tokens/sec"
                )

                response_container.markdown(full_response)
                st.session_state.messages.append({"role": "assistant", "content": full_response})

        except Exception as e:
            st.error(f"⚡ Generation error: {str(e)}")
    else:
        st.error("🤖 Model not loaded!")