Kiwinicki commited on
Commit
1dc389e
·
verified ·
1 Parent(s): 7b55753

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -15
app.py CHANGED
@@ -1,34 +1,36 @@
1
  import gradio as gr
2
  import torch
3
  from huggingface_hub import hf_hub_download
4
- import json
5
- from omegaconf import OmegaConf
6
  import sys
7
  import os
8
  from PIL import Image
9
  import torchvision.transforms as transforms
 
10
 
11
  photos_folder = "Photos"
12
 
13
- # Download model and config
14
  repo_id = "Kiwinicki/sat2map-generator"
15
- generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
16
- config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
17
- model_path = hf_hub_download(repo_id=repo_id, filename="model.py")
18
 
19
  # Add path to model
20
- sys.path.append(os.path.dirname(model_path))
21
- from model import Generator
22
 
23
- # Load configuration
24
- with open(config_path, "r") as f:
25
- config_dict = json.load(f)
26
- cfg = OmegaConf.create(config_dict)
 
 
 
27
 
28
  # Initialize model
29
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
  generator = Generator(cfg).to(device)
31
- generator.load_state_dict(torch.load(generator_path, map_location=device))
 
32
  generator.eval()
33
 
34
  # Transformations
@@ -86,12 +88,11 @@ def app():
86
  gallery = gr.Gallery(
87
  label="Image Gallery",
88
  value=gallery_images,
89
- columns=3, # Set number of columns directly in the constructor
90
  rows=2,
91
  height="auto"
92
  )
93
 
94
-
95
  with gr.Column():
96
  output_image = gr.Image(label="Result Image", type="pil")
97
 
 
1
  import gradio as gr
2
  import torch
3
  from huggingface_hub import hf_hub_download
 
 
4
  import sys
5
  import os
6
  from PIL import Image
7
  import torchvision.transforms as transforms
8
+ from safetensors.torch import load_file
9
 
10
  photos_folder = "Photos"
11
 
12
+ # Download model files
13
  repo_id = "Kiwinicki/sat2map-generator"
14
+ model_path = hf_hub_download(repo_id=repo_id, filename="generator.safetensors")
15
+ generator_code_path = hf_hub_download(repo_id=repo_id, filename="model.py")
 
16
 
17
  # Add path to model
18
+ sys.path.append(os.path.dirname(generator_code_path))
19
+ from model import Generator, GeneratorConfig
20
 
21
+ # Initialize configuration
22
+ cfg = GeneratorConfig(
23
+ channels=3,
24
+ num_features=64,
25
+ num_residuals=12,
26
+ depth=4
27
+ )
28
 
29
  # Initialize model
30
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
  generator = Generator(cfg).to(device)
32
+ state_dict = load_file(model_path)
33
+ generator.load_state_dict(state_dict)
34
  generator.eval()
35
 
36
  # Transformations
 
88
  gallery = gr.Gallery(
89
  label="Image Gallery",
90
  value=gallery_images,
91
+ columns=3,
92
  rows=2,
93
  height="auto"
94
  )
95
 
 
96
  with gr.Column():
97
  output_image = gr.Image(label="Result Image", type="pil")
98