from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import load_dataset, DatasetDict, concatenate_datasets
from config import Config
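
# NOTE: this module assumes config.py exposes a Config object with roughly the
# attributes referenced below. Only the names are taken from their usage in
# this file; the example values are illustrative, not prescriptive:
#
#   class Config:
#       MODEL_NAME = "gpt2"          # any causal-LM checkpoint
#       TOKENIZER_NAME = "gpt2"
#       DEVICE = "cuda"              # or "cpu"
#       MAX_LENGTH = 512
#       OUTPUT_DIR = "./output"
#       LOGGING_DIR = "./logs"
#       LEARNING_RATE = 5e-5
#       BATCH_SIZE = 4
#       WEIGHT_DECAY = 0.01
#       NUM_EPOCHS = 3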

class CyberAttackDetectionModel:
    def __init__(self):
        # Initialize tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(Config.TOKENIZER_NAME)
        self.model = AutoModelForCausalLM.from_pretrained(Config.MODEL_NAME)
        # Causal-LM tokenizers often lack a pad token; fall back to EOS so padding works
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model.to(Config.DEVICE)

    def preprocess_data(self, dataset):
        """
        Preprocess a raw text dataset by cleaning and tokenizing it.
        Assumes the dataset exposes a 'text' column.
        """
        def clean_text(text):
            # Basic normalization; extend for dataset-specific quirks
            # (unwanted characters, special symbols, etc.)
            text = text.lower()             # lowercase
            text = text.replace("\n", " ")  # replace newlines with spaces
            return text

        # Apply cleaning to every example
        dataset = dataset.map(lambda x: {'text': clean_text(x['text'])})

        # Tokenize and build causal-LM labels (a copy of input_ids)
        def tokenize_function(examples):
            tokens = self.tokenizer(
                examples['text'],
                truncation=True,
                padding='max_length',
                max_length=Config.MAX_LENGTH,
            )
            tokens['labels'] = tokens['input_ids'].copy()
            return tokens

        # Tokenize the whole dataset; drop the original columns so datasets
        # from different sources end up with a common schema
        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=dataset.column_names,
        )

        # Expose PyTorch tensors for the Trainer
        tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
        return tokenized_dataset

    def fine_tune(self, datasets):
        """
        Fine-tune the model on the preprocessed datasets.
        """
        # Wrap the preprocessed splits in a DatasetDict
        dataset_dict = DatasetDict({
            "train": datasets['train'],
            "validation": datasets['validation'],
        })

        # Training arguments
        training_args = TrainingArguments(
            output_dir=Config.OUTPUT_DIR,
            evaluation_strategy="epoch",
            save_strategy="epoch",  # must match evaluation_strategy for load_best_model_at_end
            learning_rate=Config.LEARNING_RATE,
            per_device_train_batch_size=Config.BATCH_SIZE,
            per_device_eval_batch_size=Config.BATCH_SIZE,
            weight_decay=Config.WEIGHT_DECAY,
            save_total_limit=3,
            num_train_epochs=Config.NUM_EPOCHS,
            logging_dir=Config.LOGGING_DIR,
            load_best_model_at_end=True,
        )
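
        # Alternative sketch (not used here): instead of copying input_ids into
        # a 'labels' column during preprocessing, a data collator can build the
        # labels on the fly and pad dynamically, e.g.:
        #
        #   from transformers import DataCollatorForLanguageModeling
        #   collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
        #   trainer = Trainer(..., data_collator=collator)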

        # Trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=dataset_dict['train'],
            eval_dataset=dataset_dict['validation'],
        )

        # Fine-tuning
        trainer.train()

    def predict(self, prompt):
        """Generate a model response for a single prompt."""
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=Config.MAX_LENGTH)
        inputs = {key: value.to(Config.DEVICE) for key, value in inputs.items()}
        # Use max_new_tokens so generation is not cut short by a long prompt
        outputs = self.model.generate(**inputs, max_new_tokens=Config.MAX_LENGTH)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def load_and_process_datasets(self):
        """
        Load and preprocess the datasets used for fine-tuning.
        """
        # OSINT and WhiteRabbitNeo datasets on the Hugging Face Hub
        osint_datasets = [
            'gonferspanish/OSINT',
            'Inforensics/missing-persons-clue-analysis-osint',
            'jester6136/osint',
            'originalbox/osint'
        ]
        wrn_datasets = [
            'WhiteRabbitNeo/WRN-Chapter-2',
            'WhiteRabbitNeo/WRN-Chapter-1',
            'WhiteRabbitNeo/Code-Functions-Level-Cyber'
        ]

        # Load and preprocess every dataset, collecting the results
        combined_datasets = []
        for dataset_name in osint_datasets + wrn_datasets:
            dataset = load_dataset(dataset_name)
            processed_data = self.preprocess_data(dataset['train'])  # assumes a 'train' split exists
            combined_datasets.append(processed_data)

        # Concatenate all preprocessed datasets and hold out 10% for validation
        combined = concatenate_datasets(combined_datasets)
        split = combined.train_test_split(test_size=0.1, seed=42)
        return DatasetDict({
            "train": split["train"],
            "validation": split["test"],
        })

if __name__ == "__main__":
    # Create the model object
    model = CyberAttackDetectionModel()

    # Load and preprocess datasets
    preprocessed_datasets = model.load_and_process_datasets()

    # Fine-tune the model
    model.fine_tune(preprocessed_datasets)

    # Example prediction
    prompt = "A network scan reveals an open port 22 with an outdated SSH service."
    print(model.predict(prompt))
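
    # Optional (sketch): persist the fine-tuned weights and tokenizer so they
    # can be reloaded later without retraining. save_pretrained is a standard
    # transformers API; Config.OUTPUT_DIR is reused here only as an assumed,
    # writable location.
    #
    #   model.model.save_pretrained(Config.OUTPUT_DIR)
    #   model.tokenizer.save_pretrained(Config.OUTPUT_DIR)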