"""
Multimodal Learning System for Cybersecurity

Integration of text, network data, and visual security information.
"""
|
|
|
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from datetime import datetime
import logging
from abc import ABC, abstractmethod
|
|
|
@dataclass |
|
class TextData: |
|
"""Text-based security data""" |
|
content: str |
|
data_type: str |
|
metadata: Dict[str, Any] |
|
timestamp: str |
|
source: str |
|
|
|
@dataclass |
|
class NetworkData: |
|
"""Network traffic data""" |
|
packet_data: bytes |
|
flow_features: Dict[str, float] |
|
protocol: str |
|
source_ip: str |
|
dest_ip: str |
|
timestamp: str |
|
metadata: Dict[str, Any] |
|
|
|
@dataclass |
|
class VisualData: |
|
"""Visual security data""" |
|
image_data: np.ndarray |
|
image_type: str |
|
features: Dict[str, Any] |
|
timestamp: str |
|
metadata: Dict[str, Any] |
|
|
|
@dataclass |
|
class MultimodalSample: |
|
"""Combined multimodal sample""" |
|
sample_id: str |
|
text_data: Optional[TextData] |
|
network_data: Optional[NetworkData] |
|
visual_data: Optional[VisualData] |
|
label: str |
|
confidence: float |
|
timestamp: str |
|
|
|
class ModalityEncoder(nn.Module, ABC): |
|
"""Abstract base class for modality encoders""" |
|
|
|
@abstractmethod |
|
def forward(self, data: Any) -> torch.Tensor: |
|
pass |
|
|
|
@abstractmethod |
|
def get_output_dim(self) -> int: |
|
pass |
|
|
|
class TextEncoder(ModalityEncoder): |
|
"""Encoder for text-based security data""" |
|
|
|
def __init__(self, vocab_size: int = 10000, embed_dim: int = 256, hidden_dim: int = 512): |
|
super().__init__() |
|
self.embed_dim = embed_dim |
|
self.hidden_dim = hidden_dim |
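        # Token embedding feeding a bidirectional LSTM, followed by self-attention and a projection back to hidden_dim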
|
|
|
|
|
self.embedding = nn.Embedding(vocab_size, embed_dim) |
|
self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True) |
|
self.attention = nn.MultiheadAttention(hidden_dim * 2, num_heads=8) |
|
        self.output_proj = nn.Linear(hidden_dim * 2 + 128, hidden_dim)  # bi-LSTM context plus the two 64-dim conv pattern vectors
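        # Parallel 1D convolutions over the bi-LSTM outputs acting as local threat- and temporal-pattern detectors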
|
|
|
|
|
self.threat_patterns = nn.Conv1d(hidden_dim * 2, 64, kernel_size=3, padding=1) |
|
self.temporal_patterns = nn.Conv1d(hidden_dim * 2, 64, kernel_size=5, padding=2) |
|
|
|
def forward(self, text_tokens: torch.Tensor) -> torch.Tensor: |
|
|
|
embedded = self.embedding(text_tokens) |
|
|
|
|
|
lstm_out, (h_n, c_n) = self.lstm(embedded) |
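        # nn.MultiheadAttention was built without batch_first, so inputs are transposed to (seq, batch, embed) and back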
|
|
|
|
|
attn_out, _ = self.attention( |
|
lstm_out.transpose(0, 1), |
|
lstm_out.transpose(0, 1), |
|
lstm_out.transpose(0, 1) |
|
) |
|
attn_out = attn_out.transpose(0, 1) |
|
|
|
|
|
lstm_transposed = lstm_out.transpose(1, 2) |
|
threat_features = F.relu(self.threat_patterns(lstm_transposed)) |
|
temporal_features = F.relu(self.temporal_patterns(lstm_transposed)) |
|
|
|
|
|
threat_pooled = F.adaptive_avg_pool1d(threat_features, 1).squeeze(-1) |
|
temporal_pooled = F.adaptive_avg_pool1d(temporal_features, 1).squeeze(-1) |
|
|
|
|
|
combined = torch.cat([ |
|
attn_out.mean(dim=1), |
|
threat_pooled, |
|
temporal_pooled |
|
], dim=1) |
|
|
|
        output = self.output_proj(combined)
|
return F.relu(output) |
|
|
|
def get_output_dim(self) -> int: |
|
return self.hidden_dim |
|
|
|
class NetworkEncoder(ModalityEncoder): |
|
"""Encoder for network traffic data""" |
|
|
|
def __init__(self, flow_feature_dim: int = 50, packet_embed_dim: int = 128, hidden_dim: int = 512): |
|
super().__init__() |
|
self.flow_feature_dim = flow_feature_dim |
|
self.packet_embed_dim = packet_embed_dim |
|
self.hidden_dim = hidden_dim |
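        # Small MLP over the aggregated flow statistics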
|
|
|
|
|
self.flow_encoder = nn.Sequential( |
|
nn.Linear(flow_feature_dim, 256), |
|
nn.ReLU(), |
|
nn.Dropout(0.3), |
|
nn.Linear(256, 256), |
|
nn.ReLU() |
|
) |
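        # Byte-level embedding of the raw packet payload followed by two 1D convolutions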
|
|
|
|
|
self.packet_embedding = nn.Embedding(256, packet_embed_dim) |
|
        self.packet_conv1 = nn.Conv1d(packet_embed_dim, 128, kernel_size=3, padding=1)
        self.packet_conv2 = nn.Conv1d(128, 64, kernel_size=5, padding=2)
|
|
|
|
|
self.protocol_embedding = nn.Embedding(10, 32) |
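        # Temporal convolution over the concatenated flow (256) + packet (64) + protocol (32) features, then a projection to hidden_dim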
|
|
|
|
|
self.temporal_conv = nn.Conv1d(256 + 64 + 32, 128, kernel_size=3, padding=1) |
|
|
|
|
|
self.output_proj = nn.Linear(128 + 256, hidden_dim) |
|
|
|
def forward(self, network_data: Dict[str, torch.Tensor]) -> torch.Tensor: |
|
|
|
flow_features = network_data['flow_features'] |
|
packet_bytes = network_data['packet_bytes'] |
|
protocol_ids = network_data['protocol_ids'] |
|
|
|
|
|
flow_encoded = self.flow_encoder(flow_features) |
|
|
|
|
|
packet_embedded = self.packet_embedding(packet_bytes) |
|
packet_transposed = packet_embedded.transpose(1, 2) |
|
|
|
        packet_conv1 = F.relu(self.packet_conv1(packet_transposed))
        packet_conv2 = F.relu(self.packet_conv2(packet_conv1))
|
packet_pooled = F.adaptive_avg_pool1d(packet_conv2, 1).squeeze(-1) |
|
|
|
|
|
protocol_embedded = self.protocol_embedding(protocol_ids) |
|
|
|
|
|
combined_features = torch.cat([ |
|
flow_encoded, packet_pooled, protocol_embedded |
|
], dim=1).unsqueeze(-1) |
|
|
|
temporal_features = F.relu(self.temporal_conv(combined_features)) |
|
temporal_pooled = temporal_features.squeeze(-1) |
|
|
|
|
|
final_features = torch.cat([temporal_pooled, flow_encoded], dim=1) |
|
output = self.output_proj(final_features) |
|
|
|
return F.relu(output) |
|
|
|
def get_output_dim(self) -> int: |
|
return self.hidden_dim |
|
|
|
class VisualEncoder(ModalityEncoder): |
|
"""Encoder for visual security data""" |
|
|
|
def __init__(self, hidden_dim: int = 512): |
|
super().__init__() |
|
self.hidden_dim = hidden_dim |
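        # VGG-style backbone: four blocks of 3x3 convolutions with batch norm, ReLU, and pooling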
|
|
|
|
|
self.conv_layers = nn.Sequential( |
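            # Block 1: 3 -> 64 channels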
|
|
|
nn.Conv2d(3, 64, kernel_size=3, padding=1), |
|
nn.BatchNorm2d(64), |
|
nn.ReLU(), |
|
nn.Conv2d(64, 64, kernel_size=3, padding=1), |
|
nn.BatchNorm2d(64), |
|
nn.ReLU(), |
|
nn.MaxPool2d(2, 2), |
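            # Block 2: 64 -> 128 channels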
|
|
|
|
|
nn.Conv2d(64, 128, kernel_size=3, padding=1), |
|
nn.BatchNorm2d(128), |
|
nn.ReLU(), |
|
nn.Conv2d(128, 128, kernel_size=3, padding=1), |
|
nn.BatchNorm2d(128), |
|
nn.ReLU(), |
|
nn.MaxPool2d(2, 2), |
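            # Block 3: 128 -> 256 channels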
|
|
|
|
|
nn.Conv2d(128, 256, kernel_size=3, padding=1), |
|
nn.BatchNorm2d(256), |
|
nn.ReLU(), |
|
nn.Conv2d(256, 256, kernel_size=3, padding=1), |
|
nn.BatchNorm2d(256), |
|
nn.ReLU(), |
|
nn.MaxPool2d(2, 2), |
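            # Block 4: 256 -> 512 channels, pooled to a fixed 7x7 grid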
|
|
|
|
|
nn.Conv2d(256, 512, kernel_size=3, padding=1), |
|
nn.BatchNorm2d(512), |
|
nn.ReLU(), |
|
nn.AdaptiveAvgPool2d((7, 7)) |
|
) |
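        # 1x1 convolution heads specialised for topology, anomaly, and threat-indicator cues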
|
|
|
|
|
self.topology_detector = nn.Conv2d(512, 64, kernel_size=1) |
|
self.anomaly_detector = nn.Conv2d(512, 64, kernel_size=1) |
|
self.threat_indicator_detector = nn.Conv2d(512, 64, kernel_size=1) |
|
|
|
|
|
self.global_pool = nn.AdaptiveAvgPool2d((1, 1)) |
|
self.output_proj = nn.Linear(512 + 64 * 3, hidden_dim) |
|
|
|
def forward(self, image_data: torch.Tensor) -> torch.Tensor: |
|
|
|
|
|
|
|
conv_features = self.conv_layers(image_data) |
|
|
|
|
|
topology_features = F.relu(self.topology_detector(conv_features)) |
|
anomaly_features = F.relu(self.anomaly_detector(conv_features)) |
|
threat_features = F.relu(self.threat_indicator_detector(conv_features)) |
|
|
|
|
|
conv_pooled = self.global_pool(conv_features).view(conv_features.size(0), -1) |
|
topology_pooled = self.global_pool(topology_features).view(topology_features.size(0), -1) |
|
anomaly_pooled = self.global_pool(anomaly_features).view(anomaly_features.size(0), -1) |
|
threat_pooled = self.global_pool(threat_features).view(threat_features.size(0), -1) |
|
|
|
|
|
combined_features = torch.cat([ |
|
conv_pooled, topology_pooled, anomaly_pooled, threat_pooled |
|
], dim=1) |
|
|
|
output = self.output_proj(combined_features) |
|
return F.relu(output) |
|
|
|
def get_output_dim(self) -> int: |
|
return self.hidden_dim |
|
|
|
class MultimodalFusionLayer(nn.Module): |
|
"""Fusion layer for combining multimodal features""" |
|
|
|
def __init__(self, text_dim: int, network_dim: int, visual_dim: int, |
|
fusion_dim: int = 512, num_heads: int = 8): |
|
super().__init__() |
|
self.text_dim = text_dim |
|
self.network_dim = network_dim |
|
self.visual_dim = visual_dim |
|
self.fusion_dim = fusion_dim |
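        # Project each modality into the shared fusion space (identity when the dimensions already match)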
|
|
|
|
|
self.text_proj = nn.Linear(text_dim, fusion_dim) if text_dim != fusion_dim else nn.Identity() |
|
self.network_proj = nn.Linear(network_dim, fusion_dim) if network_dim != fusion_dim else nn.Identity() |
|
self.visual_proj = nn.Linear(visual_dim, fusion_dim) if visual_dim != fusion_dim else nn.Identity() |
|
|
|
|
|
self.cross_attention = nn.MultiheadAttention(fusion_dim, num_heads, batch_first=True) |
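        # Learnable modality-mixing weights plus per-modality sigmoid gates that scale each modality's contribution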
|
|
|
|
|
self.attention_weights = nn.Parameter(torch.ones(3) / 3) |
|
|
|
|
|
self.text_gate = nn.Sequential( |
|
nn.Linear(fusion_dim, fusion_dim // 4), |
|
nn.ReLU(), |
|
nn.Linear(fusion_dim // 4, 1), |
|
nn.Sigmoid() |
|
) |
|
self.network_gate = nn.Sequential( |
|
nn.Linear(fusion_dim, fusion_dim // 4), |
|
nn.ReLU(), |
|
nn.Linear(fusion_dim // 4, 1), |
|
nn.Sigmoid() |
|
) |
|
self.visual_gate = nn.Sequential( |
|
nn.Linear(fusion_dim, fusion_dim // 4), |
|
nn.ReLU(), |
|
nn.Linear(fusion_dim // 4, 1), |
|
nn.Sigmoid() |
|
) |
|
|
|
|
|
self.fusion_network = nn.Sequential( |
|
nn.Linear(fusion_dim, fusion_dim), |
|
nn.ReLU(), |
|
nn.Dropout(0.3), |
|
nn.Linear(fusion_dim, fusion_dim) |
|
) |
|
|
|
def forward(self, text_features: Optional[torch.Tensor] = None, |
|
network_features: Optional[torch.Tensor] = None, |
|
visual_features: Optional[torch.Tensor] = None) -> torch.Tensor: |
|
|
|
available_modalities = [] |
|
projected_features = [] |
|
|
|
|
|
if text_features is not None: |
|
text_proj = self.text_proj(text_features) |
|
available_modalities.append(('text', text_proj)) |
|
projected_features.append(text_proj) |
|
|
|
if network_features is not None: |
|
network_proj = self.network_proj(network_features) |
|
available_modalities.append(('network', network_proj)) |
|
projected_features.append(network_proj) |
|
|
|
if visual_features is not None: |
|
visual_proj = self.visual_proj(visual_features) |
|
available_modalities.append(('visual', visual_proj)) |
|
projected_features.append(visual_proj) |
|
|
|
if not projected_features: |
|
raise ValueError("At least one modality must be provided") |
|
|
|
if len(projected_features) == 1: |
|
|
|
return self.fusion_network(projected_features[0]) |
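        # Two or more modalities: stack them as a short sequence and let them attend to one another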
|
|
|
|
|
stacked_features = torch.stack(projected_features, dim=1) |
|
|
|
|
|
attended_features, attention_weights = self.cross_attention( |
|
stacked_features, stacked_features, stacked_features |
|
) |
|
|
|
|
|
gated_features = [] |
|
for i, (modality, features) in enumerate(available_modalities): |
|
if modality == 'text' and text_features is not None: |
|
gate = self.text_gate(features) |
|
gated_features.append(attended_features[:, i] * gate) |
|
elif modality == 'network' and network_features is not None: |
|
gate = self.network_gate(features) |
|
gated_features.append(attended_features[:, i] * gate) |
|
elif modality == 'visual' and visual_features is not None: |
|
gate = self.visual_gate(features) |
|
gated_features.append(attended_features[:, i] * gate) |
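        # Blend the gated modalities with softmax-normalised learnable weights (plain mean as a fallback)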
|
|
|
|
|
if len(gated_features) == 2: |
|
weights = F.softmax(self.attention_weights[:2], dim=0) |
|
fused = weights[0] * gated_features[0] + weights[1] * gated_features[1] |
|
elif len(gated_features) == 3: |
|
weights = F.softmax(self.attention_weights, dim=0) |
|
fused = (weights[0] * gated_features[0] + |
|
weights[1] * gated_features[1] + |
|
weights[2] * gated_features[2]) |
|
else: |
|
fused = torch.stack(gated_features, dim=1).mean(dim=1) |
|
|
|
|
|
output = self.fusion_network(fused) |
|
return output |
|
|
|
class MultimodalSecurityClassifier(nn.Module): |
|
"""Complete multimodal cybersecurity classifier""" |
|
|
|
def __init__(self, num_classes: int, vocab_size: int = 10000, |
|
flow_feature_dim: int = 50, fusion_dim: int = 512): |
|
super().__init__() |
|
self.num_classes = num_classes |
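        # One encoder per modality, each emitting a fusion_dim-sized representation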
|
|
|
|
|
self.text_encoder = TextEncoder(vocab_size=vocab_size, hidden_dim=fusion_dim) |
|
self.network_encoder = NetworkEncoder(flow_feature_dim=flow_feature_dim, hidden_dim=fusion_dim) |
|
self.visual_encoder = VisualEncoder(hidden_dim=fusion_dim) |
|
|
|
|
|
self.fusion_layer = MultimodalFusionLayer( |
|
text_dim=fusion_dim, |
|
network_dim=fusion_dim, |
|
visual_dim=fusion_dim, |
|
fusion_dim=fusion_dim |
|
) |
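        # Fused classification head plus lightweight per-modality heads used as auxiliary outputs during training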
|
|
|
|
|
self.classifier = nn.Sequential( |
|
nn.Linear(fusion_dim, fusion_dim // 2), |
|
nn.ReLU(), |
|
nn.Dropout(0.4), |
|
nn.Linear(fusion_dim // 2, fusion_dim // 4), |
|
nn.ReLU(), |
|
nn.Dropout(0.3), |
|
nn.Linear(fusion_dim // 4, num_classes) |
|
) |
|
|
|
|
|
self.text_classifier = nn.Linear(fusion_dim, num_classes) |
|
self.network_classifier = nn.Linear(fusion_dim, num_classes) |
|
self.visual_classifier = nn.Linear(fusion_dim, num_classes) |
|
|
|
def forward(self, text_tokens: Optional[torch.Tensor] = None, |
|
network_data: Optional[Dict[str, torch.Tensor]] = None, |
|
visual_data: Optional[torch.Tensor] = None, |
|
return_individual_outputs: bool = False) -> Dict[str, torch.Tensor]: |
|
|
|
outputs = {} |
|
|
|
|
|
text_features = None |
|
network_features = None |
|
visual_features = None |
|
|
|
if text_tokens is not None: |
|
text_features = self.text_encoder(text_tokens) |
|
if return_individual_outputs: |
|
outputs['text_logits'] = self.text_classifier(text_features) |
|
|
|
if network_data is not None: |
|
network_features = self.network_encoder(network_data) |
|
if return_individual_outputs: |
|
outputs['network_logits'] = self.network_classifier(network_features) |
|
|
|
if visual_data is not None: |
|
visual_features = self.visual_encoder(visual_data) |
|
if return_individual_outputs: |
|
outputs['visual_logits'] = self.visual_classifier(visual_features) |
|
|
|
|
|
if text_features is not None or network_features is not None or visual_features is not None: |
|
fused_features = self.fusion_layer(text_features, network_features, visual_features) |
|
outputs['fused_logits'] = self.classifier(fused_features) |
|
|
|
return outputs |
|
|
|
class MultimodalSecuritySystem: |
|
"""Complete multimodal learning system for cybersecurity""" |
|
|
|
def __init__(self, num_classes: int = 10, device: str = "cpu"): |
|
self.num_classes = num_classes |
|
self.device = device |
|
|
|
|
|
self.model = MultimodalSecurityClassifier(num_classes=num_classes) |
|
self.model.to(device) |
|
|
|
|
|
self.text_processor = self._create_text_processor() |
|
self.network_processor = self._create_network_processor() |
|
self.visual_processor = self._create_visual_processor() |
|
|
|
|
|
self.optimizer = None |
|
self.criterion = nn.CrossEntropyLoss() |
|
self.logger = logging.getLogger(__name__) |
|
|
|
def _create_text_processor(self): |
|
"""Create text data processor""" |
|
|
|
def process_text(text_data: TextData) -> torch.Tensor: |
|
|
|
words = text_data.content.lower().split() |
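            # Hashing-trick tokenizer into the fixed 10k vocabulary; note that hash() is not
            # stable across interpreter runs unless PYTHONHASHSEED is set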
|
|
|
token_ids = [hash(word) % 10000 for word in words[:512]] |
|
|
|
|
|
if len(token_ids) < 512: |
|
token_ids.extend([0] * (512 - len(token_ids))) |
|
else: |
|
token_ids = token_ids[:512] |
|
|
|
return torch.tensor(token_ids, dtype=torch.long) |
|
|
|
return process_text |
|
|
|
def _create_network_processor(self): |
|
"""Create network data processor""" |
|
def process_network(network_data: NetworkData) -> Dict[str, torch.Tensor]: |
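            # Five aggregate flow statistics, zero-padded to the 50-dimensional vector the encoder expects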
|
|
|
flow_features = torch.tensor([ |
|
network_data.flow_features.get('packet_count', 0), |
|
network_data.flow_features.get('byte_count', 0), |
|
network_data.flow_features.get('duration', 0), |
|
network_data.flow_features.get('avg_packet_size', 0), |
|
network_data.flow_features.get('packets_per_second', 0) |
|
] + [0] * 45, dtype=torch.float32)[:50] |
|
|
|
|
|
packet_bytes = list(network_data.packet_data[:1024]) |
|
if len(packet_bytes) < 1024: |
|
packet_bytes.extend([0] * (1024 - len(packet_bytes))) |
|
packet_tensor = torch.tensor(packet_bytes, dtype=torch.long) |
|
|
|
|
|
protocol_map = {'tcp': 0, 'udp': 1, 'icmp': 2, 'http': 3, 'https': 4} |
|
protocol_id = torch.tensor( |
|
protocol_map.get(network_data.protocol.lower(), 5), |
|
dtype=torch.long |
|
) |
|
|
|
return { |
|
'flow_features': flow_features, |
|
'packet_bytes': packet_tensor, |
|
'protocol_ids': protocol_id |
|
} |
|
|
|
return process_network |
|
|
|
def _create_visual_processor(self): |
|
"""Create visual data processor""" |
|
def process_visual(visual_data: VisualData) -> torch.Tensor: |
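            # Convert images to normalised (3, 224, 224) float tensors; grayscale inputs are replicated across channels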
|
|
|
if visual_data.image_data.shape[-1] == 3: |
|
image_tensor = torch.from_numpy(visual_data.image_data).float() |
|
image_tensor = image_tensor.permute(2, 0, 1) |
|
else: |
|
|
|
image_tensor = torch.from_numpy(visual_data.image_data).float() |
|
if len(image_tensor.shape) == 2: |
|
image_tensor = image_tensor.unsqueeze(0).repeat(3, 1, 1) |
|
|
|
|
|
if image_tensor.shape[1] != 224 or image_tensor.shape[2] != 224: |
|
image_tensor = F.interpolate( |
|
image_tensor.unsqueeze(0), size=(224, 224), mode='bilinear' |
|
).squeeze(0) |
|
|
|
|
|
image_tensor = image_tensor / 255.0 |
|
|
|
return image_tensor |
|
|
|
return process_visual |
|
|
|
def prepare_batch(self, samples: List[MultimodalSample]) -> Dict[str, torch.Tensor]: |
|
"""Prepare a batch of multimodal samples""" |
|
batch = { |
|
'text_tokens': [], |
|
'network_data': {'flow_features': [], 'packet_bytes': [], 'protocol_ids': []}, |
|
'visual_data': [], |
|
'labels': [], |
|
'sample_ids': [] |
|
} |
|
|
|
for sample in samples: |
|
batch['sample_ids'].append(sample.sample_id) |
|
|
|
|
|
if sample.text_data: |
|
text_tokens = self.text_processor(sample.text_data) |
|
batch['text_tokens'].append(text_tokens) |
|
            else:
                # Pad missing text with an all-zero token sequence so every sample
                # keeps the same row index across all modality tensors.
                batch['text_tokens'].append(torch.zeros(512, dtype=torch.long))
|
|
|
|
|
if sample.network_data: |
|
network_processed = self.network_processor(sample.network_data) |
|
batch['network_data']['flow_features'].append(network_processed['flow_features']) |
|
batch['network_data']['packet_bytes'].append(network_processed['packet_bytes']) |
|
batch['network_data']['protocol_ids'].append(network_processed['protocol_ids']) |
|
            else:
                # Pad missing network data with zero flow/packet features and the
                # catch-all protocol id to preserve batch alignment.
                batch['network_data']['flow_features'].append(torch.zeros(50, dtype=torch.float32))
                batch['network_data']['packet_bytes'].append(torch.zeros(1024, dtype=torch.long))
                batch['network_data']['protocol_ids'].append(torch.tensor(5, dtype=torch.long))
|
|
|
|
|
if sample.visual_data: |
|
visual_processed = self.visual_processor(sample.visual_data) |
|
batch['visual_data'].append(visual_processed) |
|
            else:
                # Pad missing imagery with a blank placeholder frame.
                batch['visual_data'].append(torch.zeros(3, 224, 224, dtype=torch.float32))
|
|
|
|
|
batch['labels'].append(sample.label) |
|
|
|
|
|
result = {} |
|
|
|
|
|
valid_text = [t for t in batch['text_tokens'] if t is not None] |
|
if valid_text: |
|
result['text_tokens'] = torch.stack(valid_text).to(self.device) |
|
|
|
|
|
valid_flow = [f for f in batch['network_data']['flow_features'] if f is not None] |
|
valid_packets = [p for p in batch['network_data']['packet_bytes'] if p is not None] |
|
valid_protocols = [p for p in batch['network_data']['protocol_ids'] if p is not None] |
|
|
|
if valid_flow: |
|
result['network_data'] = { |
|
'flow_features': torch.stack(valid_flow).to(self.device), |
|
'packet_bytes': torch.stack(valid_packets).to(self.device), |
|
'protocol_ids': torch.stack(valid_protocols).to(self.device) |
|
} |
|
|
|
|
|
valid_visual = [v for v in batch['visual_data'] if v is not None] |
|
if valid_visual: |
|
result['visual_data'] = torch.stack(valid_visual).to(self.device) |
|
|
|
|
|
label_map = { |
|
'benign': 0, 'malware': 1, 'phishing': 2, 'ddos': 3, 'intrusion': 4, |
|
'lateral_movement': 5, 'data_exfiltration': 6, 'ransomware': 7, |
|
'insider_threat': 8, 'unknown': 9 |
|
} |
|
label_indices = [label_map.get(label, 9) for label in batch['labels']] |
|
result['labels'] = torch.tensor(label_indices, dtype=torch.long).to(self.device) |
|
|
|
return result |
|
|
|
def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]: |
|
"""Single training step""" |
|
if self.optimizer is None: |
|
self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-4) |
|
|
|
self.model.train() |
|
self.optimizer.zero_grad() |
|
|
|
|
|
outputs = self.model( |
|
text_tokens=batch.get('text_tokens'), |
|
network_data=batch.get('network_data'), |
|
visual_data=batch.get('visual_data'), |
|
return_individual_outputs=True |
|
) |
|
|
|
|
|
losses = {} |
|
total_loss = 0 |
|
|
|
labels = batch['labels'] |
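        # Main loss on the fused logits; the per-modality heads add down-weighted auxiliary losses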
|
|
|
|
|
if 'fused_logits' in outputs: |
|
fusion_loss = self.criterion(outputs['fused_logits'], labels) |
|
losses['fusion_loss'] = fusion_loss.item() |
|
total_loss += fusion_loss |
|
|
|
|
|
aux_weight = 0.3 |
|
if 'text_logits' in outputs: |
|
text_loss = self.criterion(outputs['text_logits'], labels) |
|
losses['text_loss'] = text_loss.item() |
|
total_loss += aux_weight * text_loss |
|
|
|
if 'network_logits' in outputs: |
|
network_loss = self.criterion(outputs['network_logits'], labels) |
|
losses['network_loss'] = network_loss.item() |
|
total_loss += aux_weight * network_loss |
|
|
|
if 'visual_logits' in outputs: |
|
visual_loss = self.criterion(outputs['visual_logits'], labels) |
|
losses['visual_loss'] = visual_loss.item() |
|
total_loss += aux_weight * visual_loss |
|
|
|
|
|
total_loss.backward() |
|
self.optimizer.step() |
|
|
|
losses['total_loss'] = total_loss.item() |
|
return losses |
|
|
|
def predict(self, samples: List[MultimodalSample]) -> List[Dict[str, Any]]: |
|
"""Make predictions on multimodal samples""" |
|
self.model.eval() |
|
|
|
batch = self.prepare_batch(samples) |
|
predictions = [] |
|
|
|
with torch.no_grad(): |
|
outputs = self.model( |
|
text_tokens=batch.get('text_tokens'), |
|
network_data=batch.get('network_data'), |
|
visual_data=batch.get('visual_data'), |
|
return_individual_outputs=True |
|
) |
|
|
|
|
|
if 'fused_logits' in outputs: |
|
probs = F.softmax(outputs['fused_logits'], dim=1) |
|
pred_classes = torch.argmax(probs, dim=1) |
|
confidence_scores = torch.max(probs, dim=1)[0] |
|
|
|
|
|
class_names = [ |
|
'benign', 'malware', 'phishing', 'ddos', 'intrusion', |
|
'lateral_movement', 'data_exfiltration', 'ransomware', |
|
'insider_threat', 'unknown' |
|
] |
|
|
|
for i, sample in enumerate(samples): |
|
predictions.append({ |
|
'sample_id': sample.sample_id, |
|
'predicted_class': class_names[pred_classes[i].item()], |
|
'confidence': confidence_scores[i].item(), |
|
'class_probabilities': { |
|
class_names[j]: probs[i][j].item() |
|
for j in range(len(class_names)) |
|
} |
|
}) |
|
|
|
return predictions |
|
|
|
|
|
if __name__ == "__main__": |
|
print("๐ Multimodal Learning System Testing:") |
|
print("=" * 50) |
|
|
|
|
|
multimodal_system = MultimodalSecuritySystem(num_classes=10, device="cpu") |
|
|
|
|
|
print("\n๐ Creating sample multimodal data...") |
|
|
|
|
|
text_sample = TextData( |
|
content="suspicious network activity detected from ip 192.168.1.100 attempting connection to external server", |
|
data_type="security_log", |
|
metadata={"source": "ids", "severity": "high"}, |
|
timestamp=datetime.now().isoformat(), |
|
source="security_system" |
|
) |
|
|
|
|
|
network_sample = NetworkData( |
|
packet_data=b'\x45\x00\x00\x3c\x1c\x46\x40\x00\x40\x06\x00\x00\xc0\xa8\x01\x64' * 64, |
|
flow_features={ |
|
"packet_count": 150, |
|
"byte_count": 9600, |
|
"duration": 30.5, |
|
"avg_packet_size": 64, |
|
"packets_per_second": 4.9 |
|
}, |
|
protocol="tcp", |
|
source_ip="192.168.1.100", |
|
dest_ip="external_server", |
|
timestamp=datetime.now().isoformat(), |
|
metadata={"port": 443, "flags": ["SYN", "ACK"]} |
|
) |
|
|
|
|
|
visual_sample = VisualData( |
|
image_data=np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8), |
|
image_type="network_topology", |
|
features={"nodes": 15, "edges": 23, "anomalous_connections": 2}, |
|
timestamp=datetime.now().isoformat(), |
|
metadata={"generated": True, "tool": "network_visualizer"} |
|
) |
|
|
|
|
|
samples = [ |
|
MultimodalSample( |
|
sample_id="sample_001", |
|
text_data=text_sample, |
|
network_data=network_sample, |
|
visual_data=visual_sample, |
|
label="intrusion", |
|
confidence=0.85, |
|
timestamp=datetime.now().isoformat() |
|
), |
|
MultimodalSample( |
|
sample_id="sample_002", |
|
text_data=text_sample, |
|
network_data=None, |
|
visual_data=visual_sample, |
|
label="malware", |
|
confidence=0.92, |
|
timestamp=datetime.now().isoformat() |
|
), |
|
MultimodalSample( |
|
sample_id="sample_003", |
|
text_data=None, |
|
network_data=network_sample, |
|
visual_data=None, |
|
label="benign", |
|
confidence=0.78, |
|
timestamp=datetime.now().isoformat() |
|
) |
|
] |
|
|
|
|
|
print("๐ง Testing batch preparation...") |
|
batch = multimodal_system.prepare_batch(samples) |
|
print(f" Batch components: {list(batch.keys())}") |
|
if 'text_tokens' in batch: |
|
print(f" Text tokens shape: {batch['text_tokens'].shape}") |
|
if 'network_data' in batch: |
|
print(f" Network flow features shape: {batch['network_data']['flow_features'].shape}") |
|
if 'visual_data' in batch: |
|
print(f" Visual data shape: {batch['visual_data'].shape}") |
|
|
|
|
|
print("\n๐ฎ Testing multimodal inference...") |
|
predictions = multimodal_system.predict(samples) |
|
|
|
for pred in predictions: |
|
print(f"\n Sample: {pred['sample_id']}") |
|
print(f" Predicted: {pred['predicted_class']}") |
|
print(f" Confidence: {pred['confidence']:.3f}") |
|
print(f" Top 3 probabilities:") |
|
sorted_probs = sorted(pred['class_probabilities'].items(), |
|
key=lambda x: x[1], reverse=True)[:3] |
|
for class_name, prob in sorted_probs: |
|
print(f" {class_name}: {prob:.3f}") |
|
|
|
|
|
print("\n๐ Testing training step...") |
|
losses = multimodal_system.train_step(batch) |
|
print(f" Training losses: {losses}") |
|
|
|
print("\nโ
Multimodal Learning System implemented and tested") |
|
print(f" Model parameters: {sum(p.numel() for p in multimodal_system.model.parameters()):,}") |
|
print(f" Supported modalities: Text, Network, Visual") |
|
print(f" Fusion strategy: Cross-modal attention with learnable gates") |
|
|