from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import load_dataset, DatasetDict, concatenate_datasets
from config import Config
import torch
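
# config.py is not part of this file; Config is assumed to expose roughly the
# following attributes (an illustrative sketch only -- the real names and values
# live in config.py and may differ):
#
#     class Config:
#         TOKENIZER_NAME = "gpt2"    # hypothetical base tokenizer
#         MODEL_NAME = "gpt2"        # hypothetical base model
#         DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
#         MAX_LENGTH = 512           # max tokenized sequence length
#         OUTPUT_DIR = "./results"   # Trainer checkpoint directory
#         LOGGING_DIR = "./logs"     # Trainer logging directory
#         LEARNING_RATE = 5e-5
#         BATCH_SIZE = 8
#         WEIGHT_DECAY = 0.01
#         NUM_EPOCHS = 3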


class CyberAttackDetectionModel:

    def __init__(self):
        # Load the tokenizer and base model named in config.py, then move the
        # model to the configured device.
        self.tokenizer = AutoTokenizer.from_pretrained(Config.TOKENIZER_NAME)
        self.model = AutoModelForCausalLM.from_pretrained(Config.MODEL_NAME)

        # GPT-style causal LM tokenizers often ship without a padding token;
        # reuse the EOS token so padding='max_length' works during preprocessing.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model.to(Config.DEVICE)

    def preprocess_data(self, dataset):
        """
        Preprocess a raw text dataset by cleaning and tokenizing it.
        """

        def clean_text(text):
            # Lowercase and collapse newlines into spaces.
            text = text.lower()
            text = text.replace("\n", " ")
            return text

        dataset = dataset.map(lambda x: {'text': clean_text(x['text'])})

        def tokenize_function(examples):
            tokens = self.tokenizer(
                examples['text'],
                truncation=True,
                padding='max_length',
                max_length=Config.MAX_LENGTH,
            )
            # For causal LM fine-tuning the labels are the input ids; the model
            # shifts them internally when computing the loss.
            tokens['labels'] = tokens['input_ids'].copy()
            return tokens

        # Drop the original columns so every dataset ends up with the same schema
        # (input_ids, attention_mask, labels) and can be concatenated later.
        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=dataset.column_names,
        )
        tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

        return tokenized_dataset

    def fine_tune(self, datasets):
        """
        Fine-tune the model on the preprocessed train/validation splits.
        """
        dataset_dict = DatasetDict({
            "train": datasets['train'],
            "validation": datasets['validation'],
        })

        training_args = TrainingArguments(
            output_dir=Config.OUTPUT_DIR,
            evaluation_strategy="epoch",
            # load_best_model_at_end requires matching save and evaluation strategies.
            save_strategy="epoch",
            learning_rate=Config.LEARNING_RATE,
            per_device_train_batch_size=Config.BATCH_SIZE,
            per_device_eval_batch_size=Config.BATCH_SIZE,
            weight_decay=Config.WEIGHT_DECAY,
            save_total_limit=3,
            num_train_epochs=Config.NUM_EPOCHS,
            logging_dir=Config.LOGGING_DIR,
            load_best_model_at_end=True,
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=dataset_dict['train'],
            eval_dataset=dataset_dict['validation'],
        )

        trainer.train()

    def predict(self, prompt):
        """
        Generate a completion for a single prompt on the configured device.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=Config.MAX_LENGTH)
        inputs = {key: value.to(Config.DEVICE) for key, value in inputs.items()}

        outputs = self.model.generate(**inputs, max_length=Config.MAX_LENGTH)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def load_and_process_datasets(self):
        """
        Load the OSINT and WhiteRabbitNeo datasets, preprocess them, and return
        a DatasetDict with train and validation splits.
        """
        osint_datasets = [
            'gonferspanish/OSINT',
            'Inforensics/missing-persons-clue-analysis-osint',
            'jester6136/osint',
            'originalbox/osint'
        ]

        wrn_datasets = [
            'WhiteRabbitNeo/WRN-Chapter-2',
            'WhiteRabbitNeo/WRN-Chapter-1',
            'WhiteRabbitNeo/Code-Functions-Level-Cyber'
        ]

        combined_datasets = []

        for dataset_name in osint_datasets + wrn_datasets:
            dataset = load_dataset(dataset_name)
            processed_data = self.preprocess_data(dataset['train'])
            combined_datasets.append(processed_data)

        # Only the train splits are used above, so merge everything and carve out
        # a held-out validation split (90/10 here) for the Trainer.
        merged = concatenate_datasets(combined_datasets)
        split = merged.train_test_split(test_size=0.1, seed=42)

        full_dataset = DatasetDict({
            'train': split['train'],
            'validation': split['test'],
        })

        return full_dataset


if __name__ == "__main__":
    model = CyberAttackDetectionModel()

    # Build the training data, fine-tune, and run a quick sample prediction.
    preprocessed_datasets = model.load_and_process_datasets()
    model.fine_tune(preprocessed_datasets)

    prompt = "A network scan reveals an open port 22 with an outdated SSH service."
    print(model.predict(prompt))