dafeafdf dfae
train_llama.py  +14 -9
train_llama.py
CHANGED
@@ -43,20 +43,25 @@ model.print_trainable_parameters()
 dataset = datasets.load_dataset("json", data_files="final_combined_fraud_data.json", field="training_pairs")
 print("First example from dataset:", dataset["train"][0])
 
-# Tokenization with
+# Tokenization with validation
 def tokenize_data(example):
     formatted_text = f"{example['input']} {example['output']}"
     inputs = tokenizer(formatted_text, truncation=True, max_length=512, padding="max_length", return_tensors="pt")
-    input_ids = inputs["input_ids"].squeeze(0)
-    attention_mask = inputs["attention_mask"].squeeze(0)
-    labels = input_ids.
+    input_ids = inputs["input_ids"].squeeze(0)
+    attention_mask = inputs["attention_mask"].squeeze(0)
+    labels = input_ids.clone()
     input_len = len(tokenizer(example['input'])["input_ids"])
-    labels[:input_len] =
-
+    labels[:input_len] = -100  # Mask input part in labels only
+    # Validate input_ids
+    vocab_size = model.config.vocab_size  # Should be 32000 for LLaMA-2
+    if (input_ids < 0).any() or (input_ids >= vocab_size).any():
+        print(f"Invalid input_ids: min={input_ids.min()}, max={input_ids.max()}, vocab_size={vocab_size}")
+        raise ValueError("input_ids contains invalid indices")
+    print(f"Debug: input_ids[:5] = {input_ids[:5].tolist()}, labels[:5] = {labels[:5].tolist()}, attention_mask[:5] = {attention_mask[:5].tolist()}")
     return {
-        "input_ids": input_ids,
-        "labels": labels,
-        "attention_mask": attention_mask
+        "input_ids": input_ids.tolist(),
+        "labels": labels.tolist(),
+        "attention_mask": attention_mask.tolist()
     }
 
 tokenized_dataset = dataset["train"].map(tokenize_data, batched=False, remove_columns=dataset["train"].column_names)
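
Note on the labels change above: -100 is the default ignore_index of PyTorch's cross-entropy loss, which the Hugging Face Trainer's loss relies on, so masking the prompt positions this way excludes them from the loss and trains the model only on the response tokens. Below is a minimal sketch of that behaviour, with made-up logits and a toy label row; it is illustrative only and not part of the commit.

import torch
import torch.nn.functional as F

vocab_size = 32000                      # LLaMA-2 vocabulary size, as checked in the diff
logits = torch.randn(1, 6, vocab_size)  # fake model output: batch of 1, sequence of 6 tokens
labels = torch.tensor([[-100, -100, -100, 5, 17, 2]])  # prompt positions masked, response kept

# Positions labelled -100 are skipped entirely (ignore_index=-100 is the default),
# so only the last three tokens contribute to the loss.
loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1))
print(loss)

The switch to .tolist() in the return dict means the mapped dataset stores plain Python lists, which Arrow serialises directly instead of round-tripping PyTorch tensors through datasets.map.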