mrstarkng's picture
update cpu app.py
e466ef8 verified
# app.py
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
import os
# --- 1. CẤU HÌNH MODEL VÀ THIẾT BỊ ---
BASE_MODEL_ID = "VietAI/vit5-base"
ADAPTER_ID = "mrstarkng/financial-summarization-vit5-lora"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# --- 2. TẢI MODEL (Chỉ một lần khi app khởi động) ---
print(f"Loading model on {DEVICE}...")
# Tải tokenizer từ model nền
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
# Tải model nền
base_model = AutoModelForSeq2SeqLM.from_pretrained(
BASE_MODEL_ID,
torch_dtype=torch.float32, # Dùng float32 cho CPU
).to(DEVICE)
# Tải và áp dụng adapter LoRA từ Hub của bạn
model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
model.eval() # Chuyển sang chế độ đánh giá
print("✅ Model is ready!")
# --- 3. HÀM XỬ LÝ LOGIC TÓM TẮT ---
def summarize(text_to_summarize, max_length=256, num_beams=5):
"""Hàm nhận văn bản đầu vào và trả về bản tóm tắt."""
if not text_to_summarize or not text_to_summarize.strip():
return "Lỗi: Vui lòng nhập văn bản để tóm tắt."
try:
# Model ViT5 cần tiền tố "summarize: "
input_text = "summarize: " + text_to_summarize
inputs = tokenizer(input_text, return_tensors="pt", max_length=1024, truncation=True).to(DEVICE)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_length=int(max_length),
num_beams=int(num_beams),
early_stopping=True,
repetition_penalty=2.5, # Thêm để tránh lặp từ
length_penalty=1.0
)
summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
return summary
except Exception as e:
print(f"Error during summarization: {e}")
return f"Đã có lỗi xảy ra: {e}"
# --- 4. TẠO GIAO DIỆN WEB VỚI GRADIO ---
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
gr.Markdown(
"""
# 🤖 Tóm tắt Tin tức Tài chính - ViT5 LoRA
Ứng dụng demo cho model ViT5 được fine-tune bằng LoRA để tóm tắt tin tức tài chính-kinh tế tiếng Việt.
Được phát triển bởi Tony Nguyen.
"""
)
with gr.Row():
with gr.Column(scale=3):
text_input = gr.Textbox(lines=20, label="Văn bản gốc", placeholder="Dán một bài báo tài chính vào đây...")
with gr.Column(scale=2):
text_output = gr.Textbox(lines=20, label="Bản tóm tắt", interactive=False)
with gr.Row():
submit_button = gr.Button("Tóm tắt", variant="primary")
with gr.Accordion("Cài đặt nâng cao", open=False):
max_len_slider = gr.Slider(minimum=32, maximum=512, value=256, step=16, label="Độ dài tối đa (Max Length)")
num_beams_slider = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Số beam (Num Beams)")
submit_button.click(
fn=summarize,
inputs=[text_input, max_len_slider, num_beams_slider],
outputs=text_output
)
gr.Examples(
examples=[
"DIC: Tổng công ty Cổ phần Đầu tư Phát triển Xây dựng (DIC Corp, MCK: DIG, sàn HoSE) vừa công bố các văn bản báo cáo về kết quả giao dịch cổ phiếu của người nội bộ và người có liên quan của người nội bộ. Theo đó, trong phiên giao dịch ngày 22/4/2025 và 23/4/2025, hơn 4,7 triệu cổ phiếu DIG của ông Nguyễn Hùng Cường- Chủ tịch HĐQT DIC Corp đã bị công ty chứng khoán bán giải chấp...",
"Hôm 22/4, Thủ tướng Phạm Minh Chính ký Công điện 47 về giải pháp trọng tâm thúc đẩy tăng trưởng kinh tế năm 2025. Trong đó, lãnh đạo Chính phủ yêu cầu Bộ Công Thương, địa phương có giải pháp tăng kết nối, kích cầu tiêu dùng nội địa. Công điện được ban hành trong bối cảnh trụ cột xuất khẩu chịu áp lực từ các biến động thuế quan..."
],
inputs=text_input
)
# --- 5. CHẠY APP ---
# Dòng này để chạy trên local, khi deploy lên Spaces nó sẽ tự động chạy
# demo.launch()