Update app.py
Browse files
app.py
CHANGED
@@ -21,7 +21,7 @@ st.title("π DigiTs the Twin")
|
|
21 |
with st.sidebar:
|
22 |
st.header("π Upload Knowledge Files")
|
23 |
uploaded_files = st.file_uploader("Upload PDFs or .txt files", accept_multiple_files=True, type=["pdf", "txt"])
|
24 |
-
model_choice = st.selectbox("π§ Choose Model", ["Qwen", "Mistral"])
|
25 |
if uploaded_files:
|
26 |
st.success(f"{len(uploaded_files)} file(s) uploaded")
|
27 |
|
@@ -30,10 +30,8 @@ with st.sidebar:
|
|
30 |
def load_model(selected_model):
|
31 |
if selected_model == "Qwen":
|
32 |
model_id = "amiguel/GM_Qwen1.8B_Finetune"
|
33 |
-
|
34 |
-
|
35 |
-
model_id = "amiguel/Llama3_8B_Instruct_FP16"
|
36 |
-
|
37 |
else:
|
38 |
model_id = "amiguel/GM_Mistral7B_Finetune"
|
39 |
|
@@ -61,7 +59,6 @@ SYSTEM_PROMPT = (
|
|
61 |
# --- Prompt Builder ---
|
62 |
def build_prompt(messages, context="", model_name="Qwen"):
|
63 |
if "Mistral" in model_name:
|
64 |
-
# Alpaca-style prompt
|
65 |
prompt = f"You are DigiTwin, an expert in offshore inspection, maintenance, and asset integrity.\n"
|
66 |
if context:
|
67 |
prompt += f"Here is relevant context:\n{context}\n\n"
|
@@ -71,8 +68,18 @@ def build_prompt(messages, context="", model_name="Qwen"):
|
|
71 |
elif msg["role"] == "assistant":
|
72 |
prompt += f"### Response:\n{msg['content'].strip()}\n"
|
73 |
prompt += "### Response:\n"
|
74 |
-
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
prompt = f"<|im_start|>system\n{SYSTEM_PROMPT}\n\nContext:\n{context}<|im_end|>\n"
|
77 |
for msg in messages:
|
78 |
role = msg["role"]
|
@@ -80,7 +87,6 @@ def build_prompt(messages, context="", model_name="Qwen"):
|
|
80 |
prompt += "<|im_start|>assistant\n"
|
81 |
return prompt
|
82 |
|
83 |
-
|
84 |
# --- Embed Uploaded Documents ---
|
85 |
@st.cache_resource
|
86 |
def embed_uploaded_files(files):
|
@@ -125,7 +131,7 @@ BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/99
|
|
125 |
if "messages" not in st.session_state:
|
126 |
st.session_state.messages = []
|
127 |
|
128 |
-
# --- Display
|
129 |
for msg in st.session_state.messages:
|
130 |
with st.chat_message(msg["role"], avatar=USER_AVATAR if msg["role"] == "user" else BOT_AVATAR):
|
131 |
st.markdown(msg["content"])
|
@@ -141,7 +147,6 @@ if prompt := st.chat_input("Ask something based on uploaded documents..."):
|
|
141 |
docs = retriever.similarity_search(prompt, k=3)
|
142 |
context = "\n\n".join([doc.page_content for doc in docs])
|
143 |
|
144 |
-
# Limit to last 6 messages for memory
|
145 |
recent_messages = st.session_state.messages[-6:]
|
146 |
full_prompt = build_prompt(recent_messages, context, model_name=model_id)
|
147 |
|
@@ -154,9 +159,10 @@ if prompt := st.chat_input("Ask something based on uploaded documents..."):
|
|
154 |
answer += chunk
|
155 |
cleaned = answer
|
156 |
|
157 |
-
|
158 |
-
if "Mistral" in model_id:
|
159 |
cleaned = cleaned.replace("<|im_start|>", "").replace("<|im_end|>", "").strip()
|
|
|
|
|
160 |
|
161 |
container.markdown(cleaned + "β", unsafe_allow_html=True)
|
162 |
|
|
|
21 |
with st.sidebar:
|
22 |
st.header("π Upload Knowledge Files")
|
23 |
uploaded_files = st.file_uploader("Upload PDFs or .txt files", accept_multiple_files=True, type=["pdf", "txt"])
|
24 |
+
model_choice = st.selectbox("π§ Choose Model", ["Qwen", "Mistral", "Llama3"])
|
25 |
if uploaded_files:
|
26 |
st.success(f"{len(uploaded_files)} file(s) uploaded")
|
27 |
|
|
|
30 |
def load_model(selected_model):
|
31 |
if selected_model == "Qwen":
|
32 |
model_id = "amiguel/GM_Qwen1.8B_Finetune"
|
33 |
+
elif selected_model == "Llama3":
|
34 |
+
model_id = "amiguel/Llama3_8B_Instruct_FP16"
|
|
|
|
|
35 |
else:
|
36 |
model_id = "amiguel/GM_Mistral7B_Finetune"
|
37 |
|
|
|
59 |
# --- Prompt Builder ---
|
60 |
def build_prompt(messages, context="", model_name="Qwen"):
|
61 |
if "Mistral" in model_name:
|
|
|
62 |
prompt = f"You are DigiTwin, an expert in offshore inspection, maintenance, and asset integrity.\n"
|
63 |
if context:
|
64 |
prompt += f"Here is relevant context:\n{context}\n\n"
|
|
|
68 |
elif msg["role"] == "assistant":
|
69 |
prompt += f"### Response:\n{msg['content'].strip()}\n"
|
70 |
prompt += "### Response:\n"
|
71 |
+
|
72 |
+
elif "Llama" in model_name:
|
73 |
+
prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
|
74 |
+
prompt += f"{SYSTEM_PROMPT}\n\nContext:\n{context}\n"
|
75 |
+
for msg in messages:
|
76 |
+
if msg["role"] == "user":
|
77 |
+
prompt += "<|start_header_id|>user<|end_header_id|>\n" + msg["content"].strip() + "\n"
|
78 |
+
elif msg["role"] == "assistant":
|
79 |
+
prompt += "<|start_header_id|>assistant<|end_header_id|>\n" + msg["content"].strip() + "\n"
|
80 |
+
prompt += "<|start_header_id|>assistant<|end_header_id|>\n"
|
81 |
+
|
82 |
+
else: # Qwen
|
83 |
prompt = f"<|im_start|>system\n{SYSTEM_PROMPT}\n\nContext:\n{context}<|im_end|>\n"
|
84 |
for msg in messages:
|
85 |
role = msg["role"]
|
|
|
87 |
prompt += "<|im_start|>assistant\n"
|
88 |
return prompt
|
89 |
|
|
|
90 |
# --- Embed Uploaded Documents ---
|
91 |
@st.cache_resource
|
92 |
def embed_uploaded_files(files):
|
|
|
131 |
if "messages" not in st.session_state:
|
132 |
st.session_state.messages = []
|
133 |
|
134 |
+
# --- Display Chat History ---
|
135 |
for msg in st.session_state.messages:
|
136 |
with st.chat_message(msg["role"], avatar=USER_AVATAR if msg["role"] == "user" else BOT_AVATAR):
|
137 |
st.markdown(msg["content"])
|
|
|
147 |
docs = retriever.similarity_search(prompt, k=3)
|
148 |
context = "\n\n".join([doc.page_content for doc in docs])
|
149 |
|
|
|
150 |
recent_messages = st.session_state.messages[-6:]
|
151 |
full_prompt = build_prompt(recent_messages, context, model_name=model_id)
|
152 |
|
|
|
159 |
answer += chunk
|
160 |
cleaned = answer
|
161 |
|
162 |
+
if "Mistral" in model_id or "Llama" in model_id:
|
|
|
163 |
cleaned = cleaned.replace("<|im_start|>", "").replace("<|im_end|>", "").strip()
|
164 |
+
cleaned = cleaned.replace("<|start_header_id|>", "").replace("<|end_header_id|>", "")
|
165 |
+
cleaned = cleaned.replace("<|begin_of_text|>", "").strip()
|
166 |
|
167 |
container.markdown(cleaned + "β", unsafe_allow_html=True)
|
168 |
|