practic / app.py
SimrusDenuvo's picture
Update app.py
e78c1cb verified
raw
history blame
2.81 kB
from transformers import BertForSequenceClassification, BertTokenizerFast, Trainer, TrainingArguments
from datasets import load_dataset
import torch
import pandas as pd
import numpy as np
import gradio as gr
# ❗ Загрузка датасета ZhenDOS/alpha_bank_data
dataset = load_dataset("ZhenDOS/alpha_bank_data")
# ✔️ Загрузка базовой модели и токенайзера
tokenizer = BertTokenizerFast.from_pretrained("DeepPavlov/rubert-base-cased")
model = BertForSequenceClassification.from_pretrained("DeepPavlov/rubert-base-cased", num_labels=len(dataset["train"].features["label"].names))
# ➕ Токенизация входных данных
def tokenize_function(examples):
return tokenizer(examples["text"], padding="max_length", truncation=True)
tokenized_datasets = dataset.map(tokenize_function, batched=True)
# 🏃‍♂️ Настройки обучения
training_args = TrainingArguments(
output_dir="./results",
evaluation_strategy="epoch",
learning_rate=2e-5,
per_device_train_batch_size=16,
per_device_eval_batch_size=64,
num_train_epochs=3,
weight_decay=0.01,
)
# 💨 Процесс обучения
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_datasets["train"],
eval_dataset=tokenized_datasets["validation"],
)
trainer.train()
# 📊 Функционал для демонстрации через Gradio
def classify_question(question):
tokens = tokenizer(question, return_tensors="pt")
outputs = model(**tokens)
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
pred_label_idx = torch.argmax(probabilities, dim=1).item()
categories = dataset["train"].features["label"].names
return {
"Вероятности классов": dict(zip(categories, probabilities.detach().numpy()[0])),
"Прогнозируемый класс": categories[pred_label_idx],
}
# 🖥️ Графический интерфейс Gradio
demo = gr.Interface(
fn=classify_question,
inputs="text",
outputs=[
gr.Label(label="Категории"),
gr.Textbox(label="Прогнозируемый класс"),
],
examples=[
["Как перевести деньги между картами?"],
["Что такое кредитная история?"],
["Почему моя карта заблокирована?"],
],
title="Классификация клиентских запросов банка",
description="Приложение помогает определить категорию клиентского запроса и оценить вероятность принадлежности каждого класса.",
)
demo.launch()