Nathan12 commited on
Commit
6a77094
·
1 Parent(s): 4e6f440

update app

Browse files
Files changed (1) hide show
  1. app.py +142 -20
app.py CHANGED
@@ -39,34 +39,156 @@ def optimize_model(input_model, sparsity, context, criteria):
39
 
40
  return comp_path
41
 
42
- def main_interface(model_file, sparsity, action):
43
- if action == 'Speed':
44
- return optimize_model(model_file, sparsity, 'local', "large_final")
 
45
 
46
- if action == 'Size':
47
- return optimize_model(model_file, sparsity, 'global', "large_final")
 
 
 
48
 
49
- if action == 'Consumption':
50
- return optimize_model(model_file, sparsity, 'local', "random")
51
- else:
52
- return "Action not supported"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- granularity = ['weight', 'filter']
56
- context = ['local', 'global']
57
- criteria = ['large_final', 'random']
58
 
 
59
 
60
  iface = gr.Interface(
61
  fn=main_interface,
62
- inputs= [
63
- gr.File(label="Upload your PyTorch model (.pth file)"),
64
- gr.Slider(label="Compression Level", minimum=0, maximum=100),
65
- gr.Radio(["Speed", "Size", "Consumption"], label="Select Action")
66
- ],
67
- outputs=gr.File(label="Download Compressed Model"),
68
- title="FasterAI",
69
- description="Upload your neural network model (.pt file) and receive a compressed version.",
 
 
 
70
  )
71
 
72
  iface.launch()
 
39
 
40
  return comp_path
41
 
42
+ import matplotlib.pyplot as plt
43
+ import seaborn as sns
44
+ import io
45
+ import numpy as np
46
 
47
+ def get_model_size(model_path):
48
+ """Get model size in MB"""
49
+ size_bytes = os.path.getsize(model_path)
50
+ size_mb = size_bytes / (1024 * 1024)
51
+ return round(size_mb, 2)
52
 
53
+ def create_size_comparison_plot(original_size, compressed_size):
54
+ """Create a bar plot comparing model sizes"""
55
+ # Set seaborn style
56
+ sns.set_style("darkgrid")
57
+
58
+ # Create figure with higher DPI for better resolution
59
+ fig = plt.figure(figsize=(10, 6), dpi=150)
60
+
61
+ # Set transparent background
62
+ fig.patch.set_alpha(0.0)
63
+ ax = plt.gca()
64
+ ax.patch.set_alpha(0.0)
65
+
66
+ # Plot bars with custom colors and alpha
67
+ bars = plt.bar(['Original', 'Compressed'],
68
+ [original_size, compressed_size],
69
+ color=['#FF6B00', '#FF9F1C'],
70
+ alpha=0.8,
71
+ width=0.6)
72
+
73
+ # Add size labels on top of bars with improved styling
74
+ for bar in bars:
75
+ height = bar.get_height()
76
+ plt.text(bar.get_x() + bar.get_width()/2., height + (height * 0.01),
77
+ f'{height:.2f} MB',
78
+ ha='center', va='bottom',
79
+ fontsize=11,
80
+ fontweight='bold',
81
+ color='white')
82
+
83
+ # Calculate compression percentage
84
+ compression_ratio = ((original_size - compressed_size) / original_size) * 100
85
+
86
+ # Customize title and labels with better visibility
87
+ plt.title(f'Model Size Comparison\nCompression: {compression_ratio:.1f}%',
88
+ fontsize=14,
89
+ fontweight='bold',
90
+ pad=20,
91
+ color='white')
92
+
93
+ plt.xlabel('Model Version',
94
+ fontsize=12,
95
+ fontweight='bold',
96
+ labelpad=10,
97
+ color='white')
98
+
99
+ plt.ylabel('Size (MB)',
100
+ fontsize=12,
101
+ fontweight='bold',
102
+ labelpad=10,
103
+ color='white')
104
+
105
+ # Customize grid
106
+ ax.grid(alpha=0.2, color='gray')
107
+
108
+ # Remove top and right spines
109
+ sns.despine()
110
+
111
+ # Set y-axis limits with some padding
112
+ max_value = max(original_size, compressed_size)
113
+ plt.ylim(0, max_value * 1.2)
114
+
115
+ # Add more y-axis ticks
116
+ plt.yticks(np.linspace(0, max_value * 1.2, 10))
117
+
118
+ # Make tick labels white
119
+ ax.tick_params(colors='white')
120
+ for spine in ax.spines.values():
121
+ spine.set_color('white')
122
+
123
+ # Format axes with white text
124
+ ax.xaxis.label.set_color('white')
125
+ ax.yaxis.label.set_color('white')
126
+ ax.tick_params(axis='x', colors='white')
127
+ ax.tick_params(axis='y', colors='white')
128
+
129
+ # Format y-axis tick labels
130
+ ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{x:.1f}'))
131
+
132
+ # Adjust layout to prevent label cutoff
133
+ plt.tight_layout()
134
+
135
+ return fig
136
 
137
+ def main_interface(model_name, sparsity, action):
138
+ import torchvision.models as models
139
+
140
+ model_mapping = {
141
+ 'ResNet18': models.resnet18(pretrained=True),
142
+ 'ResNet50': models.resnet50(pretrained=True),
143
+ 'MobileNetV2': models.mobilenet_v2(pretrained=True),
144
+ 'EfficientNet-B0': models.efficientnet_b0(pretrained=True),
145
+ 'VGG16': models.vgg16(pretrained=True),
146
+ 'DenseNet121': models.densenet121(pretrained=True)
147
+ }
148
+
149
+ model = model_mapping[model_name]
150
+
151
+ # Save model temporarily
152
+ temp_path = "./temp_model.pth"
153
+ torch.save(model, temp_path)
154
+
155
+ original_size = get_model_size(temp_path)
156
+
157
+ try:
158
+ if action == 'Speed':
159
+ compressed_path = optimize_model(temp_path, sparsity, 'local', "large_final")
160
+ elif action == 'Size':
161
+ compressed_path = optimize_model(temp_path, sparsity, 'global', "large_final")
162
+ elif action == 'Consumption':
163
+ compressed_path = optimize_model(temp_path, sparsity, 'local', "random")
164
+ else:
165
+ return None, None
166
+
167
+ compressed_size = get_model_size(compressed_path)
168
+ size_plot = create_size_comparison_plot(original_size, compressed_size)
169
+
170
+ return compressed_path, size_plot
171
+ finally:
172
+ # Clean up temporary file
173
+ if os.path.exists(temp_path):
174
+ os.remove(temp_path)
175
 
 
 
 
176
 
177
+ available_models = ['ResNet18', 'ResNet50', 'MobileNetV2', 'EfficientNet-B0', 'VGG16', 'DenseNet121']
178
 
179
  iface = gr.Interface(
180
  fn=main_interface,
181
+ inputs=[
182
+ gr.Dropdown(choices=available_models, label="Select Model", value='ResNet18'),
183
+ gr.Slider(label="Compression Level", minimum=0, maximum=100, value=50),
184
+ gr.Radio(["Speed", "Size", "Consumption"], label="Select Action", value="Speed")
185
+ ],
186
+ outputs=[
187
+ gr.File(label="Download Compressed Model"),
188
+ gr.Plot(label="Size Comparison") # Changed from gr.Image to gr.Plot
189
+ ],
190
+ title="FasterAI Compressor",
191
+ description="Select a pre-trained PyTorch model to compress using our optimization techniques.",
192
  )
193
 
194
  iface.launch()