"Transcendental-Programmer" committed on
Commit 754afec · 1 Parent(s): 3de89cd

FEAT: added server coordination and model aggregation logic

src/main.py ADDED
@@ -0,0 +1,62 @@
+ import argparse
+ import yaml
+ import logging
+ import logging.config
+ from pathlib import Path
+ from src.server.coordinator import FederatedCoordinator
+ from src.client.model import FederatedClient
+
+ def setup_logging(config):
+     """Setup logging configuration."""
+     # Create logs directory if it doesn't exist
+     Path("logs").mkdir(exist_ok=True)
+
+     log_level = (config.get('monitoring', {}).get('log_level')
+                  or config.get('server', {}).get('monitoring', {}).get('log_level')
+                  or config.get('client', {}).get('monitoring', {}).get('log_level')
+                  or 'INFO')
+
+     # Configure logging with UTF-8 encoding
+     logging.basicConfig(
+         level=log_level,
+         format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+         handlers=[
+             logging.StreamHandler(),
+             logging.FileHandler('logs/federated_learning.log', mode='a', encoding='utf-8')
+         ]
+     )
+
+     # Reduce TensorFlow logging noise
+     logging.getLogger('tensorflow').setLevel(logging.WARNING)
+
+     # Create a divider in the log file
+     logger = logging.getLogger(__name__)
+     logger.info("\n" + "="*50)
+     logger.info("New Training Session Started")
+     logger.info("="*50 + "\n")
+
+ def load_config(config_path: str) -> dict:
+     with open(config_path, 'r') as f:
+         return yaml.safe_load(f)
+
+ def main():
+     parser = argparse.ArgumentParser(description='Federated Learning Demo')
+     parser.add_argument('--mode', choices=['server', 'client'], required=True)
+     parser.add_argument('--config', type=str, required=True)
+     args = parser.parse_args()
+
+     config = load_config(args.config)
+     setup_logging(config)
+     logger = logging.getLogger(__name__)
+
+     if args.mode == 'server':
+         coordinator = FederatedCoordinator(config)
+         logger.info("Starting server...")
+         coordinator.start()
+     else:
+         client = FederatedClient(1, config)
+         logger.info(f"Starting client with ID: {client.client_id}")
+         client.start()
+
+ if __name__ == "__main__":
+     main()
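
A minimal way to exercise this entry point, assuming the YAML configs live under config/ as the tests suggest (the server config path in particular is a guess):

    python -m src.main --mode server --config config/server_config.yaml
    python -m src.main --mode client --config config/client_config.yaml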
src/models/gan.py ADDED
@@ -0,0 +1,122 @@
+ """GAN implementation for financial data generation."""
+
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from typing import Dict, List, Tuple
+
+ class Generator(nn.Module):
+     def __init__(self, latent_dim: int, feature_dim: int, hidden_dims: List[int]):
+         super().__init__()
+
+         layers = []
+         prev_dim = latent_dim
+
+         for hidden_dim in hidden_dims:
+             layers.extend([
+                 nn.Linear(prev_dim, hidden_dim),
+                 nn.BatchNorm1d(hidden_dim),
+                 nn.LeakyReLU(0.2),
+                 nn.Dropout(0.3)
+             ])
+             prev_dim = hidden_dim
+
+         layers.append(nn.Linear(prev_dim, feature_dim))
+         layers.append(nn.Tanh())
+
+         self.model = nn.Sequential(*layers)
+
+     def forward(self, z: torch.Tensor) -> torch.Tensor:
+         return self.model(z)
+
+ class Discriminator(nn.Module):
+     def __init__(self, feature_dim: int, hidden_dims: List[int]):
+         super().__init__()
+
+         layers = []
+         prev_dim = feature_dim
+
+         for hidden_dim in hidden_dims:
+             layers.extend([
+                 nn.Linear(prev_dim, hidden_dim),
+                 nn.LeakyReLU(0.2),
+                 nn.Dropout(0.3)
+             ])
+             prev_dim = hidden_dim
+
+         layers.append(nn.Linear(prev_dim, 1))
+         layers.append(nn.Sigmoid())
+
+         self.model = nn.Sequential(*layers)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.model(x)
+
+ class FinancialGAN:
+     def __init__(self, config: Dict):
+         """Initialize the GAN."""
+         self.latent_dim = config['model']['latent_dim']
+         self.feature_dim = config['model']['feature_dim']
+         self.hidden_dims = config['model']['hidden_dims']
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+         self.generator = Generator(
+             self.latent_dim,
+             self.feature_dim,
+             self.hidden_dims
+         ).to(self.device)
+
+         self.discriminator = Discriminator(
+             self.feature_dim,
+             self.hidden_dims[::-1]
+         ).to(self.device)
+
+         self.g_optimizer = optim.Adam(
+             self.generator.parameters(),
+             lr=config['model']['learning_rate']
+         )
+         self.d_optimizer = optim.Adam(
+             self.discriminator.parameters(),
+             lr=config['model']['learning_rate']
+         )
+
+         self.criterion = nn.BCELoss()
+
+     def train_step(self, real_data: torch.Tensor) -> Tuple[float, float]:
+         """Perform one training step."""
+         batch_size = real_data.size(0)
+         real_label = torch.ones(batch_size, 1).to(self.device)
+         fake_label = torch.zeros(batch_size, 1).to(self.device)
+
+         # Train Discriminator
+         self.d_optimizer.zero_grad()
+         d_real_output = self.discriminator(real_data)
+         d_real_loss = self.criterion(d_real_output, real_label)
+
+         z = torch.randn(batch_size, self.latent_dim).to(self.device)
+         fake_data = self.generator(z)
+         d_fake_output = self.discriminator(fake_data.detach())
+         d_fake_loss = self.criterion(d_fake_output, fake_label)
+
+         d_loss = d_real_loss + d_fake_loss
+         d_loss.backward()
+         self.d_optimizer.step()
+
+         # Train Generator
+         self.g_optimizer.zero_grad()
+         g_output = self.discriminator(fake_data)
+         g_loss = self.criterion(g_output, real_label)
+         g_loss.backward()
+         self.g_optimizer.step()
+
+         return g_loss.item(), d_loss.item()
+
+     def generate_samples(self, num_samples: int) -> torch.Tensor:
+         """Generate synthetic financial data."""
+         self.generator.eval()
+         with torch.no_grad():
+             z = torch.randn(num_samples, self.latent_dim).to(self.device)
+             samples = self.generator(z)
+         self.generator.train()
+         return samples
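
A short sketch of how FinancialGAN might be driven; the config values and the random stand-in batch are illustrative assumptions, not part of this commit:

    import torch
    from src.models.gan import FinancialGAN

    config = {'model': {'latent_dim': 32, 'feature_dim': 10,
                        'hidden_dims': [64, 128], 'learning_rate': 2e-4}}
    gan = FinancialGAN(config)

    real_batch = torch.randn(64, 10).to(gan.device)   # placeholder for a real financial batch
    g_loss, d_loss = gan.train_step(real_batch)
    synthetic = gan.generate_samples(5)               # (5, feature_dim) tensor in [-1, 1]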
src/models/vae.py ADDED
@@ -0,0 +1,52 @@
+ """Variational autoencoder (VAE) model."""
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import List
+
+ class VAE(nn.Module):
+     def __init__(self, input_dim: int, latent_dim: int, hidden_dims: List[int]):
+         super(VAE, self).__init__()
+
+         # Encoder
+         modules = []
+         in_features = input_dim
+         for h_dim in hidden_dims:
+             modules.append(nn.Linear(in_features, h_dim))
+             modules.append(nn.ReLU())
+             in_features = h_dim
+         self.encoder = nn.Sequential(*modules)
+
+         # Latent space
+         self.fc_mu = nn.Linear(hidden_dims[-1], latent_dim)
+         self.fc_var = nn.Linear(hidden_dims[-1], latent_dim)
+
+         # Decoder (mirror the encoder without mutating the caller's list)
+         modules = []
+         hidden_dims = list(reversed(hidden_dims))
+         in_features = latent_dim
+         for h_dim in hidden_dims:
+             modules.append(nn.Linear(in_features, h_dim))
+             modules.append(nn.ReLU())
+             in_features = h_dim
+         modules.append(nn.Linear(hidden_dims[-1], input_dim))
+         self.decoder = nn.Sequential(*modules)
+
+     def encode(self, x):
+         h = self.encoder(x)
+         return self.fc_mu(h), self.fc_var(h)
+
+     def decode(self, z):
+         return self.decoder(z)
+
+     def reparameterize(self, mu, log_var):
+         std = torch.exp(0.5 * log_var)
+         eps = torch.randn_like(std)
+         return mu + eps * std
+
+     def forward(self, x):
+         mu, log_var = self.encode(x)
+         z = self.reparameterize(mu, log_var)
+         return self.decode(z), mu, log_var
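
The module defines only the network; a typical VAE objective (reconstruction plus KL term) would live in the training code. A sketch under the assumption of an MSE reconstruction loss and made-up dimensions:

    import torch
    import torch.nn.functional as F
    from src.models.vae import VAE

    vae = VAE(input_dim=10, latent_dim=4, hidden_dims=[64, 32])
    x = torch.randn(16, 10)                            # stand-in batch
    recon, mu, log_var = vae(x)

    recon_loss = F.mse_loss(recon, x, reduction='sum')
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    loss = recon_loss + kl                             # negative ELBO, ready for backward()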
src/rag/generator.py ADDED
@@ -0,0 +1,62 @@
+ """Generator component for the RAG system."""
+
+ from typing import List, Dict
+ import torch
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     LogitsProcessor,
+     LogitsProcessorList
+ )
+
+ class FinancialContextProcessor(LogitsProcessor):
+     """Custom logits processor for financial context."""
+     def __init__(self, financial_constraints: Dict):
+         self.constraints = financial_constraints
+
+     def __call__(self, input_ids: torch.LongTensor,
+                  scores: torch.FloatTensor) -> torch.FloatTensor:
+         # Apply financial domain constraints
+         # This is a placeholder for actual constraints
+         return scores
+
+ class RAGGenerator:
+     def __init__(self, config: Dict):
+         """Initialize the generator."""
+         self.model_name = "gpt2"  # Can be configured based on needs
+         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+         self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
+         self.max_length = 512
+
+     def prepare_context(self, retrieved_docs: List[Dict]) -> str:
+         """Prepare context from retrieved documents."""
+         context = ""
+         for doc in retrieved_docs:
+             context += f"{doc['document']['text']}\n"
+         return context.strip()
+
+     def generate(self, query: str, retrieved_docs: List[Dict],
+                  financial_constraints: Dict = None) -> str:
+         """Generate text based on query and retrieved documents."""
+         context = self.prepare_context(retrieved_docs)
+         prompt = f"Context: {context}\nQuery: {query}\nResponse:"
+
+         # Prepare logits processors
+         processors = LogitsProcessorList()
+         if financial_constraints:
+             processors.append(FinancialContextProcessor(financial_constraints))
+
+         # Generate response
+         inputs = self.tokenizer(prompt, return_tensors="pt")
+         outputs = self.model.generate(
+             inputs.input_ids,
+             max_length=self.max_length,
+             num_return_sequences=1,
+             logits_processor=processors,
+             do_sample=True,
+             temperature=0.7,
+             top_p=0.9
+         )
+
+         return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
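
Retrieved documents are expected in the {'document': {'text': ...}} shape that the FAISS retriever below produces; a hypothetical call (GPT-2 weights are downloaded on first use, and the document text is invented):

    from src.rag.generator import RAGGenerator

    generator = RAGGenerator(config={})   # the config argument is currently unused
    docs = [{'document': {'text': 'Quarterly revenue grew 12% year over year.'}}]
    answer = generator.generate('Summarise the revenue trend.', docs)
    print(answer)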
src/rag/retriever.py ADDED
@@ -0,0 +1,84 @@
+ """Retrieval component for the RAG system."""
+
+ import faiss
+ import numpy as np
+ from typing import List, Dict, Tuple
+ from elasticsearch import Elasticsearch
+ from transformers import AutoTokenizer, AutoModel
+ import torch
+
+ class FinancialDataRetriever:
+     def __init__(self, config: Dict):
+         """Initialize the retriever with configuration."""
+         self.retriever_type = config['rag']['retriever']
+         self.max_documents = config['rag']['max_documents']
+         self.similarity_threshold = config['rag']['similarity_threshold']
+
+         # Initialize FAISS index
+         self.dimension = 768  # BERT embedding dimension
+         self.index = faiss.IndexFlatL2(self.dimension)
+
+         # Initialize transformer model for embeddings
+         self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
+         self.model = AutoModel.from_pretrained('bert-base-uncased')
+
+         # Initialize Elasticsearch if needed
+         if self.retriever_type == "elasticsearch":
+             self.es = Elasticsearch()
+
+     def encode_text(self, texts: List[str]) -> np.ndarray:
+         """Encode text using BERT."""
+         tokens = self.tokenizer(texts, padding=True, truncation=True,
+                                 return_tensors="pt", max_length=512)
+         with torch.no_grad():
+             outputs = self.model(**tokens)
+             embeddings = outputs.last_hidden_state[:, 0, :].numpy()
+         return embeddings
+
+     def index_documents(self, documents: List[Dict]):
+         """Index documents for retrieval."""
+         if self.retriever_type == "faiss":
+             texts = [doc['text'] for doc in documents]
+             embeddings = self.encode_text(texts)
+             self.index.add(embeddings)
+             self.documents = documents
+         else:
+             for doc in documents:
+                 self.es.index(index="financial_data", document=doc)
+
+     def retrieve(self, query: str, k: int = None) -> List[Dict]:
+         """Retrieve relevant documents."""
+         k = k or self.max_documents
+         query_embedding = self.encode_text([query])
+
+         if self.retriever_type == "faiss":
+             distances, indices = self.index.search(query_embedding, k)
+             results = [
+                 {
+                     'document': self.documents[idx],
+                     'score': float(1 / (1 + dist))
+                 }
+                 for dist, idx in zip(distances[0], indices[0])
+                 if 1 / (1 + dist) >= self.similarity_threshold
+             ]
+         else:
+             response = self.es.search(
+                 index="financial_data",
+                 query={
+                     "match": {
+                         "text": query
+                     }
+                 },
+                 size=k
+             )
+             results = [
+                 {
+                     'document': hit['_source'],
+                     'score': hit['_score']
+                 }
+                 for hit in response['hits']['hits']
+                 if hit['_score'] >= self.similarity_threshold
+             ]
+
+         return results
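
A small FAISS-backed example; the config keys mirror the ones read in __init__, while the documents, threshold, and query are made up for illustration:

    from src.rag.retriever import FinancialDataRetriever

    config = {'rag': {'retriever': 'faiss', 'max_documents': 2, 'similarity_threshold': 0.0}}
    retriever = FinancialDataRetriever(config)
    retriever.index_documents([{'text': 'Credit card spend rose in Q3.'},
                               {'text': 'Savings rates fell to 2.1 percent.'}])

    for hit in retriever.retrieve('What happened to savings rates?'):
        print(round(hit['score'], 3), hit['document']['text'])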
src/server/aggregator.py ADDED
@@ -0,0 +1,43 @@
+ """Model aggregation logic for the federated learning server."""
+
+ import tensorflow as tf
+ from typing import List, Dict
+ import numpy as np
+ from collections import defaultdict
+
+ class FederatedAggregator:
+     def __init__(self, config: Dict):
+         """Initialize the federated aggregator."""
+         self.weighted = config['aggregation']['weighted']
+
+     def compute_metrics(self, client_metrics: List[Dict]) -> Dict:
+         """Compute aggregated metrics from client updates."""
+         if not client_metrics:
+             return {}
+
+         aggregated_metrics = defaultdict(float)
+         total_samples = sum(metrics['num_samples'] for metrics in client_metrics)
+
+         for metrics in client_metrics:
+             weight = metrics['num_samples'] / total_samples if self.weighted else 1.0
+
+             for metric_name, value in metrics['metrics'].items():
+                 aggregated_metrics[metric_name] += value * weight
+
+         return dict(aggregated_metrics)
+
+     def check_convergence(self,
+                           old_weights: List,
+                           new_weights: List,
+                           threshold: float = 1e-5) -> bool:
+         """Check if the model has converged."""
+         if old_weights is None or new_weights is None:
+             return False
+
+         weight_differences = [
+             np.mean(np.abs(old - new))
+             for old, new in zip(old_weights, new_weights)
+         ]
+
+         return all(diff < threshold for diff in weight_differences)
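
For illustration, the sample-weighted metric aggregation and the convergence check can be exercised with toy numbers (all values assumed):

    import numpy as np
    from src.server.aggregator import FederatedAggregator

    agg = FederatedAggregator({'aggregation': {'weighted': True}})
    client_metrics = [
        {'num_samples': 100, 'metrics': {'loss': 0.40, 'accuracy': 0.91}},
        {'num_samples': 300, 'metrics': {'loss': 0.30, 'accuracy': 0.95}},
    ]
    print(agg.compute_metrics(client_metrics))   # sample-weighted averages (~0.325, ~0.94)

    old_w = [np.zeros((2, 2))]
    new_w = [np.full((2, 2), 1e-6)]
    print(agg.check_convergence(old_w, new_w))   # True: mean |diff| is below the 1e-5 threshold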
src/server/coordinator.py ADDED
@@ -0,0 +1,77 @@
+ """Server-side coordinator for federated learning rounds."""
+
+ import tensorflow as tf
+ from typing import List, Dict
+ import numpy as np
+ from collections import defaultdict
+ import logging
+ import time
+
+ class FederatedCoordinator:
+     def __init__(self, config: Dict):
+         """Initialize the federated learning coordinator."""
+         self.config = config
+         self.clients = {}
+         self.current_round = 0
+         self.min_clients = config.get('server', {}).get('federated', {}).get('min_clients', 2)
+         self.rounds = config.get('server', {}).get('federated', {}).get('rounds', 10)
+
+     def register_client(self, client_id: int, client_size: int):
+         """Register a new client."""
+         self.clients[client_id] = {
+             'size': client_size,
+             'weights': None,
+             'metrics': defaultdict(list)
+         }
+
+     def aggregate_weights(self, client_updates: List[Dict]) -> List:
+         """Aggregate weights using the FedAvg algorithm."""
+         total_size = sum(self.clients[update['client_id']]['size']
+                          for update in client_updates)
+
+         aggregated_weights = [
+             np.zeros_like(w) for w in client_updates[0]['weights']
+         ]
+
+         for update in client_updates:
+             client_size = self.clients[update['client_id']]['size']
+             weight = client_size / total_size
+
+             for i, layer_weights in enumerate(update['weights']):
+                 aggregated_weights[i] += layer_weights * weight
+
+         return aggregated_weights
+
+     def start(self):
+         """Start the federated learning process."""
+         logger = logging.getLogger(__name__)
+
+         # Print server startup information
+         logger.info("\n" + "=" * 60)
+         logger.info(f"{'Federated Learning Server Starting':^60}")
+         logger.info("=" * 60)
+
+         # Print configuration details
+         logger.info("\nServer Configuration:")
+         logger.info("-" * 30)
+         logger.info(f"Minimum clients required: {self.min_clients}")
+         logger.info(f"Total rounds planned: {self.rounds}")
+         logger.info(f"Current active clients: {len(self.clients)}")
+         logger.info("-" * 30 + "\n")
+
+         while self.current_round < self.rounds:
+             round_num = self.current_round + 1
+             logger.info(f"\nRound {round_num}/{self.rounds}")
+             logger.info("-" * 30)
+
+             if len(self.clients) < self.min_clients:
+                 logger.warning(
+                     f"Waiting for clients... "
+                     f"(active: {len(self.clients)}/{self.min_clients})"
+                 )
+                 time.sleep(5)
+                 continue
+
+             logger.info(f"Active clients: {list(self.clients.keys())}")
+             logger.info(f"Starting training round {round_num}")
+             self.current_round += 1
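
A toy FedAvg aggregation with two registered clients of unequal size (the config values and weight arrays are assumed for the example):

    import numpy as np
    from src.server.coordinator import FederatedCoordinator

    coord = FederatedCoordinator({'server': {'federated': {'min_clients': 2, 'rounds': 3}}})
    coord.register_client(client_id=1, client_size=100)
    coord.register_client(client_id=2, client_size=300)

    updates = [
        {'client_id': 1, 'weights': [np.ones((2, 2))]},
        {'client_id': 2, 'weights': [np.zeros((2, 2))]},
    ]
    avg = coord.aggregate_weights(updates)   # every entry is 100/400 = 0.25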
src/utils/metrics.py ADDED
@@ -0,0 +1,119 @@
+ """Evaluation metrics for synthetic data quality and privacy."""
+
+ from typing import Dict, List
+ import numpy as np
+ from scipy.stats import wasserstein_distance, ks_2samp
+ from sklearn.metrics import mutual_info_score, silhouette_score
+ from sklearn.neighbors import NearestNeighbors
+
+ class MetricsCalculator:
+     @staticmethod
+     def calculate_distribution_similarity(real_data: np.ndarray,
+                                           synthetic_data: np.ndarray) -> Dict[str, float]:
+         """Calculate statistical similarity metrics between real and synthetic data."""
+         metrics = {}
+
+         # Wasserstein distance
+         metrics['wasserstein'] = wasserstein_distance(
+             real_data.flatten(),
+             synthetic_data.flatten()
+         )
+
+         # Mutual information (used here as a divergence proxy)
+         metrics['mutual_info'] = mutual_info_score(
+             real_data.flatten(),
+             synthetic_data.flatten()
+         )
+
+         # Kolmogorov-Smirnov test
+         ks_statistic, p_value = ks_2samp(real_data.flatten(), synthetic_data.flatten())
+         metrics['ks_statistic'] = ks_statistic
+         metrics['ks_p_value'] = p_value
+
+         # Basic statistical measures
+         metrics['mean_diff'] = abs(np.mean(real_data) - np.mean(synthetic_data))
+         metrics['std_diff'] = abs(np.std(real_data) - np.std(synthetic_data))
+         metrics['percentile_diff'] = np.mean([
+             abs(np.percentile(real_data, p) - np.percentile(synthetic_data, p))
+             for p in [25, 50, 75]
+         ])
+
+         return metrics
+
+     @staticmethod
+     def evaluate_privacy_metrics(model, test_data: np.ndarray,
+                                  synthetic_data: np.ndarray) -> Dict[str, float]:
+         """Evaluate privacy-related metrics."""
+         metrics = {}
+
+         # Membership inference risk
+         metrics['membership_inference_risk'] = MetricsCalculator._calculate_membership_inference_risk(
+             test_data, synthetic_data
+         )
+
+         # Attribute inference risk
+         metrics['attribute_inference_risk'] = MetricsCalculator._calculate_attribute_inference_risk(
+             test_data, synthetic_data
+         )
+
+         # k-anonymity approximation
+         metrics['k_anonymity_score'] = MetricsCalculator._calculate_k_anonymity(synthetic_data)
+
+         # Uniqueness score
+         metrics['uniqueness_score'] = MetricsCalculator._calculate_uniqueness(synthetic_data)
+
+         return metrics
+
+     @staticmethod
+     def _calculate_membership_inference_risk(test_data: np.ndarray,
+                                              synthetic_data: np.ndarray) -> float:
+         """Calculate membership inference risk using nearest neighbor distance ratio."""
+         k = 3  # number of neighbors to consider
+         nn = NearestNeighbors(n_neighbors=k)
+         nn.fit(synthetic_data)
+
+         distances, _ = nn.kneighbors(test_data)
+         avg_min_distances = distances.mean(axis=1)
+
+         # Normalize to [0,1] where higher values indicate higher privacy
+         risk_score = 1.0 - (1.0 / (1.0 + np.mean(avg_min_distances)))
+         return risk_score
+
+     @staticmethod
+     def _calculate_attribute_inference_risk(test_data: np.ndarray,
+                                             synthetic_data: np.ndarray) -> float:
+         """Calculate attribute inference risk using correlation analysis."""
+         real_corr = np.corrcoef(test_data.T)
+         synth_corr = np.corrcoef(synthetic_data.T)
+
+         # Compare correlation matrices
+         correlation_diff = np.abs(real_corr - synth_corr).mean()
+
+         # Convert to risk score (0 to 1, where lower is better)
+         risk_score = 1.0 - np.exp(-correlation_diff)
+         return risk_score
+
+     @staticmethod
+     def _calculate_k_anonymity(data: np.ndarray, k: int = 5) -> float:
+         """Calculate approximate k-anonymity score."""
+         nn = NearestNeighbors(n_neighbors=k)
+         nn.fit(data)
+
+         distances, _ = nn.kneighbors(data)
+         k_anonymity_scores = distances[:, -1]  # Distance to k-th neighbor
+
+         # Convert to score (0 to 1, where higher is better)
+         return float(np.mean(k_anonymity_scores > 0.1))
+
+     @staticmethod
+     def _calculate_uniqueness(data: np.ndarray) -> float:
+         """Calculate uniqueness score of the dataset."""
+         nn = NearestNeighbors(n_neighbors=2)  # 2 because first neighbor is self
+         nn.fit(data)
+
+         distances, _ = nn.kneighbors(data)
+         uniqueness_scores = distances[:, 1]  # Distance to nearest non-self neighbor
+
+         # Convert to score (0 to 1, where higher means more unique records)
+         return float(np.mean(uniqueness_scores > np.median(uniqueness_scores)))
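
Both metric groups can be smoke-tested on random draws; note that the model argument of evaluate_privacy_metrics is not used in its body, so None can be passed (the data here is purely synthetic noise for illustration):

    import numpy as np
    from src.utils.metrics import MetricsCalculator

    rng = np.random.default_rng(0)
    real = rng.normal(size=(500, 5))
    synthetic = rng.normal(loc=0.1, size=(500, 5))

    print(MetricsCalculator.calculate_distribution_similarity(real, synthetic))
    print(MetricsCalculator.evaluate_privacy_metrics(None, real, synthetic))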
src/utils/privacy.py ADDED
@@ -0,0 +1,37 @@
+ """Differential privacy utilities for federated training."""
+
+ import tensorflow_privacy as tfp
+ from typing import Dict, Any
+ import numpy as np
+
+ class PrivacyManager:
+     def __init__(self, config: Dict[str, Any]):
+         self.epsilon = config['privacy']['epsilon']
+         self.delta = config['privacy']['delta']
+         self.noise_multiplier = config['privacy']['noise_multiplier']
+
+     def add_noise_to_gradients(self, gradients: np.ndarray) -> np.ndarray:
+         """Add Gaussian noise to gradients for differential privacy."""
+         noise = np.random.normal(0, self.noise_multiplier, gradients.shape)
+         return gradients + noise
+
+     def verify_privacy_budget(self, num_iterations: int) -> bool:
+         """Check if training stays within the privacy budget."""
+         eps = self.compute_epsilon(num_iterations)
+         return eps <= self.epsilon
+
+     def compute_epsilon(self, num_iterations: int) -> float:
+         """Compute the current epsilon value."""
+         q = 1.0  # sampling ratio
+         steps = num_iterations
+         orders = ([1.25, 1.5, 1.75, 2., 2.25, 2.5, 3., 3.5, 4., 4.5] +
+                   list(range(5, 64)) + [128, 256, 512])
+
+         return tfp.compute_dp_sgd_privacy(
+             n=1000,  # number of training points
+             batch_size=32,
+             noise_multiplier=self.noise_multiplier,
+             epochs=steps,
+             delta=self.delta
+         )[0]
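
Gradient perturbation is a pure NumPy operation and easy to check in isolation (config values below are assumed); the budget check additionally needs tensorflow_privacy installed:

    import numpy as np
    from src.utils.privacy import PrivacyManager

    pm = PrivacyManager({'privacy': {'epsilon': 1.0, 'delta': 1e-5, 'noise_multiplier': 1.1}})
    grads = np.zeros((4, 4))
    noisy = pm.add_noise_to_gradients(grads)   # zero-mean Gaussian noise, sigma = 1.1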
tests/test_client.py ADDED
@@ -0,0 +1,49 @@
+ """Tests for the federated client components."""
+
+ import pytest
+ import tensorflow as tf
+ import yaml
+ from src.client.data_handler import FinancialDataHandler
+ from src.client.model import FederatedClient
+
+ @pytest.fixture
+ def config():
+     """Load test configuration."""
+     with open('config/client_config.yaml', 'r') as f:
+         return yaml.safe_load(f)['client']
+
+ def test_data_handler(config):
+     """Test data handler functionality."""
+     handler = FinancialDataHandler(config)
+
+     # Test data simulation
+     data = handler.simulate_financial_data(num_samples=100)
+     assert len(data) == 100
+     assert all(col in data.columns for col in [
+         'transaction_amount',
+         'account_balance',
+         'transaction_frequency',
+         'credit_score',
+         'days_since_last_transaction'
+     ])
+
+     # Test preprocessing
+     dataset, scaler = handler.get_client_data()
+     assert isinstance(dataset, tf.data.Dataset)
+
+ def test_federated_client(config):
+     """Test federated client functionality."""
+     client = FederatedClient(config)
+
+     # Test model building
+     assert isinstance(client.model, tf.keras.Model)
+
+     # Test local training
+     handler = FinancialDataHandler(config)
+     dataset, _ = handler.get_client_data()
+
+     training_result = client.train_local_model(dataset, epochs=1)
+     assert 'client_id' in training_result
+     assert 'weights' in training_result
+     assert 'metrics' in training_result