Kiwinicki commited on
Commit
427541b
·
1 Parent(s): 6e9efd3

upload better cycles generator

Browse files
README.md CHANGED
@@ -1,43 +1,64 @@
1
- To load and initialize the `Generator` model from the repository, follow these steps:
2
 
3
- 1. **Install Required Packages**: Ensure you have the necessary Python packages installed:
4
 
5
- ```python
6
- pip install torch omegaconf huggingface_hub
7
- ```
8
 
9
- 2. **Download Model Files**: Retrieve the `generator.pth`, `config.json`, and `model.py` files from the Hugging Face repository. You can use the `huggingface_hub` library for this:
10
 
11
- ```python
12
- from huggingface_hub import hf_hub_download
 
13
 
14
- repo_id = "Kiwinicki/sat2map-generator"
15
- generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
16
- config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
17
- model_path = hf_hub_download(repo_id=repo_id, filename="model.py")
18
- ```
19
 
20
- 3. **Load the Model**: Incorporate the downloaded `model.py` to define the `Generator` class, then load the model's state dictionary and configuration:
21
 
22
- ```python
23
- import torch
24
- import json
25
- from omegaconf import OmegaConf
26
- import sys
27
- from pathlib import Path
28
- from model import Generator
29
 
30
- # Load configuration
31
- with open(config_path, "r") as f:
32
- config_dict = json.load(f)
33
- cfg = OmegaConf.create(config_dict)
34
 
35
- # Initialize and load the generator model
36
- generator = Generator(cfg)
37
- generator.load_state_dict(torch.load(generator_path))
38
- generator.eval()
39
- x = torch.randn([1, cfg['channels'], 256, 256])
40
- out = generator(x)
41
- ```
42
 
43
- Here, `generator` is the initialized model ready for inference.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Loading the Generator Model
2
 
3
+ To load and initialize the `Generator` (based on CycleGAN with better cycles) model from the repository, follow these steps:
4
 
5
+ ## 1. Install Required Packages
 
 
6
 
7
+ Ensure you have the necessary Python packages installed:
8
 
9
+ ```bash
10
+ pip install torch==2.5.1 torchvision==0.20.1 safetensors huggingface_hub
11
+ ```
12
 
13
+ ## 2. Download Model Files
 
 
 
 
14
 
15
+ Retrieve the `pytorch_model.safetensors` and `model.py` files from the Hugging Face repository using the `huggingface_hub` library:
16
 
17
+ ```python
18
+ from huggingface_hub import hf_hub_download
 
 
 
 
 
19
 
20
+ repo_id = "Kiwinicki/sat2map-generator"
21
+ model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.safetensors")
22
+ generator_code_path = hf_hub_download(repo_id=repo_id, filename="model.py")
23
+ ```
24
 
25
+ ## 3. Load the Model
 
 
 
 
 
 
26
 
27
+ Import the `Generator` class and load the model weights from the safetensors file:
28
+
29
+ ```python
30
+ import torch
31
+ from safetensors.torch import load_file
32
+ from model import Generator, GeneratorConfig
33
+
34
+ # Initialize configuration with default values
35
+ cfg = GeneratorConfig(
36
+ channels=3,
37
+ num_features=64,
38
+ num_residuals=12,
39
+ depth=4
40
+ )
41
+
42
+ # Load the generator model
43
+ state_dict = load_file(model_path)
44
+ generator = Generator(cfg)
45
+ generator.load_state_dict(state_dict)
46
+ generator.eval()
47
+
48
+ # Test the model
49
+ x = torch.randn([1, cfg.channels, 256, 256])
50
+ out = generator(x)
51
+ print(f"Output shape: {out.shape}")
52
+ ```
53
+
54
+ ## 4. Model Configuration
55
+
56
+ The model uses the following default configuration:
57
+ - **channels**: 3 (RGB images)
58
+ - **num_features**: 64 (base number of features)
59
+ - **num_residuals**: 12 (number of residual blocks)
60
+ - **depth**: 4 (network depth)
61
+
62
+ The `generator` is now ready for inference on satellite-to-map translation tasks.
63
+
64
+ Model trained by Andrii Norets from "Czarna Magia".
config.json DELETED
@@ -1,6 +0,0 @@
1
- {
2
- "channels": 3,
3
- "num_features": 64,
4
- "num_residuals": 12,
5
- "depth": 4
6
- }
 
 
 
 
 
 
 
generator.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:925df068c3a6b7110b3be435eb4432399fd337e61ca2b512462f2b596864eca9
3
- size 59701794
 
 
 
 
pytorch_model.safetensors → generator.safetensors RENAMED
File without changes
model.py CHANGED
@@ -1,9 +1,17 @@
1
  from torch import tanh, Tensor
2
  import torch.nn as nn
3
- from omegaconf import DictConfig
4
  from abc import ABC, abstractmethod
5
 
6
 
 
 
 
 
 
 
 
 
7
  class BaseGenerator(ABC, nn.Module):
8
  def __init__(self, channels: int = 3):
9
  super().__init__()
@@ -15,7 +23,7 @@ class BaseGenerator(ABC, nn.Module):
15
 
16
 
17
  class Generator(BaseGenerator):
18
- def __init__(self, cfg: DictConfig):
19
  super().__init__(cfg.channels)
20
  self.cfg = cfg
21
  self.model = self._construct_model()
@@ -124,3 +132,17 @@ class ResidualBlock(nn.Module):
124
 
125
  def forward(self, x: Tensor) -> Tensor:
126
  return x + self.block(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from torch import tanh, Tensor
2
  import torch.nn as nn
3
+ from dataclasses import dataclass
4
  from abc import ABC, abstractmethod
5
 
6
 
7
+ @dataclass
8
+ class GeneratorConfig:
9
+ channels: int = 3
10
+ num_features: int = 64
11
+ num_residuals: int = 12
12
+ depth: int = 4
13
+
14
+
15
  class BaseGenerator(ABC, nn.Module):
16
  def __init__(self, channels: int = 3):
17
  super().__init__()
 
23
 
24
 
25
  class Generator(BaseGenerator):
26
+ def __init__(self, cfg: GeneratorConfig):
27
  super().__init__(cfg.channels)
28
  self.cfg = cfg
29
  self.model = self._construct_model()
 
132
 
133
  def forward(self, x: Tensor) -> Tensor:
134
  return x + self.block(x)
135
+
136
+ if __name__ == '__main__':
137
+ import torch
138
+ from safetensors.torch import load_file
139
+
140
+ cfg = GeneratorConfig()
141
+ state_dict = load_file('generator.safetensors')
142
+ generator = Generator(cfg)
143
+ generator.load_state_dict(state_dict)
144
+ generator.eval()
145
+
146
+ x = torch.randn([1, cfg.channels, 256, 256])
147
+ out = generator(x)
148
+ print(out.shape)
model.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:15e1ccacf5b528313d57c55df11eccf643e3efb54a1089ffaf52766c3e4174d4
3
- size 59680580
 
 
 
 
requirements.txt DELETED
@@ -1,25 +0,0 @@
1
- antlr4-python3-runtime==4.9.3
2
- certifi==2024.12.14
3
- charset-normalizer==3.4.1
4
- colorama==0.4.6
5
- filelock==3.17.0
6
- fsspec==2024.12.0
7
- huggingface-hub==0.28.0
8
- idna==3.10
9
- Jinja2==3.1.5
10
- MarkupSafe==3.0.2
11
- mpmath==1.3.0
12
- networkx==3.4.2
13
- numpy==2.2.2
14
- omegaconf==2.3.0
15
- packaging==24.2
16
- pillow==11.1.0
17
- PyYAML==6.0.2
18
- requests==2.32.3
19
- setuptools==75.8.0
20
- sympy==1.13.1
21
- torch==2.5.1
22
- torchvision==0.20.1
23
- tqdm==4.67.1
24
- typing_extensions==4.12.2
25
- urllib3==2.3.0