HTThuanHcmus commited on
Commit
4ca0883
·
verified ·
1 Parent(s): ab4e2bf

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +378 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,380 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import os
3
+ import time
4
+ from datetime import datetime
5
+ from underthesea import word_tokenize
6
+ from transformers import EncoderDecoderModel, AutoModelForSeq2SeqLM, AutoTokenizer
7
+ import torch
8
+ import logging
9
+ import transformers
10
+ import google.generativeai as genai
11
+ from utils.preprocessing import clean_text, segment_text
12
+ import asyncio
13
 
14
# Register a fresh asyncio event loop when none is running in this thread;
# asyncio.get_running_loop() reports that case by raising RuntimeError.
try:
    asyncio.get_running_loop()
except RuntimeError:
    _fallback_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(_fallback_loop)

# Reduce warning noise: the script-run-context logger and transformers are
# both limited to error-level output.
_script_ctx_logger = logging.getLogger('streamlit.runtime.scriptrunner.script_run_context')
_script_ctx_logger.setLevel(logging.ERROR)
transformers.logging.set_verbosity_error()

# Streamlit page configuration
st.set_page_config(layout="centered", page_title="Trình sinh tiêu đề")
25
+
26
# Gemini API configuration.
# The key is read from the GEMINI_API_KEY environment variable so that no
# credential is stored in the source file. Configure it as a secret of the
# deployment; any key that was ever committed here must be revoked.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
else:
    # Do not abort the app: the non-Gemini models do not need this key.
    logging.warning("GEMINI_API_KEY is not set; Gemini requests are expected to fail.")
29
+
30
# Models
# Title-generation models offered in the UI, keyed by display name. Each entry
# holds the model id ("model_path"), the tokenizer id ("tokenizer_path", None
# for Gemini), a "token" flag and the "model_type" used to pick a loader
# ("encoder-decoder", "seq2seq" or "gemini").
TITLE_MODELS = {
    "PhoBERT Encoder-Decoder": dict(
        model_path="PuppetLover/Title_generator",
        tokenizer_path="vinai/phobert-base-v2",
        token=True,
        model_type="encoder-decoder",
    ),
    "ViT5 Title Generator": dict(
        model_path="HTThuanHcmus/vit5-base-vietnews-summarization-finetune",
        tokenizer_path="HTThuanHcmus/vit5-base-vietnews-summarization-finetune",
        token=False,
        model_type="seq2seq",
    ),
    "BARTpho Title Generator": dict(
        model_path="HTThuanHcmus/bartpho-finetune",
        tokenizer_path="HTThuanHcmus/bartpho-finetune",
        token=False,
        model_type="seq2seq",
    ),
    "Gemini Title Generator": dict(
        # Alternative left commented out in the original: "gemini-1.5-pro"
        model_path="gemini-1.5-flash",
        tokenizer_path=None,
        token=False,
        model_type="gemini",
    ),
}
58
+
59
# Summarization models offered in the UI, keyed by display name. Entries use
# the keys "model_path", "tokenizer_path", "token" and "model_type".
SUMMARIZATION_MODELS = {
    "ViT5 Summarization": dict(
        model_path="HTThuanHcmus/vit5-summarization-news-finetune",
        tokenizer_path="HTThuanHcmus/vit5-summarization-news-finetune",
        token=False,
        model_type="seq2seq",
    ),
    "BARTpho Summarization": dict(
        model_path="HTThuanHcmus/bartpho-summarization-news-finetune",
        tokenizer_path="HTThuanHcmus/bartpho-summarization-news-finetune",
        token=False,
        model_type="seq2seq",
    ),
    "Gemini Summarization": dict(
        model_path="gemini-1.5-pro",
        tokenizer_path=None,
        token=False,
        model_type="gemini",
    ),
}
79
+
80
# Cached loading of model/tokenizer pairs
@st.cache_resource
def load_model_and_tokenizer(model_path, tokenizer_path, model_type, token=False):
    """Load the model (and tokenizer, if any) for one configured entry.

    Args:
        model_path: Model identifier. Given to ``genai.GenerativeModel`` for
            Gemini, otherwise to the matching ``from_pretrained`` call.
        tokenizer_path: Tokenizer identifier; when falsy no tokenizer is
            loaded.
        model_type: One of ``"gemini"``, ``"encoder-decoder"``, ``"seq2seq"``.
        token: Accepted as part of the signature but not forwarded; the
            ``from_pretrained`` calls always receive ``token=None``.

    Returns:
        A ``(model, tokenizer)`` tuple. The tokenizer is ``None`` for Gemini
        and whenever ``tokenizer_path`` is falsy.

    Raises:
        ValueError: If ``model_type`` is not one of the supported values.
    """
    if model_type == "gemini":
        return genai.GenerativeModel(model_path), None

    tokenizer = None
    if tokenizer_path:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=False)

    # Map each supported model type to the class that loads it.
    model_classes = {
        "encoder-decoder": EncoderDecoderModel,
        "seq2seq": AutoModelForSeq2SeqLM,
    }
    model_cls = model_classes.get(model_type)
    if model_cls is None:
        raise ValueError(f"Unsupported model type: {model_type}")
    model = model_cls.from_pretrained(model_path, token=None)

    # Use the GPU when one is available, otherwise stay on the CPU.
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(target_device)
    return model, tokenizer
96
+
97
# Gemini request helper
def generate_with_gemini(model, text, task):
    """Ask a Gemini model for a title or a summary of *text*.

    Args:
        model: Object exposing ``generate_content(prompt)`` whose result has a
            ``text`` attribute.
        text: The article text appended to the prompt.
        task: ``"Sinh tiêu đề"`` selects the title prompt; any other value
            selects the summary prompt.

    Returns:
        The response text with surrounding whitespace stripped.
    """
    if task == "Sinh tiêu đề":
        instruction = "Với tư cách một chuyên gia hãy tạo tiêu đề ngắn gọn cho văn bản sau"
    else:
        # The original summary prompt misspelled "Với" as "Vơi".
        instruction = "Với tư cách một chuyên gia hãy tạo tóm tắt cho văn bản"
    prompt = f"{instruction}: {text}"
    response = model.generate_content(prompt)
    return response.text.strip()
105
+
106
# Initialise session state: every key below gets its default only when it is
# not already present in st.session_state.
_SESSION_DEFAULTS = {
    "history": [],
    "show_sidebar": False,
    "selected_history_index": None,
    "current_generated": None,
    "current_task": None,
}
for _state_key, _initial_value in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
117
+
118
# Sidebar
with st.sidebar:
    # Toggle button: flips the flag that decides whether the history panel
    # below is drawn.
    if st.button("🧾 Hiện/Ẩn lịch sử"):
        st.session_state.show_sidebar = not st.session_state.show_sidebar

if st.session_state.show_sidebar:
    with st.sidebar:
        st.markdown("### 🕓 Lịch sử")
        if not st.session_state.history:
            st.write("Chưa có lịch sử nào.")
        else:
            if st.button("🗑️ Xóa tất cả lịch sử"):
                st.session_state.history = []
                st.session_state.selected_history_index = None
                st.rerun()

            for idx, history_item in enumerate(st.session_state.history):
                col1, col2 = st.columns([4, 1])
                with col1:
                    # Shortened preview: text before the first '.', capped at
                    # 60 characters; "..." is appended when the full result is
                    # longer than 60 characters.
                    short_preview = history_item['title'].split('.')[0][:60]
                    if len(history_item['title']) > 60:
                        short_preview += "..."
                    if st.button(f"- {short_preview}", key=f"history_{idx}"):
                        st.session_state.selected_history_index = idx
                        st.session_state.current_generated = None
                with col2:
                    if st.button("🗑️", key=f"delete_{idx}"):
                        st.session_state.history.pop(idx)
                        selected = st.session_state.selected_history_index
                        if selected is not None:
                            if selected == idx:
                                # The entry being viewed was deleted.
                                st.session_state.selected_history_index = None
                            elif selected > idx:
                                # Entries after the deleted one moved up by
                                # one position; follow the selected entry so
                                # it does not silently become its neighbour.
                                st.session_state.selected_history_index = selected - 1
                        # Rerun right away so the loop never continues over
                        # the list that was just modified.
                        st.rerun()
150
+
151
+
152
# A little CSS: dark theme for the page, text areas, buttons, radio options,
# the main container padding and a ".card" class.
_CUSTOM_CSS = """
<style>
body {
    background-color: #0e1117;
    color: #ffffff;
}
textarea {
    background-color: #1e1e1e !important;
    color: #ffffff !important;
    font-family: 'Courier New', monospace;
    border: 1px solid #ffffff30 !important;
    border-radius: 10px !important;
}
.stButton > button {
    background: linear-gradient(90deg, #4b6cb7 0%, #182848 100%);
    color: white;
    border: none;
    border-radius: 8px;
    padding: 10px 20px;
    margin-top: 10px;
    font-weight: bold;
    transition: all 0.3s ease;
}
.stButton > button:hover {
    background: linear-gradient(90deg, #1e3c72 0%, #2a5298 100%);
    transform: scale(1.02);
}
div[role="radiogroup"] label {
    margin-right: 15px;
    background-color: #2c2f36;
    padding: 8px 15px;
    border-radius: 5px;
    cursor: pointer;
}
div[role="radiogroup"] input:checked + label {
    background-color: #0078FF;
    color: white;
}
.block-container {
    padding-top: 1rem;
    padding-bottom: 1rem;
    padding-left: 2rem;
    padding-right: 2rem;
}
.card {
    background-color: #1e1e1e;
    padding: 20px;
    border-radius: 12px;
    box-shadow: 0 4px 12px rgba(0,0,0,0.3);
    margin-bottom: 20px;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
206
+
207
# Main app: centred page heading rendered as raw HTML.
_PAGE_HEADER = """
<h1 style='text-align: center; color: white; font-family: "Segoe UI", sans-serif;'>
Trình Sinh Tiêu Đề & Tóm Tắt
</h1>
"""
st.markdown(_PAGE_HEADER, unsafe_allow_html=True)
213
+
214
# Task selection (title generation or summarization).
task_option = st.radio(
    "Chọn chức năng bạn muốn:",
    ('Sinh tiêu đề', 'Tóm tắt nội dung'),
    horizontal=True,
    key="task_selection",
)

# For each task: the select-box label, the model catalogue it lists, and the
# widget key.
_MODEL_PICKERS = {
    'Sinh tiêu đề': ("Chọn mô hình sinh tiêu đề:", TITLE_MODELS, "title_model_selector"),
    'Tóm tắt nội dung': ("Chọn mô hình tóm tắt:", SUMMARIZATION_MODELS, "summary_model_selector"),
}

# Both stay None when the task matches no picker.
selected_model_key = None
model_config = None

_picker = _MODEL_PICKERS.get(task_option)
if _picker is not None:
    _picker_label, _picker_catalog, _picker_key = _picker
    selected_model_key = st.selectbox(
        _picker_label,
        list(_picker_catalog),
        key=_picker_key,
    )
    model_config = _picker_catalog[selected_model_key]
239
+
240
# File upload: the text comes either from an uploaded .txt/.docx file or from
# the free-text area.
uploaded_file = st.file_uploader("Hoặc tải lên file (.txt, .docx):", type=["txt", "docx"])

if uploaded_file:
    file_name = uploaded_file.name
    # Bind text_input up front so every path below leaves it defined, even
    # when the file cannot be read or its extension is unexpected.
    text_input = ""
    # Compare the extension case-insensitively ("NOTES.TXT" is still a text file).
    lowered_name = file_name.lower()
    if lowered_name.endswith(".txt"):
        try:
            # getvalue() returns the whole buffer regardless of the current
            # read position; "utf-8-sig" also drops a leading BOM if present.
            text_input = uploaded_file.getvalue().decode("utf-8-sig")
        except UnicodeDecodeError:
            st.error("Không thể đọc file: nội dung phải được mã hóa UTF-8.")
    elif lowered_name.endswith(".docx"):
        from docx import Document
        # Rewind in case the buffer was read before.
        uploaded_file.seek(0)
        doc = Document(uploaded_file)
        # Keep only paragraphs that contain non-whitespace text.
        text_input = "\n".join(para.text for para in doc.paragraphs if para.text.strip())
    else:
        st.warning("Định dạng file không được hỗ trợ. Vui lòng tải lên file .txt hoặc .docx.")
    # Read-only preview of the extracted text.
    st.text_area("Nội dung file đã tải lên:", value=text_input, height=200, key="text_input_area", disabled=True)
else:
    text_input = st.text_area("Nhập đoạn văn của bạn:", height=200, key="text_input_area")
254
+
255
# Action button shown after the text input.
def _prepare_model_input(raw_text, model_type):
    """Clean *raw_text*; the encoder-decoder model additionally gets segmented text."""
    cleaned = clean_text(raw_text)
    if model_type == "encoder-decoder":
        return segment_text(cleaned)
    return cleaned


def _generate_with_local_model(model, tokenizer, prepared_text, task):
    """Tokenize *prepared_text*, run beam-search generation and return the decoded string."""
    encoded = tokenizer(
        prepared_text,
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors="pt",
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    input_ids = encoded["input_ids"].to(device)
    # The input is padded to max_length, so pass the attention mask along to
    # let generation ignore the padding positions.
    attention_mask = encoded.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            # Titles get a shorter length budget than summaries.
            max_length=80 if task == 'Sinh tiêu đề' else 200,
            num_beams=5,
            early_stopping=True,
            no_repeat_ngram_size=2,
        )
    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    # Underscores are turned back into spaces (presumably the joins added by
    # word segmentation; confirm against segment_text).
    return decoded.replace("_", " ")


button_label = f"{task_option}"
if st.button(button_label, key="generate_button"):
    if not model_config:
        st.warning("Vui lòng chọn mô hình.")
    elif not text_input.strip():
        st.warning("Vui lòng nhập văn bản hoặc tải file lên.")
    else:
        model, tokenizer = load_model_and_tokenizer(
            model_config["model_path"],
            model_config["tokenizer_path"],
            model_config["model_type"],
            model_config.get("token", False),
        )

        if model:
            is_gemini = model_config["model_type"] == "gemini"
            processed_text = _prepare_model_input(text_input, model_config["model_type"])

            # None means "generation failed"; an empty string is still a result.
            result = None
            try:
                with st.spinner(f"⏳ Đang {task_option.lower()} với mô hình '{selected_model_key}'..."):
                    if is_gemini:
                        result = generate_with_gemini(model, processed_text, task_option)
                    else:
                        result = _generate_with_local_model(model, tokenizer, processed_text, task_option)
            except Exception as e:
                # UI boundary: report the failure to the user instead of
                # crashing the script, and echo it to the server output.
                if is_gemini:
                    st.error(f"Đã xảy ra lỗi với Gemini: {e}")
                    print(f"Error during Gemini processing: {e}")
                else:
                    st.error(f"Đã xảy ra lỗi: {e}")
                    print(f"Error during processing: {e}")

            if result is not None:
                st.session_state.current_generated = result
                st.session_state.current_task = task_option

                st.session_state.history.append({
                    "title": result,
                    "input_text": text_input,
                    "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    "model_name": selected_model_key,
                })
                st.session_state.selected_history_index = None
                # Kept outside the try block so the rerun signal can never be
                # caught by the error handler above.
                st.rerun()
336
+
337
# Show the freshly generated result.
import html

if st.session_state.current_generated:
    st.markdown("---")
    label_text = "Tiêu đề được tạo:" if st.session_state.current_task == 'Sinh tiêu đề' else "Nội dung tóm tắt:"
    st.markdown(f"<h3 style='color: #cccccc;'>{label_text}</h3>", unsafe_allow_html=True)
    # The text comes from a model, so escape it before placing it inside raw
    # HTML; line breaks become <br> so multi-line output stays in the box.
    safe_result = html.escape(st.session_state.current_generated).replace("\n", "<br>")
    st.markdown(
        "<p style='color: white; background-color: #2a2a2a; padding: 10px; border-radius: 5px;'>"
        f"{safe_result}</p>",
        unsafe_allow_html=True,
    )
344
+
345
# Show the history entry selected in the sidebar.
import html


def _history_html(text):
    """Escape *text* for use inside raw HTML and keep its line breaks visible."""
    return html.escape(str(text)).replace("\n", "<br>")


_selected_idx = st.session_state.selected_history_index
if _selected_idx is not None and _selected_idx < len(st.session_state.history):
    selected_history = st.session_state.history[_selected_idx]
    _full_input = selected_history['input_text']
    # Per-entry flag remembering whether the full input text is expanded.
    _expand_key = f"show_full_input_{_selected_idx}"
    _box_style = "color: white; background-color: #2a2a2a; padding: 10px; border-radius: 5px;"

    st.markdown("---")
    st.markdown("<h3 style='color: #cccccc;'>Kết quả đã tạo:</h3>", unsafe_allow_html=True)

    if _expand_key not in st.session_state:
        st.session_state[_expand_key] = False
    show_full = st.session_state[_expand_key]

    # Inputs longer than 1000 characters are truncated unless expanded.
    _is_long = len(_full_input) > 1000
    if show_full or not _is_long:
        input_text_to_display = _full_input
    else:
        input_text_to_display = _full_input[:1000] + "..."

    # Each st.markdown call is a separate element, so every call carries
    # balanced markup. User and model text is escaped before interpolation.
    st.markdown(
        f"<div style='{_box_style}'>"
        f"<b>Model:</b> {_history_html(selected_history['model_name'])}<br>"
        f"<b>Thời gian:</b> {_history_html(selected_history['timestamp'])}<br><br>"
        "<b>Văn bản gốc:</b><br>"
        "<div style='background-color: #3a3a3a; padding: 8px; border-radius: 5px; margin-bottom: 10px;'>"
        f"{_history_html(input_text_to_display)}</div>"
        "</div>",
        unsafe_allow_html=True,
    )

    if _is_long:
        if not show_full:
            if st.button("📖 Xem đầy đủ văn bản", key=f"show_full_{_selected_idx}"):
                st.session_state[_expand_key] = True
                st.rerun()
        else:
            if st.button("🔽 Thu gọn văn bản", key=f"collapse_full_{_selected_idx}"):
                st.session_state[_expand_key] = False
                st.rerun()

    st.markdown(
        f"<div style='{_box_style}'>"
        "<b>Kết quả:</b><br>"
        "<div style='background-color: #3a3a3a; padding: 8px; border-radius: 5px;'>"
        f"{_history_html(selected_history['title'])}</div>"
        "</div>",
        unsafe_allow_html=True,
    )