Nathan12 committed
Commit e167cea · 1 Parent(s): 13cf301

update compressor

Files changed (2):
  1. .ipynb_checkpoints/app-checkpoint.py +182 -0
  2. app.py +2 -12
.ipynb_checkpoints/app-checkpoint.py ADDED
@@ -0,0 +1,182 @@
+import fasterai
+from fasterai.sparse.all import *
+from fasterai.prune.all import *
+import torch
+import gradio as gr
+import os
+from torch.ao.quantization import get_default_qconfig_mapping
+import torch.ao.quantization.quantize_fx as quantize_fx
+from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
+
+class Quant():
+    def __init__(self, backend="x86"):
+        self.qconfig = get_default_qconfig_mapping(backend)
+
+    def quantize(self, model):
+        x = torch.randn(3, 224, 224)
+        model_prepared = prepare_fx(model.eval(), self.qconfig, x)
+        return convert_fx(model_prepared)
+
+
+def optimize_model(input_model, sparsity, context, criteria):
+
+    model = torch.load(input_model, weights_only=False)
+    model = model.eval()
+    model = model.to('cpu')
+    sp = Sparsifier(model, 'filter', context, criteria=eval(criteria))
+    sp.sparsify_model(sparsity)
+    sp._clean_buffers()
+    pr = Pruner(model, sparsity, context, criteria=eval(criteria))
+    pr.prune_model()
+    qu = Quant()
+    qu_model = qu.quantize(model)
+
+    comp_path = "./comp_model.pth"
+
+    scripted = torch.jit.script(qu_model)
+    torch.jit.save(scripted, comp_path)
+    #torch.save(qu_model, comp_path)
+
+    return comp_path
+
+import matplotlib.pyplot as plt
+import seaborn as sns
+import io
+import numpy as np
+
+def get_model_size(model_path):
+    """Get model size in MB"""
+    size_bytes = os.path.getsize(model_path)
+    size_mb = size_bytes / (1024 * 1024)
+    return round(size_mb, 2)
+
+def create_size_comparison_plot(original_size, compressed_size):
+    """Create a bar plot comparing model sizes"""
+    # Set seaborn style
+    sns.set_style("darkgrid")
+
+    # Create figure with higher DPI for better resolution
+    fig = plt.figure(figsize=(10, 6), dpi=150)
+
+    # Set transparent background
+    fig.patch.set_alpha(0.0)
+    ax = plt.gca()
+    ax.patch.set_alpha(0.0)
+
+    # Plot bars with custom colors and alpha
+    bars = plt.bar(['Original', 'Compressed'],
+                   [original_size, compressed_size],
+                   color=['#FF6B00', '#FF9F1C'],
+                   alpha=0.8,
+                   width=0.6)
+
+    # Add size labels on top of bars with improved styling
+    for bar in bars:
+        height = bar.get_height()
+        plt.text(bar.get_x() + bar.get_width()/2., height + (height * 0.01),
+                 f'{height:.2f} MB',
+                 ha='center', va='bottom',
+                 fontsize=11,
+                 fontweight='bold',
+                 color='white')
+
+    # Calculate compression percentage
+    compression_ratio = ((original_size - compressed_size) / original_size) * 100
+
+    # Customize title and labels with better visibility
+    plt.title(f'Model Size Comparison\nCompression: {compression_ratio:.1f}%',
+              fontsize=14,
+              fontweight='bold',
+              pad=20,
+              color='white')
+
+    plt.xlabel('Model Version',
+               fontsize=12,
+               fontweight='bold',
+               labelpad=10,
+               color='white')
+
+    plt.ylabel('Size (MB)',
+               fontsize=12,
+               fontweight='bold',
+               labelpad=10,
+               color='white')
+
+    # Customize grid
+    ax.grid(alpha=0.2, color='gray')
+
+    # Remove top and right spines
+    sns.despine()
+
+    # Set y-axis limits with some padding
+    max_value = max(original_size, compressed_size)
+    plt.ylim(0, max_value * 1.2)
+
+    # Add more y-axis ticks
+    plt.yticks(np.linspace(0, max_value * 1.2, 10))
+
+    # Make tick labels white
+    ax.tick_params(colors='white')
+    for spine in ax.spines.values():
+        spine.set_color('white')
+
+    # Format axes with white text
+    ax.xaxis.label.set_color('white')
+    ax.yaxis.label.set_color('white')
+    ax.tick_params(axis='x', colors='white')
+    ax.tick_params(axis='y', colors='white')
+
+    # Format y-axis tick labels
+    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{x:.1f}'))
+
+    # Adjust layout to prevent label cutoff
+    plt.tight_layout()
+
+    return fig
+
+def main_interface(model_name, sparsity, action):
+    import torchvision.models as models
+
+    model_mapping = {
+        'ResNet18': models.resnet18(pretrained=False),
+        'ResNet50': models.resnet50(pretrained=False),
+        'MobileNetV2': models.mobilenet_v2(pretrained=False),
+        'EfficientNet-B0': models.efficientnet_b0(pretrained=False),
+        'VGG16': models.vgg16(pretrained=False),
+        'DenseNet121': models.densenet121(pretrained=False)
+    }
+
+    model = model_mapping[model_name]
+
+    # Save model temporarily
+    temp_path = "./temp_model.pth"
+    torch.save(model, temp_path)
+
+    original_size = get_model_size(temp_path)
+
+    try:
+        compressed_path = optimize_model(temp_path, sparsity, 'local', "large_final")
+        compressed_size = get_model_size(compressed_path)
+        size_plot = create_size_comparison_plot(original_size, compressed_size)
+
+        return compressed_path, size_plot
+    finally:
+        # Clean up temporary file
+        if os.path.exists(temp_path):
+            os.remove(temp_path)
+
+
+available_models = ['ResNet18', 'ResNet50', 'MobileNetV2', 'EfficientNet-B0', 'VGG16', 'DenseNet121']
+
+iface = gr.Interface(
+    fn=main_interface,
+    inputs=[
+        gr.Dropdown(choices=available_models, label="Select Model", value='ResNet18'),
+        gr.Slider(label="Compression Level", minimum=0, maximum=100, value=50),
+    ],
+    outputs=[
+        gr.Plot(label="Size Comparison")  # Changed from gr.Image to gr.Plot
+    ],
+)
+
+iface.launch()
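
For reference, a minimal sketch (not part of the commit) of loading and running the artifact that optimize_model writes. It assumes the script above has already produced ./comp_model.pth and that the model takes an ImageNet-style 3x224x224 input, matching the calibration tensor in Quant.quantize:

import torch

# Load the TorchScript archive written by torch.jit.save in optimize_model
compressed = torch.jit.load("./comp_model.pth")
compressed.eval()

# Dummy batch of one 224x224 RGB image; real inputs would still need the usual preprocessing
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    out = compressed(x)
print(out.shape)
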
app.py CHANGED
@@ -20,7 +20,7 @@ class Quant():
 
 def optimize_model(input_model, sparsity, context, criteria):
 
-    model = torch.load(input_model)
+    model = torch.load(input_model, weights_only=False)
     model = model.eval()
     model = model.to('cpu')
     sp = Sparsifier(model, 'filter', context, criteria=eval(criteria))
@@ -155,15 +155,7 @@ def main_interface(model_name, sparsity, action):
     original_size = get_model_size(temp_path)
 
     try:
-        if action == 'Speed':
-            compressed_path = optimize_model(temp_path, sparsity, 'local', "large_final")
-        elif action == 'Size':
-            compressed_path = optimize_model(temp_path, sparsity, 'global', "large_final")
-        elif action == 'Consumption':
-            compressed_path = optimize_model(temp_path, sparsity, 'local', "random")
-        else:
-            return None, None
-
+        compressed_path = optimize_model(temp_path, sparsity, 'local', "large_final")
         compressed_size = get_model_size(compressed_path)
         size_plot = create_size_comparison_plot(original_size, compressed_size)
 
@@ -181,10 +173,8 @@ iface = gr.Interface(
     inputs=[
         gr.Dropdown(choices=available_models, label="Select Model", value='ResNet18'),
         gr.Slider(label="Compression Level", minimum=0, maximum=100, value=50),
-        gr.Radio(["Speed", "Size", "Consumption"], label="Select Action", value="Speed")
     ],
     outputs=[
-        gr.File(label="Download Compressed Model"),
         gr.Plot(label="Size Comparison")  # Changed from gr.Image to gr.Plot
     ],
 )
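
Along the same lines, a minimal sketch of driving the pipeline without the Gradio UI, assuming the optimize_model and get_model_size definitions above are already in scope (for example, in the notebook session behind the .ipynb checkpoint). It mirrors the fixed arguments main_interface now passes:

import torch
import torchvision.models as models

# Save an uncompressed baseline, as main_interface does before compressing
baseline = models.resnet18(pretrained=False)
torch.save(baseline, "./temp_model.pth")

# Sparsify, prune, quantize, and script it with the 'local' / "large_final" settings
compressed_path = optimize_model("./temp_model.pth", 50, 'local', "large_final")

print(f"original: {get_model_size('./temp_model.pth')} MB, "
      f"compressed: {get_model_size(compressed_path)} MB")
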