Spaces:
Running
Running
Sadjad Alikhani
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -403,19 +403,16 @@ def process_hdf5_file(uploaded_file, percentage):
|
|
403 |
# Step 4: Load the model from lwm_model module
|
404 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
405 |
print(f"Loading the LWM model on {device}...")
|
406 |
-
model = lwm_model.
|
407 |
-
#for name, param in model.state_dict().items():
|
408 |
-
# print(f"Layer: {name} | Weights: {param}")
|
409 |
|
410 |
# Step 5: Load the HDF5 file and extract the channels and labels
|
411 |
with h5py.File(uploaded_file.name, 'r') as f:
|
412 |
-
channels = np.array(f['channels']).astype(np.complex64)
|
413 |
-
labels = np.array(f['labels']).astype(np.int32)
|
414 |
print(f"Loaded dataset with {channels.shape[0]} samples.")
|
415 |
|
416 |
# Step 7: Tokenize the data using the tokenizer from input_preprocess
|
417 |
preprocessed_chs = input_preprocess.tokenizer(manual_data=channels)
|
418 |
-
#print(preprocessed_chs[0][0][-1]) #CORRECT
|
419 |
|
420 |
# Step 7: Perform inference using the functions from inference.py
|
421 |
output_emb = inference.lwm_inference(preprocessed_chs, 'cls_emb', model, device)
|
@@ -424,8 +421,6 @@ def process_hdf5_file(uploaded_file, percentage):
|
|
424 |
print(f"Output Embeddings Shape: {output_emb.shape}")
|
425 |
print(f"Output Raw Shape: {output_raw.shape}")
|
426 |
|
427 |
-
#print(f'percentage_idx: {percentage_idx}')
|
428 |
-
#print(f'percentage_value: {percentage_values_los[percentage_idx]}')
|
429 |
print(f'percentage_value: {percentage}')
|
430 |
train_data_emb, test_data_emb, train_data_raw, test_data_raw, train_labels, test_labels = identical_train_test_split(output_emb.view(len(output_emb),-1),
|
431 |
output_raw.view(len(output_raw),-1),
|
@@ -438,8 +433,7 @@ def process_hdf5_file(uploaded_file, percentage):
|
|
438 |
print(f'test_data_emb: {test_data_emb.shape}')
|
439 |
pred_raw = classify_based_on_distance(train_data_raw, train_labels, test_data_raw)
|
440 |
pred_emb = classify_based_on_distance(train_data_emb, train_labels, test_data_emb)
|
441 |
-
|
442 |
-
#print(f'actual labels: {test_labels}')
|
443 |
# Step 9: Generate confusion matrices for both raw and embeddings
|
444 |
raw_cm_image = plot_confusion_matrix(test_labels, pred_raw, title="Confusion Matrix (Raw Channels)")
|
445 |
emb_cm_image = plot_confusion_matrix(test_labels, pred_emb, title="Confusion Matrix (Embeddings)")
|
|
|
403 |
# Step 4: Load the model from lwm_model module
|
404 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
405 |
print(f"Loading the LWM model on {device}...")
|
406 |
+
model = lwm_model.lwm.from_pretrained(device=device).float()
|
|
|
|
|
407 |
|
408 |
# Step 5: Load the HDF5 file and extract the channels and labels
|
409 |
with h5py.File(uploaded_file.name, 'r') as f:
|
410 |
+
channels = np.array(f['channels']).astype(np.complex64)
|
411 |
+
labels = np.array(f['labels']).astype(np.int32)
|
412 |
print(f"Loaded dataset with {channels.shape[0]} samples.")
|
413 |
|
414 |
# Step 7: Tokenize the data using the tokenizer from input_preprocess
|
415 |
preprocessed_chs = input_preprocess.tokenizer(manual_data=channels)
|
|
|
416 |
|
417 |
# Step 7: Perform inference using the functions from inference.py
|
418 |
output_emb = inference.lwm_inference(preprocessed_chs, 'cls_emb', model, device)
|
|
|
421 |
print(f"Output Embeddings Shape: {output_emb.shape}")
|
422 |
print(f"Output Raw Shape: {output_raw.shape}")
|
423 |
|
|
|
|
|
424 |
print(f'percentage_value: {percentage}')
|
425 |
train_data_emb, test_data_emb, train_data_raw, test_data_raw, train_labels, test_labels = identical_train_test_split(output_emb.view(len(output_emb),-1),
|
426 |
output_raw.view(len(output_raw),-1),
|
|
|
433 |
print(f'test_data_emb: {test_data_emb.shape}')
|
434 |
pred_raw = classify_based_on_distance(train_data_raw, train_labels, test_data_raw)
|
435 |
pred_emb = classify_based_on_distance(train_data_emb, train_labels, test_data_emb)
|
436 |
+
|
|
|
437 |
# Step 9: Generate confusion matrices for both raw and embeddings
|
438 |
raw_cm_image = plot_confusion_matrix(test_labels, pred_raw, title="Confusion Matrix (Raw Channels)")
|
439 |
emb_cm_image = plot_confusion_matrix(test_labels, pred_emb, title="Confusion Matrix (Embeddings)")
|