HTThuanHcmus commited on
Commit
f19cca7
·
verified ·
1 Parent(s): 66e8fc3

Upload 4 files

Browse files
src/.streamlit/config.toml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
# Streamlit dark theme for the app.
[theme]
base="dark"
primaryColor="#ffffff" # primary accent: white
backgroundColor="#121212" # modern dark background
secondaryBackgroundColor="#1e1e1e" # darker secondary background
textColor="#ffffff" # white text
font="sans serif"
src/streamlit_app.py CHANGED
@@ -1,40 +1,374 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import os
3
+ import time
4
+ from datetime import datetime
5
+ from underthesea import word_tokenize
6
+ from transformers import EncoderDecoderModel, AutoModelForSeq2SeqLM, AutoTokenizer
7
+ import torch
8
+ import logging
9
+ import transformers
10
+ import google.generativeai as genai
11
+ from utils.preprocessing import clean_text, segment_text
12
 
13
+
14
# Reduce warning noise from Streamlit's script-runner context and transformers.
logging.getLogger('streamlit.runtime.scriptrunner.script_run_context').setLevel(logging.ERROR)
transformers.logging.set_verbosity_error()

# Streamlit page configuration.
st.set_page_config(page_title="Trình sinh tiêu đề", layout="centered")

# Gemini API configuration.
# SECURITY FIX: the API key was previously hard-coded here (and is now public
# in the commit history — that key must be revoked). Read it from the
# environment instead of embedding it in source.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
genai.configure(api_key=GEMINI_API_KEY)
24
+
25
# Model registries: one entry per selectable model, keyed by display name.
def _model_spec(model_path, tokenizer_path, model_type, token=False):
    """Build one model-registry entry (see load_model_and_tokenizer)."""
    return {
        "model_path": model_path,
        "tokenizer_path": tokenizer_path,
        "token": token,
        "model_type": model_type,
    }

TITLE_MODELS = {
    "PhoBERT Encoder-Decoder": _model_spec(
        "PuppetLover/Title_generator", "vinai/phobert-base-v2",
        "encoder-decoder", token=True),
    "ViT5 Title Generator": _model_spec(
        "HTThuanHcmus/vit5-base-vietnews-summarization-finetune",
        "HTThuanHcmus/vit5-base-vietnews-summarization-finetune",
        "seq2seq"),
    "BARTpho Title Generator": _model_spec(
        "HTThuanHcmus/bartpho-finetune",
        "HTThuanHcmus/bartpho-finetune",
        "seq2seq"),
    "Gemini Title Generator": _model_spec(
        "gemini-1.5-pro", None, "gemini"),
}

SUMMARIZATION_MODELS = {
    "ViT5 Summarization": _model_spec(
        "HTThuanHcmus/vit5-summarization-news-finetune",
        "HTThuanHcmus/vit5-summarization-news-finetune",
        "seq2seq"),
    "BARTpho Summarization": _model_spec(
        "HTThuanHcmus/bartpho-summarization-news-finetune",
        "HTThuanHcmus/bartpho-summarization-news-finetune",
        "seq2seq"),
    "Gemini Summarization": _model_spec(
        "gemini-1.5-pro", None, "gemini"),
}
73
+
74
# Cached model/tokenizer loader: one load per unique config per session.
@st.cache_resource
def load_model_and_tokenizer(model_path, tokenizer_path, model_type, token=False):
    """Load and cache a generation model plus its tokenizer.

    Args:
        model_path: HF hub id, or Gemini model name for model_type "gemini".
        tokenizer_path: HF hub id of the tokenizer, or None (Gemini).
        model_type: one of "gemini", "encoder-decoder", "seq2seq".
        token: if True, authenticate the HF download with the token from
            the HF_TOKEN environment variable.

    Returns:
        (model, tokenizer) — tokenizer is None for Gemini models.

    Raises:
        ValueError: for an unsupported model_type.
    """
    if model_type == "gemini":
        # Gemini is accessed via the API; no local tokenizer needed.
        model = genai.GenerativeModel(model_path)
        return model, None
    # BUG FIX: `token` used to be silently ignored (token_arg was always
    # None), so "token": True in the model registry had no effect. Honor it
    # by forwarding the HF token from the environment.
    token_arg = os.environ.get("HF_TOKEN") if token else None
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=False) if tokenizer_path else None
    if model_type == "encoder-decoder":
        model = EncoderDecoderModel.from_pretrained(model_path, token=token_arg)
    elif model_type == "seq2seq":
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path, token=token_arg)
    else:
        raise ValueError(f"Unsupported model type: {model_type}")
    model.to("cuda" if torch.cuda.is_available() else "cpu")
    return model, tokenizer
90
+
91
# Gemini generation helper.
def generate_with_gemini(model, text, task):
    """Generate a title or summary for `text` with a Gemini model.

    Args:
        model: a genai.GenerativeModel (or anything with generate_content).
        text: the (cleaned) input document.
        task: "Sinh tiêu đề" selects the title prompt; any other value
            selects the summarization prompt.

    Returns:
        The generated text with surrounding whitespace stripped.
    """
    # BUG FIX: corrected the typo "Vơi" -> "Với" in the summarization prompt.
    prompt = (
        f"Với tư cách một chuyên gia hãy tạo tiêu đề ngắn gọn cho văn bản sau: {text}" if task == "Sinh tiêu đề"
        else f"Với tư cách một chuyên gia hãy tạo tóm tắt cho văn bản: {text}"
    )
    response = model.generate_content(prompt)
    return response.text.strip()
99
+
100
# Seed session-state defaults on the first run of the script.
_SESSION_DEFAULTS = {
    "history": [],                    # list of past generation records
    "show_sidebar": False,            # history panel visibility toggle
    "selected_history_index": None,   # index of the history entry being viewed
    "current_generated": None,        # result of the most recent generation
    "current_task": None,             # task the current result belongs to
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
111
+
112
# Sidebar: history visibility toggle plus the history list itself.
with st.sidebar:
    if st.button("🧾 Hiện/Ẩn lịch sử"):
        st.session_state.show_sidebar = not st.session_state.show_sidebar

if st.session_state.show_sidebar:
    with st.sidebar:
        st.markdown("### 🕓 Lịch sử")
        if not st.session_state.history:
            st.write("Chưa có lịch sử nào.")
        else:
            if st.button("🗑️ Xóa tất cả lịch sử"):
                st.session_state.history = []
                st.session_state.selected_history_index = None
                st.rerun()

            for idx, history_item in enumerate(st.session_state.history):
                col1, col2 = st.columns([4, 1])
                with col1:
                    # Shorten the first sentence of the result for display.
                    short_preview = history_item['title'].split('.')[0][:60]
                    if len(history_item['title']) > 60:
                        short_preview += "..."
                    if st.button(f"- {short_preview}", key=f"history_{idx}"):
                        st.session_state.selected_history_index = idx
                        st.session_state.current_generated = None
                with col2:
                    if st.button("🗑️", key=f"delete_{idx}"):
                        st.session_state.history.pop(idx)
                        # BUG FIX: pop(idx) shifts every later entry down by
                        # one, but only an exact-match selection used to be
                        # cleared — a selection past the deleted entry would
                        # silently point at the wrong item. Keep it in sync.
                        sel = st.session_state.selected_history_index
                        if sel == idx:
                            st.session_state.selected_history_index = None
                        elif sel is not None and sel > idx:
                            st.session_state.selected_history_index = sel - 1
                        st.rerun()
144
+
145
+
146
# App-wide CSS overrides (dark background, gradient buttons, pill-style
# radio options, compact page padding, and a reusable .card style).
st.markdown("""
<style>
body {
    background-color: #0e1117;
    color: #ffffff;
}
textarea {
    background-color: #1e1e1e !important;
    color: #ffffff !important;
    font-family: 'Courier New', monospace;
    border: 1px solid #ffffff30 !important;
    border-radius: 10px !important;
}
.stButton > button {
    background: linear-gradient(90deg, #4b6cb7 0%, #182848 100%);
    color: white;
    border: none;
    border-radius: 8px;
    padding: 10px 20px;
    margin-top: 10px;
    font-weight: bold;
    transition: all 0.3s ease;
}
.stButton > button:hover {
    background: linear-gradient(90deg, #1e3c72 0%, #2a5298 100%);
    transform: scale(1.02);
}
div[role="radiogroup"] label {
    margin-right: 15px;
    background-color: #2c2f36;
    padding: 8px 15px;
    border-radius: 5px;
    cursor: pointer;
}
div[role="radiogroup"] input:checked + label {
    background-color: #0078FF;
    color: white;
}
.block-container {
    padding-top: 1rem;
    padding-bottom: 1rem;
    padding-left: 2rem;
    padding-right: 2rem;
}
.card {
    background-color: #1e1e1e;
    padding: 20px;
    border-radius: 12px;
    box-shadow: 0 4px 12px rgba(0,0,0,0.3);
    margin-bottom: 20px;
}
</style>
""", unsafe_allow_html=True)
200
+
201
# Page header.
st.markdown("""
<h1 style='text-align: center; color: white; font-family: "Segoe UI", sans-serif;'>
    Trình Sinh Tiêu Đề & Tóm Tắt
</h1>
""", unsafe_allow_html=True)

# Task selector: title generation vs. summarization (drives which model
# registry is offered below).
task_option = st.radio(
    "Chọn chức năng bạn muốn:",
    ('Sinh tiêu đề', 'Tóm tắt nội dung'),
    horizontal=True,
    key="task_selection"
)
214
+
215
# Resolve the model registry, selector label, and widget key for the
# chosen task, then let the user pick a model from that registry.
selected_model_key = None
model_config = None

_registry = None
if task_option == 'Sinh tiêu đề':
    _registry = TITLE_MODELS
    _label = "Chọn mô hình sinh tiêu đề:"
    _widget_key = "title_model_selector"
elif task_option == 'Tóm tắt nội dung':
    _registry = SUMMARIZATION_MODELS
    _label = "Chọn mô hình tóm tắt:"
    _widget_key = "summary_model_selector"

if _registry is not None:
    selected_model_key = st.selectbox(
        _label,
        list(_registry.keys()),
        key=_widget_key
    )
    model_config = _registry[selected_model_key]
233
+
234
# Input source: an uploaded file (.txt / .docx) or a pasted text area.
uploaded_file = st.file_uploader("Hoặc tải lên file (.txt, .docx):", type=["txt", "docx"])

if uploaded_file:
    file_name = uploaded_file.name
    if file_name.endswith(".txt"):
        # NOTE(review): assumes the .txt file is UTF-8 encoded — other
        # encodings will raise UnicodeDecodeError here; confirm acceptable.
        text_input = uploaded_file.read().decode("utf-8")
    elif file_name.endswith(".docx"):
        # Lazy import: python-docx is only needed for .docx uploads.
        from docx import Document
        doc = Document(uploaded_file)
        # Join non-empty paragraphs into one newline-separated document.
        text_input = "\n".join([para.text for para in doc.paragraphs if para.text.strip()])
    # Show the loaded content read-only so the user can verify it.
    st.text_area("Nội dung file đã tải lên:", value=text_input, height=200, key="text_input_area", disabled=True)
else:
    text_input = st.text_area("Nhập đoạn văn của bạn:", height=200, key="text_input_area")
248
+
249
# Generate button: validate input, load the selected model, run it, record
# the result in history, and rerun the script to refresh the UI.
button_label = f"{task_option}"
if st.button(button_label, key="generate_button"):
    if not model_config:
        st.warning("Vui lòng chọn mô hình.")
    elif not text_input.strip():
        st.warning("Vui lòng nhập văn bản hoặc tải file lên.")
    else:
        # Cached by @st.cache_resource, so repeated runs reuse the model.
        model, tokenizer = load_model_and_tokenizer(
            model_config["model_path"],
            model_config["tokenizer_path"],
            model_config["model_type"],
            model_config.get("token", False)
        )

        if model:
            if model_config["model_type"] == "gemini":
                # Gemini path: clean only — no word segmentation needed.
                processed_text = clean_text(text_input)
                try:
                    with st.spinner(f"⏳ Đang {task_option.lower()} với mô hình '{selected_model_key}'..."):
                        result = generate_with_gemini(model, processed_text, task_option)

                    st.session_state.current_generated = result
                    st.session_state.current_task = task_option

                    # Record the run so it appears in the sidebar history.
                    st.session_state.history.append({
                        "title": result,
                        "input_text": text_input,
                        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                        "model_name": selected_model_key
                    })
                    st.session_state.selected_history_index = None
                    st.rerun()
                except Exception as e:
                    st.error(f"Đã xảy ra lỗi với Gemini: {e}")
                    print(f"Error during Gemini processing: {e}")
            else:
                # Local HF models: the PhoBERT encoder-decoder additionally
                # needs underthesea word segmentation before tokenization.
                if model_config["model_type"] == "encoder-decoder":
                    processed_text = clean_text(text_input)
                    processed_text = segment_text(processed_text)
                else:
                    processed_text = clean_text(text_input)

                try:
                    inputs = tokenizer(
                        processed_text,
                        padding="max_length",
                        truncation=True,
                        max_length=256,
                        return_tensors="pt"
                    )
                    # Move input tensors to the same device as the model.
                    device = "cuda" if torch.cuda.is_available() else "cpu"
                    inputs = {key: value.to(device) for key, value in inputs.items()}

                    with st.spinner(f"⏳ Đang {task_option.lower()} với mô hình '{selected_model_key}'..."):
                        with torch.no_grad():
                            outputs = model.generate(
                                inputs["input_ids"],
                                # Titles are short; summaries get more room.
                                max_length=80 if task_option == 'Sinh tiêu đề' else 200,
                                num_beams=5,
                                early_stopping=True,
                                no_repeat_ngram_size=2
                            )
                    result = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
                    # Undo underthesea's underscore word-joining for display.
                    result = result.replace("_", " ")

                    st.session_state.current_generated = result
                    st.session_state.current_task = task_option

                    # Record the run so it appears in the sidebar history.
                    st.session_state.history.append({
                        "title": result,
                        "input_text": text_input,
                        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                        "model_name": selected_model_key
                    })
                    st.session_state.selected_history_index = None
                    st.rerun()
                except Exception as e:
                    st.error(f"Đã xảy ra lỗi: {e}")
                    print(f"Error during processing: {e}")
330
+
331
# Show the freshly generated result (title or summary) under the form.
if st.session_state.current_generated:
    st.markdown("---")
    label_text = "Tiêu đề được tạo:" if st.session_state.current_task == 'Sinh tiêu đề' else "Nội dung tóm tắt:"
    st.markdown(f"<h3 style='color: #cccccc;'>{label_text}</h3>", unsafe_allow_html=True)
    st.markdown(f"<p style='color: white; background-color: #2a2a2a; padding: 10px; border-radius: 5px;'>"
                f"{st.session_state.current_generated}</p>", unsafe_allow_html=True)
338
+
339
# History detail view: show the selected past run (model, time, input,
# result), with expand/collapse for inputs longer than 1000 characters.
if st.session_state.selected_history_index is not None and st.session_state.selected_history_index < len(st.session_state.history):
    selected_history = st.session_state.history[st.session_state.selected_history_index]
    st.markdown("---")
    st.markdown(f"<h3 style='color: #cccccc;'>Kết quả đã tạo:</h3>", unsafe_allow_html=True)

    # Per-entry flag: whether the full (untruncated) input is shown.
    if f"show_full_input_{st.session_state.selected_history_index}" not in st.session_state:
        st.session_state[f"show_full_input_{st.session_state.selected_history_index}"] = False

    show_full = st.session_state[f"show_full_input_{st.session_state.selected_history_index}"]

    # Truncate long inputs to 1000 chars unless the user expanded them.
    input_text_to_display = selected_history['input_text'] if show_full else (selected_history['input_text'][:1000] + "..." if len(selected_history['input_text']) > 1000 else selected_history['input_text'])

    st.markdown(f"""
    <div style='color: white; background-color: #2a2a2a; padding: 10px; border-radius: 5px;'>
        <b>Model:</b> {selected_history['model_name']}<br>
        <b>Thời gian:</b> {selected_history['timestamp']}<br><br>
        <b>Văn bản gốc:</b><br>
        <div style='background-color: #3a3a3a; padding: 8px; border-radius: 5px; margin-bottom: 10px;'>{input_text_to_display}</div>
    """, unsafe_allow_html=True)

    # Expand / collapse controls, only shown for long inputs.
    if len(selected_history['input_text']) > 1000:
        if not show_full:
            if st.button("📖 Xem đầy đủ văn bản", key=f"show_full_{st.session_state.selected_history_index}"):
                st.session_state[f"show_full_input_{st.session_state.selected_history_index}"] = True
                st.rerun()
        else:
            if st.button("🔽 Thu gọn văn bản", key=f"collapse_full_{st.session_state.selected_history_index}"):
                st.session_state[f"show_full_input_{st.session_state.selected_history_index}"] = False
                st.rerun()

    # Close the outer card with the generated result.
    st.markdown(f"""
    <b>Kết quả:</b><br>
    <div style='background-color: #3a3a3a; padding: 8px; border-radius: 5px;'>{selected_history['title']}</div>
    </div>
    """, unsafe_allow_html=True)
src/utils/__init__.py ADDED
File without changes
src/utils/preprocessing.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import re
3
+ from underthesea import word_tokenize
4
+
5
+ import unicodedata
6
+
7
def clean_text(text):
    """Normalize Vietnamese text for the generation models.

    Replaces non-breaking spaces with regular spaces, applies Unicode NFC
    normalization, drops characters outside printable ASCII and the Latin
    ranges used by Vietnamese, and trims surrounding whitespace.

    Note: the whitelist also removes control characters, including
    newlines and tabs.
    """
    normalized = unicodedata.normalize("NFC", text.replace('\xa0', ' '))
    # Keep printable ASCII plus the extended-Latin / Vietnamese ranges.
    cleaned = re.sub(r'[^\x20-\x7E\u00A0-\u1EF9\u0100-\u017F]', '', normalized)
    return cleaned.strip()
12
+
13
def segment_text(text):
    """Word-segment Vietnamese text with underthesea.

    With format="text" multi-syllable words come back joined by
    underscores, which is why the app replaces "_" with " " on model
    output before displaying it.
    """
    return word_tokenize(text, format="text")