artursultanov commited on
Commit
8227ec8
·
verified ·
1 Parent(s): 9a8d3ef

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +55 -6
README.md CHANGED
@@ -1,12 +1,61 @@
 
 
 
 
 
 
 
 
1
  # CosmoFormer Model
2
 
3
- This is a TorchScript version of our CrossFormer-based model.
4
- You can load it in plain PyTorch without needing `vit_pytorch`.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- Usage Example:
7
  ```python
8
  import torch
9
- model = torch.jit.load('cosmoformer_traced.pt')
 
 
10
  model.eval()
11
- ...
12
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: cc-by-nc-sa-4.0
3
+ datasets:
4
+ - mwalmsley/gz2
5
+ metrics:
6
+ - accuracy
7
+ ---
8
+
9
  # CosmoFormer Model
10
 
11
+ This is a **TorchScript** version of our CrossFormer-based image classification model.
12
+ It was trained on [Galaxy Zoo 2 (GZ2)](https://www.zooniverse.org/projects/zookeeper/galaxy-zoo/about/research) data to classify galaxy morphologies (spirals, ellipticals, and other morphological types).
13
+ I also leveraged the [galaxy-datasets pip package](https://github.com/mwalmsley/galaxy-datasets) by [Michael Walmsley](https://github.com/mwalmsley) for data loading and handling.
14
+
15
+
16
+ ## Model Details
17
+
18
+ - **Architecture:** [CrossFormer](https://github.com/lucidrains/vit-pytorch) variant
19
+ - **Input Resolution:** 224×224 RGB
20
+ - **Number of Classes:** Depends on your label encoder (e.g., galaxy morphology classes)
21
+ - **Checkpoint Format:** TorchScript (`.pt`) file
22
+ - **Frameworks:** Originally in PyTorch with `vit_pytorch`. Now self-contained in TorchScript.
23
+
24
+ ## Usage
25
+
26
+ You can load and run this model **directly in PyTorch** **without** installing `vit_pytorch`. Just make sure you have an environment with:
27
+
28
+ - `torch` >= 1.13.0
29
+ - `torchvision` (optional, if you need standard transforms)
30
+
31
+ ### Quick Start Example
32
 
 
33
  ```python
34
  import torch
35
+
36
+ # 1. Load the model
37
+ model = torch.jit.load("cosmoformer_traced.pt")
38
  model.eval()
39
+
40
+ # 2. Inference
41
+ # Suppose you have a 3-channel image tensor (1, 3, 224, 224)
42
+ dummy_input = torch.randn(1, 3, 224, 224)
43
+
44
+ with torch.no_grad():
45
+ outputs = model(dummy_input)
46
+
47
+ print(outputs.shape) # e.g., [1, num_classes]
48
+
49
+
50
+ @article{10.1093/mnras/stt1458,
51
+ author = {Willett, Kyle W. and Lintott, Chris J. and Bamford, Steven P. and Masters, Karen L. and Simmons, Brooke D. and Casteels, Kevin R. V. and Edmondson, Edward M. and Fortson, Lucy F. and Kaviraj, Sugata and Keel, William C. and Melvin, Thomas and Nichol, Robert C. and Raddick, M. Jordan and Schawinski, Kevin and Simpson, Robert J. and Skibba, Ramin A. and Smith, Arfon M. and Thomas, Daniel},
52
+ title = "{Galaxy Zoo 2: detailed morphological classifications for 304 122 galaxies from the Sloan Digital Sky Survey}",
53
+ journal = {Monthly Notices of the Royal Astronomical Society},
54
+ volume = {435},
55
+ number = {4},
56
+ pages = {2835-2860},
57
+ year = {2013},
58
+ month = {09},
59
+ issn = {0035-8711},
60
+ doi = {10.1093/mnras/stt1458},
61
+ }