# NLSQL / app.py
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from torch.optim import AdamW  # transformers.AdamW is deprecated; use torch's
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
# Dataset wrapping pre-tokenized (input, label) pairs from Spider
class SpiderDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        # Each item should be a fixed-length 1-D tensor so the default
        # DataLoader collate function can stack items into a batch.
        return {'input_ids': self.encodings[idx], 'labels': self.labels[idx]}

    def __len__(self):
        return len(self.encodings)
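
# ---------------------------------------------------------------------------
# Hedged sketch: one way to build the encodings/labels expected below, assuming
# the Hugging Face "spider" dataset (with "question" and "query" fields). The
# task prefix and max_len are illustrative choices, not part of the original
# script; adapt them to your own preprocessing.
# ---------------------------------------------------------------------------
def build_spider_encodings(tokenizer, split='train', max_len=128):
    from datasets import load_dataset
    raw = load_dataset('spider', split=split)
    encodings, labels = [], []
    for example in raw:
        source = tokenizer('translate English to SQL: ' + example['question'],
                           max_length=max_len, padding='max_length',
                           truncation=True, return_tensors='pt')
        target = tokenizer(example['query'], max_length=max_len,
                           padding='max_length', truncation=True,
                           return_tensors='pt')
        label_ids = target.input_ids.squeeze(0)
        # Replace padding in the labels with -100 so the loss ignores it.
        label_ids[label_ids == tokenizer.pad_token_id] = -100
        encodings.append(source.input_ids.squeeze(0))
        labels.append(label_ids)
    return encodings, labels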
# Load your preprocessed Spider dataset (e.g. built with the
# build_spider_encodings sketch above, once the tokenizer below is loaded)
train_encodings = ...  # your preprocessed input encodings (list of 1-D input-ID tensors)
train_labels = ...     # your preprocessed labels (list of 1-D label-ID tensors)
# Create a PyTorch dataset and dataloader
train_dataset = SpiderDataset(train_encodings, train_labels)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
# Load the pre-trained T5 model
model = T5ForConditionalGeneration.from_pretrained('t5-base')
tokenizer = T5Tokenizer.from_pretrained('t5-base')
# Move the model to the GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Set up the optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)
# Fine-tune the model
model.train()
for epoch in range(3):  # number of epochs
    for batch in tqdm(train_loader, desc=f'epoch {epoch}'):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)
        # T5 computes the cross-entropy loss internally when labels are
        # passed; label positions set to -100 are ignored.
        outputs = model(input_ids=input_ids, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
# Save the fine-tuned model
model.save_pretrained('your_model_directory')
tokenizer.save_pretrained('your_model_directory')
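
# ---------------------------------------------------------------------------
# Hedged usage sketch: generate SQL for a new question with the fine-tuned
# model. The 'translate English to SQL:' prefix is the illustrative one from
# the preprocessing sketch above and must match whatever you actually used.
# ---------------------------------------------------------------------------
model.eval()
question = 'How many singers do we have?'
inputs = tokenizer('translate English to SQL: ' + question,
                   return_tensors='pt').to(device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_length=128)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))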