#!/usr/bin/env python3
"""
Cybersecurity Fine-Tuning Pipeline for Free Models
Specialized training for security domain without external restrictions
"""
import json
import logging
import asyncio
import os
import subprocess
from typing import Dict, List, Any, Optional, Union, Tuple
from pathlib import Path
import yaml
import time
import hashlib
from dataclasses import dataclass
@dataclass
class FineTuningConfig:
    """Settings describing a single cybersecurity fine-tuning run.

    A plain data holder. It is not instantiated anywhere else in this file
    as shown.
    """

    # Which run this is, and which checkpoint it starts from.
    model_name: str
    base_model: str

    # Where training data is read from and where results are written.
    dataset_path: str
    output_path: str

    # Free-form trainer settings. Presumably the same keys as
    # CybersecurityFineTuner.security_training_args; confirm against callers.
    training_args: Dict[str, Any]

    # Optional tags with defaults. None of these are read elsewhere in this file.
    security_focus: str = "general"
    filter_override: bool = True
    domain_specialization: str = "cybersecurity"
class CybersecurityDatasetBuilder:
    """
    Builds cybersecurity training datasets from multiple sources
    Focuses on legitimate security research and professional use

    Each domain in ``security_domains`` maps to a ``_generate_*_data`` method
    returning records under four keys: ``conversations``, ``instructions``,
    ``code_examples`` and ``analysis_cases``. ``generate_training_data``
    merges them and writes JSONL files into ``output_dir``.
    """

    def __init__(self, output_dir: str = "data/cybersec_training"):
        """Create the builder and ensure ``output_dir`` exists.

        Args:
            output_dir: Directory the JSONL dataset files are written to.
                Created (with parents) if missing.
        """
        self.logger = logging.getLogger(__name__)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        # Security domain categories: domain key -> human-readable description.
        self.security_domains = {
            "vulnerability_assessment": "Network and application vulnerability analysis",
            "malware_analysis": "Malware reverse engineering and analysis",
            "penetration_testing": "Ethical hacking and security testing",
            "incident_response": "Security incident handling and forensics",
            "threat_intelligence": "Threat analysis and intelligence gathering",
            "security_tools": "Development and usage of security tools",
            "compliance": "Security compliance and risk management",
            "cryptography": "Cryptographic implementations and analysis"
        }

    @staticmethod
    def _empty_domain_data() -> Dict[str, List]:
        """Return a fresh record container with all four (empty) categories."""
        return {
            "conversations": [],
            "instructions": [],
            "code_examples": [],
            "analysis_cases": []
        }

    def generate_training_data(self) -> Dict[str, str]:
        """Generate the full dataset for every domain and save it to disk.

        Returns:
            Mapping of dataset kind (``conversations``, ``instructions``,
            ``code``, ``combined``) to the path of the JSONL file written.
        """
        self.logger.info("Generating cybersecurity training dataset...")
        training_data = self._empty_domain_data()
        # Merge every domain's records category by category.
        for domain, description in self.security_domains.items():
            domain_data = self._generate_domain_data(domain, description)
            for category in training_data:
                training_data[category].extend(domain_data.get(category, []))
        dataset_files = self._save_training_data(training_data)
        # Conversations + instructions are exactly the rows of the combined file.
        example_count = len(training_data["conversations"]) + len(training_data["instructions"])
        self.logger.info(f"✅ Generated {example_count} training examples")
        return dataset_files

    def _generate_domain_data(self, domain: str, description: str) -> Dict[str, List]:
        """Return the records for one domain.

        Args:
            domain: Key from ``security_domains``. Unknown keys yield empty data.
            description: Human-readable description (currently unused).

        Returns:
            Dict with all four categories present, each a list of records.
        """
        generators = {
            "vulnerability_assessment": self._generate_vuln_assessment_data,
            "malware_analysis": self._generate_malware_analysis_data,
            "penetration_testing": self._generate_pentest_data,
            "incident_response": self._generate_incident_response_data,
            "threat_intelligence": self._generate_threat_intel_data,
            "security_tools": self._generate_security_tools_data,
            "compliance": self._generate_compliance_data,
            "cryptography": self._generate_crypto_data,
        }
        domain_data = self._empty_domain_data()
        generator = generators.get(domain)
        if generator is not None:
            domain_data.update(generator())
        return domain_data

    def _generate_vuln_assessment_data(self) -> Dict[str, List]:
        """Generate vulnerability assessment training data"""
        conversations = [
            {
                "system": "You are a cybersecurity expert specializing in vulnerability assessment. Provide technical analysis for legitimate security testing.",
                "human": "How do I perform a comprehensive web application vulnerability assessment?",
                "assistant": """A comprehensive web application vulnerability assessment involves several systematic phases:
**1. Reconnaissance and Information Gathering**
- Identify web technologies, frameworks, and versions
- Map application architecture and attack surface
- Gather DNS records, subdomains, and directory structures
**2. Automated Scanning**
- Use tools like OWASP ZAP, Burp Suite, or Nessus
- Scan for common vulnerabilities (OWASP Top 10)
- Identify missing security headers and misconfigurations
**3. Manual Testing**
- Authentication bypass attempts
- Input validation testing (XSS, SQL injection)
- Business logic flaw identification
- Session management analysis
**4. Exploitation and Impact Assessment**
- Safely demonstrate vulnerability impact
- Document proof-of-concept code
- Assess business risk and data exposure
**5. Reporting and Remediation**
- Prioritize findings by risk level
- Provide specific remediation guidance
- Include timeline for vulnerability fixes
Key tools: Burp Suite Professional, OWASP ZAP, Nmap, SQLMap, Nikto
Remember to always obtain proper authorization before testing."""
            },
            {
                "system": "You are a vulnerability assessment specialist providing technical guidance for security professionals.",
                "human": "What are the key indicators of SQL injection vulnerabilities?",
                "assistant": """SQL injection vulnerabilities can be identified through several key indicators:
**1. Error-Based Indicators**
- Database error messages in HTTP responses
- Different responses to invalid vs. valid input
- SQL syntax errors exposed to the user interface
**2. Boolean-Based Blind Indicators**
- Different response patterns for true/false conditions
- Conditional logic affecting page behavior
- Time delays in response patterns
**3. Union-Based Indicators**
- Ability to control number of columns in SELECT
- Data from other tables appearing in responses
- Successful UNION SELECT operations
**4. Time-Based Blind Indicators**
- Consistent delays when using time functions
- Response time variations based on conditions
- Sleep/waitfor delay functions working
**Testing Payloads:**
- Single quote (') to test for escaping
- OR 1=1-- to test boolean logic
- UNION SELECT 1,2,3-- for union injection
- '; WAITFOR DELAY '00:00:05'-- for time delays
**Detection Tools:**
- SQLMap for automated detection
- Burp Suite's scanner extensions
- Manual payload testing in web proxies
Always test in authorized environments only."""
            }
        ]
        instructions = [
            {
                "instruction": "Create a vulnerability assessment checklist for web applications",
                "input": "I need a systematic approach to assess web application security",
                "output": """# Web Application Vulnerability Assessment Checklist
## Pre-Assessment Phase
- [ ] Obtain written authorization from client
- [ ] Define scope and testing boundaries
- [ ] Set up testing environment and tools
- [ ] Document baseline application functionality
## Information Gathering
- [ ] Passive reconnaissance (DNS, WHOIS, search engines)
- [ ] Technology fingerprinting (Wappalyzer, Whatweb)
- [ ] Directory and file enumeration (Dirbuster, Gobuster)
- [ ] Parameter discovery and analysis
## Automated Scanning
- [ ] Vulnerability scanner deployment (Nessus, OpenVAS)
- [ ] Web application scanner (OWASP ZAP, Burp Suite)
- [ ] SSL/TLS configuration testing (SSLyze, testssl.sh)
- [ ] Security header analysis (SecurityHeaders.com)
## Manual Testing Categories
### Authentication & Session Management
- [ ] Password policy assessment
- [ ] Account lockout mechanisms
- [ ] Session token security
- [ ] Multi-factor authentication bypass
### Input Validation
- [ ] SQL injection testing
- [ ] Cross-site scripting (XSS)
- [ ] Command injection
- [ ] Path traversal/directory traversal
- [ ] XML External Entity (XXE)
### Authorization & Access Control
- [ ] Vertical privilege escalation
- [ ] Horizontal privilege escalation
- [ ] Direct object references
- [ ] Function-level access control
### Business Logic
- [ ] Workflow bypass attempts
- [ ] Race conditions
- [ ] Price manipulation
- [ ] Process validation flaws
## Documentation & Reporting
- [ ] Evidence collection and screenshots
- [ ] Risk rating and impact analysis
- [ ] Remediation recommendations
- [ ] Executive summary preparation"""
            }
        ]
        # The embedded example is itself Python, so its indentation is part of
        # the training data and must be syntactically valid.
        code_examples = [
            {
                "description": "Python script for basic SQL injection testing",
                "language": "python",
                "code": """#!/usr/bin/env python3
\"\"\"
Basic SQL Injection Testing Script
For authorized penetration testing only
\"\"\"
import requests
import time
import sys
from urllib.parse import urljoin


class SQLInjectionTester:
    def __init__(self, target_url, session=None):
        self.target_url = target_url
        self.session = session or requests.Session()
        self.payloads = [
            "' OR '1'='1",
            "' OR '1'='1'--",
            "' OR '1'='1'/*",
            "admin'--",
            "admin'/*",
            "' OR 1=1--",
            "') OR ('1'='1",
            "') OR ('1'='1'--"
        ]

    def test_parameter(self, param_name, param_value):
        \"\"\"Test a specific parameter for SQL injection\"\"\"
        results = []
        for payload in self.payloads:
            test_value = param_value + payload
            data = {param_name: test_value}
            try:
                response = self.session.post(self.target_url, data=data)
                # Check for common SQL error patterns
                sql_errors = [
                    'mysql_fetch_array',
                    'ORA-01756',
                    'Microsoft OLE DB Provider',
                    'SQLServer JDBC Driver',
                    'PostgreSQL query failed',
                    'Warning: mysql_',
                    'valid MySQL result',
                    'MySqlClient.MySqlException'
                ]
                for error in sql_errors:
                    if error.lower() in response.text.lower():
                        results.append({
                            'payload': payload,
                            'vulnerable': True,
                            'error_type': error,
                            'response_length': len(response.text)
                        })
                        break
                else:
                    results.append({
                        'payload': payload,
                        'vulnerable': False,
                        'response_length': len(response.text)
                    })
                # Add delay to avoid overwhelming the server
                time.sleep(0.5)
            except requests.exceptions.RequestException as e:
                print(f"Request error: {e}")
                continue
        return results

    def generate_report(self, test_results):
        \"\"\"Generate a vulnerability report\"\"\"
        vulnerable_payloads = [r for r in test_results if r['vulnerable']]
        print("\\n" + "="*50)
        print("SQL Injection Test Results")
        print("="*50)
        print(f"Target: {self.target_url}")
        print(f"Total payloads tested: {len(test_results)}")
        print(f"Vulnerable responses: {len(vulnerable_payloads)}")
        if vulnerable_payloads:
            print("\\nVULNERABLE PAYLOADS DETECTED:")
            for result in vulnerable_payloads:
                print(f" Payload: {result['payload']}")
                print(f" Error: {result.get('error_type', 'Unknown')}")
                print()
        else:
            print("\\nNo SQL injection vulnerabilities detected.")
        return len(vulnerable_payloads) > 0


# Usage example (authorized testing only)
if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Usage: python sqli_test.py <target_url>")
        sys.exit(1)
    target = sys.argv[1]
    tester = SQLInjectionTester(target)
    # Test common parameter names
    common_params = ['username', 'user', 'login', 'id', 'search', 'q']
    for param in common_params:
        print(f"Testing parameter: {param}")
        results = tester.test_parameter(param, "test")
        tester.generate_report(results)"""
            }
        ]
        return {
            "conversations": conversations,
            "instructions": instructions,
            "code_examples": code_examples,
            "analysis_cases": []
        }

    def _generate_malware_analysis_data(self) -> Dict[str, List]:
        """Generate malware analysis training data"""
        conversations = [
            {
                "system": "You are a malware analyst providing technical guidance for security research and incident response.",
                "human": "What are the key steps in static malware analysis?",
                "assistant": """Static malware analysis involves examining malware samples without execution. Here's a systematic approach:
**1. Safe Environment Setup**
- Isolated analysis VM (no network connectivity)
- Snapshot-based rollback capability
- Proper tool installation (hex editors, disassemblers)
**2. Initial Triage**
- File type identification (file command, TrID)
- Hash calculation (MD5, SHA-1, SHA-256)
- Virus scanning (multiple engines)
- Packer/obfuscation detection (PEiD, Detect It Easy)
**3. Metadata Analysis**
- PE header examination (PE-bear, CFF Explorer)
- Import/export table analysis
- Resource section inspection
- Digital signature verification
**4. String Analysis**
- Printable string extraction (strings utility)
- Unicode string identification
- URL, IP, and registry key discovery
- Potential C2 infrastructure indicators
**5. Disassembly and Code Analysis**
- Disassembly with IDA Pro/Ghidra/x64dbg
- Control flow analysis
- Function identification and naming
- API call analysis
**6. Behavioral Indicators**
- Registry modification patterns
- File system operations
- Network communication protocols
- Persistence mechanisms
**Tools: IDA Pro, Ghidra, x64dbg, PE-bear, strings, Volatility**
**Remember: Always use isolated environments and proper containment.**"""
            }
        ]
        domain_data = self._empty_domain_data()
        domain_data["conversations"] = conversations
        return domain_data

    def _generate_pentest_data(self) -> Dict[str, List]:
        """Generate penetration testing training data"""
        conversations = [
            {
                "system": "You are a penetration testing expert providing guidance for authorized security testing.",
                "human": "How should I approach network penetration testing methodology?",
                "assistant": """Network penetration testing follows a structured methodology:
**1. Pre-Engagement**
- Scope definition and rules of engagement
- Legal authorization and contracts
- Emergency contact procedures
- Testing timeline establishment
**2. Information Gathering**
- Passive reconnaissance (OSINT)
- DNS enumeration and zone transfers
- Network range identification
- Social media intelligence gathering
**3. Active Scanning**
- Host discovery (Nmap ping sweeps)
- Port scanning and service enumeration
- OS fingerprinting and version detection
- Vulnerability scanning (Nessus, OpenVAS)
**4. Enumeration**
- Service-specific enumeration
- SMB/NetBIOS information gathering
- SNMP community string testing
- Web service fingerprinting
**5. Vulnerability Assessment**
- CVE research and validation
- Custom vulnerability verification
- False positive elimination
- Impact and exploitability analysis
**6. Exploitation**
- Proof-of-concept development
- Privilege escalation attempts
- Lateral movement techniques
- Data exfiltration simulation
**7. Post-Exploitation**
- Persistence establishment
- Additional system compromise
- Evidence collection and documentation
- Clean-up and artifact removal
**Tools: Nmap, Metasploit, Burp Suite, Cobalt Strike, BloodHound**
**Always maintain detailed logs and evidence for reporting.**"""
            }
        ]
        domain_data = self._empty_domain_data()
        domain_data["conversations"] = conversations
        return domain_data

    # The following domains have no content yet; they contribute no records.

    def _generate_incident_response_data(self) -> Dict[str, List]:
        """Generate incident response training data (placeholder: empty)."""
        return self._empty_domain_data()

    def _generate_threat_intel_data(self) -> Dict[str, List]:
        """Generate threat intelligence training data (placeholder: empty)."""
        return self._empty_domain_data()

    def _generate_security_tools_data(self) -> Dict[str, List]:
        """Generate security tools training data (placeholder: empty)."""
        return self._empty_domain_data()

    def _generate_compliance_data(self) -> Dict[str, List]:
        """Generate compliance training data (placeholder: empty)."""
        return self._empty_domain_data()

    def _generate_crypto_data(self) -> Dict[str, List]:
        """Generate cryptography training data (placeholder: empty)."""
        return self._empty_domain_data()

    @staticmethod
    def _write_jsonl(path: Path, records: List[Dict[str, Any]]) -> None:
        """Write ``records`` to ``path``, one JSON object per line."""
        with open(path, "w", encoding="utf-8") as f:
            for record in records:
                f.write(json.dumps(record) + "\n")

    def _save_training_data(self, training_data: Dict[str, List]) -> Dict[str, str]:
        """Save training data to JSONL files in ``output_dir``.

        ``analysis_cases`` are not written. The combined file holds every
        conversation, converted to instruction format, followed by every
        instruction record.

        Returns:
            Mapping of ``conversations``/``instructions``/``code``/``combined``
            to the written file path.
        """
        dataset_files = {}

        # Conversations in ChatML-like (system/human/assistant) form.
        conversations_file = self.output_dir / "cybersec_conversations.jsonl"
        self._write_jsonl(conversations_file, training_data["conversations"])
        dataset_files["conversations"] = str(conversations_file)

        # Instructions in Alpaca form (instruction/input/output).
        instructions_file = self.output_dir / "cybersec_instructions.jsonl"
        self._write_jsonl(instructions_file, training_data["instructions"])
        dataset_files["instructions"] = str(instructions_file)

        code_file = self.output_dir / "cybersec_code.jsonl"
        self._write_jsonl(code_file, training_data["code_examples"])
        dataset_files["code"] = str(code_file)

        # Combined dataset: conversations mapped onto the Alpaca schema, then instructions.
        combined_records = [
            {
                "instruction": conv["human"],
                "input": "",
                "output": conv["assistant"],
                "system": conv.get("system", "")
            }
            for conv in training_data["conversations"]
        ]
        combined_records.extend(training_data["instructions"])
        combined_file = self.output_dir / "cybersec_combined.jsonl"
        self._write_jsonl(combined_file, combined_records)
        dataset_files["combined"] = str(combined_file)

        self.logger.info(f"✅ Saved training data to {len(dataset_files)} files")
        return dataset_files
class CybersecurityFineTuner:
    """
    Prepares LoRA fine-tuning of open models on the cybersecurity dataset.

    Loads (or creates) a YAML config, splits the combined JSONL dataset into
    train/eval/test, and writes one standalone training script per configured
    model. The scripts are written, not executed.

    The default config also contains a ``security_overrides`` section. No code
    in this class reads it.
    """

    def __init__(self, config_path: str = "configs/finetune_config.yaml"):
        """Create the fine-tuner.

        Args:
            config_path: YAML config to load. If it does not exist, a default
                config is written there.
        """
        self.logger = logging.getLogger(__name__)
        self.config_path = config_path
        # Must be defined before _load_config(): building the default config reads it.
        self.security_training_args = {
            "learning_rate": 5e-5,
            "batch_size": 4,
            "gradient_accumulation_steps": 4,
            "max_seq_length": 2048,
            "num_epochs": 3,
            "warmup_ratio": 0.1,
            "save_steps": 500,
            "evaluation_strategy": "steps",
            "eval_steps": 500,
            "logging_steps": 100,
            "load_best_model_at_end": True,
            "remove_unused_columns": False,
            "dataloader_pin_memory": False
        }
        self.config = self._load_config()

    def _load_config(self) -> Dict[str, Any]:
        """Load the YAML config, falling back to (and writing) the default.

        An existing but empty YAML file (``safe_load`` returns ``None``) is
        treated like a missing one.
        """
        config_file = Path(self.config_path)
        if config_file.exists():
            with open(config_file, "r", encoding="utf-8") as f:
                loaded = yaml.safe_load(f)
            if loaded is not None:
                return loaded
        return self._create_default_config()

    def _create_default_config(self) -> Dict[str, Any]:
        """Build the default config, save it to ``config_path`` and return it."""
        config = {
            "models": [
                {
                    # NOTE(review): the name says llama3 but the base model is Llama-2.
                    "name": "llama3_cybersec_ft",
                    "base_model": "meta-llama/Llama-2-7b-chat-hf",
                    "security_focus": "general",
                    # Give each model its own copy so edits to one don't leak into the other.
                    "training_args": dict(self.security_training_args)
                },
                {
                    "name": "mistral_security_ft",
                    "base_model": "mistralai/Mistral-7B-Instruct-v0.1",
                    "security_focus": "analysis",
                    "training_args": dict(self.security_training_args)
                }
            ],
            "dataset_config": {
                "train_ratio": 0.8,
                "eval_ratio": 0.1,
                "test_ratio": 0.1,
                "max_length": 2048
            },
            "security_overrides": {
                "remove_safety_filters": True,
                "domain_focus": "cybersecurity",
                "ethical_framework": "professional_security"
            }
        }
        config_file = Path(self.config_path)
        config_file.parent.mkdir(parents=True, exist_ok=True)
        with open(config_file, "w", encoding="utf-8") as f:
            yaml.dump(config, f, default_flow_style=False)
        return config

    def prepare_dataset(self, dataset_file: str) -> str:
        """Split a JSONL dataset into train/eval/test files.

        Records are split in file order (no shuffling), using ``train_ratio``
        and ``eval_ratio`` from ``config["dataset_config"]``. Whatever remains
        goes to the test split.

        Args:
            dataset_file: Path to a JSONL file, one JSON object per line.

        Returns:
            Path of the ``splits`` directory next to ``dataset_file``. It
            contains ``train.jsonl``, ``eval.jsonl`` and ``test.jsonl``.
        """
        self.logger.info(f"Preparing dataset from {dataset_file}")
        with open(dataset_file, "r", encoding="utf-8") as f:
            # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
            data = [json.loads(line) for line in f if line.strip()]
        self.logger.info(f"Loaded {len(data)} training examples")

        ratios = self.config["dataset_config"]
        total = len(data)
        train_size = int(total * ratios["train_ratio"])
        eval_size = int(total * ratios["eval_ratio"])
        # int() truncation empties the eval split for small datasets (5 rows at
        # 10% -> 0). The generated script evaluates periodically and picks the
        # best checkpoint by eval_loss, so keep at least one eval row.
        if ratios["eval_ratio"] > 0 and eval_size == 0 and total > 1:
            eval_size = 1
            train_size = min(train_size, total - eval_size)

        splits = {
            "train": data[:train_size],
            "eval": data[train_size:train_size + eval_size],
            "test": data[train_size + eval_size:]
        }
        dataset_dir = Path(dataset_file).parent / "splits"
        dataset_dir.mkdir(parents=True, exist_ok=True)
        for split_name, split_data in splits.items():
            with open(dataset_dir / f"{split_name}.jsonl", "w", encoding="utf-8") as f:
                for item in split_data:
                    f.write(json.dumps(item) + "\n")

        self.logger.info(
            f"✅ Dataset prepared: {len(splits['train'])} train, "
            f"{len(splits['eval'])} eval, {len(splits['test'])} test"
        )
        return str(dataset_dir)

    def fine_tune_model(self, model_config: Dict[str, Any], dataset_dir: str) -> str:
        """Write a standalone LoRA fine-tuning script for one model.

        The script is saved to ``finetune_<name>.py`` in the current working
        directory. It is not executed. Every model uses the same
        Llama-2 ``[INST]``/``<<SYS>>`` prompt template.

        Args:
            model_config: Entry from ``config["models"]``. It needs ``name``,
                ``base_model`` and ``training_args``.
            dataset_dir: Directory produced by :meth:`prepare_dataset`.

        Returns:
            Path of the written script.
        """
        name = model_config["name"]
        args = model_config["training_args"]
        self.logger.info(f"🔧 Fine-tuning {name}...")

        # String values are embedded with !r so quotes/backslashes (e.g. Windows
        # paths) yield valid Python literals. Numbers are embedded verbatim.
        base_model = model_config["base_model"]
        train_file = str(Path(dataset_dir) / "train.jsonl")
        eval_file = str(Path(dataset_dir) / "eval.jsonl")
        checkpoint_dir = f"{name}_checkpoint"
        final_dir = f"{name}_final"

        finetune_script = f"""#!/usr/bin/env python3
import os
import json
import torch
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    TrainingArguments, Trainer, DataCollatorForLanguageModeling
)
from datasets import Dataset
from peft import LoraConfig, TaskType, get_peft_model


def load_dataset(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        data = [json.loads(line) for line in f if line.strip()]
    return Dataset.from_list(data)


def format_prompt(example):
    instruction = example['instruction']
    # Alpaca-style rows carry extra context in "input"; keep it in the prompt.
    if example.get('input'):
        instruction = instruction + "\\n\\n" + example['input']
    if example.get('system'):
        return f"<s>[INST] <<SYS>>\\n{{example['system']}}\\n<</SYS>>\\n\\n{{instruction}} [/INST] {{example['output']}} </s>"
    return f"<s>[INST] {{instruction}} [/INST] {{example['output']}} </s>"


# Load model and tokenizer
model_name = {base_model!r}
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    # NOTE(review): trust_remote_code runs Python code shipped with the model
    # repository; keep it only for models that require custom modeling code.
    trust_remote_code=True
)

# Add padding token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Setup LoRA configuration for efficient fine-tuning
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    lora_dropout=0.1,
    bias="none"
)
# Gradient checkpointing (enabled in TrainingArguments) on a frozen base model
# needs the embedding outputs to require grad, or backward() has nothing to flow through.
model.enable_input_require_grads()
model = get_peft_model(model, lora_config)
# fp16 mixed precision refuses to unscale fp16 gradients, so keep the trainable
# LoRA weights in fp32 while the frozen base stays in fp16.
for param in model.parameters():
    if param.requires_grad:
        param.data = param.data.float()

# Load and prepare datasets
train_dataset = load_dataset({train_file!r})
eval_dataset = load_dataset({eval_file!r})


def tokenize_function(examples):
    # With batched=True, `examples` maps column name -> list of values, so
    # rebuild one dict per row before formatting the prompt.
    columns = list(examples.keys())
    rows = [dict(zip(columns, values)) for values in zip(*(examples[c] for c in columns))]
    texts = [format_prompt(row) for row in rows]
    # The prompt already contains <s>/</s>, so don't let the tokenizer add a
    # second BOS. No padding here: the data collator pads each batch and builds
    # labels from input_ids (mlm=False).
    return tokenizer(
        texts,
        truncation=True,
        max_length={args['max_seq_length']},
        add_special_tokens=False
    )


# Apply tokenization; drop the raw text columns so only model inputs remain.
train_dataset = train_dataset.map(
    tokenize_function, batched=True, remove_columns=train_dataset.column_names
)
eval_dataset = eval_dataset.map(
    tokenize_function, batched=True, remove_columns=eval_dataset.column_names
)

# Training arguments
# NOTE(review): newer transformers releases renamed evaluation_strategy to
# eval_strategy; update this if the installed version rejects it.
training_args = TrainingArguments(
    output_dir={checkpoint_dir!r},
    learning_rate={args['learning_rate']},
    per_device_train_batch_size={args['batch_size']},
    per_device_eval_batch_size={args['batch_size']},
    gradient_accumulation_steps={args['gradient_accumulation_steps']},
    num_train_epochs={args['num_epochs']},
    warmup_ratio={args['warmup_ratio']},
    save_steps={args['save_steps']},
    eval_steps={args['eval_steps']},
    evaluation_strategy={args['evaluation_strategy']!r},
    logging_steps={args['logging_steps']},
    load_best_model_at_end={args['load_best_model_at_end']},
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    dataloader_pin_memory=False,
    fp16=True,
    gradient_checkpointing=True,
    report_to="none"
)

# Data collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
    pad_to_multiple_of=8
)

# Initialize trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer
)

# Train the model
trainer.train()

# Save the final model
trainer.save_model({final_dir!r})
tokenizer.save_pretrained({final_dir!r})
print("✅ Fine-tuning completed!")
"""
        script_path = f"finetune_{name}.py"
        # The script contains non-ASCII characters; write it as UTF-8 explicitly.
        with open(script_path, "w", encoding="utf-8") as f:
            f.write(finetune_script)
        self.logger.info(f"💾 Fine-tuning script saved to {script_path}")
        return script_path

    def run_fine_tuning(self) -> List[Dict[str, str]]:
        """Generate data, split it, and write a training script per model.

        Returns:
            One dict per configured model with ``name``, ``script`` (path of
            the generated script) and ``base_model``. Scripts are not run.
        """
        self.logger.info("Starting cybersecurity fine-tuning pipeline...")
        dataset_builder = CybersecurityDatasetBuilder()
        dataset_files = dataset_builder.generate_training_data()
        dataset_dir = self.prepare_dataset(dataset_files["combined"])

        prepared_models = []
        for model_config in self.config["models"]:
            script_path = self.fine_tune_model(model_config, dataset_dir)
            prepared_models.append({
                "name": model_config["name"],
                "script": script_path,
                "base_model": model_config["base_model"]
            })
        self.logger.info(f"✅ Fine-tuning pipeline prepared for {len(prepared_models)} models")
        return prepared_models
# Command-line interface: generate data, prepare splits, or write training scripts.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Cybersecurity Fine-Tuning Pipeline")
    parser.add_argument("--action", choices=["generate", "prepare", "train"],
                        default="generate", help="Action to perform")
    parser.add_argument("--config", default="configs/finetune_config.yaml",
                        help="Configuration file path")
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

    if args.action == "generate":
        # Generate training data only
        dataset_builder = CybersecurityDatasetBuilder()
        dataset_files = dataset_builder.generate_training_data()
        print(f"✅ Training data generated: {dataset_files}")
    elif args.action == "prepare":
        # Generate training data, then split it for training
        finetuner = CybersecurityFineTuner(args.config)
        dataset_builder = CybersecurityDatasetBuilder()
        dataset_files = dataset_builder.generate_training_data()
        dataset_dir = finetuner.prepare_dataset(dataset_files["combined"])
        print(f"✅ Dataset prepared: {dataset_dir}")
    elif args.action == "train":
        # Full pipeline: writes one training script per model (does not run them)
        finetuner = CybersecurityFineTuner(args.config)
        trained_models = finetuner.run_fine_tuning()
        print("\n🎯 Fine-tuning scripts generated:")
        for model in trained_models:
            print(f"  {model['name']}: {model['script']}")
        print("\nTo run fine-tuning:")
        for model in trained_models:
            print(f"  python {model['script']}")