Spaces:
Build error
Build error
| # # ################# | |
| # import streamlit as st | |
| # import matplotlib.pyplot as plt | |
| # import torch | |
| # from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, AdamW | |
| # from datasets import load_dataset, Dataset | |
| # from evaluate import load as load_metric | |
| # from torch.utils.data import DataLoader | |
| # import pandas as pd | |
| # import random | |
| # from collections import OrderedDict | |
| # import flwr as fl | |
| # from logging import INFO, DEBUG | |
| # from flwr.common.logger import log | |
| # import logging | |
| # import re | |
| # import plotly.graph_objects as go | |
| # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt") | |
| # def load_data(dataset_name, train_size=20, test_size=20, num_clients=2): | |
| # raw_datasets = load_dataset(dataset_name) | |
| # raw_datasets = raw_datasets.shuffle(seed=42) | |
| # del raw_datasets["unsupervised"] | |
| # tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") | |
| # def tokenize_function(examples): | |
| # return tokenizer(examples["text"], truncation=True) | |
| # tokenized_datasets = raw_datasets.map(tokenize_function, batched=True) | |
| # tokenized_datasets = tokenized_datasets.remove_columns("text") | |
| # tokenized_datasets = tokenized_datasets.rename_column("label", "labels") | |
| # train_datasets = [] | |
| # test_datasets = [] | |
| # for _ in range(num_clients): | |
| # train_dataset = tokenized_datasets["train"].select(random.sample(range(len(tokenized_datasets["train"])), train_size)) | |
| # test_dataset = tokenized_datasets["test"].select(random.sample(range(len(tokenized_datasets["test"])), test_size)) | |
| # train_datasets.append(train_dataset) | |
| # test_datasets.append(test_dataset) | |
| # data_collator = DataCollatorWithPadding(tokenizer=tokenizer) | |
| # return train_datasets, test_datasets, data_collator, raw_datasets | |
| # def train(net, trainloader, epochs): | |
| # optimizer = AdamW(net.parameters(), lr=5e-5) | |
| # net.train() | |
| # for _ in range(epochs): | |
| # for batch in trainloader: | |
| # batch = {k: v.to(DEVICE) for k, v in batch.items()} | |
| # outputs = net(**batch) | |
| # loss = outputs.loss | |
| # loss.backward() | |
| # optimizer.step() | |
| # optimizer.zero_grad() | |
| # def test(net, testloader): | |
| # metric = load_metric("accuracy") | |
| # net.eval() | |
| # loss = 0 | |
| # for batch in testloader: | |
| # batch = {k: v.to(DEVICE) for k, v in batch.items()} | |
| # with torch.no_grad(): | |
| # outputs = net(**batch) | |
| # logits = outputs.logits | |
| # loss += outputs.loss.item() | |
| # predictions = torch.argmax(logits, dim=-1) | |
| # metric.add_batch(predictions=predictions, references=batch["labels"]) | |
| # loss /= len(testloader) | |
| # accuracy = metric.compute()["accuracy"] | |
| # return loss, accuracy | |
| # class CustomClient(fl.client.NumPyClient): | |
| # def __init__(self, net, trainloader, testloader, client_id): | |
| # self.net = net | |
| # self.trainloader = trainloader | |
| # self.testloader = testloader | |
| # self.client_id = client_id | |
| # self.losses = [] | |
| # self.accuracies = [] | |
| # def get_parameters(self, config): | |
| # return [val.cpu().numpy() for _, val in self.net.state_dict().items()] | |
| # def set_parameters(self, parameters): | |
| # params_dict = zip(self.net.state_dict().keys(), parameters) | |
| # state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict}) | |
| # self.net.load_state_dict(state_dict, strict=True) | |
| # def fit(self, parameters, config): | |
| # log(INFO, f"Client {self.client_id} is starting fit()") | |
| # self.set_parameters(parameters) | |
| # train(self.net, self.trainloader, epochs=1) | |
| # loss, accuracy = test(self.net, self.testloader) | |
| # self.losses.append(loss) | |
| # self.accuracies.append(accuracy) | |
| # log(INFO, f"Client {self.client_id} finished fit() with loss: {loss:.4f} and accuracy: {accuracy:.4f}") | |
| # return self.get_parameters(config={}), len(self.trainloader.dataset), {"loss": loss, "accuracy": accuracy} | |
| # def evaluate(self, parameters, config): | |
| # log(INFO, f"Client {self.client_id} is starting evaluate()") | |
| # self.set_parameters(parameters) | |
| # loss, accuracy = test(self.net, self.testloader) | |
| # log(INFO, f"Client {self.client_id} finished evaluate() with loss: {loss:.4f} and accuracy: {accuracy:.4f}") | |
| # return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy), "loss": float(loss)} | |
| # def plot_metrics(self, round_num, plot_placeholder): | |
| # if self.losses and self.accuracies: | |
| # plot_placeholder.write(f"#### Client {self.client_id} Metrics for Round {round_num}") | |
| # plot_placeholder.write(f"Loss: {self.losses[-1]:.4f}") | |
| # plot_placeholder.write(f"Accuracy: {self.accuracies[-1]:.4f}") | |
| # fig, ax1 = plt.subplots() | |
| # color = 'tab:red' | |
| # ax1.set_xlabel('Round') | |
| # ax1.set_ylabel('Loss', color=color) | |
| # ax1.plot(range(1, len(self.losses) + 1), self.losses, color=color) | |
| # ax1.tick_params(axis='y', labelcolor=color) | |
| # ax2 = ax1.twinx() # instantiate a second axes that shares the same x-axis | |
| # color = 'tab:blue' | |
| # ax2.set_ylabel('Accuracy', color=color) | |
| # ax2.plot(range(1, len(self.accuracies) + 1), self.accuracies, color=color) | |
| # ax2.tick_params(axis='y', labelcolor=color) | |
| # fig.tight_layout() | |
| # plot_placeholder.pyplot(fig) | |
| # def read_log_file(log_path='./log.txt'): | |
| # with open(log_path, 'r') as file: | |
| # log_lines = file.readlines() | |
| # return log_lines | |
| # def parse_log(log_lines): | |
| # rounds = [] | |
| # clients = {} | |
| # memory_usage = [] | |
| # round_pattern = re.compile(r'ROUND (\d+)') | |
| # client_pattern = re.compile(r'Client (\d+) \| (INFO|DEBUG) \| (.*)') | |
| # memory_pattern = re.compile(r'memory used=(\d+\.\d+)GB') | |
| # current_round = None | |
| # for line in log_lines: | |
| # round_match = round_pattern.search(line) | |
| # client_match = client_pattern.search(line) | |
| # memory_match = memory_pattern.search(line) | |
| # if round_match: | |
| # current_round = int(round_match.group(1)) | |
| # rounds.append(current_round) | |
| # elif client_match: | |
| # client_id = int(client_match.group(1)) | |
| # log_level = client_match.group(2) | |
| # message = client_match.group(3) | |
| # if client_id not in clients: | |
| # clients[client_id] = {'rounds': [], 'messages': []} | |
| # clients[client_id]['rounds'].append(current_round) | |
| # clients[client_id]['messages'].append((log_level, message)) | |
| # elif memory_match: | |
| # memory_usage.append(float(memory_match.group(1))) | |
| # return rounds, clients, memory_usage | |
| # def plot_metrics(rounds, clients, memory_usage): | |
| # st.write("## Metrics Overview") | |
| # st.write("### Memory Usage") | |
| # plt.figure() | |
| # plt.plot(range(len(memory_usage)), memory_usage, label='Memory Usage (GB)') | |
| # plt.xlabel('Step') | |
| # plt.ylabel('Memory Usage (GB)') | |
| # plt.legend() | |
| # st.pyplot(plt) | |
| # for client_id, data in clients.items(): | |
| # st.write(f"### Client {client_id} Metrics") | |
| # info_messages = [msg for level, msg in data['messages'] if level == 'INFO'] | |
| # debug_messages = [msg for level, msg in data['messages'] if level == 'DEBUG'] | |
| # st.write("#### INFO Messages") | |
| # for msg in info_messages: | |
| # st.write(msg) | |
| # st.write("#### DEBUG Messages") | |
| # for msg in debug_messages: | |
| # st.write(msg) | |
| # # Placeholder for actual loss and accuracy values, assuming they're included in the messages | |
| # losses = [float(re.search(r'loss=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'loss=' in msg] | |
| # accuracies = [float(re.search(r'accuracy=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'accuracy=' in msg] | |
| # if losses: | |
| # plt.figure() | |
| # plt.plot(data['rounds'], losses, label='Loss') | |
| # plt.xlabel('Round') | |
| # plt.ylabel('Loss') | |
| # plt.legend() | |
| # st.pyplot(plt) | |
| # if accuracies: | |
| # plt.figure() | |
| # plt.plot(data['rounds'], accuracies, label='Accuracy') | |
| # plt.xlabel('Round') | |
| # plt.ylabel('Accuracy') | |
| # plt.legend() | |
| # st.pyplot(plt) | |
| # def read_log_file2(): | |
| # with open("./log.txt", "r") as file: | |
| # return file.read() | |
| # def main(): | |
| # st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices") | |
| # logs = read_log_file2() | |
| # # cleanLogs = # Define a pattern to match relevant log entries | |
| # pattern = re.compile(r"memory|loss|accuracy|round|client", re.IGNORECASE) | |
| # # Filter the log data | |
| # filtered_logs = [line for line in logs.splitlines() if pattern.search(line)] | |
| # st.markdown(filtered_logs) | |
| # # Provide a download button for the logs | |
| # st.download_button( | |
| # label="Download Logs", | |
| # data="\n".join(filtered_logs), | |
| # file_name="./log.txt", | |
| # mime="text/plain" | |
| # ) | |
| # dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"]) | |
| # model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased"]) | |
| # NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2) | |
| # NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3) | |
| # train_datasets, test_datasets, data_collator, raw_datasets = load_data(dataset_name, num_clients=NUM_CLIENTS) | |
| # trainloaders = [] | |
| # testloaders = [] | |
| # clients = [] | |
| # for i in range(NUM_CLIENTS): | |
| # st.write(f"### Client {i+1} Datasets") | |
| # train_df = pd.DataFrame(train_datasets[i]) | |
| # test_df = pd.DataFrame(test_datasets[i]) | |
| # st.write("#### Train Dataset (Words)") | |
| # st.dataframe(raw_datasets["train"].select(random.sample(range(len(raw_datasets["train"])), 20))) | |
| # st.write("#### Train Dataset (Tokens)") | |
| # edited_train_df = st.data_editor(train_df, key=f"train_{i}") | |
| # st.write("#### Test Dataset (Words)") | |
| # st.dataframe(raw_datasets["test"].select(random.sample(range(len(raw_datasets["test"])), 20))) | |
| # st.write("#### Test Dataset (Tokens)") | |
| # edited_test_df = st.data_editor(test_df, key=f"test_{i}") | |
| # edited_train_dataset = Dataset.from_pandas(edited_train_df) | |
| # edited_test_dataset = Dataset.from_pandas(edited_test_df) | |
| # trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator) | |
| # testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator) | |
| # trainloaders.append(trainloader) | |
| # testloaders.append(testloader) | |
| # net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE) | |
| # client = CustomClient(net, trainloader, testloader, client_id=i+1) | |
| # clients.append(client) | |
| # if st.button("Start Training"): | |
| # def client_fn(cid): | |
| # return clients[int(cid)].to_client() | |
| # def weighted_average(metrics): | |
| # accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] | |
| # losses = [num_examples * m["loss"] for num_examples, m in metrics] | |
| # examples = [num_examples for num_examples, _ in metrics] | |
| # return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)} | |
| # strategy = fl.server.strategy.FedAvg( | |
| # fraction_fit=1.0, | |
| # fraction_evaluate=1.0, | |
| # evaluate_metrics_aggregation_fn=weighted_average, | |
| # ) | |
| # for round_num in range(NUM_ROUNDS): | |
| # st.write(f"### Round {round_num + 1} ✅") | |
| # logs = read_log_file2() | |
| # filtered_log_list = [line for line in logs.splitlines() if pattern.search(line)] | |
| # filtered_logs = "\n".join(filtered_log_list) | |
| # st.markdown(filtered_logs) | |
| # # Provide a download button for the logs | |
| # # st.download_button( | |
| # # label="Download Logs", | |
| # # data=logs, | |
| # # file_name="./log.txt", | |
| # # mime="text/plain" | |
| # # ) | |
| # # # Extract relevant data | |
| # accuracy_pattern = re.compile(r"'accuracy': \{(\d+), ([\d.]+)\}") | |
| # loss_pattern = re.compile(r"'loss': \{(\d+), ([\d.]+)\}") | |
| # accuracy_matches = accuracy_pattern.findall(filtered_logs) | |
| # loss_matches = loss_pattern.findall(filtered_logs) | |
| # rounds = [int(match[0]) for match in accuracy_matches] | |
| # accuracies = [float(match[1]) for match in accuracy_matches] | |
| # losses = [float(match[1]) for match in loss_matches] | |
| # # Create accuracy plot | |
| # accuracy_fig = go.Figure() | |
| # accuracy_fig.add_trace(go.Scatter(x=rounds, y=accuracies, mode='lines+markers', name='Accuracy')) | |
| # accuracy_fig.update_layout(title='Accuracy over Rounds', xaxis_title='Round', yaxis_title='Accuracy') | |
| # # Create loss plot | |
| # loss_fig = go.Figure() | |
| # loss_fig.add_trace(go.Scatter(x=rounds, y=losses, mode='lines+markers', name='Loss')) | |
| # loss_fig.update_layout(title='Loss over Rounds', xaxis_title='Round', yaxis_title='Loss') | |
| # # Display plots in Streamlit | |
| # st.plotly_chart(accuracy_fig) | |
| # st.plotly_chart(loss_fig) | |
| # # Display data table | |
| # data = { | |
| # 'Round': rounds, | |
| # 'Accuracy': accuracies, | |
| # 'Loss': losses | |
| # } | |
| # df = pd.DataFrame(data) | |
| # st.write("## Training Metrics") | |
| # st.table(df) | |
| # plot_placeholders = [st.empty() for _ in range(NUM_CLIENTS)] | |
| # fl.simulation.start_simulation( | |
| # client_fn=client_fn, | |
| # num_clients=NUM_CLIENTS, | |
| # config=fl.server.ServerConfig(num_rounds=1), | |
| # strategy=strategy, | |
| # client_resources={"num_cpus": 1, "num_gpus": (1 if torch.cuda.is_available() else 0)}, | |
| # ray_init_args={"log_to_driver": True, "num_cpus": 1, "num_gpus": (1 if torch.cuda.is_available() else 0)} | |
| # ) | |
| # for i, client in enumerate(clients): | |
| # client.plot_metrics(round_num + 1, plot_placeholders[i]) | |
| # st.write(" ") | |
| # st.success("Training completed successfully!") | |
| # # Display final metrics | |
| # st.write("## Final Client Metrics") | |
| # for client in clients: | |
| # st.write(f"### Client {client.client_id}") | |
| # if client.losses and client.accuracies: | |
| # st.write(f"Final Loss: {client.losses[-1]:.4f}") | |
| # st.write(f"Final Accuracy: {client.accuracies[-1]:.4f}") | |
| # client.plot_metrics(NUM_ROUNDS, st.empty()) | |
| # else: | |
| # st.write("No metrics available.") | |
| # st.write(" ") | |
| # # Display log.txt content | |
| # st.write("## Training Log") | |
| # st.write(read_log_file2()) | |
| # st.write("## Training Log Analysis") | |
| # log_lines = read_log_file() | |
| # rounds, clients, memory_usage = parse_log(log_lines) | |
| # plot_metrics(rounds, clients, memory_usage) | |
| # else: | |
| # st.write("Click the 'Start Training' button to start the training process.") | |
| # if __name__ == "__main__": | |
| # main() | |
import logging
import random
import re
from collections import OrderedDict
from logging import INFO, DEBUG

import flwr as fl
import matplotlib.pyplot as plt
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
import torch
from datasets import load_dataset, Dataset
from evaluate import load as load_metric
from flwr.common.logger import log
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW
from transformers import T5Tokenizer, T5ForConditionalGeneration
# Device used for every batch and model in this file: CUDA when available, CPU otherwise.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Point Flower's logger at ./log.txt, the same path the log-reading helpers in this file open.
fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt")
class CustomDataCollator:
    """Collate tokenized examples into a batch of equally sized tensors.

    Every list-valued field is right-padded to the longest entry of that
    field in the batch: ``input_ids`` with ``pad_token_id``, ``labels`` with
    ``-100`` and every other sequence field (for example ``attention_mask``
    or ``token_type_ids``) with ``0``. Scalar fields are stacked unchanged.
    """

    # Padding values for fields that must not be padded with 0.
    _LABEL_PAD = -100

    def __init__(self, pad_token_id=0):
        """Store the padding id; ``None`` (tokenizer without pad token) becomes 0."""
        self.pad_token_id = 0 if pad_token_id is None else pad_token_id

    def __call__(self, features):
        """Build the batch.

        Args:
            features: list of dicts mapping field name to a scalar or to a
                sequence (list, tuple, or any object exposing ``tolist()``).

        Returns:
            Dict mapping each field of the first example to a ``torch.Tensor``
            whose first dimension is the batch. An empty input yields ``{}``.
        """
        if not features:
            return {}

        # Normalise array-like values to plain Python objects without touching
        # the caller's data (the previous in-place ``+=`` padded the dataset
        # rows themselves).
        rows = [
            {
                key: (value.tolist() if hasattr(value, "tolist") else value)
                for key, value in feature.items()
            }
            for feature in features
        ]

        batch = {}
        for key in rows[0]:
            column = [row[key] for row in rows]
            if isinstance(column[0], (list, tuple)):
                if key == "input_ids":
                    fill = self.pad_token_id
                elif key == "labels":
                    fill = self._LABEL_PAD
                else:
                    fill = 0
                width = max(len(seq) for seq in column)
                column = [list(seq) + [fill] * (width - len(seq)) for seq in column]
            batch[key] = torch.tensor(column)
        return batch
def load_data(dataset_name, train_size=20, test_size=20, num_clients=2, use_utf8=False, model_name="bert-base-uncased"):
    """Load a dataset, tokenize it and draw one random subset per client.

    Args:
        dataset_name: name passed to ``load_dataset``. The dataset is expected
            to provide ``train``/``test`` splits with ``text`` and ``label``
            columns.
        train_size: number of training examples sampled for each client.
        test_size: number of test examples sampled for each client.
        num_clients: how many (train, test) subsets to draw.
        use_utf8: accepted for interface compatibility; currently unused,
            byte-level tokenization is selected through ``model_name``.
        model_name: checkpoint whose tokenizer is used.

    Returns:
        ``(train_datasets, test_datasets, data_collator, raw_datasets)`` where
        the first two are lists with one tokenized subset per client and
        ``raw_datasets`` is the shuffled, untokenized dataset.
    """
    raw_datasets = load_dataset(dataset_name).shuffle(seed=42)
    # Only some datasets ship an unlabeled "unsupervised" split; dropping it
    # with ``del`` raised KeyError for the others.
    raw_datasets.pop("unsupervised", None)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if model_name == "google/byt5-small":
        # The ByT5 tokenizer already works on the UTF-8 bytes of the text, so
        # the text is passed as-is; inputs keep the fixed 512-token layout.
        tokenizer_kwargs = {"padding": "max_length", "truncation": True, "max_length": 512}
    else:
        tokenizer_kwargs = {"truncation": True}

    def tokenize_function(examples):
        return tokenizer(examples["text"], **tokenizer_kwargs)

    tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
    tokenized_datasets = tokenized_datasets.remove_columns("text")
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

    def sample_subset(split, size):
        # Never ask for more rows than the split holds (random.sample would raise).
        available = len(split)
        return split.select(random.sample(range(available), min(size, available)))

    train_datasets = []
    test_datasets = []
    for _ in range(num_clients):
        train_datasets.append(sample_subset(tokenized_datasets["train"], train_size))
        test_datasets.append(sample_subset(tokenized_datasets["test"], test_size))

    data_collator = CustomDataCollator(pad_token_id=tokenizer.pad_token_id)
    return train_datasets, test_datasets, data_collator, raw_datasets
def train(net, trainloader, epochs):
    """Fine-tune ``net`` in place on ``trainloader`` for ``epochs`` passes.

    Args:
        net: model whose forward call accepts the batch as keyword arguments
            and returns an object exposing ``.loss``.
        trainloader: iterable of dict batches mapping field name to tensor.
        epochs: number of full passes over ``trainloader``.
    """
    # torch.optim.AdamW replaces the deprecated ``transformers.AdamW``;
    # eps and weight_decay are pinned to that implementation's defaults so
    # the update rule stays the same.
    optimizer = torch.optim.AdamW(net.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)
    net.train()
    for _ in range(epochs):
        for batch in trainloader:
            batch = {name: tensor.to(DEVICE) for name, tensor in batch.items()}
            loss = net(**batch).loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
def test(net, testloader):
    """Evaluate ``net`` on ``testloader``.

    Returns:
        ``(loss, accuracy)``: the mean of the per-batch losses and the
        accuracy reported by the ``accuracy`` metric over all batches.
    """
    accuracy_metric = load_metric("accuracy")
    net.eval()
    total_loss = 0
    for raw_batch in testloader:
        on_device = {name: tensor.to(DEVICE) for name, tensor in raw_batch.items()}
        # Inference only: no gradients are needed for the forward pass.
        with torch.no_grad():
            result = net(**on_device)
        total_loss += result.loss.item()
        accuracy_metric.add_batch(
            predictions=torch.argmax(result.logits, dim=-1),
            references=on_device["labels"],
        )
    mean_loss = total_loss / len(testloader)
    return mean_loss, accuracy_metric.compute()["accuracy"]
class CustomClient(fl.client.NumPyClient):
    """Flower NumPy client wrapping one model and its train/test loaders.

    The loss and accuracy measured at the end of every ``fit`` call are kept
    in ``losses`` / ``accuracies`` so they can be plotted per round.
    """

    def __init__(self, net, trainloader, testloader, client_id):
        self.net = net
        self.trainloader = trainloader
        self.testloader = testloader
        self.client_id = client_id
        # One entry per completed fit() call.
        self.losses = []
        self.accuracies = []

    def get_parameters(self, config):
        """Return every ``state_dict`` entry of the model as a NumPy array."""
        return [val.cpu().numpy() for val in self.net.state_dict().values()]

    def set_parameters(self, parameters):
        """Load ``parameters`` (arrays in ``state_dict`` order) into the model."""
        params_dict = zip(self.net.state_dict().keys(), parameters)
        # torch.tensor keeps each array's dtype and shape; torch.Tensor(...)
        # always produced float32, round-tripping integer buffers through floats.
        state_dict = OrderedDict((k, torch.tensor(v)) for k, v in params_dict)
        self.net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Train for one epoch from ``parameters`` and report the result.

        Returns:
            The updated parameters, the number of training examples and a
            metrics dict with the loss/accuracy measured on the test loader.
        """
        log(INFO, f"Client {self.client_id} is starting fit()")
        self.set_parameters(parameters)
        train(self.net, self.trainloader, epochs=1)
        loss, accuracy = test(self.net, self.testloader)
        self.losses.append(loss)
        self.accuracies.append(accuracy)
        log(INFO, f"Client {self.client_id} finished fit() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
        return self.get_parameters(config={}), len(self.trainloader.dataset), {"loss": loss, "accuracy": accuracy}

    def evaluate(self, parameters, config):
        """Evaluate ``parameters`` on the test loader.

        Returns:
            The loss, the number of test examples and a metrics dict holding
            accuracy and loss as plain floats.
        """
        log(INFO, f"Client {self.client_id} is starting evaluate()")
        self.set_parameters(parameters)
        loss, accuracy = test(self.net, self.testloader)
        log(INFO, f"Client {self.client_id} finished evaluate() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
        return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy), "loss": float(loss)}

    def plot_metrics(self, round_num, plot_placeholder):
        """Render the latest metrics and the loss/accuracy history.

        Args:
            round_num: round number shown in the heading.
            plot_placeholder: Streamlit placeholder (e.g. ``st.empty()``) that
                receives the output. Nothing is drawn before the first
                ``fit`` call has recorded metrics.
        """
        if not (self.losses and self.accuracies):
            return

        # A placeholder shows a single element, so successive writes replaced
        # one another; a container inside it keeps text and chart together.
        panel = plot_placeholder.container()
        panel.write(f"#### Client {self.client_id} Metrics for Round {round_num}")
        panel.write(f"Loss: {self.losses[-1]:.4f}")
        panel.write(f"Accuracy: {self.accuracies[-1]:.4f}")

        fig, ax1 = plt.subplots()
        color = 'tab:red'
        ax1.set_xlabel('Round')
        ax1.set_ylabel('Loss', color=color)
        ax1.plot(range(1, len(self.losses) + 1), self.losses, color=color)
        ax1.tick_params(axis='y', labelcolor=color)

        # Second y-axis sharing the same x-axis, used for accuracy.
        ax2 = ax1.twinx()
        color = 'tab:blue'
        ax2.set_ylabel('Accuracy', color=color)
        ax2.plot(range(1, len(self.accuracies) + 1), self.accuracies, color=color)
        ax2.tick_params(axis='y', labelcolor=color)

        fig.tight_layout()
        panel.pyplot(fig)
        # Release the figure; pyplot keeps every open figure alive otherwise.
        plt.close(fig)
def read_log_file(log_path='./log.txt'):
    """Return the lines of the log file, line endings included.

    Args:
        log_path: path of the log file to read.

    Returns:
        List of lines, or an empty list when the file does not exist yet.
    """
    try:
        # Decode explicitly as UTF-8 (the platform default may differ) and
        # replace undecodable bytes instead of failing on a damaged log.
        with open(log_path, 'r', encoding='utf-8', errors='replace') as file:
            return file.readlines()
    except FileNotFoundError:
        return []
def parse_log(log_lines):
    """Extract round numbers, per-client messages and memory readings.

    Each line is classified by the first pattern that matches, in this order:
    a ``ROUND <n>`` marker, a ``Client <id> | <LEVEL> | <message>`` entry, a
    ``memory used=<x.y>GB`` reading. Lines matching none are ignored.

    Returns:
        ``(rounds, clients, memory_usage)`` where ``rounds`` lists the round
        numbers in order of appearance, ``clients`` maps a client id to
        ``{'rounds': [...], 'messages': [(level, message), ...]}`` (the round
        is ``None`` before the first marker) and ``memory_usage`` lists the
        readings as floats.
    """
    round_re = re.compile(r'ROUND (\d+)')
    client_re = re.compile(r'Client (\d+) \| (INFO|DEBUG) \| (.*)')
    memory_re = re.compile(r'memory used=(\d+\.\d+)GB')

    seen_rounds, per_client, memory_samples = [], {}, []
    active_round = None
    for entry in log_lines:
        hit = round_re.search(entry)
        if hit is not None:
            active_round = int(hit.group(1))
            seen_rounds.append(active_round)
            continue

        hit = client_re.search(entry)
        if hit is not None:
            cid, level, text = hit.groups()
            record = per_client.setdefault(int(cid), {'rounds': [], 'messages': []})
            record['rounds'].append(active_round)
            record['messages'].append((level, text))
            continue

        hit = memory_re.search(entry)
        if hit is not None:
            memory_samples.append(float(hit.group(1)))

    return seen_rounds, per_client, memory_samples
def plot_metrics(rounds, clients, memory_usage):
    """Show the parsed log in Streamlit: memory usage, messages, metric curves.

    Args:
        rounds: round numbers found in the log (accepted for interface
            compatibility; not used for drawing).
        clients: mapping of client id to ``{'rounds': [...], 'messages':
            [(level, message), ...]}`` with both lists aligned entry by entry.
        memory_usage: memory readings in GB, plotted against their index.
    """

    def draw_line(x_values, y_values, x_label, y_label):
        # Draw on an explicit figure and close it afterwards, so nothing is
        # left behind in pyplot's global state.
        fig, ax = plt.subplots()
        ax.plot(x_values, y_values, label=y_label)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.legend()
        st.pyplot(fig)
        plt.close(fig)

    metric_patterns = {
        'Loss': re.compile(r'loss=([\d\.]+)'),
        'Accuracy': re.compile(r'accuracy=([\d\.]+)'),
    }

    st.write("## Metrics Overview")
    st.write("### Memory Usage")
    draw_line(range(len(memory_usage)), memory_usage, 'Step', 'Memory Usage (GB)')

    for client_id, data in clients.items():
        st.write(f"### Client {client_id} Metrics")
        st.write("#### INFO Messages")
        for level, msg in data['messages']:
            if level == 'INFO':
                st.write(msg)
        st.write("#### DEBUG Messages")
        for level, msg in data['messages']:
            if level == 'DEBUG':
                st.write(msg)

        # Keep each DEBUG message paired with the round it was logged in, so
        # x and y always have the same length. Messages seen before the first
        # round marker carry no round and are plotted at 0.
        debug_entries = [
            (0 if round_no is None else round_no, msg)
            for round_no, (level, msg) in zip(data['rounds'], data['messages'])
            if level == 'DEBUG'
        ]
        for label, pattern in metric_patterns.items():
            points = []
            for round_no, msg in debug_entries:
                found = pattern.search(msg)
                if found is None:
                    continue
                try:
                    # A sentence-ending period is captured by the pattern.
                    value = float(found.group(1).rstrip('.'))
                except ValueError:
                    continue
                points.append((round_no, value))
            if points:
                x_values, y_values = zip(*points)
                draw_line(x_values, y_values, 'Round', label)
def read_log_file2():
    """Return the whole content of ``./log.txt`` as one string.

    Returns:
        The file content, or an empty string when the file does not exist yet.
    """
    try:
        # Decode explicitly as UTF-8 (the platform default may differ) and
        # replace undecodable bytes instead of failing on a damaged log.
        with open("./log.txt", "r", encoding="utf-8", errors="replace") as file:
            return file.read()
    except FileNotFoundError:
        return ""
def main():
    """Streamlit entry point for the federated-learning demo.

    Shows the filtered log, lets the user pick dataset, model, number of
    clients and rounds, builds one client (data loaders + model) per
    simulated device and, once "Start Training" is pressed, runs the rounds
    and displays per-round and final metrics plus an analysis of the log.
    """
    st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")

    pattern = re.compile(r"memory|loss|accuracy|round|client", re.IGNORECASE)

    def filter_log(text):
        # Keep only the log lines that mention something worth displaying.
        return "\n".join(line for line in text.splitlines() if pattern.search(line))

    # st.markdown needs a string; it used to be handed the list of lines.
    filtered_logs = filter_log(read_log_file2())
    st.markdown(filtered_logs)
    st.download_button(
        label="Download Logs",
        data=filtered_logs,
        file_name="./log.txt",
        mime="text/plain"
    )

    dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
    model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased", "google/byt5-small"])
    NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
    NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
    use_utf8 = st.checkbox("Train on Byte UTF-8 Dataset", value=False)

    train_datasets, test_datasets, data_collator, raw_datasets = load_data(dataset_name, num_clients=NUM_CLIENTS, use_utf8=use_utf8, model_name=model_name)

    # Size the classification head from the dataset when it declares its
    # classes; fall back to the previous fixed value of 2 otherwise.
    label_feature = raw_datasets["train"].features.get("label")
    num_labels = getattr(label_feature, "num_classes", 2)

    def preview_rows(split, count=20):
        # Random sample of raw rows for display, never larger than the split.
        return split.select(random.sample(range(len(split)), min(count, len(split))))

    clients = []
    for i in range(NUM_CLIENTS):
        st.write(f"### Client {i+1} Datasets")
        train_df = pd.DataFrame(train_datasets[i])
        test_df = pd.DataFrame(test_datasets[i])

        st.write("#### Train Dataset (Words)")
        st.dataframe(preview_rows(raw_datasets["train"]))
        st.write("#### Train Dataset (Tokens)")
        edited_train_df = st.data_editor(train_df, key=f"train_{i}")
        st.write("#### Test Dataset (Words)")
        st.dataframe(preview_rows(raw_datasets["test"]))
        st.write("#### Test Dataset (Tokens)")
        edited_test_df = st.data_editor(test_df, key=f"test_{i}")

        # Train on whatever the user left in the editable token tables.
        edited_train_dataset = Dataset.from_pandas(edited_train_df)
        edited_test_dataset = Dataset.from_pandas(edited_test_df)
        trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
        testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)

        net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels).to(DEVICE)
        clients.append(CustomClient(net, trainloader, testloader, client_id=i+1))

    if not st.button("Start Training"):
        st.write("Click the 'Start Training' button to start the training process.")
        return

    def client_fn(cid):
        return clients[int(cid)].to_client()

    def weighted_average(metrics):
        # Average accuracy and loss weighted by each client's example count.
        total_examples = sum(num_examples for num_examples, _ in metrics)
        if total_examples == 0:
            return {"accuracy": 0.0, "loss": 0.0}
        accuracy_sum = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
        loss_sum = sum(num_examples * m["loss"] for num_examples, m in metrics)
        return {"accuracy": accuracy_sum / total_examples, "loss": loss_sum / total_examples}

    strategy = fl.server.strategy.FedAvg(
        fraction_fit=1.0,
        fraction_evaluate=1.0,
        evaluate_metrics_aggregation_fn=weighted_average,
    )

    # Compiled once; they do not change between rounds.
    accuracy_pattern = re.compile(r"'accuracy': \{(\d+), ([\d.]+)\}")
    loss_pattern = re.compile(r"'loss': \{(\d+), ([\d.]+)\}")

    for round_num in range(NUM_ROUNDS):
        st.write(f"### Round {round_num + 1} ✅")
        filtered_logs = filter_log(read_log_file2())
        st.markdown(filtered_logs)

        # Pair accuracy and loss entries so the three columns below always
        # have the same length (unequal lists make pd.DataFrame raise).
        paired = list(zip(accuracy_pattern.findall(filtered_logs), loss_pattern.findall(filtered_logs)))
        rounds = [int(acc[0]) for acc, _ in paired]
        accuracies = [float(acc[1]) for acc, _ in paired]
        losses = [float(loss[1]) for _, loss in paired]

        accuracy_fig = go.Figure()
        accuracy_fig.add_trace(go.Scatter(x=rounds, y=accuracies, mode='lines+markers', name='Accuracy'))
        accuracy_fig.update_layout(title='Accuracy over Rounds', xaxis_title='Round', yaxis_title='Accuracy')
        loss_fig = go.Figure()
        loss_fig.add_trace(go.Scatter(x=rounds, y=losses, mode='lines+markers', name='Loss'))
        loss_fig.update_layout(title='Loss over Rounds', xaxis_title='Round', yaxis_title='Loss')
        st.plotly_chart(accuracy_fig)
        st.plotly_chart(loss_fig)

        df = pd.DataFrame({
            'Round': rounds,
            'Accuracy': accuracies,
            'Loss': losses,
        })
        st.write("## Training Metrics")
        st.table(df)

        plot_placeholders = [st.empty() for _ in range(NUM_CLIENTS)]

        # NOTE(review): every iteration starts a fresh one-round simulation;
        # whether model state carries over between these calls depends on how
        # Flower hands `clients` to its workers -- confirm.
        gpu_count = 1 if torch.cuda.is_available() else 0
        fl.simulation.start_simulation(
            client_fn=client_fn,
            num_clients=NUM_CLIENTS,
            config=fl.server.ServerConfig(num_rounds=1),
            strategy=strategy,
            client_resources={"num_cpus": 1, "num_gpus": gpu_count},
            ray_init_args={"log_to_driver": True, "num_cpus": 1, "num_gpus": gpu_count}
        )

        for client, placeholder in zip(clients, plot_placeholders):
            client.plot_metrics(round_num + 1, placeholder)
            st.write(" ")

    st.success("Training completed successfully!")

    st.write("## Final Client Metrics")
    for client in clients:
        st.write(f"### Client {client.client_id}")
        if client.losses and client.accuracies:
            st.write(f"Final Loss: {client.losses[-1]:.4f}")
            st.write(f"Final Accuracy: {client.accuracies[-1]:.4f}")
            client.plot_metrics(NUM_ROUNDS, st.empty())
        else:
            st.write("No metrics available.")
        st.write(" ")

    st.write("## Training Log")
    st.write(read_log_file2())

    st.write("## Training Log Analysis")
    # Separate names: the parsed log must not rebind the `clients` list above.
    log_rounds, log_clients, memory_usage = parse_log(read_log_file())
    plot_metrics(log_rounds, log_clients, memory_usage)


if __name__ == "__main__":
    main()