Update app.py
app.py CHANGED
@@ -17,27 +17,33 @@ HF_TOKEN = st.secrets["HF_TOKEN"]
 st.set_page_config(page_title="DigiTwin RAG", page_icon="π", layout="centered")
 st.title("π DigiTs the Twin")

-# ---
+# --- Sidebar ---
 with st.sidebar:
     st.header("π Upload Knowledge Files")
     uploaded_files = st.file_uploader("Upload PDFs or .txt files", accept_multiple_files=True, type=["pdf", "txt"])
+    model_choice = st.selectbox("π§ Choose Model", ["Qwen", "Mistral"])
     if uploaded_files:
         st.success(f"{len(uploaded_files)} file(s) uploaded")

 # --- Load Model & Tokenizer ---
 @st.cache_resource
-def load_model():
-
+def load_model(selected_model):
+    if selected_model == "Qwen":
+        model_id = "amiguel/GM_Qwen1.8B_Finetune"
+    else:
+        model_id = "amiguel/GM_Mistral7B_Finetune"
+
+    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, token=HF_TOKEN)
     model = AutoModelForCausalLM.from_pretrained(
-
+        model_id,
         device_map="auto",
         torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
         trust_remote_code=True,
         token=HF_TOKEN
     )
-    return model, tokenizer
+    return model, tokenizer, model_id

-model, tokenizer = load_model()
+model, tokenizer, model_id = load_model(model_choice)

 # --- System Prompt ---
 SYSTEM_PROMPT = (
@@ -128,14 +134,19 @@ if prompt := st.chat_input("Ask something based on uploaded documents..."):

         for chunk in generate_response(full_prompt):
             answer += chunk
-
-
+            cleaned = answer
+
+            # π§ Strip <|im_start|>, <|im_end|> if using Mistral (Qwen needs them)
+            if "Mistral" in model_id:
+                cleaned = cleaned.replace("<|im_start|>", "").replace("<|im_end|>", "").strip()
+
+            container.markdown(cleaned + "β", unsafe_allow_html=True)

         end = time.time()
-        st.session_state.messages.append({"role": "assistant", "content":
+        st.session_state.messages.append({"role": "assistant", "content": cleaned})

         input_tokens = len(tokenizer(full_prompt)["input_ids"])
-        output_tokens = len(tokenizer(
+        output_tokens = len(tokenizer(cleaned)["input_ids"])
         speed = output_tokens / (end - start)

         with st.expander("π Debug Info"):
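The cleanup in this hunk hard-codes the two ChatML markers and branches on model_id. A possible generalization, not part of the commit and assuming the only unwanted artifacts are template or special tokens, is to strip everything the active tokenizer registers as special plus the ChatML markers, so adding a third model would not require another branch; clean_response is a hypothetical helper name.

# Hypothetical alternative to the inline replace() chain above.
CHATML_MARKERS = ["<|im_start|>", "<|im_end|>"]

def clean_response(text, tokenizer):
    # all_special_tokens covers bos/eos/pad etc. for whichever model is loaded;
    # the ChatML markers are added explicitly because the Mistral tokenizer
    # does not register them as special tokens.
    for marker in set(tokenizer.all_special_tokens + CHATML_MARKERS):
        text = text.replace(marker, "")
    return text.strip()

The call site would then be cleaned = clean_response(answer, tokenizer) for either model, and the output_tokens / speed accounting in the hunk would be unchanged.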