Mariam-Elz commited on
Commit
d8baef5
·
verified ·
1 Parent(s): 31e5ac6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -8
app.py CHANGED
@@ -77,7 +77,7 @@
77
 
78
  # if __name__ == "__main__":
79
  # demo.launch()
80
- ########################3rd######################3
81
 
82
  # import torch
83
  # import gradio as gr
@@ -138,7 +138,65 @@
138
 
139
 
140
  #################4th##################
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  import torch
 
142
  import gradio as gr
143
  import requests
144
  import os
@@ -162,11 +220,26 @@ for filename, output_path in model_files.items():
162
  with open(output_path, "wb") as f:
163
  f.write(response.content)
164
 
165
- # Load model with low memory usage
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  def load_model():
 
167
  model_path = "models/CRM.pth"
168
- model = torch.load(model_path, map_location="cpu") # Load on CPU to reduce memory usage
169
- model.eval()
170
  return model
171
 
172
  model = load_model()
@@ -175,10 +248,10 @@ model = load_model()
175
  def infer(image):
176
  """Process input image and return a reconstructed image."""
177
  with torch.no_grad():
178
- image_tensor = torch.tensor(image).unsqueeze(0) # Add batch dimension
179
- image_tensor = image_tensor.to("cpu") # Keep on CPU to save memory
180
- output = model(image_tensor)
181
- return output.squeeze(0).numpy()
182
 
183
  # Create Gradio UI
184
  demo = gr.Interface(
 
77
 
78
  # if __name__ == "__main__":
79
  # demo.launch()
80
+ ########################3rd-MAIN######################3
81
 
82
  # import torch
83
  # import gradio as gr
 
138
 
139
 
140
  #################4th##################
141
+
142
+ # import torch
143
+ # import gradio as gr
144
+ # import requests
145
+ # import os
146
+
147
+ # # Define model repo
148
+ # model_repo = "Mariam-Elz/CRM"
149
+
150
+ # # Define model files and download paths
151
+ # model_files = {
152
+ # "CRM.pth": "models/CRM.pth"
153
+ # }
154
+
155
+ # os.makedirs("models", exist_ok=True)
156
+
157
+ # # Download model files only if they don't exist
158
+ # for filename, output_path in model_files.items():
159
+ # if not os.path.exists(output_path):
160
+ # url = f"https://huggingface.co/{model_repo}/resolve/main/{filename}"
161
+ # print(f"Downloading {filename}...")
162
+ # response = requests.get(url)
163
+ # with open(output_path, "wb") as f:
164
+ # f.write(response.content)
165
+
166
+ # # Load model with low memory usage
167
+ # def load_model():
168
+ # model_path = "models/CRM.pth"
169
+ # model = torch.load(model_path, map_location="cpu") # Load on CPU to reduce memory usage
170
+ # model.eval()
171
+ # return model
172
+
173
+ # model = load_model()
174
+
175
+ # # Define inference function
176
+ # def infer(image):
177
+ # """Process input image and return a reconstructed image."""
178
+ # with torch.no_grad():
179
+ # image_tensor = torch.tensor(image).unsqueeze(0) # Add batch dimension
180
+ # image_tensor = image_tensor.to("cpu") # Keep on CPU to save memory
181
+ # output = model(image_tensor)
182
+ # return output.squeeze(0).numpy()
183
+
184
+ # # Create Gradio UI
185
+ # demo = gr.Interface(
186
+ # fn=infer,
187
+ # inputs=gr.Image(type="numpy"),
188
+ # outputs=gr.Image(type="numpy"),
189
+ # title="Convolutional Reconstruction Model",
190
+ # description="Upload an image to get the reconstructed output."
191
+ # )
192
+
193
+ # if __name__ == "__main__":
194
+ # demo.launch()
195
+
196
+
197
+ ##############5TH#################
198
  import torch
199
+ import torch.nn as nn
200
  import gradio as gr
201
  import requests
202
  import os
 
220
  with open(output_path, "wb") as f:
221
  f.write(response.content)
222
 
223
+ # Define the model architecture (you MUST replace this with your actual model)
224
+ class CRM_Model(nn.Module):
225
+ def __init__(self):
226
+ super(CRM_Model, self).__init__()
227
+ self.layer1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
228
+ self.relu = nn.ReLU()
229
+ self.layer2 = nn.Conv2d(64, 3, kernel_size=3, padding=1)
230
+
231
+ def forward(self, x):
232
+ x = self.layer1(x)
233
+ x = self.relu(x)
234
+ x = self.layer2(x)
235
+ return x
236
+
237
+ # Load model with proper architecture
238
  def load_model():
239
+ model = CRM_Model() # Instantiate the model architecture
240
  model_path = "models/CRM.pth"
241
+ model.load_state_dict(torch.load(model_path, map_location="cpu")) # Load weights
242
+ model.eval() # Set to evaluation mode
243
  return model
244
 
245
  model = load_model()
 
248
  def infer(image):
249
  """Process input image and return a reconstructed image."""
250
  with torch.no_grad():
251
+ image_tensor = torch.tensor(image).unsqueeze(0).permute(0, 3, 1, 2).float() / 255.0 # Convert to tensor
252
+ output = model(image_tensor) # Run through model
253
+ output = output.squeeze(0).permute(1, 2, 0).numpy() * 255.0 # Convert back to image
254
+ return output.astype("uint8")
255
 
256
  # Create Gradio UI
257
  demo = gr.Interface(