Spaces:
Build error
Build error
| # # %%writefile app.py | |
| # import streamlit as st | |
| # import matplotlib.pyplot as plt | |
| # import torch | |
| # from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, AdamW | |
| # from datasets import load_dataset, Dataset | |
| # from evaluate import load as load_metric | |
| # from torch.utils.data import DataLoader | |
| # import pandas as pd | |
| # import random | |
| # from collections import OrderedDict | |
| # import flwr as fl | |
| # DEVICE = torch.device("cpu") | |
| # def load_data(dataset_name, train_size=20, test_size=20, num_clients=2): | |
| # raw_datasets = load_dataset(dataset_name) | |
| # raw_datasets = raw_datasets.shuffle(seed=42) | |
| # del raw_datasets["unsupervised"] | |
| # tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") | |
| # def tokenize_function(examples): | |
| # return tokenizer(examples["text"], truncation=True) | |
| # tokenized_datasets = raw_datasets.map(tokenize_function, batched=True) | |
| # tokenized_datasets = tokenized_datasets.remove_columns("text") | |
| # tokenized_datasets = tokenized_datasets.rename_column("label", "labels") | |
| # train_datasets = [] | |
| # test_datasets = [] | |
| # for _ in range(num_clients): | |
| # train_dataset = tokenized_datasets["train"].select(random.sample(range(len(tokenized_datasets["train"])), train_size)) | |
| # test_dataset = tokenized_datasets["test"].select(random.sample(range(len(tokenized_datasets["test"])), test_size)) | |
| # train_datasets.append(train_dataset) | |
| # test_datasets.append(test_dataset) | |
| # data_collator = DataCollatorWithPadding(tokenizer=tokenizer) | |
| # return train_datasets, test_datasets, data_collator | |
| # def read_log_file(): | |
| # with open("./log.txt", "r") as file: | |
| # return file.read() | |
| # def train(net, trainloader, epochs): | |
| # optimizer = AdamW(net.parameters(), lr=5e-5) | |
| # net.train() | |
| # for _ in range(epochs): | |
| # for batch in trainloader: | |
| # batch = {k: v.to(DEVICE) for k, v in batch.items()} | |
| # outputs = net(**batch) | |
| # loss = outputs.loss | |
| # loss.backward() | |
| # optimizer.step() | |
| # optimizer.zero_grad() | |
| # def test(net, testloader): | |
| # metric = load_metric("accuracy") | |
| # net.eval() | |
| # loss = 0 | |
| # for batch in testloader: | |
| # batch = {k: v.to(DEVICE) for k, v in batch.items()} | |
| # with torch.no_grad(): | |
| # outputs = net(**batch) | |
| # logits = outputs.logits | |
| # loss += outputs.loss.item() | |
| # predictions = torch.argmax(logits, dim=-1) | |
| # metric.add_batch(predictions=predictions, references=batch["labels"]) | |
| # loss /= len(testloader) | |
| # accuracy = metric.compute()["accuracy"] | |
| # return loss, accuracy | |
| # class CustomClient(fl.client.NumPyClient): | |
| # def __init__(self, net, trainloader, testloader, client_id): | |
| # self.net = net | |
| # self.trainloader = trainloader | |
| # self.testloader = testloader | |
| # self.client_id = client_id | |
| # self.losses = [] | |
| # self.accuracies = [] | |
| # def get_parameters(self, config): | |
| # return [val.cpu().numpy() for _, val in self.net.state_dict().items()] | |
| # def set_parameters(self, parameters): | |
| # params_dict = zip(self.net.state_dict().keys(), parameters) | |
| # state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict}) | |
| # self.net.load_state_dict(state_dict, strict=True) | |
| # def fit(self, parameters, config): | |
| # self.set_parameters(parameters) | |
| # train(self.net, self.trainloader, epochs=1) | |
| # loss, accuracy = test(self.net, self.testloader) | |
| # self.losses.append(loss) | |
| # self.accuracies.append(accuracy) | |
| # return self.get_parameters(config={}), len(self.trainloader.dataset), {} | |
| # def evaluate(self, parameters, config): | |
| # self.set_parameters(parameters) | |
| # loss, accuracy = test(self.net, self.testloader) | |
| # return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)} | |
| # def plot_metrics(self, round_num, plot_placeholder): | |
| # if self.losses and self.accuracies: | |
| # plot_placeholder.write(f"#### Client {self.client_id} Metrics for Round {round_num}") | |
| # plot_placeholder.write(f"Loss: {self.losses[-1]:.4f}") | |
| # plot_placeholder.write(f"Accuracy: {self.accuracies[-1]:.4f}") | |
| # fig, ax1 = plt.subplots() | |
| # color = 'tab:red' | |
| # ax1.set_xlabel('Round') | |
| # ax1.set_ylabel('Loss', color=color) | |
| # ax1.plot(range(1, len(self.losses) + 1), self.losses, color=color) | |
| # ax1.tick_params(axis='y', labelcolor=color) | |
| # ax2 = ax1.twinx() # instantiate a second axes that shares the same x-axis | |
| # color = 'tab:blue' | |
| # ax2.set_ylabel('Accuracy', color=color) | |
| # ax2.plot(range(1, len(self.accuracies) + 1), self.accuracies, color=color) | |
| # ax2.tick_params(axis='y', labelcolor=color) | |
| # fig.tight_layout() | |
| # plot_placeholder.pyplot(fig) | |
| # def main(): | |
| # st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices") | |
| # dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"]) | |
| # model_name = st.selectbox("Model", ["bert-base-uncased","facebook/hubert-base-ls960", "distilbert-base-uncased"]) | |
| # NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2) | |
| # NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3) | |
| # train_datasets, test_datasets, data_collator = load_data(dataset_name, num_clients=NUM_CLIENTS) | |
| # trainloaders = [] | |
| # testloaders = [] | |
| # clients = [] | |
| # for i in range(NUM_CLIENTS): | |
| # st.write(f"### Client {i+1} Datasets") | |
| # train_df = pd.DataFrame(train_datasets[i]) | |
| # test_df = pd.DataFrame(test_datasets[i]) | |
| # st.write("#### Train Dataset") | |
| # edited_train_df = st.data_editor(train_df, key=f"train_{i}") | |
| # st.write("#### Test Dataset") | |
| # edited_test_df = st.data_editor(test_df, key=f"test_{i}") | |
| # edited_train_dataset = Dataset.from_pandas(edited_train_df) | |
| # edited_test_dataset = Dataset.from_pandas(edited_test_df) | |
| # trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator) | |
| # testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator) | |
| # trainloaders.append(trainloader) | |
| # testloaders.append(testloader) | |
| # net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE) | |
| # client = CustomClient(net, trainloader, testloader, client_id=i+1) | |
| # clients.append(client) | |
| # if st.button("Start Training"): | |
| # def client_fn(cid): | |
| # return clients[int(cid)] | |
| # def weighted_average(metrics): | |
| # accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] | |
| # losses = [num_examples * m["loss"] for num_examples, m in metrics] | |
| # examples = [num_examples for num_examples, _ in metrics] | |
| # return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)} | |
| # strategy = fl.server.strategy.FedAvg( | |
| # fraction_fit=1.0, | |
| # fraction_evaluate=1.0, | |
| # evaluate_metrics_aggregation_fn=weighted_average, | |
| # ) | |
| # for round_num in range(NUM_ROUNDS): | |
| # st.write(f"### Round {round_num + 1}") | |
| # plot_placeholders = [st.empty() for _ in range(NUM_CLIENTS)] | |
| # fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt") | |
| # fl.simulation.start_simulation( | |
| # client_fn=client_fn, | |
| # num_clients=NUM_CLIENTS, | |
| # config=fl.server.ServerConfig(num_rounds=1), | |
| # strategy=strategy, | |
| # client_resources={"num_cpus": 1, "num_gpus": 0}, | |
| # ray_init_args={"log_to_driver": False, "num_cpus": 1, "num_gpus": 0} | |
| # ) | |
| # for i, client in enumerate(clients): | |
| # st.markdown("LOGS : "+ read_log_file()) | |
| # client.plot_metrics(round_num + 1, plot_placeholders[i]) | |
| # st.write(" ") | |
| # st.success("Training completed successfully!") | |
| # # Display final metrics | |
| # st.write("## Final Client Metrics") | |
| # for client in clients: | |
| # st.write(f"### Client {client.client_id}") | |
| # st.write(f"Final Loss: {client.losses[-1]:.4f}") | |
| # st.write(f"Final Accuracy: {client.accuracies[-1]:.4f}") | |
| # client.plot_metrics(NUM_ROUNDS, st.empty()) | |
| # st.write(" ") | |
| # else: | |
| # st.write("Click the 'Start Training' button to start the training process.") | |
| # if __name__ == "__main__": | |
| # main() | |
| # %%writefile app.py | |
| import streamlit as st | |
| import matplotlib.pyplot as plt | |
| import torch | |
| from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, AdamW | |
| from datasets import load_dataset, Dataset | |
| from evaluate import load as load_metric | |
| from torch.utils.data import DataLoader | |
| import pandas as pd | |
| import random | |
| from collections import OrderedDict | |
| import flwr as fl | |
| from logging import INFO, DEBUG | |
| from flwr.common.logger import log | |
| import logging | |
| import streamlit | |
# Run models on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Route Flower's log records into ./log.txt so the page can display and parse them.
fl.common.logger.configure(filename="./log.txt", identifier="myFlowerExperiment")
def load_data(dataset_name, train_size=20, test_size=20, num_clients=2, tokenizer_name="bert-base-uncased"):
    """Load a Hugging Face text-classification dataset and sample one shard per client.

    Args:
        dataset_name: name passed to ``datasets.load_dataset`` (e.g. "imdb").
        train_size: number of training rows sampled for each client.
        test_size: number of test rows sampled for each client.
        num_clients: number of client shards to create.
        tokenizer_name: tokenizer checkpoint used to tokenize the text.

    Returns:
        ``(train_datasets, test_datasets, data_collator, raw_datasets)``:
        per-client tokenized train/test datasets (a "labels" column plus the
        tokenizer outputs), a padding collator for the tokenizer, and the
        shuffled raw ``DatasetDict``.

    Rows are drawn with the global ``random`` module independently for each
    client, so client shards may overlap.
    """
    raw_datasets = load_dataset(dataset_name)
    # Not every dataset ships an unlabeled "unsupervised" split (imdb does);
    # drop it only when present instead of raising KeyError. Dropping it before
    # shuffling gives the same per-split order (each split uses seed 42).
    if "unsupervised" in raw_datasets:
        del raw_datasets["unsupervised"]
    raw_datasets = raw_datasets.shuffle(seed=42)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    # Most datasets keep the input in "text"; fall back to "content"
    # (the column name used by e.g. amazon_polarity).
    text_column = "text" if "text" in raw_datasets["train"].column_names else "content"

    def tokenize_function(examples):
        return tokenizer(examples[text_column], truncation=True)

    def sample_and_tokenize(split, size):
        # Sample first so only `size` rows are tokenized, not the entire split.
        subset = split.select(random.sample(range(len(split)), size))
        # Remove every raw column except the label so the collator only sees tensors.
        raw_columns = [name for name in subset.column_names if name != "label"]
        tokenized = subset.map(tokenize_function, batched=True, remove_columns=raw_columns)
        return tokenized.rename_column("label", "labels")

    train_datasets = []
    test_datasets = []
    for _ in range(num_clients):
        # Per client: sample train rows, then test rows.
        train_datasets.append(sample_and_tokenize(raw_datasets["train"], train_size))
        test_datasets.append(sample_and_tokenize(raw_datasets["test"], test_size))
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    return train_datasets, test_datasets, data_collator, raw_datasets
def train(net, trainloader, epochs, lr=5e-5):
    """Fine-tune ``net`` in place for ``epochs`` passes over ``trainloader``.

    Each batch is a mapping of tensors. It is moved to ``DEVICE`` and passed as
    keyword arguments, and ``net(**batch)`` must return an object with a
    ``.loss`` tensor.

    Args:
        net: model to train; it is switched to training mode.
        trainloader: iterable of batches.
        epochs: number of passes over ``trainloader``.
        lr: AdamW learning rate.
    """
    # torch.optim.AdamW replaces the deprecated transformers.AdamW. eps=1e-6 and
    # weight_decay=0.0 reproduce the transformers defaults (torch uses 1e-8 / 0.01).
    optimizer = torch.optim.AdamW(net.parameters(), lr=lr, eps=1e-6, weight_decay=0.0)
    net.train()
    for _ in range(epochs):
        for batch in trainloader:
            batch = {k: v.to(DEVICE) for k, v in batch.items()}
            loss = net(**batch).loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
| # class SaveModelStrategy(fl.server.strategy.FedAvg): | |
| # def aggregate_fit( | |
| # self, | |
| # server_round: int, | |
| # results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]], | |
| # failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], | |
| # ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: | |
| # """Aggregate model weights using weighted average and store checkpoint""" | |
| # # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics | |
| # aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures) | |
| # if aggregated_parameters is not None: | |
| # print(f"Saving round {server_round} aggregated_parameters...") | |
| # # Convert `Parameters` to `List[np.ndarray]` | |
| # aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(aggregated_parameters) | |
| # # Convert `List[np.ndarray]` to PyTorch`state_dict` | |
| # params_dict = zip(net.state_dict().keys(), aggregated_ndarrays) | |
| # state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) | |
| # net.load_state_dict(state_dict, strict=True) | |
| # # Save the model | |
| # torch.save(net.state_dict(), f"model_round_{server_round}.pth") | |
| # return aggregated_parameters, aggregated_metrics | |
def test(net, testloader):
    """Evaluate ``net`` on ``testloader`` and return ``(loss, accuracy)``.

    Each batch is a mapping of tensors that includes ``"labels"``. It is moved
    to ``DEVICE`` and passed to ``net(**batch)``, whose output must expose
    ``.loss`` and ``.logits``.

    Returns:
        The example-weighted mean loss and the fraction of argmax predictions
        equal to the labels, both as floats. Both are ``nan`` when the loader
        yields no examples (e.g. every row was deleted in the data editor).
    """
    net.eval()
    total_loss = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for batch in testloader:
            batch = {k: v.to(DEVICE) for k, v in batch.items()}
            outputs = net(**batch)
            labels = batch["labels"]
            batch_size = labels.size(0)
            # Assumes outputs.loss is a batch mean (the default for HF
            # classification heads); weight it by the batch size so a smaller
            # final batch does not skew the average.
            total_loss += outputs.loss.item() * batch_size
            predictions = torch.argmax(outputs.logits, dim=-1)
            # Counting matches directly avoids reloading the `evaluate`
            # accuracy metric on every call.
            correct += (predictions == labels).sum().item()
            seen += batch_size
    if seen == 0:
        return float("nan"), float("nan")
    return total_loss / seen, correct / seen
class CustomClient(fl.client.NumPyClient):
    """Flower NumPy client that fine-tunes one Hugging Face classifier on local data.

    After every ``fit`` the client evaluates on its local test loader and
    appends the result to ``losses`` / ``accuracies`` so it can be plotted.
    """

    def __init__(self, net, trainloader, testloader, client_id):
        self.net = net
        self.trainloader = trainloader
        self.testloader = testloader
        self.client_id = client_id  # id shown in log messages and page headings
        self.losses = []  # local test loss recorded after each fit()
        self.accuracies = []  # local test accuracy recorded after each fit()

    def get_parameters(self, config):
        """Return every state_dict entry as a NumPy array, in state_dict order."""
        return [val.cpu().numpy() for _, val in self.net.state_dict().items()]

    def set_parameters(self, parameters):
        """Load arrays in state_dict order (as from get_parameters) into the model, strictly."""
        params_dict = zip(self.net.state_dict().keys(), parameters)
        # torch.tensor keeps each array's dtype (integer buffers stay integer),
        # whereas the legacy torch.Tensor constructor always yields float32.
        state_dict = OrderedDict((k, torch.tensor(v)) for k, v in params_dict)
        self.net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Train one local epoch from ``parameters`` and report local test metrics."""
        log(INFO, f"Client {self.client_id} is starting fit()")
        self.set_parameters(parameters)
        train(self.net, self.trainloader, epochs=1)
        loss, accuracy = test(self.net, self.testloader)
        self.losses.append(loss)
        self.accuracies.append(accuracy)
        log(INFO, f"Client {self.client_id} finished fit() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
        return self.get_parameters(config={}), len(self.trainloader.dataset), {"loss": loss, "accuracy": accuracy}

    def evaluate(self, parameters, config):
        """Evaluate ``parameters`` on the local test set; returns (loss, n_examples, metrics)."""
        log(INFO, f"Client {self.client_id} is starting evaluate()")
        self.set_parameters(parameters)
        loss, accuracy = test(self.net, self.testloader)
        log(INFO, f"Client {self.client_id} finished evaluate() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
        return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy), "loss": float(loss)}

    def plot_metrics(self, round_num, plot_placeholder):
        """Render the latest metrics and the loss/accuracy history into ``plot_placeholder``.

        Does nothing until at least one ``fit`` has recorded metrics.
        """
        if not (self.losses and self.accuracies):
            return
        # An st.empty() placeholder holds a single element, so successive writes
        # would replace each other; group everything in one container instead.
        container = plot_placeholder.container()
        container.write(f"#### Client {self.client_id} Metrics for Round {round_num}")
        container.write(f"Loss: {self.losses[-1]:.4f}")
        container.write(f"Accuracy: {self.accuracies[-1]:.4f}")
        fig, ax1 = plt.subplots()
        color = 'tab:red'
        ax1.set_xlabel('Round')
        ax1.set_ylabel('Loss', color=color)
        ax1.plot(range(1, len(self.losses) + 1), self.losses, color=color)
        ax1.tick_params(axis='y', labelcolor=color)
        ax2 = ax1.twinx()  # second y-axis sharing the same x-axis
        color = 'tab:blue'
        ax2.set_ylabel('Accuracy', color=color)
        ax2.plot(range(1, len(self.accuracies) + 1), self.accuracies, color=color)
        ax2.tick_params(axis='y', labelcolor=color)
        fig.tight_layout()
        container.pyplot(fig)
        plt.close(fig)  # a new figure is created on every call; free it
| import matplotlib.pyplot as plt | |
| import re | |
def read_log_file(log_path='./log.txt'):
    """Return the lines of the log file at ``log_path``, each with its newline kept."""
    with open(log_path, 'r') as handle:
        lines = list(handle)
    return lines
def parse_log(log_lines):
    """Extract round numbers, per-client messages and memory readings from log lines.

    Args:
        log_lines: iterable of log-file lines (e.g. from ``read_log_file``).

    Returns:
        ``(rounds, clients, memory_usage)``:
        - ``rounds``: round numbers in the order their headers were seen.
        - ``clients``: client id -> ``{'rounds': [...], 'messages': [(level, text), ...]}``.
          ``'rounds'`` holds, for each message, the most recent round number
          (``None`` before any round header).
        - ``memory_usage``: ``memory used=<x>GB`` readings as floats.

    Each line is classified as a round header, then a client message, then a
    memory reading; the first kind that matches wins.
    """
    rounds = []
    clients = {}
    memory_usage = []
    # Accepts the "ROUND<n>ROUND <n>" form as well as "[ROUND <n>]" headers
    # (presumably what recent Flower versions log — verify against log.txt).
    round_pattern = re.compile(r'\bROUND\s*(\d+)')
    # "Client <id> | <LEVEL> | <message>"
    client_pattern = re.compile(r'Client (\d+) \| (INFO|DEBUG) \| (.*)')
    # Lines where the level precedes the message text, which is how records from
    # log(INFO, f"Client {id} ...") look once a handler prefixes the level
    # (e.g. "... | INFO ... | Client 1 is starting fit()").
    leveled_client_pattern = re.compile(r'\b(INFO|DEBUG)\b.*?\bClient (\d+) (.*)')
    memory_pattern = re.compile(r'memory used=(\d+\.\d+)GB')
    current_round = None
    for line in log_lines:
        round_match = round_pattern.search(line)
        if round_match:
            current_round = int(round_match.group(1))
            rounds.append(current_round)
            continue

        client_fields = None
        match = client_pattern.search(line)
        if match:
            client_fields = (match.group(1), match.group(2), match.group(3))
        else:
            match = leveled_client_pattern.search(line)
            if match:
                client_fields = (match.group(2), match.group(1), match.group(3))
        if client_fields is not None:
            client_id, log_level, message = int(client_fields[0]), client_fields[1], client_fields[2]
            entry = clients.setdefault(client_id, {'rounds': [], 'messages': []})
            entry['rounds'].append(current_round)
            entry['messages'].append((log_level, message))
            continue

        memory_match = memory_pattern.search(line)
        if memory_match:
            memory_usage.append(float(memory_match.group(1)))
    return rounds, clients, memory_usage
def _extract_points(data, pattern):
    """Return ``(round, value)`` for each of a client's messages that ``pattern`` matches."""
    points = []
    for round_num, (_, message) in zip(data['rounds'], data['messages']):
        match = pattern.search(message)
        if match:
            points.append((round_num, float(match.group(1))))
    return points


def _plot_series(points, ylabel):
    """Plot ``(round, value)`` pairs; fall back to update order if any round is unknown."""
    if not points:
        return
    rounds_known = all(round_num is not None for round_num, _ in points)
    xs = [round_num for round_num, _ in points] if rounds_known else list(range(1, len(points) + 1))
    fig, ax = plt.subplots()
    ax.plot(xs, [value for _, value in points], label=ylabel)
    ax.set_xlabel('Round' if rounds_known else 'Update')
    ax.set_ylabel(ylabel)
    ax.legend()
    st.pyplot(fig)
    plt.close(fig)


def plot_metrics(rounds, clients, memory_usage):
    """Render the log-analysis section: a memory chart plus per-client messages and curves.

    Args:
        rounds: round numbers from ``parse_log`` (not used here; each client
            message already carries its round).
        clients: client id -> ``{'rounds': [...], 'messages': [(level, text), ...]}``.
        memory_usage: memory readings in GB, in log order.
    """
    st.write("## Metrics Overview")
    st.write("### Memory Usage")
    if memory_usage:
        fig, ax = plt.subplots()
        ax.plot(range(len(memory_usage)), memory_usage, label='Memory Usage (GB)')
        ax.set_xlabel('Step')
        ax.set_ylabel('Memory Usage (GB)')
        ax.legend()
        st.pyplot(fig)
        plt.close(fig)
    else:
        st.write("No memory usage entries found in the log.")

    # Accept both "loss=0.5" and "loss: 0.5"; this file's client log calls use
    # the latter and emit it at INFO level, so every message is searched.
    loss_pattern = re.compile(r'\bloss[=:]\s*(\d+(?:\.\d+)?)')
    accuracy_pattern = re.compile(r'\baccuracy[=:]\s*(\d+(?:\.\d+)?)')
    for client_id, data in clients.items():
        st.write(f"### Client {client_id} Metrics")
        st.write("#### INFO Messages")
        for level, msg in data['messages']:
            if level == 'INFO':
                st.write(msg)
        st.write("#### DEBUG Messages")
        for level, msg in data['messages']:
            if level == 'DEBUG':
                st.write(msg)
        # Values are paired with their own round, so x and y always have equal length.
        _plot_series(_extract_points(data, loss_pattern), 'Loss')
        _plot_series(_extract_points(data, accuracy_pattern), 'Accuracy')
def read_log_file2():
    """Return the entire contents of ./log.txt as one string."""
    log_path = "./log.txt"
    with open(log_path, "r") as handle:
        contents = handle.read()
    return contents
def _num_labels(raw_datasets):
    """Number of classes of the train split's "label" feature; 2 if it is not a ClassLabel."""
    label_feature = raw_datasets["train"].features.get("label")
    return getattr(label_feature, "num_classes", 2)


def _decoded_preview(dataset, tokenizer):
    """Decode a tokenized dataset back to text so the preview shows the client's own rows.

    Text may be lowercased/truncated depending on the tokenizer.
    """
    return pd.DataFrame({
        "text": tokenizer.batch_decode(dataset["input_ids"], skip_special_tokens=True),
        "labels": dataset["labels"],
    })


def _make_loader(df, net, data_collator, shuffle):
    """Build a DataLoader from an (edited) DataFrame, keeping only columns ``net`` accepts."""
    import inspect

    # preserve_index=False: deleting rows in st.data_editor leaves a non-range
    # index that would otherwise become an extra "__index_level_0__" column.
    dataset = Dataset.from_pandas(df, preserve_index=False)
    # Drop columns the model's forward() does not name (e.g. "token_type_ids"
    # for DistilBERT); otherwise net(**batch) raises TypeError.
    accepted = inspect.signature(net.forward).parameters
    unused = [name for name in dataset.column_names if name not in accepted]
    if unused:
        dataset = dataset.remove_columns(unused)
    return DataLoader(dataset, shuffle=shuffle, batch_size=32, collate_fn=data_collator)


def _build_client(index, train_dataset, test_dataset, data_collator, model_name, num_labels):
    """Show and allow editing of one client's data, then wrap it in a CustomClient."""
    tokenizer = data_collator.tokenizer
    st.write(f"### Client {index + 1} Datasets")
    train_df = pd.DataFrame(train_dataset)
    test_df = pd.DataFrame(test_dataset)
    st.write("#### Train Dataset (Words)")
    st.dataframe(_decoded_preview(train_dataset, tokenizer))
    st.write("#### Train Dataset (Tokens)")
    edited_train_df = st.data_editor(train_df, key=f"train_{index}")
    st.write("#### Test Dataset (Words)")
    st.dataframe(_decoded_preview(test_dataset, tokenizer))
    st.write("#### Test Dataset (Tokens)")
    edited_test_df = st.data_editor(test_df, key=f"test_{index}")
    net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels).to(DEVICE)
    trainloader = _make_loader(edited_train_df, net, data_collator, shuffle=True)
    testloader = _make_loader(edited_test_df, net, data_collator, shuffle=False)
    return CustomClient(net, trainloader, testloader, client_id=index + 1)


def _run_simulation(clients, num_rounds):
    """Run one Flower simulation of ``num_rounds`` FedAvg rounds and return its History."""
    num_clients = len(clients)

    def client_fn(cid):
        return clients[int(cid)]

    def weighted_average(metrics):
        # Example-weighted mean of the clients' evaluate() metrics.
        total = sum(num_examples for num_examples, _ in metrics)
        if total == 0:
            return {}
        return {
            "accuracy": sum(n * m["accuracy"] for n, m in metrics) / total,
            "loss": sum(n * m["loss"] for n, m in metrics) / total,
        }

    strategy = fl.server.strategy.FedAvg(
        fraction_fit=1.0,
        fraction_evaluate=1.0,
        # FedAvg waits for 2 clients by default, which never happens with a
        # single client; require exactly the clients we have.
        min_fit_clients=num_clients,
        min_evaluate_clients=num_clients,
        min_available_clients=num_clients,
        evaluate_metrics_aggregation_fn=weighted_average,
    )
    num_gpus = 1 if torch.cuda.is_available() else 0
    # A single simulation over all rounds lets FedAvg carry the aggregated
    # weights from one round into the next, instead of restarting each round
    # from the initial weights.
    return fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=num_clients,
        config=fl.server.ServerConfig(num_rounds=num_rounds),
        strategy=strategy,
        client_resources={"num_cpus": 1, "num_gpus": num_gpus},
        ray_init_args={"log_to_driver": True, "num_cpus": 1, "num_gpus": num_gpus},
    )


def _show_history(history):
    """Show the aggregated (example-weighted) evaluation loss/accuracy for each round."""
    st.write("## Aggregated Metrics per Round")
    losses = dict(history.losses_distributed)
    accuracies = dict(history.metrics_distributed.get("accuracy", []))
    if not losses:
        st.write("No aggregated metrics were recorded.")
        return
    for server_round in sorted(losses):
        st.write(f"### Round {server_round}")
        summary = f"Loss: {losses[server_round]:.4f}"
        if server_round in accuracies:
            summary += f", Accuracy: {accuracies[server_round]:.4f}"
        st.write(summary)


def main():
    """Streamlit entry point: configure the experiment, preview client data, train and report."""
    st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
    dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
    # Text models only: an audio checkpoint such as facebook/hubert-base-ls960
    # cannot consume the token ids produced by load_data.
    model_name = st.selectbox("Model", ["bert-base-uncased", "distilbert-base-uncased"])
    num_clients = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
    num_rounds = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
    train_datasets, test_datasets, data_collator, raw_datasets = load_data(dataset_name, num_clients=num_clients)
    # Size the classifier head from the dataset (e.g. 4 classes for ag_news).
    num_labels = _num_labels(raw_datasets)

    clients = []
    for i in range(num_clients):
        clients.append(_build_client(i, train_datasets[i], test_datasets[i], data_collator, model_name, num_labels))

    if not st.button("Start Training"):
        st.write("Click the 'Start Training' button to start the training process.")
        return

    history = _run_simulation(clients, num_rounds)
    st.success("Training completed successfully!")
    _show_history(history)

    st.write("## Final Client Metrics")
    # NOTE(review): the simulation may run clients from copies inside worker
    # processes, in which case these driver-side lists stay empty — verify.
    for client in clients:
        st.write(f"### Client {client.client_id}")
        if client.losses and client.accuracies:
            st.write(f"Final Loss: {client.losses[-1]:.4f}")
            st.write(f"Final Accuracy: {client.accuracies[-1]:.4f}")
            client.plot_metrics(num_rounds, st.empty())
        else:
            st.write("No metrics available.")
        st.write(" ")

    st.write("## Training Log")
    with st.expander("Raw log"):
        # Plain text: the log contains "|" separators that markdown would mangle.
        st.text(read_log_file2())
    st.write("## Training Log Analysis")
    rounds, log_clients, memory_usage = parse_log(read_log_file())
    plot_metrics(rounds, log_clients, memory_usage)


if __name__ == "__main__":
    main()