Spaces:
Build error
Build error
| import torch | |
| from torch.utils.data import Dataset, DataLoader | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments | |
class TextDataset(Dataset):
    """Split one long text into fixed-length token blocks for causal-LM training.

    The whole text is tokenized once (no truncation, no padding) and cut into
    consecutive, non-overlapping blocks of ``max_length`` token ids. Each item
    is one 1-D ``torch.long`` tensor, so a DataLoader yields
    ``(batch_size, max_length)`` batches usable as both inputs and labels.

    Trailing tokens that do not fill a complete block are dropped, unless the
    entire text is shorter than ``max_length``; in that case it forms a single,
    shorter block. An empty text yields an empty dataset.

    Args:
        text: The raw training text.
        tokenizer: Callable returning a mapping with an ``"input_ids"`` list
            for a single string (e.g. a Hugging Face tokenizer).
        max_length: Number of tokens per block; must be positive.

    Raises:
        ValueError: If ``max_length`` is not positive.
    """

    def __init__(self, text, tokenizer, max_length):
        if max_length <= 0:
            raise ValueError(f"max_length must be positive, got {max_length}")
        self.tokenizer = tokenizer
        # Do not truncate: truncating the whole file to max_length would keep
        # only the first block. Do not pad: padding needs a pad token, which
        # some causal-LM tokenizers do not define.
        token_ids = list(tokenizer(text)["input_ids"])

        full_blocks = len(token_ids) // max_length
        if full_blocks == 0:
            # Text shorter than one block: keep it whole rather than lose it.
            blocks = [token_ids] if token_ids else []
        else:
            usable = full_blocks * max_length
            blocks = [
                token_ids[start:start + max_length]
                for start in range(0, usable, max_length)
            ]

        if blocks:
            # Shape: (num_blocks, block_len); all blocks have equal length.
            self.input_ids = torch.tensor(blocks, dtype=torch.long)
        else:
            self.input_ids = torch.empty((0, max_length), dtype=torch.long)

    def __len__(self):
        """Return the number of token blocks."""
        return self.input_ids.size(0)

    def __getitem__(self, idx):
        """Return block ``idx`` as a 1-D tensor of token ids."""
        return self.input_ids[idx]
def _train_one_epoch(model, dataloader, optimizer, device):
    """Run one optimization pass over ``dataloader`` and return the mean batch loss.

    Assumes ``dataloader`` yields at least one batch of token-id tensors.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch in dataloader:
        inputs = batch.to(device)
        # Inputs double as labels: Hugging Face causal-LM heads shift the
        # labels internally to compute next-token loss.
        outputs = model(inputs, labels=inputs)
        loss = outputs.loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches


def main():
    """Fine-tune a causal language model on a single plain-text file.

    The script reads ``text_file_path`` and splits it into token blocks with
    ``TextDataset``. It then trains ``model_name`` with AdamW for ``epochs``
    epochs, printing the mean loss per epoch, and finally saves the model and
    tokenizer to ``model_save_path``.

    Raises:
        ValueError: If the text produces no training blocks.
    """
    # Hyperparameters
    max_length = 512
    batch_size = 32
    epochs = 3
    learning_rate = 5e-5

    # File paths
    text_file_path = 'path/to/your/text/file.txt'  # Change this path
    model_save_path = 'model'  # Change this path

    # Load text data
    with open(text_file_path, 'r', encoding='utf-8') as file:
        text = file.read()

    # Load tokenizer and model
    model_name = "togethercomputer/RedPajama-INCITE-Chat-3B-v1"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        # Some causal-LM tokenizers define no pad token, and padding then
        # raises ValueError; fall back to the EOS token.
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # Preprocess data
    dataset = TextDataset(text, tokenizer, max_length)
    if len(dataset) == 0:
        # Without this check the loop below would train on nothing and the
        # unchanged base model would be saved as if it had been fine-tuned.
        raise ValueError(f"No training data could be built from {text_file_path!r}")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # Setup optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    # Training loop
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        mean_loss = _train_one_epoch(model, dataloader, optimizer, device)
        print(f"Average loss: {mean_loss}")

    # Save the model
    model.save_pretrained(model_save_path)
    tokenizer.save_pretrained(model_save_path)
# Start training only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()