# Loading the Generator Model

To load and initialize the `Generator` model (based on "CycleGAN with Better Cycles") from the Hugging Face repository, follow these steps:

## 1. Install Required Packages

Ensure you have the necessary Python packages installed:

```bash
pip install torch==2.5.1 torchvision==0.20.1 huggingface_hub
```
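
To confirm that the install step took effect, the installed versions can be read back from package metadata. The following is a minimal sketch using only the standard library; the helper name `installed_versions` is hypothetical and not part of the repository.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {package: version string, or None if the package is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


# Report what is installed for the packages named in the pip command above
print(installed_versions(["torch", "torchvision", "huggingface_hub"]))
```

A `None` entry means the package is missing from the active environment; a version other than the pinned one may still work but has not been checked here.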

## 2. Download Model Files

Retrieve the `generator.pth` and `model.py` files from the Hugging Face repository using the `huggingface_hub` library:

```python
from huggingface_hub import hf_hub_download

repo_id = "Kiwinicki/sat2map-generator"
model_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
generator_code_path = hf_hub_download(repo_id=repo_id, filename="model.py")
```
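
## 3. Load the Model

Import the `Generator` class and load the model weights from the downloaded `.pth` file. `hf_hub_download` stores both files in the local Hugging Face cache and returns their paths, so use `generator_code_path` and `model_path` from step 2 rather than names relative to the working directory:

```python
import os
import sys

import torch

# model.py sits in the Hugging Face cache, so add its directory to the import path
# (generator_code_path and model_path come from step 2)
sys.path.insert(0, os.path.dirname(generator_code_path))
from model import Generator, GeneratorConfig

# Load the generator model
cfg = GeneratorConfig()
generator = Generator(cfg)
# map_location="cpu" lets the weights load on machines without a GPU
generator.load_state_dict(torch.load(model_path, map_location="cpu"))
generator.eval()

# Test the model with a random 256x256 input
x = torch.randn([1, cfg.channels, 256, 256])
with torch.no_grad():
    out = generator(x)
print(f"Output shape: {out.shape}")
```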
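
The test above feeds random noise. For real images, the input has to become a `[1, 3, H, W]` float tensor and the output has to be mapped back to pixel values. The sketch below assumes the generator expects inputs scaled to [-1, 1] and produces outputs in the same range, which is a common CycleGAN convention but is not confirmed by this repository. The helper names `to_model_input` and `to_uint8_image` are hypothetical.

```python
import torch


def to_model_input(image_uint8: torch.Tensor) -> torch.Tensor:
    """Convert an HxWx3 uint8 image tensor to a 1x3xHxW float tensor in [-1, 1]."""
    x = image_uint8.permute(2, 0, 1).float() / 255.0  # HWC -> CHW, scale to [0, 1]
    x = x * 2.0 - 1.0                                  # assumed input range [-1, 1]
    return x.unsqueeze(0)                              # add the batch dimension


def to_uint8_image(output: torch.Tensor) -> torch.Tensor:
    """Convert a 1x3xHxW tensor in [-1, 1] back to an HxWx3 uint8 image tensor."""
    y = (output.squeeze(0).clamp(-1.0, 1.0) + 1.0) / 2.0         # back to [0, 1]
    return (y * 255.0).round().to(torch.uint8).permute(1, 2, 0)  # CHW -> HWC
```

With these helpers, inference on an image tensor `img` would look like `to_uint8_image(generator(to_model_input(img)))`, wrapped in `torch.no_grad()`. If the model was trained with a different normalization, the scaling lines need to change accordingly.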

## 4. Model Configuration

The model uses the following default configuration:

- **channels**: 3 (RGB images)
- **num_features**: 64 (base number of features)
- **num_residuals**: 12 (number of residual blocks)
- **depth**: 4 (network depth)
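
To make these defaults concrete, they can be mirrored in a small dataclass. This is an illustrative stand-in, not the `GeneratorConfig` class from `model.py`, whose actual definition may differ; the name `GeneratorConfigSketch` is hypothetical.

```python
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class GeneratorConfigSketch:
    """Hypothetical mirror of the documented defaults; not the real GeneratorConfig."""
    channels: int = 3        # RGB images
    num_features: int = 64   # base number of features
    num_residuals: int = 12  # number of residual blocks
    depth: int = 4           # network depth


defaults = asdict(GeneratorConfigSketch())
print(defaults)
```

Assuming the real config exposes attributes under the same names, comparing `getattr(cfg, name)` against each entry in `defaults` is a quick way to check that the downloaded `model.py` still matches the values listed here.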

The `generator` is now ready for inference on satellite-to-map translation tasks.

Model trained by Andrii Norets from "Czarna Magia".