"""Streamlit app: Vietnamese title generation and summarization.

Offers local Hugging Face models (PhoBERT encoder-decoder, ViT5, BARTpho)
plus the Gemini API for two tasks: title generation ("Sinh tiêu đề") and
news summarization. Keeps a per-session history in Streamlit session state.
"""

import streamlit as st
import os
import time
from datetime import datetime
from underthesea import word_tokenize
from transformers import EncoderDecoderModel, AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import logging
import transformers
import google.generativeai as genai
from utils.preprocessing import clean_text, segment_text
import asyncio

# Streamlit's script runner thread may have no running event loop, while some
# dependencies (e.g. google.generativeai's transport) expect one to exist.
try:
    asyncio.get_running_loop()
except RuntimeError:
    asyncio.set_event_loop(asyncio.new_event_loop())

# Reduce warning noise from Streamlit's script-context logger and transformers.
logging.getLogger('streamlit.runtime.scriptrunner.script_run_context').setLevel(logging.ERROR)
transformers.logging.set_verbosity_error()

# Streamlit page configuration
st.set_page_config(page_title="Trình sinh tiêu đề", layout="centered")

# Gemini API configuration.
# SECURITY FIXME: an API key was hardcoded here and is committed to source
# control — rotate that key and supply the replacement through the
# GEMINI_API_KEY environment variable. The old literal is kept only as a
# fallback so existing deployments keep working until rotation.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "AIzaSyCEDRquPDC9N09hTHGD9FfvsPP83AZT78Q")
genai.configure(api_key=GEMINI_API_KEY)

# Title-generation model registry: display name -> loading configuration.
# "token" marks checkpoints that need Hugging Face authentication.
TITLE_MODELS = {
    "PhoBERT Encoder-Decoder": {
        "model_path": "PuppetLover/Title_generator",
        "tokenizer_path": "vinai/phobert-base-v2",
        "token": True,
        "model_type": "encoder-decoder",
    },
    "ViT5 Title Generator": {
        "model_path": "HTThuanHcmus/vit5-base-vietnews-summarization-finetune",
        "tokenizer_path": "HTThuanHcmus/vit5-base-vietnews-summarization-finetune",
        "token": False,
        "model_type": "seq2seq",
    },
    "BARTpho Title Generator": {
        "model_path": "HTThuanHcmus/bartpho-finetune",
        "tokenizer_path": "HTThuanHcmus/bartpho-finetune",
        "token": False,
        "model_type": "seq2seq",
    },
    "Gemini Title Generator": {
        # "model_path": "gemini-1.5-pro",
        "model_path": "gemini-1.5-flash",
        "tokenizer_path": None,
        "token": False,
        "model_type": "gemini",
    },
}

# Summarization model registry, same schema as TITLE_MODELS.
SUMMARIZATION_MODELS = {
    "ViT5 Summarization": {
        "model_path": "HTThuanHcmus/vit5-summarization-news-finetune",
        "tokenizer_path": "HTThuanHcmus/vit5-summarization-news-finetune",
        "token": False,
        "model_type": "seq2seq",
    },
    "BARTpho Summarization": {
        "model_path": "HTThuanHcmus/bartpho-summarization-news-finetune",
        "tokenizer_path": "HTThuanHcmus/bartpho-summarization-news-finetune",
        "token": False,
        "model_type": "seq2seq",
    },
    "Gemini Summarization": {
        "model_path": "gemini-1.5-pro",
        "tokenizer_path": None,
        "token": False,
        "model_type": "gemini",
    },
}


# Cached model/tokenizer loader (one load per unique argument tuple).
@st.cache_resource
def load_model_and_tokenizer(model_path, tokenizer_path, model_type, token=False):
    """Load and cache a model (and tokenizer) for one registry entry.

    Args:
        model_path: Hugging Face hub id, or Gemini model name.
        tokenizer_path: hub id for the tokenizer, or None (Gemini).
        model_type: "gemini", "encoder-decoder" or "seq2seq".
        token: when True, authenticate with the locally stored Hugging Face
            token (needed for gated/private checkpoints).

    Returns:
        (model, tokenizer); tokenizer is None for Gemini models.

    Raises:
        ValueError: if model_type is not one of the supported values.
    """
    if model_type == "gemini":
        return genai.GenerativeModel(model_path), None
    # BUG FIX: the original set token_arg = None unconditionally, silently
    # ignoring the per-model "token" flag — honour it here.
    token_arg = True if token else None
    tokenizer = (
        AutoTokenizer.from_pretrained(tokenizer_path, use_fast=False)
        if tokenizer_path
        else None
    )
    if model_type == "encoder-decoder":
        model = EncoderDecoderModel.from_pretrained(model_path, token=token_arg)
    elif model_type == "seq2seq":
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path, token=token_arg)
    else:
        raise ValueError(f"Unsupported model type: {model_type}")
    model.to("cuda" if torch.cuda.is_available() else "cpu")
    return model, tokenizer


def generate_with_gemini(model, text, task):
    """Run one Gemini request for either task and return the stripped text.

    Args:
        model: a genai.GenerativeModel instance.
        text: the input document.
        task: "Sinh tiêu đề" for title generation, anything else summarizes.
    """
    prompt = (
        # Typo fix: "Vơi tư cách" -> "Với tư cách" in the summarization prompt.
        f"Với tư cách một chuyên gia hãy tạo tiêu đề ngắn gọn cho văn bản sau: {text}"
        if task == "Sinh tiêu đề"
        else f"Với tư cách một chuyên gia hãy tạo tóm tắt cho văn bản: {text}"
    )
    response = model.generate_content(prompt)
    return response.text.strip()


# Initialize session state (first run of a session only).
_SESSION_DEFAULTS = {
    "history": [],
    "show_sidebar": False,
    "selected_history_index": None,
    "current_generated": None,
    "current_task": None,
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Sidebar: toggle button for the history panel.
with st.sidebar:
    if st.button("🧾 Hiện/Ẩn lịch sử"):
        st.session_state.show_sidebar = not st.session_state.show_sidebar

# Sidebar: history list with per-item select/delete and a clear-all button.
if st.session_state.show_sidebar:
    with st.sidebar:
        st.markdown("### 🕓 Lịch sử")
        if not st.session_state.history:
            st.write("Chưa có lịch sử nào.")
        else:
            if st.button("🗑️ Xóa tất cả lịch sử"):
                st.session_state.history = []
                st.session_state.selected_history_index = None
                st.rerun()
            for idx, history_item in enumerate(st.session_state.history):
                col1, col2 = st.columns([4, 1])
                with col1:
                    # Shorten the first sentence of the title for display.
                    short_preview = history_item['title'].split('.')[0][:60]
                    if len(history_item['title']) > 60:
                        short_preview += "..."
                    if st.button(f"- {short_preview}", key=f"history_{idx}"):
                        st.session_state.selected_history_index = idx
                        st.session_state.current_generated = None
                with col2:
                    if st.button("🗑️", key=f"delete_{idx}"):
                        st.session_state.history.pop(idx)
                        # BUG FIX: keep the selection pointing at the same
                        # item when an earlier entry is removed; clear it
                        # when the selected entry itself was deleted.
                        sel = st.session_state.selected_history_index
                        if sel == idx:
                            st.session_state.selected_history_index = None
                        elif sel is not None and sel > idx:
                            st.session_state.selected_history_index = sel - 1
                        st.rerun()

# Custom CSS.
# NOTE(review): the stylesheet markup inside this call was lost when the file
# was extracted (HTML stripped) — restore the intended styles here.
st.markdown("""
""", unsafe_allow_html=True)

# Main app.
# NOTE(review): the HTML markup inside the following st.markdown calls was
# also stripped during extraction; the structure below is a minimal
# reconstruction and the original markup should be restored from VCS.
if st.session_state.current_generated:
    st.markdown(
        ""
        f"{st.session_state.current_generated}"
        "",
        unsafe_allow_html=True,
    )

# Show the selected history entry, if the stored index is still valid.
if (st.session_state.selected_history_index is not None
        and st.session_state.selected_history_index < len(st.session_state.history)):
    selected_history = st.session_state.history[st.session_state.selected_history_index]
    st.markdown("---")
    # NOTE(review): the original f-string here was truncated in extraction;
    # the stored title is shown as a placeholder until it is restored.
    st.markdown(f"{selected_history['title']}", unsafe_allow_html=True)