cels / README.md
alexandraroze's picture
solution
50bd1fc

A newer version of the Streamlit SDK is available: 1.48.0

Upgrade
metadata
title: Cels
emoji: 🌖
colorFrom: blue
colorTo: purple
sdk: streamlit
sdk_version: 1.42.1
app_file: app.py
pinned: false

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

Cross Attention Classifier

Ниже технические детали того, как устроен репозиторий и как обучить модель. В самой последней секции "Описание подходов" подробно описано, как я пришла к этому подходу, с какими проблемами встретилась, а также описаны два других подхода, которые я решила не реализовывать (будет мини-эссе, готовьтесь).

Описание проекта

В проекте используется self-supervised обучение (BYOL) и последующая классификацию изображений с помощью Cross Attention.

  • train_byol.py — скрипт для обучения модели-энкодера по методу BYOL.
  • train_cross_classifier.py — скрипт для обучения классификатора, который использует предварительно обученный энкодер и Cross Attention.
  • app.py — Streamlit-приложение для инференса и визуализации предсказаний (генерация случайных изображений и получение метки от модели).

Структура репозитория

.
├── src
│   ├── dataset.py            # Реализация датасетов (RandomAugmentedDataset и RandomPairDataset)
│   ├── inference.py          # Класс для инференса (CrossAttentionInference) и вспомогательные методы
│   └── models.py             # Определения моделей (BYOL, VGGLikeEncode, CrossAttentionClassifier)
├── train_byol.py             # Скрипт обучения модели BYOL
├── train_cross_classifier.py # Скрипт обучения Cross Attention Classifier (использует готовый энкодер)
├── app.py                    # Streamlit-приложение для инференса
├── requirements.txt          # Список Python-зависимостей (pip install -r requirements.txt)
└── pyproject.toml / poetry.lock # Файл для установки зависимостей через Poetry

Установка зависимостей

Можно установить зависимости двумя способами:

  1. Через pip и requirements.txt:
    pip install -r requirements.txt
    
  2. Через Poetry:
    poetry install
    

Как обучить модель

1. Обучение энкодера с помощью BYOL

Нужно запустить:

python train_byol.py
  • Этот скрипт обучает модель энкодера (VGGLikeEncode) методом BYOL на данных, сгенерированных RandomAugmentedDataset.
  • После обучения лучшая модель (с минимальным val_loss) сохраняется в best_byol.pth.

2. Обучение Cross Attention Classifier

best_byol.pth (веса энкодера) должны лежать в корневой папке (можно указать другой путь). Затем нужно запустить:

python train_cross_classifier.py
  • Этот скрипт использует предобученный энкодер и обучает классификатор для определения, содержат ли картинки одинаковую геометрическую фигуру.
  • По итогам сохранит веса модели-классификатора в best_attention_classifier.pth.

Как запустить инференс

Запуск через Streamlit-приложение

  1. Файл весов best_attention_classifier.pth должен лежать в корневой папке
  2. Нужно запустить Streamlit-приложение:
    streamlit run app.py
    
  3. Дальше, нужно перейти по адресу, который выдаст Streamlit (по умолчанию http://localhost:8501).
  4. Нажмите кнопку «Сгенерировать изображения». Приложение сгенерирует пару случайных изображений и покажет предсказанную моделью метку.

Использование класса инференса в коде напрямую

Можно использовать модель напрямую (без интерфейса Streamlit), импортируйте класс из src/inference.py, передайте путь к весам модели и вызовите метод предсказания. Пример:

import torch
from src.inference import CrossAttentionInference

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

inference = CrossAttentionInference(
    model_path="best_attention_classifier.pth",
    device=device
)

pred_label, (img1, img2) = inference.predict_random_pair()
print(f"Предсказанная метка: {pred_label}")

Описание подходов

BYOL + Cross-attention (выбранный подход)

Когда я обдумывала финальную архитектуру, я поняла, что у креативности нет предела, поэтому каждое решение должно быть обосновано не только фразой "прикольно, можно попробовать", но и существующими проблемами, которые хочется решить. И я решила отталкиваться от реальных задач, а именно от проблемы с отсутствием данных в медицинской сфере. В текущей задаче такой проблемы, очевидно, нет, мы можем сгенерировать хоть миллион изображений и для всех будет лейбл. Но вот что, если у нас нет возможности сгенерировать миллион изображений? Или если у нас есть 100к изображений, но только 10000 из них размечены? Например, у нас есть неплохой банк изображений с разными опухолями, но размечены только 10% из них. Как можно использовать эти данные для обучения модели, чтобы она могла классифицировать новые изображения? Саму же задачу можно перенести на задачу вида "мониторинг прогрессирования заболевания" или "сравнение патологий".

Поэтому я решила использовать self-supervised обучение для того, чтобы обучить модель на неразмеченных данных и затем дообучить ее на небольшом датасете с разметкой.

BYOL

BYOL Bootstrap Your Own Latent — это метод self-supervised обучения, который позволяет обучить модель на неразмеченных данных. Важной особенностью конкретно этого подхода заключается в том, что этот метод не требует негативных пар для обучения, как некоторые другие contrastive методы. В BYOL используется две копии одной и той же модели, которые обучаются предсказывать друг друга на основе двух views (аугментаций) одного изображения. Не схлапываться в один вектор помогает то, что архитектура не симметрична, так как в одной из веток добавляется MLP предиктор, а также stop gradient операция. В итоге модель учится извлекать признаки из изображения, которые можно использовать для дообучения на меньшем датасете.
В данном случае я использовала энкодер с похожей на VGG архитектурой. Выбрала VGG я потому, что использовать, например, ResNet со skip-connection нет смысла, так как изображение всего ли 32x32, и через несколько слоев feature мапа была уже 8x8.

В целом сама задача заставляет балансировать между сложными подходами и реальной возможностью обучить модель на подобном датасете, так как реализовать можно (почти) что угодно, но переобучиться на таком датасете достаточно легко.

Изначально я планировала добавить в датасет для предобучения другие фигуры (треугольники, звездочки и тд), но сами эти фигуры занимают несколько пикселей, и аугментации их сильно искажают. В целом на таких маленьких изображениях почти все аугментации становятся агрессивными. Поэтому я остановилась на двух фигурах, но добавила в аугментации реверс цвета, повороты, сдвигы, гауссовский шум и тд.

Cross Attention

Честно скажу, именно на этот подход меня вдохновила задача с собеседования, где нужно было сопоставить два снимка одной и той же области. Я нашла статью - An Adaptive Remote Sensing Image-Matching Network Based on Cross Attention and Deformable Convolution где авторы решают похожую задачу (они тоже кстати используют VGG), но для более сложных изображений (сопоставление фотографий со спутника). Я помню в чем заключается проблема cross-attention - у него квадратичная сложность, и если изображение имеет размер 512x512, то это уже становится проблемой. Но так как в задаче изображения 32x32, я решила, что будет уместно применить данный подход (в предложенных подходах дальше я опишу, как бы решала задачу, если бы изображения были больше). Также, я добавила position эмбеддинги, так как при переходе к cross-attention информация о позиции теряется.

Почему cross-attention? Он позволяет каждому пикселю (точнее, каждому патчу) в одном изображении "смотреть" на все патчи в другом изображении. Таким образом, если фигуры находятся в противоположных углах, модель это учтет. Ну и плюс тенденции последних лет - внимание, внимание, внимание.

Итоговая архитектура

Сама архитектура представляет собой два VGGLike энкодера с shared весами, предобученных с помощью BYOL, после которых идет слой MultiheadAttention, а затем классификационная голова. Во время предобучения VGGLike энкодера последним слоем был AdaptiveAvgPool2d. Этот слой не использовался во время обучения классфикатора, так как на вход MultiheadAttention требовалась информативная карта признаков (я использовала 8x8).

Таким образом, когда на вход поступает два изображения, каждое из них проходит через энкодер, после чего происходит cross-attention между ними, и на выходе получается вероятность того, что изображения содержат одинаковую фигуру. Это не самый сложный подход, который можно было придумать, но он позволяет взглянуть на задачу под другим углом - в реальности у нас нет датасета с неограниченным количеством размеченных данных, и нужно уметь работать с тем, что есть.

Метрики

Вот здесь можно посмотреть метрики в wandb:

Другие подходы

Swin transformer

  1. Каждое 32×32 изображение делится на патчи размером 4×4, это даёт 64 патча на изображение. Каждый патч выпрямляется и проходит через линейный слой для получения векторного представления.

  2. Далее мы применим early fusion (так как если применить late fusion, нам придется применять cross-attention, чтобы действительно учесть взаимодействие между патчами из разных изображений). После извлечения патчей из двух изображений мы просто конкатенируем их по оси последовательности, получая 128 токенов.

  3. В window multi-head attention мы делим эту последовательность на окна фиксированного размера. Допустим, каждое окно включает 16 токенов подряд. Это значит, что фигура, находящаяся в определённом блоке патчей, будет анализироваться локально вместе со смежными патчами. Применяем self-attention и затем сдвигаем окна (в целом, как и должно быть в swin blockе). Дальше идет patch merging, и мы получаем 16 патчей на одно изображение (то есть 32 патча на два изображения). Достаточно еще двух таких слоев (16 -> 4 -> 1), чтобы у нас остался один патч на изображение.

  4. Далее мы используем global average pooling, и передаем выход в классификационную голову.

Почему я не стала реализовывать этот подход

Swin transformer хорошо сработает на крупных изображениях с мелкими деталями, но в данной задаче спустя всего 3 слоя мы уже получаем один токен на изображение. В первом слое локальное внимание ограничено работает сразу для двух изображений, а в следующем слое остается уже не так много токенов, чтобы извлекать информацию о фигурах.

Siamese Network с Triplet Loss

Вместо простой классификации мы обучаем энкодер, который преобразует изображения в эмбеддинги так, чтобы похожие изображения (круг-круг, квадрат-квадрат) были ближе друг к другу, а разные изображения (круг-квадрат) были дальше.

Используем Triplet Loss, где берём три изображения:

  • Anchor – произвольное изображение (например, квадрат).
  • Positive – ещё один квадрат.
  • Negative – круг.

Модель минимизирует расстояние между anchor и positive и максимизирует его для negative. Чтобы модели было сложнее, используем hard negatives. Например, генерировать изображение с одинаковыми характеристиками, такими как положение фигуры, цвет, блюр, но с другой фигурой.

Используем легкий shared CNN энкодер. Энкодер обрабатывает изображения независимо, но выходные эмбеддинги сравниваются через triplet loss. Важно, чтобы размерность эмбеддинга была достаточно низкой, чтобы не переобучиться на простой структуре.

Получаем эмбеддинги двух изображений, считаем евклидово расстояние. Если меньше порога - фигуры одинаковые, иначе разные.

Почему я не стала реализовывать этот подход

Я уже делала это на своей работе, поэтому хотелось попробовать что-нибудь новое :)

Проблемы с которыми я столкнулась

  1. Вначале я решила написать и обучить всю архитектуру целиком (энкодер + cross-attention классификатор), но сразу же столкнулась с тем, что модель просто не обучалась. Чтобы это отдебажить, я решила начать с малого - создала простой датасет и научила простую CNN предсказывать метку для двух изображений сразу. Дальше, я добавляла углубление в энкодер, параллельно мониторя количество параметров, чтобы понимать, какое количество сэмплов мне нужно для обучения. Таким образом, я дошла до финальной архитектуры.
  2. У меня все еще сохранялись проблемы во время обучения (сеть не обучалась). Мониторинг нормы градиентов и весов помог мне понять, что веса из-за attention просто зануляются. Это я решила изменением оптимизатора на AdamW и уменьшением learning rate.
  3. Изначально планировалось показать улучшения в обучении с помощью self-supervised обучения (в сравнении с обучением с нуля), но по факту при тех же самых условиях обучение проходило одинаково в обоих случаях. Это можно объяснить тем, что изображения были слишком маленькими и простыми, и подходу без предобучения также не требовалось много времени и большого количества данных. Чтобы self-supervised метод действительно хорошо сработал (особенно без негативных примеров), нужны сложные аугментации, в этом же случае сложные аугментации сильно искажали изображения.
  4. Судя по кривой обучения, итоговый классификатор очень долгое время находился на плато, так как первые 7-8 эпох из 10 лосс не падал, а точность оставалась на уровне 50%. Это можно объяснить тем, что градиенты очень маленькие или очень шумные. Также, все зависит от исходной инициализации, и при маленьком датасете это может стать проблемой, так как по началу накапливается недостаточно сигналов, чтобы сойти с плато.

Что бы я точно не стала делать

Здесь я опишу подходы, которые сразу пришли мне в голову, но которые я бы точно не стала делать по итогу. Опять же, я отталкивалась от переноса задачи на реальные данные.

  1. Сверточная сеть, которая принимает на вход одно изображение, и выдает для него класс (круг или квадрат). Соответственно, получив предсказания для двух изображений, мы можем сделать вывод о том, содержат ли они одинаковую фигуру. Этот подход очень простой, решает задачу в лоб, но он не масштабируется, так как при переносе на реальные кейсы терпит крах, потому что далеко не всегда у нас есть два четко разделенных класса (да и в целом у нас может и не быть классов, а только изображения, которые нужно сопоставить между собой).

  2. Детекция + классификация. Можно было достаточно просто обучить детектор, который находил бы как определенный класс (круг или квадрат), так и просто "фигуру" без класса (казалось бы, решение предыдущей проблемы). В реальности же этот подход тоже не масштабируется, так как 1) это дорогостоящая разметка, 2) детекторы могут ошибаться (и для таких кейсов мы бы тогда вообще могли ничего не предсказать), 3) задача может состоять в сравнении нескольких разнородных объектов на изображении, а не одного (например, образование новых опухолей). То же само касается и сегментации.

Здесь стоит сделать важную поправку, что есть реальные задачи, где эти подходы могут сработать (например, мы точно знаем, что на изображении нас интересует только один объект, а все остальное - неинформативный фон).