MFBDA commited on
Commit
726808d
·
verified ·
1 Parent(s): 8d08465

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -2
app.py CHANGED
@@ -1,7 +1,12 @@
 
1
  import pandas as pd
2
  from sklearn.model_selection import train_test_split
3
- from transformers import AutoTokenizer
 
 
 
4
 
 
5
  # Carregar os dados
6
  df = pd.read_csv("files/dados.csv")
7
 
@@ -25,4 +30,90 @@ tokenizer = AutoTokenizer.from_pretrained("neuralmind/bert-base-portuguese-cased
25
  train_encodings = tokenizer(train_texts, truncation=True, padding=True, max_length=128)
26
  test_encodings = tokenizer(test_texts, truncation=True, padding=True, max_length=128)
27
 
28
- # Restante do código para criar o dataset e fine-tuning...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Importações necessárias
2
  import pandas as pd
3
  from sklearn.model_selection import train_test_split
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
5
+ from torch.utils.data import Dataset
6
+ import torch
7
+ import gradio as gr
8
 
9
+ # === PASSO 1: CARREGAR E PRÉ-PROCESSAR OS DADOS ===
10
  # Carregar os dados
11
  df = pd.read_csv("files/dados.csv")
12
 
 
30
  train_encodings = tokenizer(train_texts, truncation=True, padding=True, max_length=128)
31
  test_encodings = tokenizer(test_texts, truncation=True, padding=True, max_length=128)
32
 
33
+ # Criar um dataset personalizado
34
+ class CustomDataset(Dataset):
35
+ def __init__(self, encodings, labels):
36
+ self.encodings = encodings
37
+ self.labels = labels
38
+
39
+ def __len__(self):
40
+ return len(self.labels)
41
+
42
+ def __getitem__(self, idx):
43
+ item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
44
+ item["labels"] = torch.tensor(self.labels[idx])
45
+ return item
46
+
47
+ # Criar datasets
48
+ train_dataset = CustomDataset(train_encodings, train_labels)
49
+ test_dataset = CustomDataset(test_encodings, test_labels)
50
+
51
+ # === PASSO 2: FINE-TUNING DO MODELO ===
52
+ # Carregar o modelo pré-treinado para classificação
53
+ model = AutoModelForSequenceClassification.from_pretrained(
54
+ "neuralmind/bert-base-portuguese-cased",
55
+ num_labels=3 # Número de classes (Baixa, Média, Alta)
56
+ )
57
+
58
+ # Configurar os argumentos de treinamento
59
+ training_args = TrainingArguments(
60
+ output_dir="./results",
61
+ evaluation_strategy="epoch",
62
+ learning_rate=2e-5,
63
+ per_device_train_batch_size=8,
64
+ per_device_eval_batch_size=8,
65
+ num_train_epochs=3,
66
+ weight_decay=0.01,
67
+ logging_dir="./logs",
68
+ logging_steps=10,
69
+ save_strategy="epoch"
70
+ )
71
+
72
+ # Criar o Trainer
73
+ trainer = Trainer(
74
+ model=model,
75
+ args=training_args,
76
+ train_dataset=train_dataset,
77
+ eval_dataset=test_dataset
78
+ )
79
+
80
+ # Fine-tune o modelo
81
+ print("Iniciando o fine-tuning do modelo...")
82
+ trainer.train()
83
+ print("Fine-tuning concluído!")
84
+
85
+ # Salvar o modelo ajustado
86
+ model.save_pretrained("./modelo-ajustado")
87
+ tokenizer.save_pretrained("./modelo-ajustado")
88
+
89
+ # === PASSO 3: INTEGRAR COM GRADIO ===
90
+ # Carregar o modelo ajustado
91
+ classifier = pipeline("text-classification", model="./modelo-ajustado")
92
+
93
+ # Função para classificar a criticidade
94
+ def classificar_criticidade(descricao):
95
+ resultado = classifier(descricao)[0]
96
+ label = resultado['label']
97
+ score = resultado['score']
98
+
99
+ # Mapear os rótulos ajustados
100
+ if label == "LABEL_0":
101
+ return f"Criticidade: Baixa (Confiança: {score:.2f})"
102
+ elif label == "LABEL_1":
103
+ return f"Criticidade: Média (Confiança: {score:.2f})"
104
+ elif label == "LABEL_2":
105
+ return f"Criticidade: Alta (Confiança: {score:.2f})"
106
+ else:
107
+ return "Não foi possível determinar a criticidade."
108
+
109
+ # Interface Gradio
110
+ interface = gr.Interface(
111
+ fn=classificar_criticidade,
112
+ inputs=gr.Textbox(lines=2, placeholder="Descreva a compra..."),
113
+ outputs="text",
114
+ title="Classificador de Criticidade de Compra",
115
+ description="Insira a descrição da compra para receber uma classificação de criticidade."
116
+ )
117
+
118
+ # Iniciar a interface
119
+ interface.launch()