from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import load_dataset, DatasetDict, concatenate_datasets
from config import Config
import torch
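
# The `Config` object imported above is defined elsewhere in this repository.
# Reference sketch (assumption, illustrative placeholder values only) of the
# attributes this module expects it to expose:
#
#   class Config:
#       MODEL_NAME = "your-org/your-causal-lm"   # hypothetical checkpoint name
#       TOKENIZER_NAME = MODEL_NAME
#       DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
#       MAX_LENGTH = 512
#       OUTPUT_DIR = "./results"
#       LOGGING_DIR = "./logs"
#       LEARNING_RATE = 2e-5
#       BATCH_SIZE = 4
#       WEIGHT_DECAY = 0.01
#       NUM_EPOCHS = 3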

class CyberAttackDetectionModel:
    def __init__(self):
        # Initialize the tokenizer and model from the configured checkpoints
        self.tokenizer = AutoTokenizer.from_pretrained(Config.TOKENIZER_NAME)
        # Some causal LM tokenizers ship without a pad token; fall back to EOS
        # so that padding='max_length' works during preprocessing
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(Config.MODEL_NAME)
        self.model.to(Config.DEVICE)
    
    def preprocess_data(self, dataset):
        """
        Preprocess a raw text dataset by cleaning and tokenizing it.
        """
        # Basic text normalization; extend this to match each dataset's characteristics
        def clean_text(text):
            text = text.lower()              # Lowercase everything
            text = text.replace("\n", " ")   # Replace newlines with spaces
            return text
        
        # Apply cleaning to the 'text' column (assumed to exist in every dataset)
        dataset = dataset.map(lambda x: {'text': clean_text(x['text'])})
        
        # Tokenize and build labels for causal language modeling
        def tokenize_function(examples):
            tokens = self.tokenizer(
                examples['text'],
                truncation=True,
                padding='max_length',
                max_length=Config.MAX_LENGTH,
            )
            # Labels are the input ids themselves, with padding positions
            # masked to -100 so they are ignored by the loss
            tokens['labels'] = [
                [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
                for ids, attn in zip(tokens['input_ids'], tokens['attention_mask'])
            ]
            return tokens
        
        # Tokenize the entire dataset and drop the original columns so that
        # differently-structured datasets can be concatenated later
        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=dataset.column_names,
        )
        
        # Set the format for PyTorch
        tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
        
        return tokenized_dataset
    
    def fine_tune(self, datasets):
        """
        Fine-tune the model on the preprocessed datasets.
        """
        # Wrap the preprocessed splits in a DatasetDict
        dataset_dict = DatasetDict({
            "train": datasets['train'],
            "validation": datasets['validation'],
        })
        
        # Training arguments
        training_args = TrainingArguments(
            output_dir=Config.OUTPUT_DIR,
            evaluation_strategy="epoch",
            save_strategy="epoch",  # must match evaluation_strategy when load_best_model_at_end=True
            learning_rate=Config.LEARNING_RATE,
            per_device_train_batch_size=Config.BATCH_SIZE,
            per_device_eval_batch_size=Config.BATCH_SIZE,
            weight_decay=Config.WEIGHT_DECAY,
            save_total_limit=3,
            num_train_epochs=Config.NUM_EPOCHS,
            logging_dir=Config.LOGGING_DIR,
            load_best_model_at_end=True,
        )
        
        # Trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=dataset_dict['train'],
            eval_dataset=dataset_dict['validation'],
        )
        
        # Fine-tune, then persist the best checkpoint and tokenizer
        trainer.train()
        trainer.save_model(Config.OUTPUT_DIR)
        self.tokenizer.save_pretrained(Config.OUTPUT_DIR)
    
    def predict(self, prompt):
        self.model.eval()
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=Config.MAX_LENGTH)
        inputs = {key: value.to(Config.DEVICE) for key, value in inputs.items()}
        
        # Note: max_length counts the prompt tokens as well as the generated ones
        with torch.no_grad():
            outputs = self.model.generate(**inputs, max_length=Config.MAX_LENGTH)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def load_and_process_datasets(self):
        """
        Load and preprocess the OSINT and WhiteRabbitNeo datasets for fine-tuning.
        """
        # OSINT and WhiteRabbitNeo datasets hosted on the Hugging Face Hub
        osint_datasets = [
            'gonferspanish/OSINT', 
            'Inforensics/missing-persons-clue-analysis-osint',
            'jester6136/osint', 
            'originalbox/osint'
        ]
        
        wrn_datasets = [
            'WhiteRabbitNeo/WRN-Chapter-2',
            'WhiteRabbitNeo/WRN-Chapter-1',
            'WhiteRabbitNeo/Code-Functions-Level-Cyber'
        ]
        
        # Load and preprocess every dataset, collecting the tokenized results
        combined_datasets = []
        for dataset_name in osint_datasets + wrn_datasets:
            dataset = load_dataset(dataset_name)
            processed_data = self.preprocess_data(dataset['train'])  # Assuming the 'train' split exists
            combined_datasets.append(processed_data)
        
        # Concatenate all preprocessed datasets and carve out a validation split
        full_train = concatenate_datasets(combined_datasets)
        split = full_train.train_test_split(test_size=0.1, seed=42)
        
        return DatasetDict({
            'train': split['train'],
            'validation': split['test'],
        })

if __name__ == "__main__":
    # Create the model object
    model = CyberAttackDetectionModel()

    # Load and preprocess datasets
    preprocessed_datasets = model.load_and_process_datasets()

    # Fine-tune the model
    model.fine_tune(preprocessed_datasets)

    # Example prediction
    prompt = "A network scan reveals an open port 22 with an outdated SSH service."
    print(model.predict(prompt))
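
# Example (assumption): after fine_tune() saves the model and tokenizer to
# Config.OUTPUT_DIR, they can be reloaded later for inference without retraining:
#
#   tokenizer = AutoTokenizer.from_pretrained(Config.OUTPUT_DIR)
#   model = AutoModelForCausalLM.from_pretrained(Config.OUTPUT_DIR).to(Config.DEVICE)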