quchenyuan committed · Commit 975380a · verified · 1 parent: ddc9fc7

Update README.md

Files changed (1): README.md (+141 -3)
---
license: apache-2.0
tags:
- vision
---

# VisualSplit

**VisualSplit** is a ViT-based model that explicitly factorises an image into **classical visual descriptors**—such as **edges**, **color segmentation**, and a **grayscale histogram**—and learns to reconstruct the image conditioned on those descriptors. This design yields **interpretable representations** where geometry (edges), albedo/appearance (segmented colors), and global tone (histogram) can be reasoned about or varied independently.

> **Training data**: ImageNet-1K.

---

## Model Description

- **Inputs** (at inference):
  - An RGB image (for convenience), which the provided `FeatureExtractor` converts into the descriptors the model actually consumes (edges, color segmentation, grayscale histogram).
- **Outputs**:
  - A reconstructed RGB image tensor (same spatial size as the model’s training resolution; `224×224` by default unless you trained otherwise).
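
For orientation, the sketch below shows the kind of tensors the three descriptors are expected to be. The shapes and value ranges here (single-channel edges, a 256-bin histogram, `224×224` resolution) are assumptions for illustration; the repo's `FeatureExtractor` is the authoritative source of the real descriptors.

```python
import torch

# Illustrative placeholders only; use the repo's FeatureExtractor for real inputs.
batch = 1
edge = torch.rand(batch, 1, 224, 224)           # assumed: single-channel edge map in [0, 1]
gray_hist = torch.rand(batch, 256)              # assumed: 256-bin grayscale histogram
gray_hist = gray_hist / gray_hist.sum(dim=-1, keepdim=True)  # normalise to a distribution
segmented_rgb = torch.rand(batch, 3, 224, 224)  # assumed: colour-segmented RGB image

print(edge.shape, gray_hist.shape, segmented_rgb.shape)
```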

---

## Getting Started (Inference)

Below are two ways to run inference with the uploaded `model.safetensors`.

### 1) Minimal PyTorch + safetensors (load state dict)

```python
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# 1) Import the model & config classes from the VisualSplit repo
from visualsplit.models.CrossViT import CrossViTForPreTraining, CrossViTConfig
from visualsplit.utils import FeatureExtractor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2) Build a config matching your training (edit if you changed widths/depths)
config = CrossViTConfig(
    image_size=224,  # change if your training size differs
    patch_size=16,
    # ... any other config fields your repo exposes
)

model = CrossViTForPreTraining(config).to(device)
model.eval()

# 3) Download and load the state dict from this model repo
#    Replace REPO_ID with your Hugging Face model id, e.g. "HenryQUQ/visualsplit"
ckpt_path = hf_hub_download(repo_id="REPO_ID", filename="model.safetensors")
state_dict = load_file(ckpt_path)
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("Missing keys:", missing)
print("Unexpected keys:", unexpected)

# 4) Prepare an input image and extract descriptors
from PIL import Image
from torchvision import transforms

image = Image.open("input.jpg").convert("RGB")
transform = transforms.Compose([
    transforms.Resize((config.image_size, config.image_size)),
    transforms.ToTensor(),
])
pixel_values = transform(image).unsqueeze(0).to(device)  # (1, 3, H, W)

# FeatureExtractor provided by the repo returns the required descriptor tensors
extractor = FeatureExtractor().to(device)
with torch.no_grad():
    edge, gray_hist, segmented_rgb, _ = extractor(pixel_values)

# 5) Run inference (reconstruction)
with torch.no_grad():
    outputs = model(
        source_edge=edge,
        source_gray_level_histogram=gray_hist,
        source_segmented_rgb=segmented_rgb,
    )
# The forward's return keys may differ between code versions; adjust accordingly:
reconstructed = outputs["logits_reshape"]  # (1, 3, H, W)

# 6) Convert to PIL for visualisation
to_pil = transforms.ToPILImage()
recon_img = to_pil(reconstructed.squeeze(0).cpu().clamp(0, 1))
recon_img.save("reconstructed.png")
print("Saved to reconstructed.png")
```
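
As a quick qualitative check, it can help to save the input and the reconstruction side by side rather than the reconstruction alone. The snippet below is a small add-on that reuses the `pixel_values` and `reconstructed` tensors from the example above and relies only on `torchvision.utils.save_image`; the output filename is arbitrary.

```python
import torch
from torchvision.utils import save_image

# Stack input and reconstruction into one 2-panel grid image.
comparison = torch.cat(
    [pixel_values.cpu(), reconstructed.cpu().clamp(0, 1)], dim=0
)  # (2, 3, H, W)
save_image(comparison, "comparison.png", nrow=2)
print("Saved to comparison.png")
```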

### 2) Reproducing the notebook flow (`notebook/validation.ipynb`)

The repository provides a validation notebook that:
1. Loads the trained model,
2. Uses `FeatureExtractor` to compute **edges**, **color-segmented RGB**, and **grayscale histograms**,
3. Runs the model to obtain a reconstructed image,
4. Saves/visualises the result (a sketch of this flow follows below).
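
The sketch below mirrors that flow using the variables from the section 1 snippet (`pixel_values`, `edge`, `segmented_rgb`, `reconstructed`); it assumes single-channel edge maps and is only a stand-in for the notebook's own visualisation cells, so the actual panel layout in `notebook/validation.ipynb` may differ.

```python
import matplotlib.pyplot as plt

# Collect the panels the notebook visualises: input, descriptors, reconstruction.
panels = {
    "input": pixel_values[0].cpu().permute(1, 2, 0),
    "edges": edge[0].cpu().squeeze(),                        # assumed single-channel
    "segmented RGB": segmented_rgb[0].cpu().permute(1, 2, 0),
    "reconstruction": reconstructed[0].cpu().clamp(0, 1).permute(1, 2, 0),
}

fig, axes = plt.subplots(1, len(panels), figsize=(4 * len(panels), 4))
for ax, (title, img) in zip(axes, panels.items()):
    ax.imshow(img, cmap="gray" if img.ndim == 2 else None)
    ax.set_title(title)
    ax.axis("off")
fig.savefig("validation_panels.png", bbox_inches="tight")
print("Saved to validation_panels.png")
```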

---

## Installation & Requirements

```bash
# clone the VisualSplit code and install it as an editable package
git clone https://github.com/HenryQUQ/VisualSplit.git
cd VisualSplit
pip install -e .
```
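
A quick way to confirm the environment is set up is to import the classes used in the inference example above; this assumes the editable install succeeded and that `torch` is already present in your environment.

```python
# Minimal sanity check after `pip install -e .`
import torch
from visualsplit.models.CrossViT import CrossViTConfig, CrossViTForPreTraining
from visualsplit.utils import FeatureExtractor

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
```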

---

## Training Data

- **Dataset**: **ImageNet-1K**.

> This repository only hosts the **trained checkpoint for inference**. Follow the GitHub repo for the full training pipeline and data preparation scripts.

---

## Model Sources

- **Code**: https://github.com/HenryQUQ/VisualSplit
- **Weights (this page)**: this Hugging Face model repo

---

## Citation

If you use this model or its ideas, please cite:

```bibtex
@inproceedings{Qu2025VisualSplit,
  title     = {Exploring Image Representation with Decoupled Classical Visual Descriptors},
  author    = {Qu, Chenyuan and Chen, Hao and Jiao, Jianbo},
  booktitle = {British Machine Vision Conference (BMVC)},
  year      = {2025}
}
```

---