Mariam-Elz commited on
Commit
d3607a8
·
verified ·
1 Parent(s): 1f9fe97

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -15
app.py CHANGED
@@ -79,37 +79,93 @@
79
  # demo.launch()
80
  ########################3rd######################3
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  import torch
83
  import gradio as gr
84
  import requests
85
  import os
86
 
87
- # Download model weights from Hugging Face model repo (if not already present)
88
- model_repo = "Mariam-Elz/CRM" # Your Hugging Face model repo
89
 
 
90
  model_files = {
91
- "ccm-diffusion.pth": "ccm-diffusion.pth",
92
- "pixel-diffusion.pth": "pixel-diffusion.pth",
93
- "CRM.pth": "CRM.pth",
94
  }
95
 
96
  os.makedirs("models", exist_ok=True)
97
 
 
98
  for filename, output_path in model_files.items():
99
- file_path = f"models/{output_path}"
100
- if not os.path.exists(file_path):
101
  url = f"https://huggingface.co/{model_repo}/resolve/main/{filename}"
102
  print(f"Downloading {filename}...")
103
  response = requests.get(url)
104
- with open(file_path, "wb") as f:
105
  f.write(response.content)
106
 
107
- # Load model (This part depends on how the model is defined)
108
- device = "cuda" if torch.cuda.is_available() else "cpu"
109
-
110
  def load_model():
111
  model_path = "models/CRM.pth"
112
- model = torch.load(model_path, map_location=device)
113
  model.eval()
114
  return model
115
 
@@ -119,10 +175,10 @@ model = load_model()
119
  def infer(image):
120
  """Process input image and return a reconstructed image."""
121
  with torch.no_grad():
122
- # Assuming model expects a tensor input
123
- image_tensor = torch.tensor(image).to(device)
124
  output = model(image_tensor)
125
- return output.cpu().numpy()
126
 
127
  # Create Gradio UI
128
  demo = gr.Interface(
 
79
  # demo.launch()
80
  ########################3rd######################3
81
 
82
+ # import torch
83
+ # import gradio as gr
84
+ # import requests
85
+ # import os
86
+
87
+ # # Download model weights from Hugging Face model repo (if not already present)
88
+ # model_repo = "Mariam-Elz/CRM" # Your Hugging Face model repo
89
+
90
+ # model_files = {
91
+ # "ccm-diffusion.pth": "ccm-diffusion.pth",
92
+ # "pixel-diffusion.pth": "pixel-diffusion.pth",
93
+ # "CRM.pth": "CRM.pth",
94
+ # }
95
+
96
+ # os.makedirs("models", exist_ok=True)
97
+
98
+ # for filename, output_path in model_files.items():
99
+ # file_path = f"models/{output_path}"
100
+ # if not os.path.exists(file_path):
101
+ # url = f"https://huggingface.co/{model_repo}/resolve/main/{filename}"
102
+ # print(f"Downloading {filename}...")
103
+ # response = requests.get(url)
104
+ # with open(file_path, "wb") as f:
105
+ # f.write(response.content)
106
+
107
+ # # Load model (This part depends on how the model is defined)
108
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
109
+
110
+ # def load_model():
111
+ # model_path = "models/CRM.pth"
112
+ # model = torch.load(model_path, map_location=device)
113
+ # model.eval()
114
+ # return model
115
+
116
+ # model = load_model()
117
+
118
+ # # Define inference function
119
+ # def infer(image):
120
+ # """Process input image and return a reconstructed image."""
121
+ # with torch.no_grad():
122
+ # # Assuming model expects a tensor input
123
+ # image_tensor = torch.tensor(image).to(device)
124
+ # output = model(image_tensor)
125
+ # return output.cpu().numpy()
126
+
127
+ # # Create Gradio UI
128
+ # demo = gr.Interface(
129
+ # fn=infer,
130
+ # inputs=gr.Image(type="numpy"),
131
+ # outputs=gr.Image(type="numpy"),
132
+ # title="Convolutional Reconstruction Model",
133
+ # description="Upload an image to get the reconstructed output."
134
+ # )
135
+
136
+ # if __name__ == "__main__":
137
+ # demo.launch()
138
+
139
+
140
+ #################4th##################
141
  import torch
142
  import gradio as gr
143
  import requests
144
  import os
145
 
146
+ # Define model repo
147
+ model_repo = "Mariam-Elz/CRM"
148
 
149
+ # Define model files and download paths
150
  model_files = {
151
+ "CRM.pth": "models/CRM.pth"
 
 
152
  }
153
 
154
  os.makedirs("models", exist_ok=True)
155
 
156
+ # Download model files only if they don't exist
157
  for filename, output_path in model_files.items():
158
+ if not os.path.exists(output_path):
 
159
  url = f"https://huggingface.co/{model_repo}/resolve/main/{filename}"
160
  print(f"Downloading {filename}...")
161
  response = requests.get(url)
162
+ with open(output_path, "wb") as f:
163
  f.write(response.content)
164
 
165
+ # Load model with low memory usage
 
 
166
  def load_model():
167
  model_path = "models/CRM.pth"
168
+ model = torch.load(model_path, map_location="cpu") # Load on CPU to reduce memory usage
169
  model.eval()
170
  return model
171
 
 
175
  def infer(image):
176
  """Process input image and return a reconstructed image."""
177
  with torch.no_grad():
178
+ image_tensor = torch.tensor(image).unsqueeze(0) # Add batch dimension
179
+ image_tensor = image_tensor.to("cpu") # Keep on CPU to save memory
180
  output = model(image_tensor)
181
+ return output.squeeze(0).numpy()
182
 
183
  # Create Gradio UI
184
  demo = gr.Interface(