Spaces: Build error
Update app.py
app.py CHANGED
@@ -392,11 +392,11 @@
 # if __name__ == "__main__":
 #     main()
 
-
 import streamlit as st
 import matplotlib.pyplot as plt
 import torch
-from transformers import AutoTokenizer,
+from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW
+from transformers import T5Tokenizer, T5ForConditionalGeneration
 from datasets import load_dataset, Dataset
 from evaluate import load as load_metric
 from torch.utils.data import DataLoader
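Review note: the AdamW re-export from transformers has long been deprecated and is removed in recent releases, so this import line alone can surface as a Space build error if requirements.txt resolves to a current version. A minimal sketch of the usual drop-in replacement, assuming the optimizer is constructed from the model elsewhere in app.py:

    # Hypothetical swap: torch.optim.AdamW instead of transformers.AdamW.
    # `net` here stands for the model built later in main().
    from torch.optim import AdamW
    optimizer = AdamW(net.parameters(), lr=5e-5)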
@@ -413,35 +413,39 @@ import plotly.graph_objects as go
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt")
 
-class CustomDataCollator
+class CustomDataCollator:
+    def __init__(self, pad_token_id=0):
+        self.pad_token_id = pad_token_id
+
     def __call__(self, features):
-        max_length
+        max_length = max(len(f["input_ids"]) for f in features)
+        for f in features:
+            f['input_ids'] += [self.pad_token_id] * (max_length - len(f['input_ids']))
+        batch = {k: torch.tensor([f[k] for f in features]) for k in features[0].keys()}
+        return batch
 
-def load_data(dataset_name, train_size=20, test_size=20, num_clients=2, use_utf8=False):
+def load_data(dataset_name, train_size=20, test_size=20, num_clients=2, use_utf8=False, model_name="bert-base-uncased"):
     raw_datasets = load_dataset(dataset_name)
     raw_datasets = raw_datasets.shuffle(seed=42)
     del raw_datasets["unsupervised"]
 
-    if
-        tokenizer =
+    if model_name == "google/byt5-small":
+        tokenizer = T5Tokenizer.from_pretrained(model_name)
 
-        def
-            examples["input_ids"] = [list(text.encode('utf-8')) for text in examples["text"]]
+        def utf8_encode_function(examples):
+            examples["input_ids"] = [tokenizer(text.encode('utf-8'), return_tensors="pt")["input_ids"].squeeze().tolist() for text in examples["text"]]
             return examples
 
-        tokenized_datasets = raw_datasets.map(
+        tokenized_datasets = raw_datasets.map(utf8_encode_function, batched=True)
         tokenized_datasets = tokenized_datasets.remove_columns("text")
         tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
     else:
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
 
+        def tokenize_function(examples):
+            return tokenizer(examples["text"], truncation=True)
+
+        tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
         tokenized_datasets = tokenized_datasets.remove_columns("text")
         tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
 
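Review note: the new CustomDataCollator pads only input_ids, while tokenize_function also produces attention_mask (and the datasets carry labels), so torch.tensor([f[k] for f in features]) will raise on ragged attention_mask lists whenever a batch mixes lengths. Note also that the use_utf8 flag is still accepted but no longer consulted; the branch now keys off model_name. A minimal sketch of a collator that pads every list-valued field, assuming the feature layout produced above:

    import torch

    class PaddingDataCollator:
        # Hypothetical variant of CustomDataCollator: pads every list-valued
        # field (input_ids, attention_mask, ...) to the batch maximum.
        def __init__(self, pad_token_id=0):
            self.pad_token_id = pad_token_id

        def __call__(self, features):
            max_length = max(len(f["input_ids"]) for f in features)
            batch = {}
            for key in features[0].keys():
                values = [f[key] for f in features]
                if isinstance(values[0], list):
                    # input_ids pad with the pad token; mask-like fields pad with 0
                    pad_value = self.pad_token_id if key == "input_ids" else 0
                    values = [v + [pad_value] * (max_length - len(v)) for v in values]
                batch[key] = torch.tensor(values)  # labels are scalars and pass through
            return batch

For the tokenizer branch, transformers' built-in DataCollatorWithPadding would cover the same ground, though not the raw-byte branch.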
@@ -454,7 +458,7 @@ def load_data(dataset_name, train_size=20, test_size=20, num_clients=2, use_utf8
         train_datasets.append(train_dataset)
         test_datasets.append(test_dataset)
 
-    data_collator = CustomDataCollator(tokenizer
+    data_collator = CustomDataCollator(pad_token_id=tokenizer.pad_token_id)
 
     return train_datasets, test_datasets, data_collator, raw_datasets
 
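Review note: passing tokenizer.pad_token_id straight through assumes every selectable checkpoint defines a pad token; for tokenizers without one the attribute is None, which the padding arithmetic above then feeds into torch.tensor. A small defensive guard, as a sketch:

    # Hypothetical guard: fall back to 0 when no pad token is defined
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    data_collator = CustomDataCollator(pad_token_id=pad_id)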
@@ -634,15 +638,11 @@ def read_log_file2():
 def main():
     st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
     logs = read_log_file2()
-    # cleanLogs = # Define a pattern to match relevant log entries
     pattern = re.compile(r"memory|loss|accuracy|round|client", re.IGNORECASE)
 
-
-    # Filter the log data
     filtered_logs = [line for line in logs.splitlines() if pattern.search(line)]
     st.markdown(filtered_logs)
 
-    # Provide a download button for the logs
     st.download_button(
         label="Download Logs",
         data="\n".join(filtered_logs),
@@ -650,13 +650,13 @@ def main():
         mime="text/plain"
     )
     dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
-    model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased"])
+    model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased", "google/byt5-small"])
 
     NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
     NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
     use_utf8 = st.checkbox("Train on Byte UTF-8 Dataset", value=False)
 
-    train_datasets, test_datasets, data_collator, raw_datasets = load_data(dataset_name, num_clients=NUM_CLIENTS, use_utf8=use_utf8)
+    train_datasets, test_datasets, data_collator, raw_datasets = load_data(dataset_name, num_clients=NUM_CLIENTS, use_utf8=use_utf8, model_name=model_name)
 
     trainloaders = []
     testloaders = []
@@ -684,9 +684,6 @@ def main():
         trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
         testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
 
-        trainloaders.append(trainloader)
-        testloaders.append(testloader)
-
         net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
         client = CustomClient(net, trainloader, testloader, client_id=i+1)
         clients.append(client)
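Review note: with these two append calls deleted, the trainloaders and testloaders lists initialized earlier in main() are never filled; each CustomClient keeps its own loaders, so this is harmless only if nothing downstream still iterates those lists. If anything does, the appends have to stay:

    trainloaders.append(trainloader)
    testloaders.append(testloader)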
@@ -711,18 +708,10 @@ def main():
         st.write(f"### Round {round_num + 1} ✅")
 
         logs = read_log_file2()
-        filtered_log_list = [line for line in logs.splitlines if pattern.search(line)]
+        filtered_log_list = [line for line in logs.splitlines() if pattern.search(line)]
         filtered_logs = "\n".join(filtered_log_list)
 
         st.markdown(filtered_logs)
-        # Provide a download button for the logs
-        # st.download_button(
-        #     label="Download Logs",
-        #     data=logs,
-        #     file_name="./log.txt",
-        #     mime="text/plain"
-        # )
-        # # Extract relevant data
         accuracy_pattern = re.compile(r"'accuracy': \{(\d+), ([\d.]+)\}")
         loss_pattern = re.compile(r"'loss': \{(\d+), ([\d.]+)\}")
 
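Review note: both patterns hard-code a log shape in which metrics print as 'accuracy': {round, value}; if the Flower version in use serializes metrics differently (e.g. as lists of (round, value) tuples), findall silently returns nothing and the charts render empty. A self-contained check of the assumed format:

    import re

    accuracy_pattern = re.compile(r"'accuracy': \{(\d+), ([\d.]+)\}")
    sample = "INFO flwr: metrics {'accuracy': {3, 0.8125}}"  # assumed log line shape
    print(accuracy_pattern.findall(sample))  # [('3', '0.8125')] -> round and value as strings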
@@ -733,21 +722,17 @@ def main():
         accuracies = [float(match[1]) for match in accuracy_matches]
         losses = [float(match[1]) for match in loss_matches]
 
-        # Create accuracy plot
         accuracy_fig = go.Figure()
         accuracy_fig.add_trace(go.Scatter(x=rounds, y=accuracies, mode='lines+markers', name='Accuracy'))
         accuracy_fig.update_layout(title='Accuracy over Rounds', xaxis_title='Round', yaxis_title='Accuracy')
 
-        # Create loss plot
         loss_fig = go.Figure()
         loss_fig.add_trace(go.Scatter(x=rounds, y=losses, mode='lines+markers', name='Loss'))
         loss_fig.update_layout(title='Loss over Rounds', xaxis_title='Round', yaxis_title='Loss')
 
-        # Display plots in Streamlit
         st.plotly_chart(accuracy_fig)
         st.plotly_chart(loss_fig)
 
-        # Display data table
         data = {
             'Round': rounds,
             'Accuracy': accuracies,
@@ -775,7 +760,6 @@ def main():
 
     st.success("Training completed successfully!")
 
-    # Display final metrics
     st.write("## Final Client Metrics")
     for client in clients:
         st.write(f"### Client {client.client_id}")
@@ -788,7 +772,6 @@ def main():
 
     st.write(" ")
 
-    # Display log.txt content
     st.write("## Training Log")
     st.write(read_log_file2())
 