Jainesh212 committed
Commit 5555578 · 1 Parent(s): 371fdf3

Delete finetuning.py

Files changed (1):
  1. finetuning.py +0 -214

finetuning.py DELETED
@@ -1,214 +0,0 @@
import numpy as np
import pandas as pd
import os
from tqdm.notebook import tqdm
from torch import cuda
import torch
import transformers
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertModel, DistilBertTokenizer
import shutil

device = 'cuda' if cuda.is_available() else 'cpu'

label_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

df_train = pd.read_csv("train.csv")

MAX_LEN = 512
TRAIN_BATCH_SIZE = 32
VALID_BATCH_SIZE = 32
EPOCHS = 2
LEARNING_RATE = 1e-05

# Subsample 512 rows so the script runs at demo scale.
df_train = df_train.sample(n=512)

# 80/20 train/validation split of the sampled data.
train_size = 0.8
df_train_sampled = df_train.sample(frac=train_size, random_state=44)
df_val = df_train.drop(df_train_sampled.index).reset_index(drop=True)
df_train_sampled = df_train_sampled.reset_index(drop=True)

model_name = 'distilbert-base-uncased'
tokenizer = DistilBertTokenizer.from_pretrained(model_name, do_lower_case=True)
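
# Every comment is padded or truncated to exactly MAX_LEN tokens in the
# dataset below; 512 is DistilBERT's maximum input length.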

class ToxicDataset(Dataset):
    def __init__(self, data, tokenizer, max_len):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.labels = self.data[label_cols].values

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Index the single row; the original passed the whole comment_text
        # column, which would tokenize the Series' string repr instead of
        # one comment.
        text = str(self.data.comment_text[idx])
        tokenized_text = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            'input_ids': tokenized_text['input_ids'].flatten(),
            'attention_mask': tokenized_text['attention_mask'].flatten(),
            'targets': torch.FloatTensor(self.labels[idx])
        }
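
# Illustrative sanity check (hypothetical, not in the original commit): one
# dataset item should yield fixed-length MAX_LEN token tensors and a
# 6-dimensional multi-hot target, one slot per entry in label_cols.
# _sample = ToxicDataset(df_train_sampled, tokenizer, MAX_LEN)[0]
# assert _sample['input_ids'].shape == (MAX_LEN,)
# assert _sample['targets'].shape == (len(label_cols),)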

train_dataset = ToxicDataset(df_train_sampled, tokenizer, MAX_LEN)
valid_dataset = ToxicDataset(df_val, tokenizer, MAX_LEN)

train_data_loader = DataLoader(train_dataset,
                               batch_size=TRAIN_BATCH_SIZE,
                               shuffle=True,
                               num_workers=0)

val_data_loader = DataLoader(valid_dataset,
                             batch_size=VALID_BATCH_SIZE,
                             shuffle=False,
                             num_workers=0)
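
# Each batch from these loaders is a dict of tensors:
# input_ids / attention_mask -> (batch_size, MAX_LEN), targets -> (batch_size, 6).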

class CustomDistilBertClass(torch.nn.Module):
    def __init__(self):
        super(CustomDistilBertClass, self).__init__()
        self.distilbert_model = DistilBertModel.from_pretrained(model_name, return_dict=True)
        self.dropout = torch.nn.Dropout(0.3)
        self.linear = torch.nn.Linear(768, 6)

    def forward(self, input_ids, attn_mask):
        output = self.distilbert_model(
            input_ids,
            attention_mask=attn_mask,
        )
        # last_hidden_state: (batch, seq_len, 768) -> per-token logits (batch, seq_len, 6)
        output_dropout = self.dropout(output.last_hidden_state)
        output = self.linear(output_dropout)
        return output
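
# Note: forward() returns per-token logits of shape (batch, seq_len, 6); the
# training loop below keeps only the [CLS] position (index 0) as the
# sentence-level prediction. A hypothetical equivalent would slice inside
# forward() instead:
#     cls_state = output.last_hidden_state[:, 0, :]   # (batch, 768)
#     return self.linear(self.dropout(cls_state))     # (batch, 6)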

model = CustomDistilBertClass()
model.to(device)

def loss_fn(outputs, targets):
    # BCEWithLogitsLoss = sigmoid + binary cross-entropy per label, so each
    # of the six labels is predicted independently (multi-label, not softmax).
    return torch.nn.BCEWithLogitsLoss()(outputs, targets)

optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
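
# Worked example (illustrative): for one sample with logits
# [2.0, -2.0, 0.0, -3.0, 1.0, -1.0] and targets [1, 0, 0, 0, 1, 0],
# BCEWithLogitsLoss averages the six per-label binary cross-entropies:
#     loss_fn(torch.tensor([[2.0, -2.0, 0.0, -3.0, 1.0, -1.0]]),
#             torch.tensor([[1., 0., 0., 0., 1., 0.]]))
# returns a scalar of roughly 0.27.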

def train_model(n_epochs, training_loader, validation_loader, model,
                optimizer, checkpoint_path, best_model_path):

    valid_loss_min = np.inf

    for epoch in range(1, n_epochs + 1):
        train_loss = 0
        valid_loss = 0

        model.train()
        print('############# Epoch {}: Training Start #############'.format(epoch))
        for batch_idx, data in enumerate(training_loader):
            ids = data['input_ids'].to(device, dtype=torch.long)
            mask = data['attention_mask'].to(device, dtype=torch.long)

            outputs = model(ids, mask)
            # Keep only the [CLS] position as the sentence-level prediction.
            outputs = outputs[:, 0, :]
            targets = data['targets'].to(device, dtype=torch.float)
            loss = loss_fn(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Running mean of the batch losses.
            train_loss = train_loss + ((1 / (batch_idx + 1)) * (loss.item() - train_loss))

        print('############# Epoch {}: Training End #############'.format(epoch))

        print('############# Epoch {}: Validation Start #############'.format(epoch))

        model.eval()

        with torch.no_grad():
            for batch_idx, data in enumerate(validation_loader):
                ids = data['input_ids'].to(device, dtype=torch.long)
                mask = data['attention_mask'].to(device, dtype=torch.long)

                targets = data['targets'].to(device, dtype=torch.float)
                outputs = model(ids, mask)
                outputs = outputs[:, 0, :]
                loss = loss_fn(outputs, targets)

                valid_loss = valid_loss + ((1 / (batch_idx + 1)) * (loss.item() - valid_loss))

        print('############# Epoch {}: Validation End #############'.format(epoch))
        # train_loss and valid_loss are already running means; the original
        # divided them by len(loader) again, which under-reported both losses.
        print('Epoch: {} \tAverage Training Loss: {:.6f} \tAverage Validation Loss: {:.6f}'.format(
            epoch,
            train_loss,
            valid_loss
        ))

        checkpoint = {
            'epoch': epoch + 1,
            'valid_loss_min': valid_loss,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }

        save_ckp(checkpoint, False, checkpoint_path, best_model_path)

        if valid_loss <= valid_loss_min:
            print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...'.format(valid_loss_min, valid_loss))
            save_ckp(checkpoint, True, checkpoint_path, best_model_path)
            valid_loss_min = valid_loss

        print('############# Epoch {} Done #############\n'.format(epoch))

    return model

def load_ckp(checkpoint_fpath, model, optimizer):
    """
    checkpoint_fpath: path of the checkpoint to load
    model: model that we want to load checkpoint parameters into
    optimizer: optimizer we defined in previous training
    """
    checkpoint = torch.load(checkpoint_fpath)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    # Stored as a plain float, so no .item() call is needed (the original
    # called .item() on it, which would fail on a float).
    valid_loss_min = checkpoint['valid_loss_min']
    return model, optimizer, checkpoint['epoch'], valid_loss_min
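
# Resuming (illustrative, hypothetical call): restores the weights and
# optimizer state written by save_ckp below.
# model, optimizer, start_epoch, valid_loss_min = load_ckp("model.pt", model, optimizer)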

def save_ckp(state, is_best, checkpoint_path, best_model_path):
    """
    state: checkpoint we want to save
    is_best: is this the best checkpoint; min validation loss
    checkpoint_path: path to save checkpoint
    best_model_path: path to save best model
    """
    torch.save(state, checkpoint_path)
    if is_best:
        # Copy the latest checkpoint over the best-model file.
        shutil.copyfile(checkpoint_path, best_model_path)

ckpt_path = "model.pt"
best_model_path = "best_model.pt"

trained_model = train_model(EPOCHS,
                            train_data_loader,
                            val_data_loader,
                            model,
                            optimizer,
                            ckpt_path,
                            best_model_path)
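
# Hypothetical usage sketch (not in the original commit): score one comment
# with the trained model, assuming training above completed and `tokenizer`,
# `trained_model`, `device`, `label_cols`, and MAX_LEN are as defined here.
trained_model.eval()
with torch.no_grad():
    enc = tokenizer.encode_plus("you are wonderful",
                                add_special_tokens=True,
                                max_length=MAX_LEN,
                                padding='max_length',
                                truncation=True,
                                return_attention_mask=True,
                                return_tensors='pt')
    logits = trained_model(enc['input_ids'].to(device),
                           enc['attention_mask'].to(device))
    # Sigmoid turns the six independent [CLS] logits into per-label probabilities.
    probs = torch.sigmoid(logits[:, 0, :]).cpu().numpy()[0]
    print(dict(zip(label_cols, probs)))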