Sadjad Alikhani committed
Commit d1b5811 · verified · 1 Parent(s): a392854

Update app.py

Files changed (1):
  1. app.py +85 -52
app.py CHANGED
@@ -7,14 +7,16 @@ import io
 import sys
 import torch
 import subprocess
+import h5py
+from sklearn.metrics import confusion_matrix
+import matplotlib.pyplot as plt

 # Paths to the predefined images folder
 RAW_PATH = os.path.join("images", "raw")
 EMBEDDINGS_PATH = os.path.join("images", "embeddings")

-# Specific values for percentage and complexity
+# Specific values for percentage of data for training
 percentage_values = [10, 30, 50, 70, 100]
-complexity_values = [16, 32]

 # Custom class to capture print output
 class PrintCapture(io.StringIO):
@@ -30,11 +32,10 @@ class PrintCapture(io.StringIO):
         return ''.join(self.output)

 # Function to load and display predefined images based on user selection
-def display_predefined_images(percentage_idx, complexity_idx):
+def display_predefined_images(percentage_idx):
     percentage = percentage_values[percentage_idx]
-    complexity = complexity_values[complexity_idx]
-    raw_image_path = os.path.join(RAW_PATH, f"percentage_{percentage}_complexity_{complexity}.png")
-    embeddings_image_path = os.path.join(EMBEDDINGS_PATH, f"percentage_{percentage}_complexity_{complexity}.png")
+    raw_image_path = os.path.join(RAW_PATH, f"percentage_{percentage}_complexity_16.png")  # Assume complexity 16 for simplicity
+    embeddings_image_path = os.path.join(EMBEDDINGS_PATH, f"percentage_{percentage}_complexity_16.png")

     raw_image = Image.open(raw_image_path)
     embeddings_image = Image.open(embeddings_image_path)
@@ -62,8 +63,57 @@ def load_module_from_path(module_name, file_path):
     spec.loader.exec_module(module)
     return module

-# Function to process the uploaded .p file and perform inference using the custom model
-def process_p_file(uploaded_file, percentage_idx, complexity_idx):
+# Function to split dataset into training and test sets based on user selection
+def split_dataset(channels, labels, percentage_idx):
+    percentage = percentage_values[percentage_idx] / 100
+    num_samples = channels.shape[0]
+    train_size = int(num_samples * percentage)
+    print(f'Number of Training Samples: {train_size}')
+
+    indices = np.arange(num_samples)
+    np.random.shuffle(indices)
+
+    train_idx, test_idx = indices[:train_size], indices[train_size:]
+
+    train_data, test_data = channels[train_idx], channels[test_idx]
+    train_labels, test_labels = labels[train_idx], labels[test_idx]
+
+    return train_data, test_data, train_labels, test_labels
+
+# Function to calculate Euclidean distance between a point and a centroid
+def euclidean_distance(x, centroid):
+    return np.linalg.norm(x - centroid)
+
+# Function to classify test data based on distance to class centroids
+def classify_based_on_distance(train_data, train_labels, test_data):
+    centroid_0 = np.mean(train_data[train_labels == 0], axis=0)
+    centroid_1 = np.mean(train_data[train_labels == 1], axis=0)
+
+    predictions = []
+    for test_point in test_data:
+        dist_0 = euclidean_distance(test_point, centroid_0)
+        dist_1 = euclidean_distance(test_point, centroid_1)
+        predictions.append(0 if dist_0 < dist_1 else 1)
+
+    return np.array(predictions)
+
+# Function to generate confusion matrix plot
+def plot_confusion_matrix(y_true, y_pred, title):
+    cm = confusion_matrix(y_true, y_pred)
+    plt.figure(figsize=(5, 5))
+    plt.imshow(cm, cmap='Blues')
+    plt.title(title)
+    plt.xlabel('Predicted')
+    plt.ylabel('Actual')
+    plt.colorbar()
+    plt.xticks([0, 1], labels=[0, 1])
+    plt.yticks([0, 1], labels=[0, 1])
+    plt.tight_layout()
+    plt.savefig(f"{title}.png")
+    return Image.open(f"{title}.png")
+
+# Function to process the uploaded HDF5 file and perform classification using the custom model
+def process_hdf5_file(uploaded_file, percentage_idx):
     capture = PrintCapture()
     sys.stdout = capture  # Redirect print statements to capture

@@ -90,51 +140,42 @@ def process_p_file(uploaded_file, percentage_idx, complexity_idx):
         input_preprocess_path = os.path.join(os.getcwd(), 'input_preprocess.py')
         inference_path = os.path.join(os.getcwd(), 'inference.py')

-        print(lwm_model_path)
-        print(input_preprocess_path)
-        print(inference_path)
-
         # Load lwm_model
-        if os.path.exists(lwm_model_path):
-            lwm_model = load_module_from_path("lwm_model", lwm_model_path)
-        else:
-            return f"Error: lwm_model.py not found at {lwm_model_path}"
+        lwm_model = load_module_from_path("lwm_model", lwm_model_path)

         # Load input_preprocess
-        if os.path.exists(input_preprocess_path):
-            input_preprocess = load_module_from_path("input_preprocess", input_preprocess_path)
-        else:
-            return f"Error: input_preprocess.py not found at {input_preprocess_path}"
+        input_preprocess = load_module_from_path("input_preprocess", input_preprocess_path)

         # Load inference
-        if os.path.exists(inference_path):
-            inference = load_module_from_path("inference", inference_path)
-        else:
-            return f"Error: inference.py not found at {inference_path}"
+        inference = load_module_from_path("inference", inference_path)

         # Step 4: Load the model from lwm_model module
         device = 'cpu'
         print(f"Loading the LWM model on {device}...")
         model = lwm_model.LWM.from_pretrained(device=device)

-        # Step 5: Tokenize the data using the tokenizer from input_preprocess
-        with open(uploaded_file.name, 'rb') as f:
-            manual_data = pickle.load(f)
+        # Step 5: Load the HDF5 file and extract the channels and labels
+        with h5py.File(uploaded_file.name, 'r') as f:
+            channels = np.array(f['channels'])  # Assuming 'channels' dataset in the HDF5 file
+            labels = np.array(f['labels'])  # Assuming 'labels' dataset in the HDF5 file
+        print(f"Loaded dataset with {channels.shape[0]} samples.")

-        preprocessed_chs = input_preprocess.tokenizer(manual_data=manual_data)
+        # Step 6: Split the dataset into training and test sets
+        train_data_raw, test_data_raw, train_labels, test_labels = split_dataset(channels, labels, percentage_idx)

-        # Step 6: Perform inference using the functions from inference.py
-        output_emb = inference.lwm_inference(preprocessed_chs, 'channel_emb', model)
-        output_raw = inference.create_raw_dataset(preprocessed_chs, device)
+        # Step 7: Tokenize the data using the tokenizer from input_preprocess
+        preprocessed_chs = input_preprocess.tokenizer(manual_data=channels)
+        train_data_emb, test_data_emb, _, _ = split_dataset(preprocessed_chs, labels, percentage_idx)

-        print(f"Output Embeddings Shape: {output_emb.shape}")
-        print(f"Output Raw Shape: {output_raw.shape}")
+        # Step 8: Perform classification using the Euclidean distance for both raw and embeddings
+        pred_raw = classify_based_on_distance(train_data_raw, train_labels, test_data_raw)
+        pred_emb = classify_based_on_distance(train_data_emb, train_labels, test_data_emb)

-        # Step 7: Generate random images as a test
-        random_raw_image = create_random_image()
-        random_embeddings_image = create_random_image()
+        # Step 9: Generate confusion matrices for both raw and embeddings
+        raw_cm_image = plot_confusion_matrix(test_labels, pred_raw, title="Confusion Matrix (Raw Channels)")
+        emb_cm_image = plot_confusion_matrix(test_labels, pred_emb, title="Confusion Matrix (Embeddings)")

-        return random_raw_image, random_embeddings_image, capture.get_output()
+        return raw_cm_image, emb_cm_image, capture.get_output()

     except Exception as e:
         return str(e), str(e), capture.get_output()
@@ -143,11 +184,11 @@ def process_p_file(uploaded_file, percentage_idx, complexity_idx):
         sys.stdout = sys.__stdout__  # Reset print statements

 # Function to handle logic based on whether a file is uploaded or not
-def los_nlos_classification(file, percentage_idx, complexity_idx):
+def los_nlos_classification(file, percentage_idx):
     if file is not None:
-        return process_p_file(file, percentage_idx, complexity_idx)
+        return process_hdf5_file(file, percentage_idx)
     else:
-        return display_predefined_images(percentage_idx, complexity_idx), None
+        return display_predefined_images(percentage_idx), None

 # Define the Gradio interface
 with gr.Blocks(css="""
@@ -183,38 +224,30 @@ with gr.Blocks(css="""
             with gr.Column(elem_id="slider-container"):
                 gr.Markdown("Percentage of Data for Training")
                 percentage_slider_bp = gr.Slider(minimum=0, maximum=4, step=1, value=0, interactive=True, elem_id="vertical-slider")
-            with gr.Column(elem_id="slider-container"):
-                gr.Markdown("Task Complexity")
-                complexity_slider_bp = gr.Slider(minimum=0, maximum=1, step=1, value=0, interactive=True, elem_id="vertical-slider")

         with gr.Row():
             raw_img_bp = gr.Image(label="Raw Channels", type="pil", width=300, height=300, interactive=False)
             embeddings_img_bp = gr.Image(label="Embeddings", type="pil", width=300, height=300, interactive=False)

-        percentage_slider_bp.change(fn=display_predefined_images, inputs=[percentage_slider_bp, complexity_slider_bp], outputs=[raw_img_bp, embeddings_img_bp])
-        complexity_slider_bp.change(fn=display_predefined_images, inputs=[percentage_slider_bp, complexity_slider_bp], outputs=[raw_img_bp, embeddings_img_bp])
+        percentage_slider_bp.change(fn=display_predefined_images, inputs=[percentage_slider_bp], outputs=[raw_img_bp, embeddings_img_bp])

     with gr.Tab("LoS/NLoS Classification Task"):
         gr.Markdown("### LoS/NLoS Classification Task")

-        file_input = gr.File(label="Upload .p File", file_types=[".p"])
+        file_input = gr.File(label="Upload HDF5 Dataset", file_types=[".h5"])

         with gr.Row():
             with gr.Column(elem_id="slider-container"):
                 gr.Markdown("Percentage of Data for Training")
                 percentage_slider_los = gr.Slider(minimum=0, maximum=4, step=1, value=0, interactive=True, elem_id="vertical-slider")
-            with gr.Column(elem_id="slider-container"):
-                gr.Markdown("Task Complexity")
-                complexity_slider_los = gr.Slider(minimum=0, maximum=1, step=1, value=0, interactive=True, elem_id="vertical-slider")

         with gr.Row():
             raw_img_los = gr.Image(label="Raw Channels", type="pil", width=300, height=300, interactive=False)
             embeddings_img_los = gr.Image(label="Embeddings", type="pil", width=300, height=300, interactive=False)
             output_textbox = gr.Textbox(label="Console Output", lines=10)

-        file_input.change(fn=los_nlos_classification, inputs=[file_input, percentage_slider_los, complexity_slider_los], outputs=[raw_img_los, embeddings_img_los, output_textbox])
-        percentage_slider_los.change(fn=los_nlos_classification, inputs=[file_input, percentage_slider_los, complexity_slider_los], outputs=[raw_img_los, embeddings_img_los, output_textbox])
-        complexity_slider_los.change(fn=los_nlos_classification, inputs=[file_input, percentage_slider_los, complexity_slider_los], outputs=[raw_img_los, embeddings_img_los, output_textbox])
+        file_input.change(fn=los_nlos_classification, inputs=[file_input, percentage_slider_los], outputs=[raw_img_los, embeddings_img_los, output_textbox])
+        percentage_slider_los.change(fn=los_nlos_classification, inputs=[file_input, percentage_slider_los], outputs=[raw_img_los, embeddings_img_los, output_textbox])

 # Launch the app
 if __name__ == "__main__":
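
Note on the new upload format: process_hdf5_file expects the uploaded .h5 file to contain 'channels' and 'labels' datasets. A minimal sketch of building a compatible file with h5py follows; the dataset names match the diff, while the sample count, channel dimensionality, and binary label encoding are illustrative assumptions, not part of the commit.

# Minimal sketch of a compatible upload file for the updated app.
# Dataset names 'channels' and 'labels' follow process_hdf5_file in the diff;
# the shapes and binary label encoding below are assumptions.
import numpy as np
import h5py

num_samples, channel_dim = 200, 64  # illustrative sizes
channels = np.random.randn(num_samples, channel_dim).astype(np.float32)
labels = np.random.randint(0, 2, size=num_samples)  # binary LoS/NLoS labels (encoding assumed)

with h5py.File("example_dataset.h5", "w") as f:
    f.create_dataset("channels", data=channels)
    f.create_dataset("labels", data=labels)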
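
The classification step itself is a nearest-centroid rule: each class is summarized by the mean of its training vectors, and a test point takes the label of the nearer centroid. A self-contained sanity check on synthetic two-cluster data (all values below are illustrative, not from the commit):

# Nearest-centroid classification, the rule implemented by
# classify_based_on_distance, demonstrated on synthetic data.
import numpy as np

rng = np.random.default_rng(0)
train = np.vstack([rng.normal(0.0, 1.0, (50, 4)),   # class 0 cluster
                   rng.normal(3.0, 1.0, (50, 4))])  # class 1 cluster
train_labels = np.repeat([0, 1], 50)
test = rng.normal(3.0, 1.0, (10, 4))                # points drawn near class 1

centroid_0 = train[train_labels == 0].mean(axis=0)
centroid_1 = train[train_labels == 1].mean(axis=0)
dist_0 = np.linalg.norm(test - centroid_0, axis=1)
dist_1 = np.linalg.norm(test - centroid_1, axis=1)
predictions = np.where(dist_0 < dist_1, 0, 1)
print(predictions)  # expected: predominantly 1s

On well-separated clusters like these the rule predicts mostly class 1 for points drawn near the class-1 centroid; the commit applies the same rule to both raw channels and LWM embeddings and compares the resulting confusion matrices.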