amiguel committed on
Commit 6a8c296 · verified · 1 Parent(s): d60b079

Update app.py

Files changed (1)
  1. app.py +42 -91
app.py CHANGED
@@ -1,9 +1,8 @@
 import streamlit as st
 from transformers import (
     AutoTokenizer,
-    AutoModelForCausalLM,
-    TextIteratorStreamer,
-    AutoConfig
+    AutoModelForSeq2SeqLM,
+    TextIteratorStreamer
 )
 from huggingface_hub import login
 from threading import Thread
@@ -13,50 +12,27 @@ import torch
 import time
 import os
 
-# Check if 'peft' is installed
-try:
-    from peft import PeftModel, PeftConfig
-except ImportError:
-    raise ImportError(
-        "The 'peft' library is required but not installed. "
-        "Please install it using: `pip install peft`"
-    )
-
 # 🔐 Hugging Face Token via Environment Variable
 HF_TOKEN = os.environ.get("HF_TOKEN")
 if not HF_TOKEN:
     raise ValueError("Missing Hugging Face Token. Please set the HF_TOKEN environment variable.")
 
-# 🎛 Model base and adapters
-BASE_MODEL_NAME = "unicamp-dl/ptt5-base-portuguese-vocab" #"neuralmind/bert-base-portuguese-cased" #"pierreguillou/gpt2-small-portuguese" # #"mistralai/Mistral-7B-Instruct-v0.2"
-MODEL_OPTIONS = {
-    "Full Fine-Tuned": "amiguel/mistral-angolan-laborlaw-ptt5", #"amiguel/mistral-angolan-laborlaw-bert-base-pt", #"amiguel/mistral-angolan-laborlaw-gpt2",#, #"amiguel/mistral-angolan-laborlaw",
-    "LoRA Adapter": "amiguel/SmolLM2-360M-concise-reasoning-lora",
-    "QLoRA Adapter": "amiguel/SmolLM2-360M-concise-reasoning-qlora"
-}
-
+# ✅ Only PT-T5 Model
+MODEL_NAME = "amiguel/mistral-angolan-laborlaw-ptt5"
 
-# 🖼 UI Setup
+# UI Setup
 st.set_page_config(page_title="Assistente LGT | Angola", page_icon="🚀", layout="centered")
 st.title("🚀 Assistente LGT | Angola 🚀")
 
 USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
 BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"
 
-# Sidebar
+# Upload sidebar
 with st.sidebar:
-    st.header("Model Selection 🤖")
-    model_type = st.selectbox("Choose Model Type", list(MODEL_OPTIONS.keys()), index=0)
-    selected_model = MODEL_OPTIONS[model_type]
-
-    st.header("Upload Documents 📂")
-    uploaded_file = st.file_uploader("Choose a PDF or XLSX file", type=["pdf", "xlsx"], label_visibility="collapsed")
-
-# Chat memory
-if "messages" not in st.session_state:
-    st.session_state.messages = []
+    st.header("Upload Documentos 📂")
+    uploaded_file = st.file_uploader("Escolhe um ficheiro PDF ou XLSX", type=["pdf", "xlsx"], label_visibility="collapsed")
 
-# 🔍 File processing
+# Cache file processing
 @st.cache_data
 def process_file(uploaded_file):
     if uploaded_file is None:
@@ -69,57 +45,32 @@ def process_file(uploaded_file):
         df = pd.read_excel(uploaded_file)
         return df.to_markdown()
     except Exception as e:
-        st.error(f"📄 Error processing file: {str(e)}")
+        st.error(f"📄 Erro ao processar o ficheiro: {str(e)}")
         return ""
 
-# 🧠 Load model and tokenizer
+# Cache model loading
 @st.cache_resource
-def load_model(model_type, selected_model):
+def load_model():
     try:
         login(token=HF_TOKEN)
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
-
-        tokenizer = AutoTokenizer.from_pretrained(selected_model, token=HF_TOKEN)
-
-        if model_type == "Full Fine-Tuned":
-            model = AutoModelForCausalLM.from_pretrained(
-                selected_model,
-                device_map="auto",
-                torch_dtype=dtype,
-                token=HF_TOKEN
-            )
-        else:
-            base_model = AutoModelForCausalLM.from_pretrained(
-                BASE_MODEL_NAME,
-                device_map="auto",
-                torch_dtype=dtype,
-                token=HF_TOKEN
-            )
-            model = PeftModel.from_pretrained(
-                base_model,
-                selected_model,
-                is_trainable=False,
-                torch_dtype=dtype,
-                token=HF_TOKEN
-            )
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN, use_fast=False)
+        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float32).to("cuda" if torch.cuda.is_available() else "cpu")
         return model, tokenizer
     except Exception as e:
-        st.error(f"🤖 Model loading failed: {str(e)}")
+        st.error(f"🤖 Erro ao carregar o modelo: {str(e)}")
         return None, None
 
-# 🚀 Generate response
-def generate_with_streaming(prompt, file_context, model, tokenizer):
-    full_prompt = f"Analisa este contexto:\n{file_context}\n\nPergunta: {prompt}\nResposta:"
-
-    inputs = tokenizer(full_prompt, return_tensors="pt")
-    inputs = {k: v.to(model.device) for k, v in inputs.items()}
+# Streaming response generation
+def generate_response(prompt, context, model, tokenizer):
+    full_prompt = f"Contexto:\n{context}\n\nPergunta: {prompt}\nResposta:"
 
+    inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True).to(model.device)
     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-    gen_kwargs = {
+
+    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
-        "max_new_tokens": 1024,
+        "max_new_tokens": 512,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
@@ -128,37 +79,38 @@ def generate_with_streaming(prompt, file_context, model, tokenizer):
        "streamer": streamer
    }
 
-    Thread(target=model.generate, kwargs=gen_kwargs).start()
+    Thread(target=model.generate, kwargs=generation_kwargs).start()
    return streamer
 
-# 🧾 Display chat history
-for msg in st.session_state.messages:
-    avatar = USER_AVATAR if msg["role"] == "user" else BOT_AVATAR
-    with st.chat_message(msg["role"], avatar=avatar):
-        st.markdown(msg["content"])
+# Store chat history
+if "messages" not in st.session_state:
+    st.session_state.messages = []
+
+# Show chat history
+for message in st.session_state.messages:
+    avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
+    with st.chat_message(message["role"], avatar=avatar):
+        st.markdown(message["content"])
 
-# 🔎 Main interaction loop
-if prompt := st.chat_input("Pergunta sobre a LGT?"):
-    # Display user message
+# Chat input
+if prompt := st.chat_input("Faca uma pergunta sobre a LGT..."):
    with st.chat_message("user", avatar=USER_AVATAR):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})
 
-    # Load model if needed
-    if "model" not in st.session_state or st.session_state.get("model_type") != model_type:
-        with st.spinner("🔄 A carregar modelo..."):
-            model, tokenizer = load_model(model_type, selected_model)
+    # Load model if not loaded
+    if "model" not in st.session_state:
+        with st.spinner("🔄 A carregar o modelo PT-T5..."):
+            model, tokenizer = load_model()
            if not model:
                st.stop()
            st.session_state.model = model
            st.session_state.tokenizer = tokenizer
-            st.session_state.model_type = model_type
    else:
        model = st.session_state.model
        tokenizer = st.session_state.tokenizer
 
-    # Prepare context
-    file_context = process_file(uploaded_file) or "Sem contexto adicional disponível."
+    context = process_file(uploaded_file) or "Sem contexto adicional disponível."
 
    # Generate assistant response
    with st.chat_message("assistant", avatar=BOT_AVATAR):
@@ -166,13 +118,12 @@ if prompt := st.chat_input("Pergunta sobre a LGT?"):
        full_response = ""
        try:
            start_time = time.time()
-            streamer = generate_with_streaming(prompt, file_context, model, tokenizer)
+            streamer = generate_response(prompt, context, model, tokenizer)
 
            for chunk in streamer:
                full_response += chunk.strip() + " "
                response_box.markdown(full_response + "▌", unsafe_allow_html=True)
 
-            # Token and speed metrics
            end_time = time.time()
            input_tokens = len(tokenizer(prompt)["input_ids"])
            output_tokens = len(tokenizer(full_response)["input_ids"])
@@ -181,8 +132,8 @@ if prompt := st.chat_input("Pergunta sobre a LGT?"):
            cost_aoa = cost_usd * 1160
 
            st.caption(
-                f"🔑 Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
-                f"🕒 Speed: {speed:.1f}t/s | 💰 USD: ${cost_usd:.4f} | 🇦🇴 AOA: {cost_aoa:.2f}"
+                f"🔑 Tokens: {input_tokens} → {output_tokens} | 🕒 Velocidade: {speed:.1f}t/s | "
+                f"💰 USD: ${cost_usd:.4f} | 🇦🇴 AOA: {cost_aoa:.2f}"
            )
 
            response_box.markdown(full_response.strip())
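Both the removed and the added code stream tokens the same way: model.generate runs in a background Thread while the Streamlit thread iterates a TextIteratorStreamer and renders each chunk as it arrives. Below is a minimal, self-contained sketch of that pattern; the checkpoint name, the prompt, and the explicit do_sample=True flag are illustrative assumptions, not values taken from this commit.

# Minimal sketch of the Thread + TextIteratorStreamer pattern used in app.py.
# The checkpoint, prompt, and do_sample=True are illustrative assumptions.
from threading import Thread

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TextIteratorStreamer

model_name = "google/flan-t5-small"  # placeholder seq2seq checkpoint, not the app's model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

inputs = tokenizer("Pergunta: O que regula a LGT?\nResposta:", return_tensors="pt")

# skip_prompt keeps prompt/decoder-start tokens out of the stream;
# skip_special_tokens drops pad/eos markers from the decoded chunks.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

generation_kwargs = dict(
    **inputs,
    max_new_tokens=64,
    do_sample=True,  # sampling so temperature/top_p take effect
    temperature=0.7,
    top_p=0.9,
    streamer=streamer,
)

# generate() blocks until decoding finishes, so it runs in a worker thread
# while the main thread consumes chunks as they are produced.
Thread(target=model.generate, kwargs=generation_kwargs).start()

for chunk in streamer:
    print(chunk, end="", flush=True)

In app.py the same loop feeds response_box.markdown(...) instead of print, which is what makes the reply appear incrementally in the chat window.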