---
license: cc-by-nc-sa-4.0
datasets:
- mwalmsley/gz2
metrics:
- accuracy 
---
# Important Disclaimer
This model is a part of my **bachelor thesis** (VUT in Brno, FIT).  

# CosmoFormer Model

This is a **TorchScript** version of our CrossFormer-based image classification model.  
It was trained on [Galaxy Zoo 2 (GZ2)](https://www.zooniverse.org/projects/zookeeper/galaxy-zoo/about/research) data to classify galaxy morphologies (spirals, ellipticals, and other morphological types).
I also leveraged the [galaxy-datasets pip package](https://github.com/mwalmsley/galaxy-datasets) by [Michael Walmsley](https://github.com/mwalmsley) for data loading and handling.


## Model Details

- **Architecture:** [CrossFormer](https://github.com/lucidrains/vit-pytorch) variant
- **Model accuracy:** 75%  
- **Input Resolution:** 224×224 RGB  
- **Number of Classes:** 8 galaxy morphology classes (the index-to-label order depends on the label encoder used during training; see `label_mapping` in the example below)  
- **Checkpoint Format:** TorchScript (`.pt`) file  
- **Frameworks:** Originally built in PyTorch with `vit_pytorch`; now a self-contained TorchScript module.  
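The class indices used in the quick-start `label_mapping` happen to follow the alphabetical order of the class names, which is what an encoder that sorts its classes produces (scikit-learn's `LabelEncoder`, for instance, stores `classes_` in sorted order). Assuming the checkpoint was trained with such an encoder, the mapping can be rebuilt from the class names alone. `CLASS_NAMES` and `build_label_mapping` are illustrative names, not part of the released model.

```python
# Class names listed in arbitrary order; the helper sorts them.
CLASS_NAMES = [
    "smooth_round",
    "smooth_inbetween",
    "smooth_cigar",
    "edge_on_disk",
    "barred_spiral",
    "unbarred_spiral",
    "featured_without_bar_or_spiral",
    "irregular",
]


def build_label_mapping(class_names):
    """Hypothetical helper: return {index: name}, assuming the label encoder
    assigned indices in sorted (alphabetical) order of unique class names."""
    unique_sorted = sorted(set(class_names))
    return dict(enumerate(unique_sorted))


label_mapping = build_label_mapping(CLASS_NAMES)
for idx, name in label_mapping.items():
    print(idx, name)
```

If your own label encoder used a different ordering, the indices will not line up with this checkpoint, so treat the sorted-order assumption as something to verify rather than a guarantee.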

## Usage

You can load and run this model **directly in PyTorch** **without** installing `vit_pytorch`. Just make sure you have an environment with:

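- `torch` >= 1.13.0  
- `torchvision` (optional, if you need standard transforms; the `transforms.v2` API with `ToImage` and `ToDtype(..., scale=True)` used in the example requires `torchvision` >= 0.16)  
- `huggingface_hub` and `Pillow` (used by the example below to download the checkpoint and open images)  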
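No `vit_pytorch` install is needed because a TorchScript file stores the module's compiled graph and weights rather than a reference to the Python class that defined it. The sketch below illustrates this with a tiny stand-in network (`TinyClassifier` is hypothetical and unrelated to the CrossFormer architecture): it is traced, serialized to an in-memory buffer, and reloaded with `torch.jit.load`, the same call the quick start uses on the downloaded `.pt` file.

```python
import io

import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in model; NOT the CrossFormer architecture."""

    def __init__(self, num_classes: int = 8):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.head = nn.Linear(8, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.features(x))


torch.manual_seed(0)
model = TinyClassifier().eval()
example = torch.rand(1, 3, 224, 224)  # same input shape as the real model

# Tracing records the operations executed on the example input.
traced = torch.jit.trace(model, example)

# Serialize graph + weights to an in-memory buffer (a .pt file works the same way).
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)

# The reloaded object is a ScriptModule; running it does not call TinyClassifier.forward.
loaded = torch.jit.load(buffer, map_location="cpu")
loaded.eval()

with torch.no_grad():
    reference = model(example)
    restored = loaded(example)

print(isinstance(loaded, torch.jit.ScriptModule), restored.shape)
print(torch.allclose(reference, restored))
```

Because tracing records a single execution path, a traced model reproduces the computation for inputs of the traced shape; keeping inputs at 224×224, as the quick start does, stays on the safe side.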

### Quick Start Example

```python
import torch
import torchvision.transforms.v2 as v2
from huggingface_hub import hf_hub_download
from PIL import Image

label_mapping = {
    0: 'barred_spiral',
    1: 'edge_on_disk',
    2: 'featured_without_bar_or_spiral',
    3: 'irregular',
    4: 'smooth_cigar',
    5: 'smooth_inbetween',
    6: 'smooth_round',
    7: 'unbarred_spiral'
}

# 1. Download the TorchScript checkpoint from the Hugging Face Hub (returns a local file path)
ts_path = hf_hub_download(
    repo_id="artursultanov/cosmoformer-model",
    filename="cosmoformer_traced_cpu.pt"
)

# 2. Load the TorchScript model onto the CPU and switch it to inference mode
model = torch.jit.load(ts_path, map_location="cpu")
model.eval()

# 3. Preprocess to match the model's expected input: 224x224 RGB, float32 scaled to [0, 1]
transform = v2.Compose([
    v2.Resize((224, 224)),
    v2.ToImage(),                           # PIL image -> tensor image [C, H, W]
    v2.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float32 [0, 1]
])

# 4. Load the image
image_path = "test_image.jpg"
image = Image.open(image_path).convert("RGB")

tensor = transform(image)  # shape [3, 224, 224]
tensor = tensor.unsqueeze(0)  # add a batch dimension: shape [1, 3, 224, 224]

# 5. Inference
with torch.no_grad():
    output = model(tensor)
    predicted_idx = torch.argmax(output, dim=1).item()

predicted_label = label_mapping[predicted_idx]
print("Predicted class:", predicted_label)
```
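The quick start reports only the arg-max class. To see how confident the model is and which classes are runners-up, the outputs can be turned into softmax probabilities. This sketch assumes the model returns unnormalized logits of shape `[batch, 8]` (the quick start only relies on `argmax`, which does not tell us either way); `top_k_predictions` is a hypothetical helper, and the hand-written logits stand in for `model(batch)`.

```python
import torch

# Same index -> label mapping as in the quick start.
label_mapping = {
    0: "barred_spiral",
    1: "edge_on_disk",
    2: "featured_without_bar_or_spiral",
    3: "irregular",
    4: "smooth_cigar",
    5: "smooth_inbetween",
    6: "smooth_round",
    7: "unbarred_spiral",
}


def top_k_predictions(logits: torch.Tensor, mapping: dict, k: int = 3):
    """Hypothetical helper: convert [batch, num_classes] logits into per-image
    lists of (label, probability), highest probability first.

    Assumes unnormalized logits; if the model already returned probabilities,
    applying softmax again would distort them.
    """
    probs = torch.softmax(logits, dim=1)
    top_probs, top_idx = probs.topk(k, dim=1)
    results = []
    for p_row, i_row in zip(top_probs.tolist(), top_idx.tolist()):
        results.append([(mapping[i], p) for i, p in zip(i_row, p_row)])
    return results


# Stand-in for `model(batch)`, where a real batch could be built with
# torch.stack([transform(img) for img in images]) -> shape [N, 3, 224, 224].
logits = torch.tensor([
    [0.1, 0.0, 0.2, 0.0, 0.0, 1.5, 3.0, 0.0],
    [2.5, 0.0, 0.3, 0.0, 0.0, 0.0, 0.0, 2.0],
])

for preds in top_k_predictions(logits, label_mapping, k=3):
    print([(label, round(p, 3)) for label, p in preds])
```

Stacking several preprocessed images into one tensor lets a single forward pass under `torch.no_grad()` score the whole batch.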

## Citation

The training data come from the Galaxy Zoo 2 catalogue:

```bibtex
@article{10.1093/mnras/stt1458,
author = {Willett, Kyle W. and Lintott, Chris J. and Bamford, Steven P. and Masters, Karen L. and Simmons, Brooke D. and Casteels, Kevin R. V. and Edmondson, Edward M. and Fortson, Lucy F. and Kaviraj, Sugata and Keel, William C. and Melvin, Thomas and Nichol, Robert C. and Raddick, M. Jordan and Schawinski, Kevin and Simpson, Robert J. and Skibba, Ramin A. and Smith, Arfon M. and Thomas, Daniel},
title = "{Galaxy Zoo 2: detailed morphological classifications for 304 122 galaxies from the Sloan Digital Sky Survey}",
journal = {Monthly Notices of the Royal Astronomical Society},
volume = {435},
number = {4},
pages = {2835-2860},
year = {2013},
month = {09},
issn = {0035-8711},
doi = {10.1093/mnras/stt1458},
}
```