File size: 25,411 Bytes
def2cca
 
 
 
 
 
 
 
 
 
 
 
50bd1fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
---
title: Cels
emoji: 🌖
colorFrom: blue
colorTo: purple
sdk: streamlit
sdk_version: 1.42.1
app_file: app.py
pinned: false
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference


# Cross Attention Classifier

Ниже технические детали того, как устроен репозиторий и как обучить модель.
В самой последней секции "Описание подходов" подробно описано, как я пришла к этому подходу, с какими проблемами встретилась, а также описаны два других подхода, которые я решила не реализовывать (будет мини-эссе, готовьтесь).

## Описание проекта

В проекте используется self-supervised обучение (BYOL) и последующая классификацию изображений с помощью Cross Attention.  
- **train_byol.py** — скрипт для обучения модели-энкодера по методу BYOL.  
- **train_cross_classifier.py** — скрипт для обучения классификатора, который использует предварительно обученный энкодер и Cross Attention.  
- **app.py** — Streamlit-приложение для инференса и визуализации предсказаний (генерация случайных изображений и получение метки от модели).

## Структура репозитория

```
.
├── src
│   ├── dataset.py            # Реализация датасетов (RandomAugmentedDataset и RandomPairDataset)
│   ├── inference.py          # Класс для инференса (CrossAttentionInference) и вспомогательные методы
│   └── models.py             # Определения моделей (BYOL, VGGLikeEncode, CrossAttentionClassifier)
├── train_byol.py             # Скрипт обучения модели BYOL
├── train_cross_classifier.py # Скрипт обучения Cross Attention Classifier (использует готовый энкодер)
├── app.py                    # Streamlit-приложение для инференса
├── requirements.txt          # Список Python-зависимостей (pip install -r requirements.txt)
└── pyproject.toml / poetry.lock # Файл для установки зависимостей через Poetry
```

## Установка зависимостей

Можно установить зависимости двумя способами:  
1. **Через `pip` и `requirements.txt`:**  
   ```bash
   pip install -r requirements.txt
   ```
2. **Через Poetry:**  
   ```bash
   poetry install
   ```

## Как обучить модель

### 1. Обучение энкодера с помощью BYOL
Нужно запустить:
```bash
python train_byol.py
```
- Этот скрипт обучает модель энкодера (`VGGLikeEncode`) методом BYOL на данных, сгенерированных `RandomAugmentedDataset`.
- После обучения лучшая модель (с минимальным `val_loss`) сохраняется в `best_byol.pth`.

### 2. Обучение Cross Attention Classifier

`best_byol.pth` (веса энкодера) должны лежать в корневой папке (можно указать другой путь). Затем нужно запустить:
```bash
python train_cross_classifier.py
```
- Этот скрипт использует предобученный энкодер и обучает классификатор для определения, содержат ли картинки одинаковую геометрическую фигуру.
- По итогам сохранит веса модели-классификатора в `best_attention_classifier.pth`.

## Как запустить инференс

### Запуск через Streamlit-приложение

1. Файл весов `best_attention_classifier.pth` должен лежать в корневой папке
2. Нужно запустить Streamlit-приложение:
   ```bash
   streamlit run app.py
   ```
3. Дальше, нужно перейти по адресу, который выдаст Streamlit (по умолчанию [http://localhost:8501](http://localhost:8501)).  
4. Нажмите кнопку **«Сгенерировать изображения»**. Приложение сгенерирует пару случайных изображений и покажет предсказанную моделью метку.

### Использование класса инференса в коде напрямую

Можно использовать модель напрямую (без интерфейса Streamlit), импортируйте класс из `src/inference.py`, передайте путь к весам модели и вызовите метод предсказания. Пример:

```python
import torch
from src.inference import CrossAttentionInference

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

inference = CrossAttentionInference(
    model_path="best_attention_classifier.pth",
    device=device
)

pred_label, (img1, img2) = inference.predict_random_pair()
print(f"Предсказанная метка: {pred_label}")
```

---

## Описание подходов


### BYOL + Cross-attention (выбранный подход)

Когда я обдумывала финальную архитектуру, я поняла, что у креативности нет предела, поэтому каждое решение должно быть обосновано не только фразой "прикольно, можно попробовать", но и существующими проблемами, которые хочется решить.
И я решила отталкиваться от реальных задач, а именно от проблемы с отсутствием данных в медицинской сфере.
В текущей задаче такой проблемы, очевидно, нет, мы можем сгенерировать хоть миллион изображений и для всех будет лейбл.
Но вот что, если у нас нет возможности сгенерировать миллион изображений? Или если у нас есть 100к изображений, но только 10000 из них размечены? 
Например, у нас есть неплохой банк изображений с разными опухолями, но размечены только 10% из них. Как можно использовать эти данные для обучения модели, чтобы она могла классифицировать новые изображения?
Саму же задачу можно перенести на задачу вида "мониторинг прогрессирования заболевания" или "сравнение патологий".

Поэтому я решила использовать self-supervised обучение для того, чтобы обучить модель на неразмеченных данных и затем дообучить ее на небольшом датасете с разметкой.

#### BYOL
BYOL [Bootstrap Your Own Latent](https://arxiv.org/pdf/2006.07733) — это метод self-supervised обучения, который позволяет обучить модель на неразмеченных данных. 
Важной особенностью конкретно этого подхода заключается в том, что этот метод не требует негативных пар для обучения, как некоторые другие contrastive методы. 
В BYOL используется две копии одной и той же модели, которые обучаются предсказывать друг друга на основе двух views (аугментаций) одного изображения. 
Не схлапываться в один вектор помогает то, что архитектура не симметрична, так как в одной из веток добавляется MLP предиктор, а также stop gradient операция.
В итоге модель учится извлекать признаки из изображения, которые можно использовать для дообучения на меньшем датасете.  
В данном случае я использовала энкодер с похожей на `VGG` архитектурой.
Выбрала VGG я потому, что использовать, например, ResNet со skip-connection нет смысла, так как изображение всего ли 32x32, и через несколько слоев feature мапа была уже 8x8.

В целом сама задача заставляет балансировать между сложными подходами и реальной возможностью обучить модель на подобном датасете, так как реализовать можно (почти) что угодно, но переобучиться на таком датасете достаточно легко.

Изначально я планировала добавить в датасет для предобучения другие фигуры (треугольники, звездочки и тд), но сами эти фигуры занимают несколько пикселей, и аугментации их сильно искажают. В целом на таких маленьких изображениях почти все аугментации становятся агрессивными.
Поэтому я остановилась на двух фигурах, но добавила в аугментации реверс цвета, повороты, сдвигы, гауссовский шум и тд.

#### Cross Attention
Честно скажу, именно на этот подход меня вдохновила задача с собеседования, где нужно было сопоставить два снимка одной и той же области.
Я нашла статью - [An Adaptive Remote Sensing Image-Matching Network Based on Cross Attention and Deformable Convolution](https://www.researchgate.net/publication/388063503_An_Adaptive_Remote_Sensing_Image-Matching_Network_Based_on_Cross_Attention_and_Deformable_Convolution)
где авторы решают похожую задачу (они тоже кстати используют VGG), но для более сложных изображений (сопоставление фотографий со спутника).
Я помню в чем заключается проблема cross-attention - у него квадратичная сложность, и если изображение имеет размер 512x512, то это уже становится проблемой. 
Но так как в задаче изображения 32x32, я решила, что будет уместно применить данный подход (в предложенных подходах дальше я опишу, как бы решала задачу, если бы изображения были больше).
Также, я добавила position эмбеддинги, так как при переходе к cross-attention информация о позиции теряется.

Почему cross-attention? Он позволяет каждому пикселю (точнее, каждому патчу) в одном изображении "смотреть" на все патчи в другом изображении.
Таким образом, если фигуры находятся в противоположных углах, модель это учтет.
Ну и плюс тенденции последних лет - внимание, внимание, внимание.

#### Итоговая архитектура
Сама архитектура представляет собой два VGGLike энкодера с shared весами, предобученных с помощью BYOL, после которых идет слой MultiheadAttention, а затем классификационная голова.
Во время предобучения VGGLike энкодера последним слоем был AdaptiveAvgPool2d. Этот слой не использовался во время обучения классфикатора, так как на вход MultiheadAttention требовалась информативная карта признаков (я использовала 8x8).

Таким образом, когда на вход поступает два изображения, каждое из них проходит через энкодер, после чего происходит cross-attention между ними, и на выходе получается вероятность того, что изображения содержат одинаковую фигуру.
Это не самый сложный подход, который можно было придумать, но он позволяет взглянуть на задачу под другим углом - в реальности у нас нет датасета с неограниченным количеством размеченных данных, и нужно уметь работать с тем, что есть.

### Метрики

Вот здесь можно посмотреть метрики в wandb:
- [Обучение BYOL](https://wandb.ai/alexandraroze/contrastive_learning_byol/reports/-BYOL--VmlldzoxMTQzMjA1Mw?accessToken=nh0kzpepsr0faflptx63n91kljc5wl6mt3wi3ay4wxpjmua55bf32nm36qjby0ai)
- [Обучение cross-attention классификатора](https://api.wandb.ai/links/alexandraroze/hmtnzhv9)


## Другие подходы


### Swin transformer


1. Каждое 32×32 изображение делится на патчи размером 4×4, это даёт 64 патча на изображение.
Каждый патч выпрямляется и проходит через линейный слой для получения векторного представления.

2. Далее мы применим early fusion (так как если применить late fusion, нам придется применять cross-attention, чтобы действительно учесть взаимодействие между патчами из разных изображений).
После извлечения патчей из двух изображений мы просто конкатенируем их по оси последовательности, получая 128 токенов.

3. В window multi-head attention мы делим эту последовательность на окна фиксированного размера. Допустим, каждое окно включает 16 токенов подряд. Это значит, что фигура, находящаяся в определённом блоке патчей, будет анализироваться локально вместе со смежными патчами.
Применяем self-attention и затем сдвигаем окна (в целом, как и должно быть в swin blockе).
Дальше идет patch merging, и мы получаем 16 патчей на одно изображение (то есть 32 патча на два изображения).
Достаточно еще двух таких слоев (16 -> 4 -> 1), чтобы у нас остался один патч на изображение.

4. Далее мы используем global average pooling, и передаем выход в классификационную голову.

#### Почему я не стала реализовывать этот подход
Swin transformer хорошо сработает на крупных изображениях с мелкими деталями, но в данной задаче спустя всего 3 слоя мы уже получаем один токен на изображение.
В первом слое локальное внимание ограничено работает сразу для двух изображений, а в следующем слое остается уже не так много токенов, чтобы извлекать информацию о фигурах.



### Siamese Network с Triplet Loss

Вместо простой классификации мы обучаем энкодер, который преобразует изображения в эмбеддинги так, чтобы похожие изображения (круг-круг, квадрат-квадрат) были ближе друг к другу, а разные изображения (круг-квадрат) были дальше.

Используем Triplet Loss, где берём три изображения:
- Anchor – произвольное изображение (например, квадрат).
- Positive – ещё один квадрат.
- Negative – круг.

Модель минимизирует расстояние между anchor и positive и максимизирует его для negative. Чтобы модели было сложнее, используем hard negatives. Например, генерировать изображение с одинаковыми характеристиками, такими как положение фигуры, цвет, блюр, но с другой фигурой.

Используем легкий shared CNN энкодер. Энкодер обрабатывает изображения независимо, но выходные эмбеддинги сравниваются через triplet loss.
Важно, чтобы размерность эмбеддинга была достаточно низкой, чтобы не переобучиться на простой структуре.

Получаем эмбеддинги двух изображений, считаем евклидово расстояние. 
Если меньше порога - фигуры одинаковые, иначе разные.


#### Почему я не стала реализовывать этот подход
Я уже делала это на своей работе, поэтому хотелось попробовать что-нибудь новое :)


### Проблемы с которыми я столкнулась
1. Вначале я решила написать и обучить всю архитектуру целиком (энкодер + cross-attention классификатор), но сразу же столкнулась с тем, что модель просто не обучалась.
   Чтобы это отдебажить, я решила начать с малого - создала простой датасет и научила простую CNN предсказывать метку для двух изображений сразу.
   Дальше, я добавляла углубление в энкодер, параллельно мониторя количество параметров, чтобы понимать, какое количество сэмплов мне нужно для обучения.
   Таким образом, я дошла до финальной архитектуры. 
2. У меня все еще сохранялись проблемы во время обучения (сеть не обучалась). Мониторинг нормы градиентов и весов помог мне понять, что веса из-за attention просто зануляются. Это я решила изменением оптимизатора на AdamW и уменьшением learning rate. 
3. Изначально планировалось показать улучшения в обучении с помощью self-supervised обучения (в сравнении с обучением с нуля), но по факту при тех же самых условиях обучение проходило одинаково в обоих случаях.
Это можно объяснить тем, что изображения были слишком маленькими и простыми, и подходу без предобучения также не требовалось много времени и большого количества данных. 
Чтобы self-supervised метод действительно хорошо сработал (особенно без негативных примеров), нужны сложные аугментации, в этом же случае сложные аугментации сильно искажали изображения.
4. Судя по кривой обучения, итоговый классификатор очень долгое время находился на плато, так как первые 7-8 эпох из 10 лосс не падал, а точность оставалась на уровне 50%. Это можно объяснить тем, что градиенты очень маленькие или очень шумные. Также, все зависит от исходной инициализации, и при маленьком датасете это может стать проблемой, так как по началу накапливается недостаточно сигналов, чтобы сойти с плато.


## Что бы я точно не стала делать
Здесь я опишу подходы, которые сразу пришли мне в голову, но которые я бы точно не стала делать по итогу.
Опять же, я отталкивалась от переноса задачи на реальные данные.

1. Сверточная сеть, которая принимает на вход одно изображение, и выдает для него класс (круг или квадрат). 
Соответственно, получив предсказания для двух изображений, мы можем сделать вывод о том, содержат ли они одинаковую фигуру.
Этот подход очень простой, решает задачу в лоб, но он не масштабируется, так как при переносе на реальные кейсы терпит крах, потому что далеко не всегда у нас есть два четко разделенных класса (да и в целом у нас может и не быть классов, а только изображения, которые нужно сопоставить между собой).

2. Детекция + классификация.
Можно было достаточно просто обучить детектор, который находил бы как определенный класс (круг или квадрат), так и просто "фигуру" без класса (казалось бы, решение предыдущей проблемы).
В реальности же этот подход тоже не масштабируется, так как 1) это дорогостоящая разметка, 2) детекторы могут ошибаться (и для таких кейсов мы бы тогда вообще могли ничего не предсказать), 3) задача может состоять в сравнении нескольких разнородных объектов на изображении, а не одного (например, образование новых опухолей).
То же само касается и сегментации.


Здесь стоит сделать важную поправку, что есть реальные задачи, где эти подходы могут сработать (например, мы точно знаем, что на изображении нас интересует только один объект, а все остальное - неинформативный фон).