Commit 754afec by “Transcendental-Programmer”
1 Parent(s): 3de89cd
FEAT: added server coordination and model aggregation logic
- src/main.py +62 -0
- src/models/gan.py +122 -0
- src/models/vae.py +52 -0
- src/rag/generator.py +62 -0
- src/rag/retriever.py +84 -0
- src/server/aggregator.py +43 -0
- src/server/coordinator.py +77 -0
- src/utils/metrics.py +119 -0
- src/utils/privacy.py +37 -0
- tests/test_client.py +49 -0
src/main.py
ADDED
@@ -0,0 +1,62 @@
import argparse
import yaml
import logging
import logging.config
from pathlib import Path
from src.server.coordinator import FederatedCoordinator
from src.client.model import FederatedClient

def setup_logging(config):
    """Setup logging configuration."""
    # Create logs directory if it doesn't exist
    Path("logs").mkdir(exist_ok=True)

    log_level = (config.get('monitoring', {}).get('log_level')
                 or config.get('server', {}).get('monitoring', {}).get('log_level')
                 or config.get('client', {}).get('monitoring', {}).get('log_level')
                 or 'INFO')

    # Configure logging with UTF-8 encoding
    logging.basicConfig(
        level=log_level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler('logs/federated_learning.log', mode='a', encoding='utf-8')
        ]
    )

    # Reduce TensorFlow logging noise
    logging.getLogger('tensorflow').setLevel(logging.WARNING)

    # Create a divider in the log file
    logger = logging.getLogger(__name__)
    logger.info("\n" + "="*50)
    logger.info("New Training Session Started")
    logger.info("="*50 + "\n")

def load_config(config_path: str) -> dict:
    with open(config_path, 'r') as f:
        return yaml.safe_load(f)

def main():
    parser = argparse.ArgumentParser(description='Federated Learning Demo')
    parser.add_argument('--mode', choices=['server', 'client'], required=True)
    parser.add_argument('--config', type=str, required=True)
    args = parser.parse_args()

    config = load_config(args.config)
    setup_logging(config)
    logger = logging.getLogger(__name__)

    if args.mode == 'server':
        coordinator = FederatedCoordinator(config)
        logger.info("Starting server...")
        coordinator.start()
    else:
        client = FederatedClient(1, config)
        logger.info(f"Starting client with ID: {client.client_id}")
        client.start()

if __name__ == "__main__":
    main()
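As a side note, the log level is resolved with a three-step fallback (top-level monitoring, then server.monitoring, then client.monitoring, then INFO). A minimal sketch of that lookup, with an invented config dict:

# Illustration of the log_level fallback chain used in setup_logging above;
# the config values below are made up for the example.
config = {'server': {'monitoring': {'log_level': 'DEBUG'}}}

log_level = (config.get('monitoring', {}).get('log_level')
             or config.get('server', {}).get('monitoring', {}).get('log_level')
             or config.get('client', {}).get('monitoring', {}).get('log_level')
             or 'INFO')

print(log_level)  # DEBUG: the top-level 'monitoring' key is absent, so server.monitoring wins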
src/models/gan.py
ADDED
@@ -0,0 +1,122 @@
"""GAN implementation for financial data generation."""

import torch
import torch.nn as nn
import torch.optim as optim
from typing import Dict, List, Tuple

class Generator(nn.Module):
    def __init__(self, latent_dim: int, feature_dim: int, hidden_dims: List[int]):
        super().__init__()

        layers = []
        prev_dim = latent_dim

        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3)
            ])
            prev_dim = hidden_dim

        layers.append(nn.Linear(prev_dim, feature_dim))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.model(z)

class Discriminator(nn.Module):
    def __init__(self, feature_dim: int, hidden_dims: List[int]):
        super().__init__()

        layers = []
        prev_dim = feature_dim

        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3)
            ])
            prev_dim = hidden_dim

        layers.append(nn.Linear(prev_dim, 1))
        layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

class FinancialGAN:
    def __init__(self, config: Dict):
        """Initialize the GAN."""
        self.latent_dim = config['model']['latent_dim']
        self.feature_dim = config['model']['feature_dim']
        self.hidden_dims = config['model']['hidden_dims']
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.generator = Generator(
            self.latent_dim,
            self.feature_dim,
            self.hidden_dims
        ).to(self.device)

        self.discriminator = Discriminator(
            self.feature_dim,
            self.hidden_dims[::-1]
        ).to(self.device)

        self.g_optimizer = optim.Adam(
            self.generator.parameters(),
            lr=config['model']['learning_rate']
        )
        self.d_optimizer = optim.Adam(
            self.discriminator.parameters(),
            lr=config['model']['learning_rate']
        )

        self.criterion = nn.BCELoss()

    def train_step(self, real_data: torch.Tensor) -> Tuple[float, float]:
        """Perform one training step."""
        batch_size = real_data.size(0)
        real_label = torch.ones(batch_size, 1).to(self.device)
        fake_label = torch.zeros(batch_size, 1).to(self.device)

        # Train Discriminator
        self.d_optimizer.zero_grad()
        d_real_output = self.discriminator(real_data)
        d_real_loss = self.criterion(d_real_output, real_label)

        z = torch.randn(batch_size, self.latent_dim).to(self.device)
        fake_data = self.generator(z)
        d_fake_output = self.discriminator(fake_data.detach())
        d_fake_loss = self.criterion(d_fake_output, fake_label)

        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        self.d_optimizer.step()

        # Train Generator
        self.g_optimizer.zero_grad()
        g_output = self.discriminator(fake_data)
        g_loss = self.criterion(g_output, real_label)
        g_loss.backward()
        self.g_optimizer.step()

        return g_loss.item(), d_loss.item()

    def generate_samples(self, num_samples: int) -> torch.Tensor:
        """Generate synthetic financial data."""
        self.generator.eval()
        with torch.no_grad():
            z = torch.randn(num_samples, self.latent_dim).to(self.device)
            samples = self.generator(z)
        self.generator.train()
        return samples
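A minimal sketch of driving the GAN above with random tensors; the config values are illustrative, not taken from the repo's config files.

# Hypothetical smoke test for FinancialGAN; dimensions and learning rate are made up.
import torch
from src.models.gan import FinancialGAN

config = {'model': {'latent_dim': 16, 'feature_dim': 8,
                    'hidden_dims': [32, 64], 'learning_rate': 2e-4}}
gan = FinancialGAN(config)

for step in range(3):
    real_batch = torch.randn(64, 8).to(gan.device)   # stand-in for preprocessed financial features
    g_loss, d_loss = gan.train_step(real_batch)
    print(f"step {step}: g_loss={g_loss:.4f} d_loss={d_loss:.4f}")

synthetic = gan.generate_samples(5)                  # (5, feature_dim) tensor in [-1, 1] from the Tanh head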
src/models/vae.py
ADDED
@@ -0,0 +1,52 @@
"""vae.py module."""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List

class VAE(nn.Module):
    def __init__(self, input_dim: int, latent_dim: int, hidden_dims: List[int]):
        super(VAE, self).__init__()

        # Encoder
        modules = []
        in_features = input_dim
        for h_dim in hidden_dims:
            modules.append(nn.Linear(in_features, h_dim))
            modules.append(nn.ReLU())
            in_features = h_dim
        self.encoder = nn.Sequential(*modules)

        # Latent space
        self.fc_mu = nn.Linear(hidden_dims[-1], latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1], latent_dim)

        # Decoder
        modules = []
        hidden_dims.reverse()
        in_features = latent_dim
        for h_dim in hidden_dims:
            modules.append(nn.Linear(in_features, h_dim))
            modules.append(nn.ReLU())
            in_features = h_dim
        modules.append(nn.Linear(hidden_dims[-1], input_dim))
        self.decoder = nn.Sequential(*modules)

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_var(h)

    def decode(self, z):
        return self.decoder(z)

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var
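The file defines the network but no objective. A sketch of the standard reconstruction-plus-KL loss this forward() signature is usually trained with (dimensions invented, not part of this commit):

# Standard VAE objective wired to the class above; batch and dims are toy values.
import torch
import torch.nn.functional as F
from src.models.vae import VAE

model = VAE(input_dim=8, latent_dim=4, hidden_dims=[32, 16])
x = torch.randn(64, 8)                      # toy batch

recon, mu, log_var = model(x)
recon_loss = F.mse_loss(recon, x, reduction='sum')
kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
(recon_loss + kl).backward()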
src/rag/generator.py
ADDED
@@ -0,0 +1,62 @@
"""Generator component for the RAG system."""

from typing import List, Dict
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    LogitsProcessor,
    LogitsProcessorList
)

class FinancialContextProcessor(LogitsProcessor):
    """Custom logits processor for financial context."""
    def __init__(self, financial_constraints: Dict):
        self.constraints = financial_constraints

    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor) -> torch.FloatTensor:
        # Apply financial domain constraints
        # This is a placeholder for actual constraints
        return scores

class RAGGenerator:
    def __init__(self, config: Dict):
        """Initialize the generator."""
        self.model_name = "gpt2"  # Can be configured based on needs
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
        self.max_length = 512

    def prepare_context(self, retrieved_docs: List[Dict]) -> str:
        """Prepare context from retrieved documents."""
        context = ""
        for doc in retrieved_docs:
            context += f"{doc['document']['text']}\n"
        return context.strip()

    def generate(self, query: str, retrieved_docs: List[Dict],
                 financial_constraints: Dict = None) -> str:
        """Generate text based on query and retrieved documents."""
        context = self.prepare_context(retrieved_docs)
        prompt = f"Context: {context}\nQuery: {query}\nResponse:"

        # Prepare logits processors
        processors = LogitsProcessorList()
        if financial_constraints:
            processors.append(FinancialContextProcessor(financial_constraints))

        # Generate response
        inputs = self.tokenizer(prompt, return_tensors="pt")
        outputs = self.model.generate(
            inputs.input_ids,
            max_length=self.max_length,
            num_return_sequences=1,
            logits_processor=processors,
            do_sample=True,
            temperature=0.7,
            top_p=0.9
        )

        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
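A sketch of calling the generator with a single hand-written "retrieved" document (the document text and query are invented; the gpt2 checkpoint is downloaded on first use):

# Hypothetical usage of RAGGenerator; the dict layout matches prepare_context's
# expectation of doc['document']['text'].
from src.rag.generator import RAGGenerator

generator = RAGGenerator(config={})          # config is currently unused by __init__
docs = [{'document': {'text': 'Card spend rose 4% quarter over quarter.'}, 'score': 0.92}]
print(generator.generate("Summarise recent spending trends.", docs))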
src/rag/retriever.py
ADDED
@@ -0,0 +1,84 @@
"""Retrieval component for the RAG system."""

import faiss
import numpy as np
from typing import List, Dict, Tuple
from elasticsearch import Elasticsearch
from transformers import AutoTokenizer, AutoModel
import torch

class FinancialDataRetriever:
    def __init__(self, config: Dict):
        """Initialize the retriever with configuration."""
        self.retriever_type = config['rag']['retriever']
        self.max_documents = config['rag']['max_documents']
        self.similarity_threshold = config['rag']['similarity_threshold']

        # Initialize FAISS index
        self.dimension = 768  # BERT embedding dimension
        self.index = faiss.IndexFlatL2(self.dimension)

        # Initialize transformer model for embeddings
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        self.model = AutoModel.from_pretrained('bert-base-uncased')

        # Initialize Elasticsearch if needed
        if self.retriever_type == "elasticsearch":
            self.es = Elasticsearch()

    def encode_text(self, texts: List[str]) -> np.ndarray:
        """Encode text using BERT."""
        tokens = self.tokenizer(texts, padding=True, truncation=True,
                                return_tensors="pt", max_length=512)
        with torch.no_grad():
            outputs = self.model(**tokens)
            embeddings = outputs.last_hidden_state[:, 0, :].numpy()
        return embeddings

    def index_documents(self, documents: List[Dict]):
        """Index documents for retrieval."""
        if self.retriever_type == "faiss":
            texts = [doc['text'] for doc in documents]
            embeddings = self.encode_text(texts)
            self.index.add(embeddings)
            self.documents = documents
        else:
            for doc in documents:
                self.es.index(index="financial_data", document=doc)

    def retrieve(self, query: str, k: int = None) -> List[Dict]:
        """Retrieve relevant documents."""
        k = k or self.max_documents
        query_embedding = self.encode_text([query])

        if self.retriever_type == "faiss":
            distances, indices = self.index.search(query_embedding, k)
            results = [
                {
                    'document': self.documents[idx],
                    'score': float(1 / (1 + dist))
                }
                for dist, idx in zip(distances[0], indices[0])
                if 1 / (1 + dist) >= self.similarity_threshold
            ]
        else:
            response = self.es.search(
                index="financial_data",
                query={
                    "match": {
                        "text": query
                    }
                },
                size=k
            )
            results = [
                {
                    'document': hit['_source'],
                    'score': hit['_score']
                }
                for hit in response['hits']['hits']
                if hit['_score'] >= self.similarity_threshold
            ]

        return results
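A sketch of the FAISS path with two toy documents; the config keys mirror what the class reads, the values and texts are illustrative:

# Hypothetical round trip through index_documents/retrieve; bert-base-uncased
# is downloaded on first use.
from src.rag.retriever import FinancialDataRetriever

config = {'rag': {'retriever': 'faiss', 'max_documents': 2, 'similarity_threshold': 0.0}}
retriever = FinancialDataRetriever(config)

retriever.index_documents([
    {'text': 'Savings interest accrues daily.'},
    {'text': 'Credit card payments are due monthly.'},
])
for hit in retriever.retrieve("When is the card payment due?"):
    print(round(hit['score'], 3), hit['document']['text'])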
src/server/aggregator.py
ADDED
@@ -0,0 +1,43 @@
"""aggregator.py module."""

import tensorflow as tf
from typing import List, Dict
import numpy as np
from collections import defaultdict

class FederatedAggregator:
    def __init__(self, config: Dict):
        """Initialize the federated aggregator."""
        self.weighted = config['aggregation']['weighted']

    def compute_metrics(self, client_metrics: List[Dict]) -> Dict:
        """Compute aggregated metrics from client updates."""
        if not client_metrics:
            return {}

        aggregated_metrics = defaultdict(float)
        total_samples = sum(metrics['num_samples'] for metrics in client_metrics)

        for metrics in client_metrics:
            weight = metrics['num_samples'] / total_samples if self.weighted else 1.0

            for metric_name, value in metrics['metrics'].items():
                aggregated_metrics[metric_name] += value * weight

        return dict(aggregated_metrics)

    def check_convergence(self,
                          old_weights: List,
                          new_weights: List,
                          threshold: float = 1e-5) -> bool:
        """Check if the model has converged."""
        if old_weights is None or new_weights is None:
            return False

        weight_differences = [
            np.mean(np.abs(old - new))
            for old, new in zip(old_weights, new_weights)
        ]

        return all(diff < threshold for diff in weight_differences)
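A sketch of the weighted metric averaging, with invented per-client numbers in the layout compute_metrics expects ('num_samples' plus a 'metrics' mapping):

from src.server.aggregator import FederatedAggregator

aggregator = FederatedAggregator({'aggregation': {'weighted': True}})
client_metrics = [
    {'num_samples': 100, 'metrics': {'loss': 0.50, 'accuracy': 0.80}},
    {'num_samples': 300, 'metrics': {'loss': 0.30, 'accuracy': 0.90}},
]
print(aggregator.compute_metrics(client_metrics))
# {'loss': 0.35, 'accuracy': 0.875}: each client weighted by its share of the samples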
src/server/coordinator.py
ADDED
@@ -0,0 +1,77 @@
"""coordinator.py module."""

import tensorflow as tf
from typing import List, Dict
import numpy as np
from collections import defaultdict
import logging
import time

class FederatedCoordinator:
    def __init__(self, config: Dict):
        """Initialize the federated learning coordinator."""
        self.config = config
        self.clients = {}
        self.current_round = 0
        self.min_clients = config.get('server', {}).get('federated', {}).get('min_clients', 2)
        self.rounds = config.get('server', {}).get('federated', {}).get('rounds', 10)

    def register_client(self, client_id: int, client_size: int):
        """Register a new client."""
        self.clients[client_id] = {
            'size': client_size,
            'weights': None,
            'metrics': defaultdict(list)
        }

    def aggregate_weights(self, client_updates: List[Dict]) -> List:
        """Aggregate weights using FedAvg algorithm."""
        total_size = sum(self.clients[update['client_id']]['size']
                         for update in client_updates)

        aggregated_weights = [
            np.zeros_like(w) for w in client_updates[0]['weights']
        ]

        for update in client_updates:
            client_size = self.clients[update['client_id']]['size']
            weight = client_size / total_size

            for i, layer_weights in enumerate(update['weights']):
                aggregated_weights[i] += layer_weights * weight

        return aggregated_weights

    def start(self):
        """Start the federated learning process."""
        logger = logging.getLogger(__name__)

        # Print server startup information
        logger.info("\n" + "=" * 60)
        logger.info(f"{'Federated Learning Server Starting':^60}")
        logger.info("=" * 60)

        # Print configuration details
        logger.info("\nServer Configuration:")
        logger.info("-" * 30)
        logger.info(f"Minimum clients required: {self.min_clients}")
        logger.info(f"Total rounds planned: {self.rounds}")
        logger.info(f"Current active clients: {len(self.clients)}")
        logger.info("-" * 30 + "\n")

        while self.current_round < self.rounds:
            round_num = self.current_round + 1
            logger.info(f"\nRound {round_num}/{self.rounds}")
            logger.info("-" * 30)

            if len(self.clients) < self.min_clients:
                logger.warning(
                    f"Waiting for clients... "
                    f"(active: {len(self.clients)}/{self.min_clients})"
                )
                time.sleep(5)
                continue

            logger.info(f"Active clients: {list(self.clients.keys())}")
            logger.info(f"Starting training round {round_num}")
            self.current_round += 1
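A sketch of the FedAvg weighting in aggregate_weights with two registered clients and one-layer "models" (all numbers invented):

import numpy as np
from src.server.coordinator import FederatedCoordinator

coordinator = FederatedCoordinator({'server': {'federated': {'min_clients': 2, 'rounds': 1}}})
coordinator.register_client(client_id=1, client_size=100)
coordinator.register_client(client_id=2, client_size=300)

updates = [
    {'client_id': 1, 'weights': [np.array([0.0, 4.0])]},
    {'client_id': 2, 'weights': [np.array([4.0, 0.0])]},
]
print(coordinator.aggregate_weights(updates))
# [array([3., 1.])]: 0.25 * client 1 + 0.75 * client 2, per layer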
src/utils/metrics.py
ADDED
@@ -0,0 +1,119 @@
"""metrics.py module."""

from typing import Dict, List
import numpy as np
from scipy.stats import wasserstein_distance, ks_2samp
from sklearn.metrics import mutual_info_score, silhouette_score
from sklearn.neighbors import NearestNeighbors

class MetricsCalculator:
    @staticmethod
    def calculate_distribution_similarity(real_data: np.ndarray,
                                          synthetic_data: np.ndarray) -> Dict[str, float]:
        """Calculate statistical similarity metrics between real and synthetic data."""
        metrics = {}

        # Wasserstein distance
        metrics['wasserstein'] = wasserstein_distance(
            real_data.flatten(),
            synthetic_data.flatten()
        )

        # KL divergence approximation
        metrics['mutual_info'] = mutual_info_score(
            real_data.flatten(),
            synthetic_data.flatten()
        )

        # Kolmogorov-Smirnov test
        ks_statistic, p_value = ks_2samp(real_data.flatten(), synthetic_data.flatten())
        metrics['ks_statistic'] = ks_statistic
        metrics['ks_p_value'] = p_value

        # Basic statistical measures
        metrics['mean_diff'] = abs(np.mean(real_data) - np.mean(synthetic_data))
        metrics['std_diff'] = abs(np.std(real_data) - np.std(synthetic_data))
        metrics['percentile_diff'] = np.mean([
            abs(np.percentile(real_data, p) - np.percentile(synthetic_data, p))
            for p in [25, 50, 75]
        ])

        return metrics

    @staticmethod
    def evaluate_privacy_metrics(model, test_data: np.ndarray,
                                 synthetic_data: np.ndarray) -> Dict[str, float]:
        """Evaluate privacy-related metrics."""
        metrics = {}

        # Membership inference risk
        metrics['membership_inference_risk'] = MetricsCalculator._calculate_membership_inference_risk(
            test_data, synthetic_data
        )

        # Attribute inference risk
        metrics['attribute_inference_risk'] = MetricsCalculator._calculate_attribute_inference_risk(
            test_data, synthetic_data
        )

        # k-anonymity approximation
        metrics['k_anonymity_score'] = MetricsCalculator._calculate_k_anonymity(synthetic_data)

        # Uniqueness score
        metrics['uniqueness_score'] = MetricsCalculator._calculate_uniqueness(synthetic_data)

        return metrics

    @staticmethod
    def _calculate_membership_inference_risk(test_data: np.ndarray,
                                             synthetic_data: np.ndarray) -> float:
        """Calculate membership inference risk using nearest neighbor distance ratio."""
        k = 3  # number of neighbors to consider
        nn = NearestNeighbors(n_neighbors=k)
        nn.fit(synthetic_data)

        distances, _ = nn.kneighbors(test_data)
        avg_min_distances = distances.mean(axis=1)

        # Normalize to [0,1] where higher values indicate higher privacy
        risk_score = 1.0 - (1.0 / (1.0 + np.mean(avg_min_distances)))
        return risk_score

    @staticmethod
    def _calculate_attribute_inference_risk(test_data: np.ndarray,
                                            synthetic_data: np.ndarray) -> float:
        """Calculate attribute inference risk using correlation analysis."""
        real_corr = np.corrcoef(test_data.T)
        synth_corr = np.corrcoef(synthetic_data.T)

        # Compare correlation matrices
        correlation_diff = np.abs(real_corr - synth_corr).mean()

        # Convert to risk score (0 to 1, where lower is better)
        risk_score = 1.0 - np.exp(-correlation_diff)
        return risk_score

    @staticmethod
    def _calculate_k_anonymity(data: np.ndarray, k: int = 5) -> float:
        """Calculate approximate k-anonymity score."""
        nn = NearestNeighbors(n_neighbors=k)
        nn.fit(data)

        distances, _ = nn.kneighbors(data)
        k_anonymity_scores = distances[:, -1]  # Distance to k-th neighbor

        # Convert to score (0 to 1, where higher is better)
        return float(np.mean(k_anonymity_scores > 0.1))

    @staticmethod
    def _calculate_uniqueness(data: np.ndarray) -> float:
        """Calculate uniqueness score of the dataset."""
        nn = NearestNeighbors(n_neighbors=2)  # 2 because first neighbor is self
        nn.fit(data)

        distances, _ = nn.kneighbors(data)
        uniqueness_scores = distances[:, 1]  # Distance to nearest non-self neighbor

        # Convert to score (0 to 1, where higher means more unique records)
        return float(np.mean(uniqueness_scores > np.median(uniqueness_scores)))
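A sketch of the distribution-similarity metrics on two random 1-D samples; note that mutual_info_score treats its inputs as discrete labels, so its value on raw floats is only a rough signal:

# Toy comparison of "real" vs "synthetic" samples using the calculator above.
import numpy as np
from src.utils.metrics import MetricsCalculator

rng = np.random.default_rng(0)
real = rng.normal(0.0, 1.0, size=(500, 1))        # toy "real" data
fake = rng.normal(0.1, 1.1, size=(500, 1))        # toy "synthetic" data

metrics = MetricsCalculator.calculate_distribution_similarity(real, fake)
print({k: round(float(v), 4) for k, v in metrics.items()})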
src/utils/privacy.py
ADDED
@@ -0,0 +1,37 @@
"""privacy.py module."""

import tensorflow_privacy as tfp
from typing import Dict, Any
import numpy as np

class PrivacyManager:
    def __init__(self, config: Dict[str, Any]):
        self.epsilon = config['privacy']['epsilon']
        self.delta = config['privacy']['delta']
        self.noise_multiplier = config['privacy']['noise_multiplier']

    def add_noise_to_gradients(self, gradients: np.ndarray) -> np.ndarray:
        """Add Gaussian noise to gradients for differential privacy."""
        noise = np.random.normal(0, self.noise_multiplier, gradients.shape)
        return gradients + noise

    def verify_privacy_budget(self, num_iterations: int) -> bool:
        """Check if training stays within privacy budget."""
        eps = self.compute_epsilon(num_iterations)
        return eps <= self.epsilon

    def compute_epsilon(self, num_iterations: int) -> float:
        """Compute the current epsilon value."""
        q = 1.0  # sampling ratio
        steps = num_iterations
        orders = ([1.25, 1.5, 1.75, 2., 2.25, 2.5, 3., 3.5, 4., 4.5] +
                  list(range(5, 64)) + [128, 256, 512])

        return tfp.compute_dp_sgd_privacy(
            n=1000,  # number of training points
            batch_size=32,
            noise_multiplier=self.noise_multiplier,
            epochs=steps,
            delta=self.delta
        )[0]
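A sketch of the privacy manager with an illustrative (epsilon, delta, noise multiplier) budget; compute_epsilon hard-codes n=1000 and batch_size=32, so the budget check is only as good as those assumptions:

import numpy as np
from src.utils.privacy import PrivacyManager

pm = PrivacyManager({'privacy': {'epsilon': 8.0, 'delta': 1e-5, 'noise_multiplier': 1.1}})

noisy = pm.add_noise_to_gradients(np.zeros((4, 4)))   # Gaussian noise, sigma = noise_multiplier
print(pm.verify_privacy_budget(num_iterations=5))     # True if the spent epsilon stays <= 8.0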
tests/test_client.py
ADDED
@@ -0,0 +1,49 @@
"""test_client.py module."""

import pytest
import tensorflow as tf
import yaml
from src.client.data_handler import FinancialDataHandler
from src.client.model import FederatedClient

@pytest.fixture
def config():
    """Load test configuration."""
    with open('config/client_config.yaml', 'r') as f:
        return yaml.safe_load(f)['client']

def test_data_handler(config):
    """Test data handler functionality."""
    handler = FinancialDataHandler(config)

    # Test data simulation
    data = handler.simulate_financial_data(num_samples=100)
    assert len(data) == 100
    assert all(col in data.columns for col in [
        'transaction_amount',
        'account_balance',
        'transaction_frequency',
        'credit_score',
        'days_since_last_transaction'
    ])

    # Test preprocessing
    dataset, scaler = handler.get_client_data()
    assert isinstance(dataset, tf.data.Dataset)

def test_federated_client(config):
    """Test federated client functionality."""
    client = FederatedClient(config)

    # Test model building
    assert isinstance(client.model, tf.keras.Model)

    # Test local training
    handler = FinancialDataHandler(config)
    dataset, _ = handler.get_client_data()

    training_result = client.train_local_model(dataset, epochs=1)
    assert 'client_id' in training_result
    assert 'weights' in training_result
    assert 'metrics' in training_result