upload better cycles generator
Browse files- README.md +54 -33
- config.json +0 -6
- generator.pth +0 -3
- pytorch_model.safetensors → generator.safetensors +0 -0
- model.py +24 -2
- model.safetensors +0 -3
- requirements.txt +0 -25
README.md
CHANGED
@@ -1,43 +1,64 @@
|
|
1 |
-
|
2 |
|
3 |
-
|
4 |
|
5 |
-
|
6 |
-
pip install torch omegaconf huggingface_hub
|
7 |
-
```
|
8 |
|
9 |
-
|
10 |
|
11 |
-
|
12 |
-
|
|
|
13 |
|
14 |
-
|
15 |
-
generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
|
16 |
-
config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
|
17 |
-
model_path = hf_hub_download(repo_id=repo_id, filename="model.py")
|
18 |
-
```
|
19 |
|
20 |
-
|
21 |
|
22 |
-
|
23 |
-
|
24 |
-
import json
|
25 |
-
from omegaconf import OmegaConf
|
26 |
-
import sys
|
27 |
-
from pathlib import Path
|
28 |
-
from model import Generator
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
|
35 |
-
|
36 |
-
generator = Generator(cfg)
|
37 |
-
generator.load_state_dict(torch.load(generator_path))
|
38 |
-
generator.eval()
|
39 |
-
x = torch.randn([1, cfg['channels'], 256, 256])
|
40 |
-
out = generator(x)
|
41 |
-
```
|
42 |
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Loading the Generator Model
|
2 |
|
3 |
+
To load and initialize the `Generator` (based on CycleGAN with better cycles) model from the repository, follow these steps:
|
4 |
|
5 |
+
## 1. Install Required Packages
|
|
|
|
|
6 |
|
7 |
+
Ensure you have the necessary Python packages installed:
|
8 |
|
9 |
+
```bash
|
10 |
+
pip install torch==2.5.1 torchvision==0.20.1 safetensors huggingface_hub
|
11 |
+
```
|
12 |
|
13 |
+
## 2. Download Model Files
|
|
|
|
|
|
|
|
|
14 |
|
15 |
+
Retrieve the `pytorch_model.safetensors` and `model.py` files from the Hugging Face repository using the `huggingface_hub` library:
|
16 |
|
17 |
+
```python
|
18 |
+
from huggingface_hub import hf_hub_download
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
+
repo_id = "Kiwinicki/sat2map-generator"
|
21 |
+
model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.safetensors")
|
22 |
+
generator_code_path = hf_hub_download(repo_id=repo_id, filename="model.py")
|
23 |
+
```
|
24 |
|
25 |
+
## 3. Load the Model
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
+
Import the `Generator` class and load the model weights from the safetensors file:
|
28 |
+
|
29 |
+
```python
|
30 |
+
import torch
|
31 |
+
from safetensors.torch import load_file
|
32 |
+
from model import Generator, GeneratorConfig
|
33 |
+
|
34 |
+
# Initialize configuration with default values
|
35 |
+
cfg = GeneratorConfig(
|
36 |
+
channels=3,
|
37 |
+
num_features=64,
|
38 |
+
num_residuals=12,
|
39 |
+
depth=4
|
40 |
+
)
|
41 |
+
|
42 |
+
# Load the generator model
|
43 |
+
state_dict = load_file(model_path)
|
44 |
+
generator = Generator(cfg)
|
45 |
+
generator.load_state_dict(state_dict)
|
46 |
+
generator.eval()
|
47 |
+
|
48 |
+
# Test the model
|
49 |
+
x = torch.randn([1, cfg.channels, 256, 256])
|
50 |
+
out = generator(x)
|
51 |
+
print(f"Output shape: {out.shape}")
|
52 |
+
```
|
53 |
+
|
54 |
+
## 4. Model Configuration
|
55 |
+
|
56 |
+
The model uses the following default configuration:
|
57 |
+
- **channels**: 3 (RGB images)
|
58 |
+
- **num_features**: 64 (base number of features)
|
59 |
+
- **num_residuals**: 12 (number of residual blocks)
|
60 |
+
- **depth**: 4 (network depth)
|
61 |
+
|
62 |
+
The `generator` is now ready for inference on satellite-to-map translation tasks.
|
63 |
+
|
64 |
+
Model trained by Andrii Norets from "Czarna Magia".
|
config.json
DELETED
@@ -1,6 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"channels": 3,
|
3 |
-
"num_features": 64,
|
4 |
-
"num_residuals": 12,
|
5 |
-
"depth": 4
|
6 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
generator.pth
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:925df068c3a6b7110b3be435eb4432399fd337e61ca2b512462f2b596864eca9
|
3 |
-
size 59701794
|
|
|
|
|
|
|
|
pytorch_model.safetensors → generator.safetensors
RENAMED
File without changes
|
model.py
CHANGED
@@ -1,9 +1,17 @@
|
|
1 |
from torch import tanh, Tensor
|
2 |
import torch.nn as nn
|
3 |
-
from
|
4 |
from abc import ABC, abstractmethod
|
5 |
|
6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
class BaseGenerator(ABC, nn.Module):
|
8 |
def __init__(self, channels: int = 3):
|
9 |
super().__init__()
|
@@ -15,7 +23,7 @@ class BaseGenerator(ABC, nn.Module):
|
|
15 |
|
16 |
|
17 |
class Generator(BaseGenerator):
|
18 |
-
def __init__(self, cfg:
|
19 |
super().__init__(cfg.channels)
|
20 |
self.cfg = cfg
|
21 |
self.model = self._construct_model()
|
@@ -124,3 +132,17 @@ class ResidualBlock(nn.Module):
|
|
124 |
|
125 |
def forward(self, x: Tensor) -> Tensor:
|
126 |
return x + self.block(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from torch import tanh, Tensor
|
2 |
import torch.nn as nn
|
3 |
+
from dataclasses import dataclass
|
4 |
from abc import ABC, abstractmethod
|
5 |
|
6 |
|
7 |
+
@dataclass
|
8 |
+
class GeneratorConfig:
|
9 |
+
channels: int = 3
|
10 |
+
num_features: int = 64
|
11 |
+
num_residuals: int = 12
|
12 |
+
depth: int = 4
|
13 |
+
|
14 |
+
|
15 |
class BaseGenerator(ABC, nn.Module):
|
16 |
def __init__(self, channels: int = 3):
|
17 |
super().__init__()
|
|
|
23 |
|
24 |
|
25 |
class Generator(BaseGenerator):
|
26 |
+
def __init__(self, cfg: GeneratorConfig):
|
27 |
super().__init__(cfg.channels)
|
28 |
self.cfg = cfg
|
29 |
self.model = self._construct_model()
|
|
|
132 |
|
133 |
def forward(self, x: Tensor) -> Tensor:
|
134 |
return x + self.block(x)
|
135 |
+
|
136 |
+
if __name__ == '__main__':
|
137 |
+
import torch
|
138 |
+
from safetensors.torch import load_file
|
139 |
+
|
140 |
+
cfg = GeneratorConfig()
|
141 |
+
state_dict = load_file('generator.safetensors')
|
142 |
+
generator = Generator(cfg)
|
143 |
+
generator.load_state_dict(state_dict)
|
144 |
+
generator.eval()
|
145 |
+
|
146 |
+
x = torch.randn([1, cfg.channels, 256, 256])
|
147 |
+
out = generator(x)
|
148 |
+
print(out.shape)
|
model.safetensors
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:15e1ccacf5b528313d57c55df11eccf643e3efb54a1089ffaf52766c3e4174d4
|
3 |
-
size 59680580
|
|
|
|
|
|
|
|
requirements.txt
DELETED
@@ -1,25 +0,0 @@
|
|
1 |
-
antlr4-python3-runtime==4.9.3
|
2 |
-
certifi==2024.12.14
|
3 |
-
charset-normalizer==3.4.1
|
4 |
-
colorama==0.4.6
|
5 |
-
filelock==3.17.0
|
6 |
-
fsspec==2024.12.0
|
7 |
-
huggingface-hub==0.28.0
|
8 |
-
idna==3.10
|
9 |
-
Jinja2==3.1.5
|
10 |
-
MarkupSafe==3.0.2
|
11 |
-
mpmath==1.3.0
|
12 |
-
networkx==3.4.2
|
13 |
-
numpy==2.2.2
|
14 |
-
omegaconf==2.3.0
|
15 |
-
packaging==24.2
|
16 |
-
pillow==11.1.0
|
17 |
-
PyYAML==6.0.2
|
18 |
-
requests==2.32.3
|
19 |
-
setuptools==75.8.0
|
20 |
-
sympy==1.13.1
|
21 |
-
torch==2.5.1
|
22 |
-
torchvision==0.20.1
|
23 |
-
tqdm==4.67.1
|
24 |
-
typing_extensions==4.12.2
|
25 |
-
urllib3==2.3.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|