Sadjad Alikhani commited on
Commit
0176215
·
verified ·
1 Parent(s): 148ab33

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -113
app.py CHANGED
@@ -1,118 +1,52 @@
1
- import gradio as gr
2
- import os
3
- from PIL import Image
4
-
5
- # Paths to the images folder
6
- RAW_PATH = os.path.join("images", "raw")
7
- EMBEDDINGS_PATH = os.path.join("images", "embeddings")
8
-
9
- # Specific values for percentage and complexity
10
- percentage_values = [10, 30, 50, 70, 100]
11
- complexity_values = [16, 32]
12
-
13
- # Function to load and display images based on user selection
14
- def display_images(percentage_idx, complexity_idx):
15
- # Map the slider index to the actual value
16
- percentage = percentage_values[percentage_idx]
17
- complexity = complexity_values[complexity_idx]
18
-
19
- # Generate the paths to the images
20
- raw_image_path = os.path.join(RAW_PATH, f"percentage_{percentage}_complexity_{complexity}.png")
21
- embeddings_image_path = os.path.join(EMBEDDINGS_PATH, f"percentage_{percentage}_complexity_{complexity}.png")
22
-
23
- # Load images using PIL
24
- raw_image = Image.open(raw_image_path)
25
- embeddings_image = Image.open(embeddings_image_path)
26
-
27
- # Return the loaded images
28
- return raw_image, embeddings_image
29
-
30
- # Define the beam prediction function (template based)
31
- def beam_prediction(percentage_idx, complexity_idx):
32
- # Add your beam prediction logic here (this is placeholder code)
33
- raw_img, embeddings_img = display_images(percentage_idx, complexity_idx)
34
- return raw_img, embeddings_img
35
-
36
- # Define the LoS/NLoS classification function (template based)
37
- def los_nlos_classification(uploaded_file, percentage_idx, complexity_idx):
38
- # Placeholder code for processing the uploaded .py file (can be extended)
39
- # Add your LoS/NLoS classification logic here
40
- raw_img, embeddings_img = display_images(percentage_idx, complexity_idx)
41
- return raw_img, embeddings_img
42
-
43
- # Define the Gradio interface
44
- with gr.Blocks(css="""
45
- .vertical-slider input[type=range] {
46
- writing-mode: bt-lr; /* IE */
47
- -webkit-appearance: slider-vertical; /* WebKit */
48
- width: 8px;
49
- height: 200px;
50
- }
51
- .slider-container {
52
- display: inline-block;
53
- margin-right: 50px;
54
- text-align: center;
55
- }
56
- """) as demo:
57
-
58
- # Contact Section
59
- gr.Markdown(
60
- """
61
- ## Contact
62
- <div style="display: flex; align-items: center;">
63
- <a target="_blank" href="mailto:[email protected]"><img src="https://img.shields.io/badge/[email protected]?logo=gmail " alt="Email"></a>&nbsp;&nbsp;
64
- <a target="_blank" href="https://telegram.me/wirelessmodel"><img src="https://img.shields.io/badge/[email protected]?logo=telegram " alt="Telegram"></a>&nbsp;&nbsp;
65
- </div>
66
- """
67
- )
68
-
69
- # Tabs for Beam Prediction and LoS/NLoS Classification
70
- with gr.Tab("Beam Prediction Task"):
71
- gr.Markdown("### Beam Prediction Task")
72
 
73
- # Sliders for percentage and complexity
74
- with gr.Row():
75
- with gr.Column(elem_id="slider-container"):
76
- gr.Markdown("Percentage of Data for Training")
77
- percentage_slider_bp = gr.Slider(minimum=0, maximum=4, step=1, value=0, interactive=True, elem_id="vertical-slider")
78
- with gr.Column(elem_id="slider-container"):
79
- gr.Markdown("Task Complexity")
80
- complexity_slider_bp = gr.Slider(minimum=0, maximum=1, step=1, value=0, interactive=True, elem_id="vertical-slider")
81
-
82
- # Image outputs (display the images side by side and set a smaller size for the images)
83
- with gr.Row():
84
- raw_img_bp = gr.Image(label="Raw Channels", type="pil", width=300, height=300, interactive=False)
85
- embeddings_img_bp = gr.Image(label="Embeddings", type="pil", width=300, height=300, interactive=False)
86
-
87
- # Instant image updates when sliders change
88
- percentage_slider_bp.change(fn=beam_prediction, inputs=[percentage_slider_bp, complexity_slider_bp], outputs=[raw_img_bp, embeddings_img_bp])
89
- complexity_slider_bp.change(fn=beam_prediction, inputs=[percentage_slider_bp, complexity_slider_bp], outputs=[raw_img_bp, embeddings_img_bp])
90
-
91
- with gr.Tab("LoS/NLoS Classification Task"):
92
- gr.Markdown("### LoS/NLoS Classification Task")
93
 
94
- # File uploader for uploading .py file
95
- file_input = gr.File(label="Upload .py File", file_types=[".py"])
96
-
97
- # Sliders for percentage and complexity
98
- with gr.Row():
99
- with gr.Column(elem_id="slider-container"):
100
- gr.Markdown("Percentage of Data for Training")
101
- percentage_slider_los = gr.Slider(minimum=0, maximum=4, step=1, value=0, interactive=True, elem_id="vertical-slider")
102
- with gr.Column(elem_id="slider-container"):
103
- gr.Markdown("Task Complexity")
104
- complexity_slider_los = gr.Slider(minimum=0, maximum=1, step=1, value=0, interactive=True, elem_id="vertical-slider")
105
 
106
- # Image outputs (display the images side by side and set a smaller size for the images)
107
- with gr.Row():
108
- raw_img_los = gr.Image(label="Raw Channels", type="pil", width=300, height=300, interactive=False)
109
- embeddings_img_los = gr.Image(label="Embeddings", type="pil", width=300, height=300, interactive=False)
110
 
111
- # Instant image updates when sliders or file input change
112
- file_input.change(fn=los_nlos_classification, inputs=[file_input, percentage_slider_los, complexity_slider_los], outputs=[raw_img_los, embeddings_img_los])
113
- percentage_slider_los.change(fn=los_nlos_classification, inputs=[file_input, percentage_slider_los, complexity_slider_los], outputs=[raw_img_los, embeddings_img_los])
114
- complexity_slider_los.change(fn=los_nlos_classification, inputs=[file_input, percentage_slider_los, complexity_slider_los], outputs=[raw_img_los, embeddings_img_los])
115
 
116
- # Launch the app
117
- if __name__ == "__main__":
118
- demo.launch()
 
1
+ import torch
2
+ from transformers import AutoModel # Assuming you use a transformer-like model in your LWM repo
3
+ import numpy as np
4
+ import importlib.util
5
+
6
+ # Function to load the pre-trained model from Hugging Face
7
+ def load_pretrained_model():
8
+ # Load the pre-trained model from the Hugging Face repo
9
+ model = AutoModel.from_pretrained("sadjadalikhani/LWM")
10
+ model.eval() # Set model to evaluation mode
11
+ return model
12
+
13
+ # Function to process the uploaded .py file and perform inference using the model
14
+ def process_python_file(uploaded_file, percentage_idx, complexity_idx):
15
+ try:
16
+ # Step 1: Load the model
17
+ model = load_pretrained_model()
18
+
19
+ # Step 2: Load the uploaded .py file that contains the wireless channel matrix
20
+ # Import the Python file dynamically
21
+ spec = importlib.util.spec_from_file_location("uploaded_module", uploaded_file.name)
22
+ uploaded_module = importlib.util.module_from_spec(spec)
23
+ spec.loader.exec_module(uploaded_module)
24
+
25
+ # Assuming the uploaded file defines a variable called 'channel_matrix'
26
+ channel_matrix = uploaded_module.channel_matrix # This should be defined in the uploaded file
27
+
28
+ # Step 3: Perform inference on the channel matrix using the model
29
+ with torch.no_grad():
30
+ input_tensor = torch.tensor(channel_matrix).unsqueeze(0) # Add batch dimension
31
+ output = model(input_tensor) # Perform inference
32
+
33
+ # Step 4: Generate new images based on the inference results
34
+ # You can modify this logic depending on how you want to visualize the results
35
+ generated_raw_img = np.random.rand(300, 300, 3) * 255 # Placeholder: Replace with actual inference result
36
+ generated_embeddings_img = np.random.rand(300, 300, 3) * 255 # Placeholder: Replace with actual inference result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
+ # Save the generated images
39
+ generated_raw_image_path = os.path.join(GENERATED_PATH, f"generated_raw_{percentage_idx}_{complexity_idx}.png")
40
+ generated_embeddings_image_path = os.path.join(GENERATED_PATH, f"generated_embeddings_{percentage_idx}_{complexity_idx}.png")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ Image.fromarray(generated_raw_img.astype(np.uint8)).save(generated_raw_image_path)
43
+ Image.fromarray(generated_embeddings_img.astype(np.uint8)).save(generated_embeddings_image_path)
 
 
 
 
 
 
 
 
 
44
 
45
+ # Load the generated images
46
+ raw_image = Image.open(generated_raw_image_path)
47
+ embeddings_image = Image.open(generated_embeddings_image_path)
 
48
 
49
+ return raw_image, embeddings_image
 
 
 
50
 
51
+ except Exception as e:
52
+ return str(e), str(e)