Kiwinicki commited on
Commit
35a938f
·
1 Parent(s): 0562814

working version on pth weights

Browse files
Files changed (2) hide show
  1. app.py +8 -16
  2. requirements.txt +1 -2
app.py CHANGED
@@ -5,32 +5,24 @@ import sys
5
  import os
6
  from PIL import Image
7
  import torchvision.transforms as transforms
8
- from safetensors.torch import load_file
9
 
10
  photos_folder = "Photos"
11
 
12
- # Download model files
13
  repo_id = "Kiwinicki/sat2map-generator"
14
- model_path = hf_hub_download(repo_id=repo_id, filename="generator.safetensors")
15
- generator_code_path = hf_hub_download(repo_id=repo_id, filename="model.py")
16
 
17
  # Add path to model
18
- sys.path.append(os.path.dirname(generator_code_path))
19
  from model import Generator, GeneratorConfig
20
 
21
- # Initialize configuration
22
- cfg = GeneratorConfig(
23
- channels=3,
24
- num_features=64,
25
- num_residuals=12,
26
- depth=4
27
- )
28
-
29
  # Initialize model
 
30
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
  generator = Generator(cfg).to(device)
32
- state_dict = load_file(model_path)
33
- generator.load_state_dict(state_dict)
34
  generator.eval()
35
 
36
  # Transformations
@@ -40,6 +32,7 @@ transform = transforms.Compose([
40
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
41
  ])
42
 
 
43
  def process_image(image):
44
  if image is None:
45
  return None
@@ -55,7 +48,6 @@ def process_image(image):
55
  output_image = output_tensor.squeeze(0).cpu()
56
  output_image = output_image * 0.5 + 0.5 # Denormalization
57
  output_image = transforms.ToPILImage()(output_image)
58
-
59
  return output_image
60
 
61
  def load_images_from_folder(folder):
 
5
  import os
6
  from PIL import Image
7
  import torchvision.transforms as transforms
8
+
9
 
10
  photos_folder = "Photos"
11
 
12
+ # Download model and config
13
  repo_id = "Kiwinicki/sat2map-generator"
14
+ generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
15
+ model_path = hf_hub_download(repo_id=repo_id, filename="model.py")
16
 
17
  # Add path to model
18
+ sys.path.append(os.path.dirname(model_path))
19
  from model import Generator, GeneratorConfig
20
 
 
 
 
 
 
 
 
 
21
  # Initialize model
22
+ cfg = GeneratorConfig()
23
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
  generator = Generator(cfg).to(device)
25
+ generator.load_state_dict(torch.load(generator_path))
 
26
  generator.eval()
27
 
28
  # Transformations
 
32
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
33
  ])
34
 
35
+
36
  def process_image(image):
37
  if image is None:
38
  return None
 
48
  output_image = output_tensor.squeeze(0).cpu()
49
  output_image = output_image * 0.5 + 0.5 # Denormalization
50
  output_image = transforms.ToPILImage()(output_image)
 
51
  return output_image
52
 
53
  def load_images_from_folder(folder):
requirements.txt CHANGED
@@ -3,5 +3,4 @@ torch>=2.0.0
3
  torchvision>=0.15.0
4
  gradio
5
  pillow
6
- pydantic==2.10.6
7
- safetensors
 
3
  torchvision>=0.15.0
4
  gradio
5
  pillow
6
+ pydantic==2.10.6