Spaces:
Runtime error
Runtime error
changed to pipelines
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- __pycache__/constants.cpython-39.pyc +0 -0
- __pycache__/models.cpython-39.pyc +0 -0
- __pycache__/utils.cpython-39.pyc +0 -0
- app.py +55 -60
- detection_models/__init__.py +0 -0
- detection_models/__pycache__/__init__.cpython-39.pyc +0 -0
- detection_models/yolo_stamp/__init__.py +0 -0
- detection_models/yolo_stamp/__pycache__/__init__.cpython-39.pyc +0 -0
- detection_models/yolo_stamp/__pycache__/constants.cpython-39.pyc +0 -0
- detection_models/yolo_stamp/__pycache__/model.cpython-39.pyc +0 -0
- detection_models/yolo_stamp/__pycache__/utils.cpython-39.pyc +0 -0
- constants.py → detection_models/yolo_stamp/constants.py +0 -8
- detection_models/yolo_stamp/data.py +141 -0
- detection_models/yolo_stamp/loss.py +52 -0
- detection_models/yolo_stamp/model.py +80 -0
- detection_models/yolo_stamp/train.ipynb +185 -0
- detection_models/yolo_stamp/utils.py +275 -0
- detection_models/yolov8/__init__.py +0 -0
- detection_models/yolov8/train.ipynb +144 -0
- embedding_models/__init__.py +0 -0
- embedding_models/__pycache__/__init__.cpython-39.pyc +0 -0
- embedding_models/vae/__init__.py +0 -0
- embedding_models/vae/__pycache__/__init__.cpython-39.pyc +0 -0
- embedding_models/vae/__pycache__/constants.cpython-39.pyc +0 -0
- embedding_models/vae/__pycache__/model.cpython-39.pyc +0 -0
- embedding_models/vae/constants.py +6 -0
- embedding_models/vae/losses.py +77 -0
- embedding_models/vae/model.py +147 -0
- embedding_models/vae/train.ipynb +393 -0
- embedding_models/vits8/__init__.py +0 -0
- embedding_models/vits8/example.py +10 -0
- embedding_models/vits8/model.py +13 -0
- embedding_models/vits8/oml/__init__.py +0 -0
- embedding_models/vits8/oml/create_dataset.py +71 -0
- embedding_models/vits8/oml/data/test/images/99d_15.bmp +0 -0
- embedding_models/vits8/oml/data/test/images/99e_20.bmp +0 -0
- embedding_models/vits8/oml/data/test/images/99f_25.bmp +0 -0
- embedding_models/vits8/oml/data/test/images/99g_30.bmp +0 -0
- embedding_models/vits8/oml/data/test/images/99h_35.bmp +0 -0
- embedding_models/vits8/oml/data/test/images/99i_40.bmp +0 -0
- embedding_models/vits8/oml/data/train_val/df_stamps.csv +41 -0
- embedding_models/vits8/oml/data/train_val/images/circle12_1239.png +0 -0
- embedding_models/vits8/oml/data/train_val/images/circle15_1244.png +0 -0
- embedding_models/vits8/oml/data/train_val/images/circle19_1236.png +0 -0
- embedding_models/vits8/oml/data/train_val/images/circle21_1249.png +0 -0
- embedding_models/vits8/oml/data/train_val/images/circle24_1234.png +0 -0
- embedding_models/vits8/oml/data/train_val/images/circle24_1246.png +0 -0
- embedding_models/vits8/oml/data/train_val/images/circle2_1233.png +0 -0
- embedding_models/vits8/oml/data/train_val/images/circle2_1246.png +0 -0
- embedding_models/vits8/oml/data/train_val/images/circle2_1249.png +0 -0
__pycache__/constants.cpython-39.pyc
ADDED
|
Binary file (660 Bytes). View file
|
|
|
__pycache__/models.cpython-39.pyc
ADDED
|
Binary file (5.22 kB). View file
|
|
|
__pycache__/utils.cpython-39.pyc
ADDED
|
Binary file (7.73 kB). View file
|
|
|
app.py
CHANGED
|
@@ -1,102 +1,97 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
import numpy as np
|
| 3 |
-
from ultralytics import YOLO
|
| 4 |
-
from torchvision.transforms.functional import to_tensor
|
| 5 |
-
from huggingface_hub import hf_hub_download
|
| 6 |
import torch
|
| 7 |
-
import
|
| 8 |
-
from albumentations.pytorch.transforms import ToTensorV2
|
| 9 |
-
import pandas as pd
|
| 10 |
from sklearn.metrics.pairwise import cosine_similarity
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
from
|
| 13 |
-
from
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
yolov8 = YOLO(hf_hub_download('stamps-labs/yolov8-finetuned', filename='best.torchscript'), task='detect')
|
| 19 |
|
| 20 |
-
|
| 21 |
-
yolo_stamp.load_state_dict(torch.load(hf_hub_download('stamps-labs/yolo-stamp', filename='state_dict.pth'), map_location='cpu'))
|
| 22 |
-
yolo_stamp = yolo_stamp.to(device)
|
| 23 |
-
yolo_stamp.eval()
|
| 24 |
-
transform = A.Compose([
|
| 25 |
-
A.Normalize(),
|
| 26 |
-
ToTensorV2(p=1.0),
|
| 27 |
-
])
|
| 28 |
|
| 29 |
-
vits8 = torch.jit.load(hf_hub_download('stamps-labs/vits8-stamp', filename='vits8stamp-torchscript.pth'), map_location='cpu')
|
| 30 |
-
vits8 = vits8.to(device)
|
| 31 |
-
vits8.eval()
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
|
|
|
| 37 |
|
| 38 |
|
| 39 |
-
def
|
| 40 |
|
| 41 |
-
shape = torch.tensor(image.size)
|
| 42 |
image = image.convert('RGB')
|
| 43 |
|
| 44 |
if det_choice == 'yolov8':
|
| 45 |
-
|
| 46 |
-
image = image.resize((640, 640))
|
| 47 |
-
boxes = yolov8(image)[0].boxes.xyxy.cpu()
|
| 48 |
-
image_with_boxes = visualize_bbox(image, boxes)
|
| 49 |
|
| 50 |
elif det_choice == 'yolo-stamp':
|
| 51 |
-
|
| 52 |
-
image = image.resize((448, 448))
|
| 53 |
-
image_tensor = transform(image=np.array(image))['image']
|
| 54 |
-
output = yolo_stamp(image_tensor.unsqueeze(0).to(device))
|
| 55 |
-
|
| 56 |
-
boxes = output_tensor_to_boxes(output[0].detach().cpu())
|
| 57 |
-
boxes = nonmax_suppression(boxes)
|
| 58 |
-
boxes = xywh2xyxy(torch.tensor(boxes)[:, :4])
|
| 59 |
-
image_with_boxes = visualize_bbox(image, boxes)
|
| 60 |
else:
|
| 61 |
return
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
embeddings = []
|
| 65 |
if emb_choice == 'vits8':
|
| 66 |
-
for
|
| 67 |
-
|
| 68 |
-
embeddings.append(vits8(cropped_stamp.unsqueeze(0).to(device))[0].detach().cpu())
|
| 69 |
|
| 70 |
elif emb_choice == 'vae-encoder':
|
| 71 |
-
for
|
| 72 |
-
|
| 73 |
-
embeddings.append(np.array(encoder(cropped_stamp.unsqueeze(0).to(device))[0][0].detach().cpu()))
|
| 74 |
|
| 75 |
embeddings = np.stack(embeddings)
|
| 76 |
|
| 77 |
similarities = cosine_similarity(embeddings)
|
| 78 |
|
| 79 |
-
boxes = boxes * coef
|
| 80 |
df_boxes = pd.DataFrame(boxes, columns=['x1', 'y1', 'x2', 'y2'])
|
| 81 |
|
| 82 |
fig, ax = plt.subplots()
|
| 83 |
im, cbar = heatmap(similarities, range(1, len(embeddings) + 1), range(1, len(embeddings) + 1), ax=ax,
|
| 84 |
cmap="YlGn", cbarlabel="Embeddings similarities")
|
| 85 |
texts = annotate_heatmap(im, valfmt="{x:.3f}")
|
| 86 |
-
return image_with_boxes, df_boxes, embeddings, fig
|
| 87 |
|
| 88 |
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
gr.Image(type="pil"),
|
| 92 |
gr.Dropdown(choices=['yolov8', 'yolo-stamp'], value='yolov8', label='Detection model'),
|
|
|
|
| 93 |
gr.Dropdown(choices=['vits8', 'vae-encoder'], value='vits8', label='Embedding model'),
|
| 94 |
]
|
| 95 |
-
|
| 96 |
-
gr.Image(type="pil"),
|
| 97 |
gr.DataFrame(type='pandas', label="Bounding boxes"),
|
|
|
|
| 98 |
gr.DataFrame(type='numpy', label="Embeddings"),
|
| 99 |
gr.Plot(label="Cosine Similarities")
|
| 100 |
]
|
| 101 |
-
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import torch
|
| 3 |
+
import numpy as np
|
|
|
|
|
|
|
| 4 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from PIL import Image, ImageDraw
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import matplotlib
|
| 9 |
|
| 10 |
+
from pipelines.detection.yolo_v8 import Yolov8Pipeline
|
| 11 |
+
from pipelines.detection.yolo_stamp import YoloStampPipeline
|
| 12 |
+
from pipelines.segmentation.deeplabv3 import DeepLabv3Pipeline
|
| 13 |
+
from pipelines.feature_extraction.vae import VaePipeline
|
| 14 |
+
from pipelines.feature_extraction.vits8 import Vits8Pipeline
|
|
|
|
|
|
|
| 15 |
|
| 16 |
+
from utils import *
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
+
yolov8 = Yolov8Pipeline.from_pretrained(local_model_path='yolov8_old_backup.pt')
|
| 20 |
+
yolo_stamp = YoloStampPipeline.from_pretrained('stamps-labs/yolo-stamp', 'weights.pt')
|
| 21 |
+
vae = VaePipeline.from_pretrained('stamps-labs/vae-encoder', 'weights.pt')
|
| 22 |
+
vits8 = Vits8Pipeline.from_pretrained('stamps-labs/vits8-stamp', 'weights.pt')
|
| 23 |
+
dlv3 = DeepLabv3Pipeline.from_pretrained('stamps-labs/deeplabv3-finetuned', 'weights.pt')
|
| 24 |
|
| 25 |
|
| 26 |
+
def doc_predict(image, det_choice, seg_choice, emb_choice):
|
| 27 |
|
|
|
|
| 28 |
image = image.convert('RGB')
|
| 29 |
|
| 30 |
if det_choice == 'yolov8':
|
| 31 |
+
boxes = yolov8(image)
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
elif det_choice == 'yolo-stamp':
|
| 34 |
+
boxes = yolo_stamp(image)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
else:
|
| 36 |
return
|
| 37 |
+
image_with_boxes = visualize_bbox(image, boxes)
|
| 38 |
+
|
| 39 |
+
segmented_stamps = []
|
| 40 |
+
for box in boxes:
|
| 41 |
+
cropped_stamp = image.crop(box.tolist())
|
| 42 |
+
segmented_stamps.append(dlv3(cropped_stamp) if seg_choice else cropped_stamp)
|
| 43 |
+
|
| 44 |
+
widths, heights = zip(*(i.size for i in segmented_stamps))
|
| 45 |
+
|
| 46 |
+
total_width = sum(widths)
|
| 47 |
+
max_height = max(heights)
|
| 48 |
+
|
| 49 |
+
concatenated_stamps = Image.new('RGB', (total_width, max_height))
|
| 50 |
+
|
| 51 |
+
x_offset = 0
|
| 52 |
+
for im in segmented_stamps:
|
| 53 |
+
concatenated_stamps.paste(im, (x_offset,0))
|
| 54 |
+
x_offset += im.size[0]
|
| 55 |
|
| 56 |
embeddings = []
|
| 57 |
if emb_choice == 'vits8':
|
| 58 |
+
for stamp in segmented_stamps:
|
| 59 |
+
embeddings.append(vits8(stamp))
|
|
|
|
| 60 |
|
| 61 |
elif emb_choice == 'vae-encoder':
|
| 62 |
+
for stamp in segmented_stamps:
|
| 63 |
+
embeddings.append(vae(stamp))
|
|
|
|
| 64 |
|
| 65 |
embeddings = np.stack(embeddings)
|
| 66 |
|
| 67 |
similarities = cosine_similarity(embeddings)
|
| 68 |
|
|
|
|
| 69 |
df_boxes = pd.DataFrame(boxes, columns=['x1', 'y1', 'x2', 'y2'])
|
| 70 |
|
| 71 |
fig, ax = plt.subplots()
|
| 72 |
im, cbar = heatmap(similarities, range(1, len(embeddings) + 1), range(1, len(embeddings) + 1), ax=ax,
|
| 73 |
cmap="YlGn", cbarlabel="Embeddings similarities")
|
| 74 |
texts = annotate_heatmap(im, valfmt="{x:.3f}")
|
| 75 |
+
return image_with_boxes, df_boxes, concatenated_stamps, embeddings, fig
|
| 76 |
|
| 77 |
|
| 78 |
+
doc_examples = [['examples/1.jpg', 'yolov8', True, 'vits8'], ['examples/2.jpg', 'yolo-stamp', False, 'vae-encoder'], ['examples/3.jpg', 'yolov8', True, 'vits8']]
|
| 79 |
+
doc_inputs = [
|
| 80 |
+
gr.Image(label="Document image", type="pil"),
|
| 81 |
gr.Dropdown(choices=['yolov8', 'yolo-stamp'], value='yolov8', label='Detection model'),
|
| 82 |
+
gr.Checkbox(label="Use segmentation model"),
|
| 83 |
gr.Dropdown(choices=['vits8', 'vae-encoder'], value='vits8', label='Embedding model'),
|
| 84 |
]
|
| 85 |
+
doc_outputs = [
|
| 86 |
+
gr.Image(label="Document with bounding boxes", type="pil"),
|
| 87 |
gr.DataFrame(type='pandas', label="Bounding boxes"),
|
| 88 |
+
gr.Image(label="Segmented stamps", type="pil"),
|
| 89 |
gr.DataFrame(type='numpy', label="Embeddings"),
|
| 90 |
gr.Plot(label="Cosine Similarities")
|
| 91 |
]
|
| 92 |
+
|
| 93 |
+
with gr.Blocks() as demo:
|
| 94 |
+
with gr.Tab("Signle document"):
|
| 95 |
+
gr.Interface(doc_predict, doc_inputs, doc_outputs, examples=doc_examples)
|
| 96 |
+
|
| 97 |
+
demo.launch(inline=False)
|
detection_models/__init__.py
ADDED
|
File without changes
|
detection_models/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (165 Bytes). View file
|
|
|
detection_models/yolo_stamp/__init__.py
ADDED
|
File without changes
|
detection_models/yolo_stamp/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (176 Bytes). View file
|
|
|
detection_models/yolo_stamp/__pycache__/constants.cpython-39.pyc
ADDED
|
Binary file (611 Bytes). View file
|
|
|
detection_models/yolo_stamp/__pycache__/model.cpython-39.pyc
ADDED
|
Binary file (3.07 kB). View file
|
|
|
detection_models/yolo_stamp/__pycache__/utils.cpython-39.pyc
ADDED
|
Binary file (9.31 kB). View file
|
|
|
constants.py → detection_models/yolo_stamp/constants.py
RENAMED
|
@@ -23,11 +23,3 @@ STD = (0.229, 0.224, 0.225)
|
|
| 23 |
MEAN = (0.485, 0.456, 0.406)
|
| 24 |
# box color to show the bounding box on image
|
| 25 |
BOX_COLOR = (0, 0, 255)
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
# dimenstion of image embedding
|
| 29 |
-
Z_DIM = 128
|
| 30 |
-
# hidden dimensions for encoder model
|
| 31 |
-
ENC_HIDDEN_DIM = 16
|
| 32 |
-
# hidden dimensions for decoder model
|
| 33 |
-
DEC_HIDDEN_DIM = 64
|
|
|
|
| 23 |
MEAN = (0.485, 0.456, 0.406)
|
| 24 |
# box color to show the bounding box on image
|
| 25 |
BOX_COLOR = (0, 0, 255)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
detection_models/yolo_stamp/data.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset, DataLoader
|
| 3 |
+
import numpy as np
|
| 4 |
+
from sklearn.model_selection import train_test_split
|
| 5 |
+
import albumentations as A
|
| 6 |
+
from albumentations.pytorch.transforms import ToTensorV2
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from random import randint
|
| 11 |
+
|
| 12 |
+
from utils import *
|
| 13 |
+
|
| 14 |
+
"""
|
| 15 |
+
Dataset class for storing stamps data.
|
| 16 |
+
|
| 17 |
+
Arguments:
|
| 18 |
+
data -- list of dictionaries containing file_path (path to the image), box_nb (number of boxes on the image), and boxes of shape (4,)
|
| 19 |
+
image_folder -- path to folder containing images
|
| 20 |
+
transforms -- transforms from albumentations package
|
| 21 |
+
"""
|
| 22 |
+
class StampDataset(Dataset):
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
data=read_data(),
|
| 26 |
+
image_folder=Path(IMAGE_FOLDER),
|
| 27 |
+
transforms=None):
|
| 28 |
+
self.data = data
|
| 29 |
+
self.image_folder = image_folder
|
| 30 |
+
self.transforms = transforms
|
| 31 |
+
|
| 32 |
+
def __getitem__(self, idx):
|
| 33 |
+
item = self.data[idx]
|
| 34 |
+
image_fn = self.image_folder / item['file_path']
|
| 35 |
+
boxes = item['boxes']
|
| 36 |
+
box_nb = item['box_nb']
|
| 37 |
+
labels = torch.zeros((box_nb, 2), dtype=torch.int64)
|
| 38 |
+
labels[:, 0] = 1
|
| 39 |
+
|
| 40 |
+
img = np.array(Image.open(image_fn))
|
| 41 |
+
|
| 42 |
+
try:
|
| 43 |
+
if self.transforms:
|
| 44 |
+
sample = self.transforms(**{
|
| 45 |
+
"image":img,
|
| 46 |
+
"bboxes": boxes,
|
| 47 |
+
"labels": labels,
|
| 48 |
+
})
|
| 49 |
+
img = sample['image']
|
| 50 |
+
boxes = torch.stack(tuple(map(torch.tensor, zip(*sample['bboxes'])))).permute(1, 0)
|
| 51 |
+
except:
|
| 52 |
+
return self.__getitem__(randint(0, len(self.data)-1))
|
| 53 |
+
|
| 54 |
+
target_tensor = boxes_to_tensor(boxes.type(torch.float32))
|
| 55 |
+
return img, target_tensor
|
| 56 |
+
|
| 57 |
+
def __len__(self):
|
| 58 |
+
return len(self.data)
|
| 59 |
+
|
| 60 |
+
def collate_fn(batch):
|
| 61 |
+
return tuple(zip(*batch))
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def get_datasets(data_path=ANNOTATIONS_PATH, train_transforms=None, val_transforms=None):
|
| 65 |
+
"""
|
| 66 |
+
Creates StampDataset objects.
|
| 67 |
+
|
| 68 |
+
Arguments:
|
| 69 |
+
data_path -- string or Path, specifying path to annotations file
|
| 70 |
+
train_transforms -- transforms to be applied during training
|
| 71 |
+
val_transforms -- transforms to be applied during validation
|
| 72 |
+
|
| 73 |
+
Returns:
|
| 74 |
+
(train_dataset, val_dataset) -- tuple of StampDataset for training and validation
|
| 75 |
+
"""
|
| 76 |
+
data = read_data(data_path)
|
| 77 |
+
if train_transforms is None:
|
| 78 |
+
train_transforms = A.Compose([
|
| 79 |
+
A.RandomCropNearBBox(max_part_shift=0.6, p=0.4),
|
| 80 |
+
A.Resize(height=448, width=448),
|
| 81 |
+
A.HorizontalFlip(p=0.5),
|
| 82 |
+
A.VerticalFlip(p=0.5),
|
| 83 |
+
# A.Affine(scale=(0.9, 1.1), translate_percent=(0.05, 0.1), rotate=(-45, 45), shear=(-30, 30), p=0.3),
|
| 84 |
+
# A.Blur(blur_limit=4, p=0.3),
|
| 85 |
+
A.Normalize(),
|
| 86 |
+
ToTensorV2(p=1.0),
|
| 87 |
+
],
|
| 88 |
+
bbox_params={
|
| 89 |
+
"format":"coco",
|
| 90 |
+
'label_fields': ['labels']
|
| 91 |
+
})
|
| 92 |
+
|
| 93 |
+
if val_transforms is None:
|
| 94 |
+
val_transforms = A.Compose([
|
| 95 |
+
A.Resize(height=448, width=448),
|
| 96 |
+
A.Normalize(),
|
| 97 |
+
ToTensorV2(p=1.0),
|
| 98 |
+
],
|
| 99 |
+
bbox_params={
|
| 100 |
+
"format":"coco",
|
| 101 |
+
'label_fields': ['labels']
|
| 102 |
+
})
|
| 103 |
+
train, test_data = train_test_split(data, test_size=0.1, shuffle=True)
|
| 104 |
+
|
| 105 |
+
train_data, val_data = train_test_split(train, test_size=0.2, shuffle=True)
|
| 106 |
+
|
| 107 |
+
train_dataset = StampDataset(train_data, transforms=train_transforms)
|
| 108 |
+
val_dataset = StampDataset(val_data, transforms=val_transforms)
|
| 109 |
+
test_dataset = StampDataset(test_data, transforms=val_transforms)
|
| 110 |
+
|
| 111 |
+
return train_dataset, val_dataset, test_dataset
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def get_loaders(batch_size=8, data_path=ANNOTATIONS_PATH, num_workers=0, train_transforms=None, val_transforms=None):
|
| 115 |
+
"""
|
| 116 |
+
Creates StampDataset objects.
|
| 117 |
+
|
| 118 |
+
Arguments:
|
| 119 |
+
batch_size -- integer specifying the number of images in the batch
|
| 120 |
+
data_path -- string or Path, specifying path to annotations file
|
| 121 |
+
train_transforms -- transforms to be applied during training
|
| 122 |
+
val_transforms -- transforms to be applied during validation
|
| 123 |
+
|
| 124 |
+
Returns:
|
| 125 |
+
(train_loader, val_loader) -- tuple of DataLoader for training and validation
|
| 126 |
+
"""
|
| 127 |
+
train_dataset, val_dataset, _ = get_datasets(data_path)
|
| 128 |
+
|
| 129 |
+
train_loader = DataLoader(
|
| 130 |
+
train_dataset,
|
| 131 |
+
batch_size=batch_size,
|
| 132 |
+
shuffle=True,
|
| 133 |
+
num_workers=num_workers,
|
| 134 |
+
collate_fn=collate_fn, drop_last=True)
|
| 135 |
+
|
| 136 |
+
val_loader = DataLoader(
|
| 137 |
+
val_dataset,
|
| 138 |
+
batch_size=batch_size,
|
| 139 |
+
collate_fn=collate_fn)
|
| 140 |
+
|
| 141 |
+
return train_loader, val_loader
|
detection_models/yolo_stamp/loss.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from utils import *
|
| 4 |
+
|
| 5 |
+
"""
|
| 6 |
+
Class for loss for training YOLO model.
|
| 7 |
+
|
| 8 |
+
Argmunets:
|
| 9 |
+
h_coord: weight for loss related to coordinates and shapes of box
|
| 10 |
+
h__noobj: weight for loss of predicting presence of box when it is absent.
|
| 11 |
+
"""
|
| 12 |
+
class YOLOLoss(nn.Module):
|
| 13 |
+
def __init__(self, h_coord=0.5, h_noobj=2., h_shape=2., h_obj=10.):
|
| 14 |
+
super().__init__()
|
| 15 |
+
self.h_coord = h_coord
|
| 16 |
+
self.h_noobj = h_noobj
|
| 17 |
+
self.h_shape = h_shape
|
| 18 |
+
self.h_obj = h_obj
|
| 19 |
+
|
| 20 |
+
def square_error(self, output, target):
|
| 21 |
+
return (output - target) ** 2
|
| 22 |
+
|
| 23 |
+
def forward(self, output, target):
|
| 24 |
+
|
| 25 |
+
pred_xy, pred_wh, pred_obj = yolo_head(output)
|
| 26 |
+
gt_xy, gt_wh, gt_obj = process_target(target)
|
| 27 |
+
|
| 28 |
+
pred_ul = pred_xy - 0.5 * pred_wh
|
| 29 |
+
pred_br = pred_xy + 0.5 * pred_wh
|
| 30 |
+
pred_area = pred_wh[..., 0] * pred_wh[..., 1]
|
| 31 |
+
|
| 32 |
+
gt_ul = gt_xy - 0.5 * gt_wh
|
| 33 |
+
gt_br = gt_xy + 0.5 * gt_wh
|
| 34 |
+
gt_area = gt_wh[..., 0] * gt_wh[..., 1]
|
| 35 |
+
|
| 36 |
+
intersect_ul = torch.max(pred_ul, gt_ul)
|
| 37 |
+
intersect_br = torch.min(pred_br, gt_br)
|
| 38 |
+
intersect_wh = intersect_br - intersect_ul
|
| 39 |
+
intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
|
| 40 |
+
|
| 41 |
+
iou = intersect_area / (pred_area + gt_area - intersect_area)
|
| 42 |
+
max_iou = torch.max(iou, dim=3, keepdim=True)[0]
|
| 43 |
+
best_box_index = torch.unsqueeze(torch.eq(iou, max_iou).float(), dim=-1)
|
| 44 |
+
gt_box_conf = best_box_index * gt_obj
|
| 45 |
+
|
| 46 |
+
xy_loss = (self.square_error(pred_xy, gt_xy) * gt_box_conf).sum()
|
| 47 |
+
wh_loss = (self.square_error(pred_wh, gt_wh) * gt_box_conf).sum()
|
| 48 |
+
obj_loss = (self.square_error(pred_obj, gt_obj) * gt_box_conf).sum()
|
| 49 |
+
noobj_loss = (self.square_error(pred_obj, gt_obj) * (1 - gt_box_conf)).sum()
|
| 50 |
+
|
| 51 |
+
total_loss = self.h_coord * xy_loss + self.h_shape * wh_loss + self.h_obj * obj_loss + self.h_noobj * noobj_loss
|
| 52 |
+
return total_loss
|
detection_models/yolo_stamp/model.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from .constants import *
|
| 5 |
+
|
| 6 |
+
"""
|
| 7 |
+
Class for custom activation.
|
| 8 |
+
"""
|
| 9 |
+
class SymReLU(nn.Module):
|
| 10 |
+
def __init__(self, inplace: bool = False):
|
| 11 |
+
super().__init__()
|
| 12 |
+
self.inplace = inplace
|
| 13 |
+
|
| 14 |
+
def forward(self, input):
|
| 15 |
+
return torch.min(torch.max(input, -torch.ones_like(input)), torch.ones_like(input))
|
| 16 |
+
|
| 17 |
+
def extra_repr(self) -> str:
|
| 18 |
+
inplace_str = 'inplace=True' if self.inplace else ''
|
| 19 |
+
return inplace_str
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
"""
|
| 23 |
+
Class implementing YOLO-Stamp architecture described in https://link.springer.com/article/10.1134/S1054661822040046.
|
| 24 |
+
"""
|
| 25 |
+
class YOLOStamp(nn.Module):
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
anchors=ANCHORS,
|
| 29 |
+
in_channels=3,
|
| 30 |
+
):
|
| 31 |
+
super().__init__()
|
| 32 |
+
|
| 33 |
+
self.register_buffer('anchors', torch.tensor(anchors))
|
| 34 |
+
|
| 35 |
+
self.act = SymReLU()
|
| 36 |
+
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
| 37 |
+
self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
| 38 |
+
self.norm1 = nn.BatchNorm2d(num_features=8)
|
| 39 |
+
self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
| 40 |
+
self.norm2 = nn.BatchNorm2d(num_features=16)
|
| 41 |
+
self.conv3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
| 42 |
+
self.norm3 = nn.BatchNorm2d(num_features=16)
|
| 43 |
+
self.conv4 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
| 44 |
+
self.norm4 = nn.BatchNorm2d(num_features=16)
|
| 45 |
+
self.conv5 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
| 46 |
+
self.norm5 = nn.BatchNorm2d(num_features=16)
|
| 47 |
+
self.conv6 = nn.Conv2d(in_channels=16, out_channels=24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
| 48 |
+
self.norm6 = nn.BatchNorm2d(num_features=24)
|
| 49 |
+
self.conv7 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
| 50 |
+
self.norm7 = nn.BatchNorm2d(num_features=24)
|
| 51 |
+
self.conv8 = nn.Conv2d(in_channels=24, out_channels=48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
| 52 |
+
self.norm8 = nn.BatchNorm2d(num_features=48)
|
| 53 |
+
self.conv9 = nn.Conv2d(in_channels=48, out_channels=48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
| 54 |
+
self.norm9 = nn.BatchNorm2d(num_features=48)
|
| 55 |
+
self.conv10 = nn.Conv2d(in_channels=48, out_channels=48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
| 56 |
+
self.norm10 = nn.BatchNorm2d(num_features=48)
|
| 57 |
+
self.conv11 = nn.Conv2d(in_channels=48, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
| 58 |
+
self.norm11 = nn.BatchNorm2d(num_features=64)
|
| 59 |
+
self.conv12 = nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
|
| 60 |
+
self.norm12 = nn.BatchNorm2d(num_features=256)
|
| 61 |
+
self.conv13 = nn.Conv2d(in_channels=256, out_channels=len(anchors) * 5, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
|
| 62 |
+
|
| 63 |
+
def forward(self, x, head=True):
|
| 64 |
+
x = x.type(self.conv1.weight.dtype)
|
| 65 |
+
x = self.act(self.pool(self.norm1(self.conv1(x))))
|
| 66 |
+
x = self.act(self.pool(self.norm2(self.conv2(x))))
|
| 67 |
+
x = self.act(self.pool(self.norm3(self.conv3(x))))
|
| 68 |
+
x = self.act(self.pool(self.norm4(self.conv4(x))))
|
| 69 |
+
x = self.act(self.pool(self.norm5(self.conv5(x))))
|
| 70 |
+
x = self.act(self.norm6(self.conv6(x)))
|
| 71 |
+
x = self.act(self.norm7(self.conv7(x)))
|
| 72 |
+
x = self.act(self.pool(self.norm8(self.conv8(x))))
|
| 73 |
+
x = self.act(self.norm9(self.conv9(x)))
|
| 74 |
+
x = self.act(self.norm10(self.conv10(x)))
|
| 75 |
+
x = self.act(self.norm11(self.conv11(x)))
|
| 76 |
+
x = self.act(self.norm12(self.conv12(x)))
|
| 77 |
+
x = self.conv13(x)
|
| 78 |
+
nb, _, nh, nw= x.shape
|
| 79 |
+
x = x.permute(0, 2, 3, 1).view(nb, nh, nw, self.anchors.shape[0], 5)
|
| 80 |
+
return x
|
detection_models/yolo_stamp/train.ipynb
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"from model import *\n",
|
| 10 |
+
"from loss import *\n",
|
| 11 |
+
"from data import *\n",
|
| 12 |
+
"from torch import optim\n",
|
| 13 |
+
"from tqdm import tqdm\n",
|
| 14 |
+
"\n",
|
| 15 |
+
"import pytorch_lightning as pl\n",
|
| 16 |
+
"from torchmetrics.detection import MeanAveragePrecision\n",
|
| 17 |
+
"from pytorch_lightning.loggers import TensorBoardLogger"
|
| 18 |
+
]
|
| 19 |
+
},
|
| 20 |
+
{
|
| 21 |
+
"cell_type": "code",
|
| 22 |
+
"execution_count": 2,
|
| 23 |
+
"metadata": {},
|
| 24 |
+
"outputs": [],
|
| 25 |
+
"source": [
|
| 26 |
+
"_, _, test_dataset = get_datasets()"
|
| 27 |
+
]
|
| 28 |
+
},
|
| 29 |
+
{
|
| 30 |
+
"cell_type": "code",
|
| 31 |
+
"execution_count": 3,
|
| 32 |
+
"metadata": {},
|
| 33 |
+
"outputs": [],
|
| 34 |
+
"source": [
|
| 35 |
+
"class LitModel(pl.LightningModule):\n",
|
| 36 |
+
" def __init__(self):\n",
|
| 37 |
+
" super().__init__()\n",
|
| 38 |
+
" self.model = YOLOStamp()\n",
|
| 39 |
+
" self.criterion = YOLOLoss()\n",
|
| 40 |
+
" self.val_map = MeanAveragePrecision(box_format='xywh', iou_type='bbox')\n",
|
| 41 |
+
" \n",
|
| 42 |
+
" def forward(self, x):\n",
|
| 43 |
+
" return self.model(x)\n",
|
| 44 |
+
"\n",
|
| 45 |
+
" def configure_optimizers(self):\n",
|
| 46 |
+
" optimizer = optim.AdamW(self.parameters(), lr=1e-3)\n",
|
| 47 |
+
" # return optimizer\n",
|
| 48 |
+
" scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 1000)\n",
|
| 49 |
+
" return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler}\n",
|
| 50 |
+
"\n",
|
| 51 |
+
" def training_step(self, batch, batch_idx):\n",
|
| 52 |
+
" images, targets = batch\n",
|
| 53 |
+
" tensor_images = torch.stack(images)\n",
|
| 54 |
+
" tensor_targets = torch.stack(targets)\n",
|
| 55 |
+
" output = self.model(tensor_images)\n",
|
| 56 |
+
" loss = self.criterion(output, tensor_targets)\n",
|
| 57 |
+
" self.log(\"train_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n",
|
| 58 |
+
" return loss\n",
|
| 59 |
+
"\n",
|
| 60 |
+
" def validation_step(self, batch, batch_idx):\n",
|
| 61 |
+
" images, targets = batch\n",
|
| 62 |
+
" tensor_images = torch.stack(images)\n",
|
| 63 |
+
" tensor_targets = torch.stack(targets)\n",
|
| 64 |
+
" output = self.model(tensor_images)\n",
|
| 65 |
+
" loss = self.criterion(output, tensor_targets)\n",
|
| 66 |
+
" self.log(\"val_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n",
|
| 67 |
+
"\n",
|
| 68 |
+
" for i in range(len(images)):\n",
|
| 69 |
+
" boxes = output_tensor_to_boxes(output[i].detach().cpu())\n",
|
| 70 |
+
" boxes = nonmax_suppression(boxes)\n",
|
| 71 |
+
" target = target_tensor_to_boxes(targets[i])[::BOX]\n",
|
| 72 |
+
" if not boxes:\n",
|
| 73 |
+
" boxes = torch.zeros((1, 5))\n",
|
| 74 |
+
" preds = [\n",
|
| 75 |
+
" dict(\n",
|
| 76 |
+
" boxes=torch.tensor(boxes)[:, :4].clone().detach(),\n",
|
| 77 |
+
" scores=torch.tensor(boxes)[:, 4].clone().detach(),\n",
|
| 78 |
+
" labels=torch.zeros(len(boxes)),\n",
|
| 79 |
+
" )\n",
|
| 80 |
+
" ]\n",
|
| 81 |
+
" target = [\n",
|
| 82 |
+
" dict(\n",
|
| 83 |
+
" boxes=torch.tensor(target),\n",
|
| 84 |
+
" labels=torch.zeros(len(target)),\n",
|
| 85 |
+
" )\n",
|
| 86 |
+
" ]\n",
|
| 87 |
+
" self.val_map.update(preds, target)\n",
|
| 88 |
+
" \n",
|
| 89 |
+
" def on_validation_epoch_end(self):\n",
|
| 90 |
+
" mAPs = {\"val_\" + k: v for k, v in self.val_map.compute().items()}\n",
|
| 91 |
+
" mAPs_per_class = mAPs.pop(\"val_map_per_class\")\n",
|
| 92 |
+
" mARs_per_class = mAPs.pop(\"val_mar_100_per_class\")\n",
|
| 93 |
+
" self.log_dict(mAPs)\n",
|
| 94 |
+
" self.val_map.reset()\n",
|
| 95 |
+
"\n",
|
| 96 |
+
" image = test_dataset[randint(0, len(test_dataset) - 1)][0].to(self.device)\n",
|
| 97 |
+
" output = self.model(image.unsqueeze(0))\n",
|
| 98 |
+
" boxes = output_tensor_to_boxes(output[0].detach().cpu())\n",
|
| 99 |
+
" boxes = nonmax_suppression(boxes)\n",
|
| 100 |
+
" img = image.permute(1, 2, 0).cpu().numpy()\n",
|
| 101 |
+
" img = visualize_bbox(img.copy(), boxes=boxes)\n",
|
| 102 |
+
" img = (255. * (img * np.array(STD) + np.array(MEAN))).astype(np.uint8)\n",
|
| 103 |
+
" \n",
|
| 104 |
+
" self.logger.experiment.add_image(\"detected boxes\", torch.tensor(img).permute(2, 0, 1), self.current_epoch)\n"
|
| 105 |
+
]
|
| 106 |
+
},
|
| 107 |
+
{
|
| 108 |
+
"cell_type": "code",
|
| 109 |
+
"execution_count": 4,
|
| 110 |
+
"metadata": {},
|
| 111 |
+
"outputs": [],
|
| 112 |
+
"source": [
|
| 113 |
+
"litmodel = LitModel()"
|
| 114 |
+
]
|
| 115 |
+
},
|
| 116 |
+
{
|
| 117 |
+
"cell_type": "code",
|
| 118 |
+
"execution_count": 5,
|
| 119 |
+
"metadata": {},
|
| 120 |
+
"outputs": [],
|
| 121 |
+
"source": [
|
| 122 |
+
"logger = TensorBoardLogger(\"detection_logs\")"
|
| 123 |
+
]
|
| 124 |
+
},
|
| 125 |
+
{
|
| 126 |
+
"cell_type": "code",
|
| 127 |
+
"execution_count": 7,
|
| 128 |
+
"metadata": {},
|
| 129 |
+
"outputs": [],
|
| 130 |
+
"source": [
|
| 131 |
+
"epochs = 100"
|
| 132 |
+
]
|
| 133 |
+
},
|
| 134 |
+
{
|
| 135 |
+
"cell_type": "code",
|
| 136 |
+
"execution_count": 8,
|
| 137 |
+
"metadata": {},
|
| 138 |
+
"outputs": [],
|
| 139 |
+
"source": [
|
| 140 |
+
"train_loader, val_loader = get_loaders(batch_size=8)"
|
| 141 |
+
]
|
| 142 |
+
},
|
| 143 |
+
{
|
| 144 |
+
"cell_type": "code",
|
| 145 |
+
"execution_count": null,
|
| 146 |
+
"metadata": {},
|
| 147 |
+
"outputs": [],
|
| 148 |
+
"source": [
|
| 149 |
+
"trainer = pl.Trainer(accelerator=\"auto\", max_epochs=epochs, logger=logger)\n",
|
| 150 |
+
"trainer.fit(model=litmodel, train_dataloaders=train_loader, val_dataloaders=val_loader)"
|
| 151 |
+
]
|
| 152 |
+
},
|
| 153 |
+
{
|
| 154 |
+
"cell_type": "code",
|
| 155 |
+
"execution_count": null,
|
| 156 |
+
"metadata": {},
|
| 157 |
+
"outputs": [],
|
| 158 |
+
"source": [
|
| 159 |
+
"%tensorboard"
|
| 160 |
+
]
|
| 161 |
+
}
|
| 162 |
+
],
|
| 163 |
+
"metadata": {
|
| 164 |
+
"kernelspec": {
|
| 165 |
+
"display_name": "Python 3",
|
| 166 |
+
"language": "python",
|
| 167 |
+
"name": "python3"
|
| 168 |
+
},
|
| 169 |
+
"language_info": {
|
| 170 |
+
"codemirror_mode": {
|
| 171 |
+
"name": "ipython",
|
| 172 |
+
"version": 3
|
| 173 |
+
},
|
| 174 |
+
"file_extension": ".py",
|
| 175 |
+
"mimetype": "text/x-python",
|
| 176 |
+
"name": "python",
|
| 177 |
+
"nbconvert_exporter": "python",
|
| 178 |
+
"pygments_lexer": "ipython3",
|
| 179 |
+
"version": "3.9.0"
|
| 180 |
+
},
|
| 181 |
+
"orig_nbformat": 4
|
| 182 |
+
},
|
| 183 |
+
"nbformat": 4,
|
| 184 |
+
"nbformat_minor": 2
|
| 185 |
+
}
|
detection_models/yolo_stamp/utils.py
ADDED
|
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import cv2
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import numpy as np
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
from .constants import *
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def output_tensor_to_boxes(boxes_tensor):
|
| 11 |
+
"""
|
| 12 |
+
Converts the YOLO output tensor to list of boxes with probabilites.
|
| 13 |
+
|
| 14 |
+
Arguments:
|
| 15 |
+
boxes_tensor -- tensor of shape (S, S, BOX, 5)
|
| 16 |
+
|
| 17 |
+
Returns:
|
| 18 |
+
boxes -- list of shape (None, 5)
|
| 19 |
+
|
| 20 |
+
Note: "None" is here because you don't know the exact number of selected boxes, as it depends on the threshold.
|
| 21 |
+
For example, the actual output size of scores would be (10, 5) if there are 10 boxes
|
| 22 |
+
"""
|
| 23 |
+
cell_w, cell_h = W/S, H/S
|
| 24 |
+
boxes = []
|
| 25 |
+
|
| 26 |
+
for i in range(S):
|
| 27 |
+
for j in range(S):
|
| 28 |
+
for b in range(BOX):
|
| 29 |
+
anchor_wh = torch.tensor(ANCHORS[b])
|
| 30 |
+
data = boxes_tensor[i,j,b]
|
| 31 |
+
xy = torch.sigmoid(data[:2])
|
| 32 |
+
wh = torch.exp(data[2:4])*anchor_wh
|
| 33 |
+
obj_prob = torch.sigmoid(data[4])
|
| 34 |
+
|
| 35 |
+
if obj_prob > OUTPUT_THRESH:
|
| 36 |
+
x_center, y_center, w, h = xy[0], xy[1], wh[0], wh[1]
|
| 37 |
+
x, y = x_center+j-w/2, y_center+i-h/2
|
| 38 |
+
x,y,w,h = x*cell_w, y*cell_h, w*cell_w, h*cell_h
|
| 39 |
+
box = [x,y,w,h, obj_prob]
|
| 40 |
+
boxes.append(box)
|
| 41 |
+
return boxes
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def plot_img(img, size=(7,7)):
|
| 45 |
+
plt.figure(figsize=size)
|
| 46 |
+
plt.imshow(img)
|
| 47 |
+
plt.show()
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def plot_normalized_img(img, std=STD, mean=MEAN, size=(7,7)):
|
| 51 |
+
mean = mean if isinstance(mean, np.ndarray) else np.array(mean)
|
| 52 |
+
std = std if isinstance(std, np.ndarray) else np.array(std)
|
| 53 |
+
plt.figure(figsize=size)
|
| 54 |
+
plt.imshow((255. * (img * std + mean)).astype(np.uint))
|
| 55 |
+
plt.show()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def visualize_bbox(img, boxes, thickness=2, color=BOX_COLOR, draw_center=True):
|
| 59 |
+
"""
|
| 60 |
+
Draws boxes on the given image.
|
| 61 |
+
|
| 62 |
+
Arguments:
|
| 63 |
+
img -- torch.Tensor of shape (3, W, H) or numpy.ndarray of shape (W, H, 3)
|
| 64 |
+
boxes -- list of shape (None, 5)
|
| 65 |
+
thickness -- number specifying the thickness of box border
|
| 66 |
+
color -- RGB tuple of shape (3,) specifying the color of boxes
|
| 67 |
+
draw_center -- boolean specifying whether to draw center or not
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
img_copy -- numpy.ndarray of shape(W, H, 3) containing image with bouning boxes
|
| 71 |
+
"""
|
| 72 |
+
img_copy = img.cpu().permute(1,2,0).numpy() if isinstance(img, torch.Tensor) else img.copy()
|
| 73 |
+
for box in boxes:
|
| 74 |
+
x,y,w,h = int(box[0]), int(box[1]), int(box[2]), int(box[3])
|
| 75 |
+
img_copy = cv2.rectangle(
|
| 76 |
+
img_copy,
|
| 77 |
+
(x,y),(x+w, y+h),
|
| 78 |
+
color, thickness)
|
| 79 |
+
if draw_center:
|
| 80 |
+
center = (x+w//2, y+h//2)
|
| 81 |
+
img_copy = cv2.circle(img_copy, center=center, radius=3, color=(0,255,0), thickness=2)
|
| 82 |
+
return img_copy
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def read_data(annotations=Path(ANNOTATIONS_PATH)):
|
| 86 |
+
"""
|
| 87 |
+
Reads annotations data from .csv file. Must contain columns: image_name, bbox_x, bbox_y, bbox_width, bbox_height.
|
| 88 |
+
|
| 89 |
+
Arguments:
|
| 90 |
+
annotations_path -- string or Path specifying path of annotations file
|
| 91 |
+
|
| 92 |
+
Returns:
|
| 93 |
+
data -- list of dictionaries containing path, number of boxes and boxes itself
|
| 94 |
+
"""
|
| 95 |
+
data = []
|
| 96 |
+
|
| 97 |
+
boxes = pd.read_csv(annotations)
|
| 98 |
+
image_names = boxes['image_name'].unique()
|
| 99 |
+
|
| 100 |
+
for image_name in image_names:
|
| 101 |
+
cur_boxes = boxes[boxes['image_name'] == image_name]
|
| 102 |
+
img_data = {
|
| 103 |
+
'file_path': image_name,
|
| 104 |
+
'box_nb': len(cur_boxes),
|
| 105 |
+
'boxes': []}
|
| 106 |
+
stamp_nb = img_data['box_nb']
|
| 107 |
+
if stamp_nb <= STAMP_NB_MAX:
|
| 108 |
+
img_data['boxes'] = cur_boxes[['bbox_x', 'bbox_y','bbox_width','bbox_height']].values
|
| 109 |
+
data.append(img_data)
|
| 110 |
+
return data
|
| 111 |
+
|
| 112 |
+
def xywh2xyxy(x):
|
| 113 |
+
"""
|
| 114 |
+
Converts xywh format to xyxy
|
| 115 |
+
|
| 116 |
+
Arguments:
|
| 117 |
+
x -- torch.Tensor or np.array (xywh format)
|
| 118 |
+
|
| 119 |
+
Returns:
|
| 120 |
+
y -- torch.Tensor or np.array (xyxy)
|
| 121 |
+
"""
|
| 122 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
| 123 |
+
y[..., 0] = x[..., 0]
|
| 124 |
+
y[..., 1] = x[..., 1]
|
| 125 |
+
y[..., 2] = x[..., 0] + x[..., 2]
|
| 126 |
+
y[..., 3] = x[..., 1] + x[..., 3]
|
| 127 |
+
return y
|
| 128 |
+
|
| 129 |
+
def boxes_to_tensor(boxes):
|
| 130 |
+
"""
|
| 131 |
+
Convert list of boxes (and labels) to tensor format
|
| 132 |
+
|
| 133 |
+
Arguments:
|
| 134 |
+
boxes -- list of boxes
|
| 135 |
+
|
| 136 |
+
Returns:
|
| 137 |
+
boxes_tensor -- tensor of shape (S, S, BOX, 5)
|
| 138 |
+
"""
|
| 139 |
+
boxes_tensor = torch.zeros((S, S, BOX, 5))
|
| 140 |
+
cell_w, cell_h = W/S, H/S
|
| 141 |
+
for i, box in enumerate(boxes):
|
| 142 |
+
x, y, w, h = box
|
| 143 |
+
# normalize xywh with cell_size
|
| 144 |
+
x, y, w, h = x / cell_w, y / cell_h, w / cell_w, h / cell_h
|
| 145 |
+
center_x, center_y = x + w / 2, y + h / 2
|
| 146 |
+
grid_x = int(np.floor(center_x))
|
| 147 |
+
grid_y = int(np.floor(center_y))
|
| 148 |
+
|
| 149 |
+
if grid_x < S and grid_y < S:
|
| 150 |
+
boxes_tensor[grid_y, grid_x, :, 0:4] = torch.tensor(BOX * [[center_x - grid_x, center_y - grid_y, w, h]])
|
| 151 |
+
boxes_tensor[grid_y, grid_x, :, 4] = torch.tensor(BOX * [1.])
|
| 152 |
+
return boxes_tensor
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def target_tensor_to_boxes(boxes_tensor, output_threshold=OUTPUT_THRESH):
|
| 156 |
+
"""
|
| 157 |
+
Recover target tensor (tensor output of dataset) to bboxes.
|
| 158 |
+
Arguments:
|
| 159 |
+
boxes_tensor -- tensor of shape (S, S, BOX, 5)
|
| 160 |
+
Returns:
|
| 161 |
+
boxes -- list of boxes, each box is [x, y, w, h]
|
| 162 |
+
"""
|
| 163 |
+
cell_w, cell_h = W/S, H/S
|
| 164 |
+
boxes = []
|
| 165 |
+
for i in range(S):
|
| 166 |
+
for j in range(S):
|
| 167 |
+
for b in range(BOX):
|
| 168 |
+
data = boxes_tensor[i,j,b]
|
| 169 |
+
x_center,y_center, w, h, obj_prob = data[0], data[1], data[2], data[3], data[4]
|
| 170 |
+
if obj_prob > output_threshold:
|
| 171 |
+
x, y = x_center+j-w/2, y_center+i-h/2
|
| 172 |
+
x,y,w,h = x*cell_w, y*cell_h, w*cell_w, h*cell_h
|
| 173 |
+
box = [x,y,w,h]
|
| 174 |
+
boxes.append(box)
|
| 175 |
+
return boxes
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def overlap(interval_1, interval_2):
|
| 179 |
+
"""
|
| 180 |
+
Calculates length of overlap between two intervals.
|
| 181 |
+
|
| 182 |
+
Arguments:
|
| 183 |
+
interval_1 -- list or tuple of shape (2,) containing endpoints of the first interval
|
| 184 |
+
interval_2 -- list or tuple of shape (2, 2) containing endpoints of the second interval
|
| 185 |
+
|
| 186 |
+
Returns:
|
| 187 |
+
overlap -- length of overlap
|
| 188 |
+
"""
|
| 189 |
+
x1, x2 = interval_1
|
| 190 |
+
x3, x4 = interval_2
|
| 191 |
+
if x3 < x1:
|
| 192 |
+
if x4 < x1:
|
| 193 |
+
return 0
|
| 194 |
+
else:
|
| 195 |
+
return min(x2,x4) - x1
|
| 196 |
+
else:
|
| 197 |
+
if x2 < x3:
|
| 198 |
+
return 0
|
| 199 |
+
else:
|
| 200 |
+
return min(x2,x4) - x3
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def compute_iou(box1, box2):
|
| 204 |
+
"""
|
| 205 |
+
Compute IOU between box1 and box2.
|
| 206 |
+
|
| 207 |
+
Argmunets:
|
| 208 |
+
box1 -- list of shape (5, ). Represents the first box
|
| 209 |
+
box2 -- list of shape (5, ). Represents the second box
|
| 210 |
+
Each box is [x, y, w, h, prob]
|
| 211 |
+
|
| 212 |
+
Returns:
|
| 213 |
+
iou -- intersection over union score between two boxes
|
| 214 |
+
"""
|
| 215 |
+
x1,y1,w1,h1 = box1[0], box1[1], box1[2], box1[3]
|
| 216 |
+
x2,y2,w2,h2 = box2[0], box2[1], box2[2], box2[3]
|
| 217 |
+
|
| 218 |
+
area1, area2 = w1*h1, w2*h2
|
| 219 |
+
intersect_w = overlap((x1,x1+w1), (x2,x2+w2))
|
| 220 |
+
intersect_h = overlap((y1,y1+h1), (y2,y2+w2))
|
| 221 |
+
if intersect_w == w1 and intersect_h == h1 or intersect_w == w2 and intersect_h == h2:
|
| 222 |
+
return 1.
|
| 223 |
+
intersect_area = intersect_w*intersect_h
|
| 224 |
+
iou = intersect_area/(area1 + area2 - intersect_area)
|
| 225 |
+
return iou
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def nonmax_suppression(boxes, iou_thresh = IOU_THRESH):
|
| 229 |
+
"""
|
| 230 |
+
Removes ovelap bboxes
|
| 231 |
+
|
| 232 |
+
Arguments:
|
| 233 |
+
boxes -- list of shape (None, 5)
|
| 234 |
+
iou_thresh -- maximal value of iou when boxes are considered different
|
| 235 |
+
Each box is [x, y, w, h, prob]
|
| 236 |
+
|
| 237 |
+
Returns:
|
| 238 |
+
boxes -- list of shape (None, 5) with removed overlapping boxes
|
| 239 |
+
"""
|
| 240 |
+
boxes = sorted(boxes, key=lambda x: x[4], reverse=True)
|
| 241 |
+
for i, current_box in enumerate(boxes):
|
| 242 |
+
if current_box[4] <= 0:
|
| 243 |
+
continue
|
| 244 |
+
for j in range(i+1, len(boxes)):
|
| 245 |
+
iou = compute_iou(current_box, boxes[j])
|
| 246 |
+
if iou > iou_thresh:
|
| 247 |
+
boxes[j][4] = 0
|
| 248 |
+
boxes = [box for box in boxes if box[4] > 0]
|
| 249 |
+
return boxes
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def yolo_head(yolo_output):
|
| 254 |
+
"""
|
| 255 |
+
Converts a yolo output tensor to separate tensors of coordinates, shapes and probabilities.
|
| 256 |
+
|
| 257 |
+
Arguments:
|
| 258 |
+
yolo_output -- tensor of shape (batch_size, S, S, BOX, 5)
|
| 259 |
+
|
| 260 |
+
Returns:
|
| 261 |
+
xy -- tensor of shape (batch_size, S, S, BOX, 2) containing coordinates of centers of found boxes for each anchor in each grid cell
|
| 262 |
+
wh -- tensor of shape (batch_size, S, S, BOX, 2) containing width and height of found boxes for each anchor in each grid cell
|
| 263 |
+
prob -- tensor of shape (batch_size, S, S, BOX, 1) containing the probability of presence of boxes for each anchor in each grid cell
|
| 264 |
+
"""
|
| 265 |
+
xy = torch.sigmoid(yolo_output[..., 0:2])
|
| 266 |
+
anchors_wh = torch.tensor(ANCHORS, device=yolo_output.device).view(1, 1, 1, len(ANCHORS), 2)
|
| 267 |
+
wh = torch.exp(yolo_output[..., 2:4]) * anchors_wh
|
| 268 |
+
prob = torch.sigmoid(yolo_output[..., 4:5])
|
| 269 |
+
return xy, wh, prob
|
| 270 |
+
|
| 271 |
+
def process_target(target):
|
| 272 |
+
xy = target[..., 0:2]
|
| 273 |
+
wh = target[..., 2:4]
|
| 274 |
+
prob = target[..., 4:5]
|
| 275 |
+
return xy, wh, prob
|
detection_models/yolov8/__init__.py
ADDED
|
File without changes
|
detection_models/yolov8/train.ipynb
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"import os\n",
|
| 10 |
+
"HOME = os.getcwd()\n",
|
| 11 |
+
"print(HOME)"
|
| 12 |
+
]
|
| 13 |
+
},
|
| 14 |
+
{
|
| 15 |
+
"cell_type": "code",
|
| 16 |
+
"execution_count": null,
|
| 17 |
+
"metadata": {},
|
| 18 |
+
"outputs": [],
|
| 19 |
+
"source": [
|
| 20 |
+
"# Pip install method (recommended)\n",
|
| 21 |
+
"\n",
|
| 22 |
+
"%pip install ultralytics==8.0.20\n",
|
| 23 |
+
"\n",
|
| 24 |
+
"from IPython import display\n",
|
| 25 |
+
"display.clear_output()\n",
|
| 26 |
+
"\n",
|
| 27 |
+
"import ultralytics\n",
|
| 28 |
+
"ultralytics.checks()"
|
| 29 |
+
]
|
| 30 |
+
},
|
| 31 |
+
{
|
| 32 |
+
"cell_type": "code",
|
| 33 |
+
"execution_count": null,
|
| 34 |
+
"metadata": {},
|
| 35 |
+
"outputs": [],
|
| 36 |
+
"source": [
|
| 37 |
+
"from ultralytics import YOLO\n",
|
| 38 |
+
"\n",
|
| 39 |
+
"from IPython.display import display, Image"
|
| 40 |
+
]
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"cell_type": "code",
|
| 44 |
+
"execution_count": null,
|
| 45 |
+
"metadata": {},
|
| 46 |
+
"outputs": [],
|
| 47 |
+
"source": [
|
| 48 |
+
"!mkdir {HOME}/datasets\n",
|
| 49 |
+
"%cd {HOME}/datasets\n",
|
| 50 |
+
"\n",
|
| 51 |
+
"%pip install roboflow --quiet\n",
|
| 52 |
+
"\n",
|
| 53 |
+
"from roboflow import Roboflow\n",
|
| 54 |
+
"rf = Roboflow(api_key=\"YOUR_API_KEY\")\n",
|
| 55 |
+
"project = rf.workspace(\"WORKSPACE\").project(\"PROJECT\")\n",
|
| 56 |
+
"dataset = project.version(1).download(\"yolov8\")"
|
| 57 |
+
]
|
| 58 |
+
},
|
| 59 |
+
{
|
| 60 |
+
"cell_type": "code",
|
| 61 |
+
"execution_count": null,
|
| 62 |
+
"metadata": {},
|
| 63 |
+
"outputs": [],
|
| 64 |
+
"source": [
|
| 65 |
+
"%cd {HOME}\n",
|
| 66 |
+
"\n",
|
| 67 |
+
"!yolo task=detect mode=train model=yolov8s.pt data={dataset.location}/data.yaml epochs=25 imgsz=800 plots=True"
|
| 68 |
+
]
|
| 69 |
+
},
|
| 70 |
+
{
|
| 71 |
+
"cell_type": "code",
|
| 72 |
+
"execution_count": null,
|
| 73 |
+
"metadata": {},
|
| 74 |
+
"outputs": [],
|
| 75 |
+
"source": [
|
| 76 |
+
"%cd {HOME}\n",
|
| 77 |
+
"Image(filename=f'{HOME}/runs/detect/train/confusion_matrix.png', width=600)"
|
| 78 |
+
]
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"cell_type": "code",
|
| 82 |
+
"execution_count": null,
|
| 83 |
+
"metadata": {},
|
| 84 |
+
"outputs": [],
|
| 85 |
+
"source": [
|
| 86 |
+
"%cd {HOME}\n",
|
| 87 |
+
"Image(filename=f'{HOME}/runs/detect/train/results.png', width=600)"
|
| 88 |
+
]
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"cell_type": "code",
|
| 92 |
+
"execution_count": null,
|
| 93 |
+
"metadata": {},
|
| 94 |
+
"outputs": [],
|
| 95 |
+
"source": [
|
| 96 |
+
"%cd {HOME}\n",
|
| 97 |
+
"Image(filename=f'{HOME}/runs/detect/train/val_batch0_pred.jpg', width=600)"
|
| 98 |
+
]
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"cell_type": "code",
|
| 102 |
+
"execution_count": null,
|
| 103 |
+
"metadata": {},
|
| 104 |
+
"outputs": [],
|
| 105 |
+
"source": [
|
| 106 |
+
"%cd {HOME}\n",
|
| 107 |
+
"\n",
|
| 108 |
+
"!yolo task=detect mode=val model={HOME}/runs/detect/train/weights/best.pt data={dataset.location}/data.yaml"
|
| 109 |
+
]
|
| 110 |
+
},
|
| 111 |
+
{
|
| 112 |
+
"cell_type": "code",
|
| 113 |
+
"execution_count": null,
|
| 114 |
+
"metadata": {},
|
| 115 |
+
"outputs": [],
|
| 116 |
+
"source": [
|
| 117 |
+
"%cd {HOME}\n",
|
| 118 |
+
"!yolo task=detect mode=predict model={HOME}/runs/detect/train/weights/best.pt conf=0.25 source={dataset.location}/test/images save=True"
|
| 119 |
+
]
|
| 120 |
+
},
|
| 121 |
+
{
|
| 122 |
+
"cell_type": "code",
|
| 123 |
+
"execution_count": null,
|
| 124 |
+
"metadata": {},
|
| 125 |
+
"outputs": [],
|
| 126 |
+
"source": [
|
| 127 |
+
"import glob\n",
|
| 128 |
+
"from IPython.display import Image, display\n",
|
| 129 |
+
"\n",
|
| 130 |
+
"for image_path in glob.glob(f'{HOME}/runs/detect/predict3/*.jpg')[:3]:\n",
|
| 131 |
+
" display(Image(filename=image_path, width=600))\n",
|
| 132 |
+
" print(\"\\n\")"
|
| 133 |
+
]
|
| 134 |
+
}
|
| 135 |
+
],
|
| 136 |
+
"metadata": {
|
| 137 |
+
"language_info": {
|
| 138 |
+
"name": "python"
|
| 139 |
+
},
|
| 140 |
+
"orig_nbformat": 4
|
| 141 |
+
},
|
| 142 |
+
"nbformat": 4,
|
| 143 |
+
"nbformat_minor": 2
|
| 144 |
+
}
|
embedding_models/__init__.py
ADDED
|
File without changes
|
embedding_models/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (165 Bytes). View file
|
|
|
embedding_models/vae/__init__.py
ADDED
|
File without changes
|
embedding_models/vae/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (169 Bytes). View file
|
|
|
embedding_models/vae/__pycache__/constants.cpython-39.pyc
ADDED
|
Binary file (237 Bytes). View file
|
|
|
embedding_models/vae/__pycache__/model.cpython-39.pyc
ADDED
|
Binary file (5.87 kB). View file
|
|
|
embedding_models/vae/constants.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# dimenstion of image embedding
|
| 2 |
+
Z_DIM = 128
|
| 3 |
+
# hidden dimensions for encoder model
|
| 4 |
+
ENC_HIDDEN_DIM = 16
|
| 5 |
+
# hidden dimensions for decoder model
|
| 6 |
+
DEC_HIDDEN_DIM = 64
|
embedding_models/vae/losses.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.distributions.kl import kl_divergence
|
| 4 |
+
from torch.distributions.normal import Normal
|
| 5 |
+
from torch.nn.functional import relu
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class BatchHardTripletLoss(nn.Module):
|
| 10 |
+
def __init__(self, margin=1., squared=False, agg='sum'):
|
| 11 |
+
"""
|
| 12 |
+
Initalize the loss function with a margin parameter, whether or not to consider
|
| 13 |
+
squared Euclidean distance and how to aggregate the loss in a batch
|
| 14 |
+
"""
|
| 15 |
+
super().__init__()
|
| 16 |
+
self.margin = margin
|
| 17 |
+
self.squared = squared
|
| 18 |
+
self.agg = agg
|
| 19 |
+
self.eps = 1e-8
|
| 20 |
+
|
| 21 |
+
def get_pairwise_distances(self, embeddings):
|
| 22 |
+
"""
|
| 23 |
+
Computing Euclidean distance for all possible pairs of embeddings.
|
| 24 |
+
"""
|
| 25 |
+
ab = embeddings.mm(embeddings.t())
|
| 26 |
+
a_squared = ab.diag().unsqueeze(1)
|
| 27 |
+
b_squared = ab.diag().unsqueeze(0)
|
| 28 |
+
distances = a_squared - 2 * ab + b_squared
|
| 29 |
+
distances = relu(distances)
|
| 30 |
+
|
| 31 |
+
if not self.squared:
|
| 32 |
+
distances = torch.sqrt(distances + self.eps)
|
| 33 |
+
|
| 34 |
+
return distances
|
| 35 |
+
|
| 36 |
+
def hardest_triplet_mining(self, dist_mat, labels):
|
| 37 |
+
|
| 38 |
+
assert len(dist_mat.size()) == 2
|
| 39 |
+
assert dist_mat.size(0) == dist_mat.size(1)
|
| 40 |
+
N = dist_mat.size(0)
|
| 41 |
+
|
| 42 |
+
is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
|
| 43 |
+
is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())
|
| 44 |
+
|
| 45 |
+
dist_ap, relative_p_inds = torch.max(
|
| 46 |
+
(dist_mat * is_pos), 1, keepdim=True)
|
| 47 |
+
|
| 48 |
+
dist_an, relative_n_inds = torch.min(
|
| 49 |
+
(dist_mat * is_neg), 1, keepdim=True)
|
| 50 |
+
|
| 51 |
+
return dist_ap, dist_an
|
| 52 |
+
|
| 53 |
+
def forward(self, embeddings, labels):
|
| 54 |
+
|
| 55 |
+
distances = self.get_pairwise_distances(embeddings)
|
| 56 |
+
dist_ap, dist_an = self.hardest_triplet_mining(distances, labels)
|
| 57 |
+
|
| 58 |
+
triplet_loss = relu(dist_ap - dist_an + self.margin).sum()
|
| 59 |
+
return triplet_loss
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class VAELoss(nn.Module):
|
| 63 |
+
def __init__(self):
|
| 64 |
+
super().__init__()
|
| 65 |
+
self.reconstruction_loss = nn.BCELoss(reduction='sum')
|
| 66 |
+
|
| 67 |
+
def kl_divergence_loss(self, q_dist):
|
| 68 |
+
return kl_divergence(
|
| 69 |
+
q_dist, Normal(torch.zeros_like(q_dist.mean), torch.ones_like(q_dist.stddev))
|
| 70 |
+
).sum(-1)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def forward(self, output, target, encoding):
|
| 74 |
+
loss = self.kl_divergence_loss(encoding).sum() + self.reconstruction_loss(output, target)
|
| 75 |
+
return loss
|
| 76 |
+
|
| 77 |
+
|
embedding_models/vae/model.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from torch.distributions.normal import Normal
|
| 3 |
+
|
| 4 |
+
from .constants import *
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Encoder(nn.Module):
|
| 8 |
+
'''
|
| 9 |
+
Encoder Class
|
| 10 |
+
Values:
|
| 11 |
+
im_chan: the number of channels of the output image, a scalar
|
| 12 |
+
hidden_dim: the inner dimension, a scalar
|
| 13 |
+
'''
|
| 14 |
+
|
| 15 |
+
def __init__(self, im_chan=3, output_chan=Z_DIM, hidden_dim=ENC_HIDDEN_DIM):
|
| 16 |
+
super(Encoder, self).__init__()
|
| 17 |
+
self.z_dim = output_chan
|
| 18 |
+
self.disc = nn.Sequential(
|
| 19 |
+
self.make_disc_block(im_chan, hidden_dim),
|
| 20 |
+
self.make_disc_block(hidden_dim, hidden_dim * 2),
|
| 21 |
+
self.make_disc_block(hidden_dim * 2, hidden_dim * 4),
|
| 22 |
+
self.make_disc_block(hidden_dim * 4, hidden_dim * 8),
|
| 23 |
+
self.make_disc_block(hidden_dim * 8, output_chan * 2, final_layer=True),
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
|
| 27 |
+
'''
|
| 28 |
+
Function to return a sequence of operations corresponding to a encoder block of the VAE,
|
| 29 |
+
corresponding to a convolution, a batchnorm (except for in the last layer), and an activation
|
| 30 |
+
Parameters:
|
| 31 |
+
input_channels: how many channels the input feature representation has
|
| 32 |
+
output_channels: how many channels the output feature representation should have
|
| 33 |
+
kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
|
| 34 |
+
stride: the stride of the convolution
|
| 35 |
+
final_layer: whether we're on the final layer (affects activation and batchnorm)
|
| 36 |
+
'''
|
| 37 |
+
if not final_layer:
|
| 38 |
+
return nn.Sequential(
|
| 39 |
+
nn.Conv2d(input_channels, output_channels, kernel_size, stride),
|
| 40 |
+
nn.BatchNorm2d(output_channels),
|
| 41 |
+
nn.LeakyReLU(0.2, inplace=True),
|
| 42 |
+
)
|
| 43 |
+
else:
|
| 44 |
+
return nn.Sequential(
|
| 45 |
+
nn.Conv2d(input_channels, output_channels, kernel_size, stride),
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
def forward(self, image):
|
| 49 |
+
'''
|
| 50 |
+
Function for completing a forward pass of the Encoder: Given an image tensor,
|
| 51 |
+
returns a 1-dimension tensor representing fake/real.
|
| 52 |
+
Parameters:
|
| 53 |
+
image: a flattened image tensor with dimension (im_dim)
|
| 54 |
+
'''
|
| 55 |
+
disc_pred = self.disc(image)
|
| 56 |
+
encoding = disc_pred.view(len(disc_pred), -1)
|
| 57 |
+
# The stddev output is treated as the log of the variance of the normal
|
| 58 |
+
# distribution by convention and for numerical stability
|
| 59 |
+
return encoding[:, :self.z_dim], encoding[:, self.z_dim:].exp()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class Decoder(nn.Module):
|
| 63 |
+
'''
|
| 64 |
+
Decoder Class
|
| 65 |
+
Values:
|
| 66 |
+
z_dim: the dimension of the noise vector, a scalar
|
| 67 |
+
im_chan: the number of channels of the output image, a scalar
|
| 68 |
+
hidden_dim: the inner dimension, a scalar
|
| 69 |
+
'''
|
| 70 |
+
|
| 71 |
+
def __init__(self, z_dim=Z_DIM, im_chan=3, hidden_dim=DEC_HIDDEN_DIM):
|
| 72 |
+
super(Decoder, self).__init__()
|
| 73 |
+
self.z_dim = z_dim
|
| 74 |
+
self.gen = nn.Sequential(
|
| 75 |
+
self.make_gen_block(z_dim, hidden_dim * 16),
|
| 76 |
+
self.make_gen_block(hidden_dim * 16, hidden_dim * 8, kernel_size=4, stride=1),
|
| 77 |
+
self.make_gen_block(hidden_dim * 8, hidden_dim * 4),
|
| 78 |
+
self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4),
|
| 79 |
+
self.make_gen_block(hidden_dim * 2, hidden_dim, kernel_size=4),
|
| 80 |
+
self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
|
| 84 |
+
'''
|
| 85 |
+
Function to return a sequence of operations corresponding to a Decoder block of the VAE,
|
| 86 |
+
corresponding to a transposed convolution, a batchnorm (except for in the last layer), and an activation
|
| 87 |
+
Parameters:
|
| 88 |
+
input_channels: how many channels the input feature representation has
|
| 89 |
+
output_channels: how many channels the output feature representation should have
|
| 90 |
+
kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
|
| 91 |
+
stride: the stride of the convolution
|
| 92 |
+
final_layer: whether we're on the final layer (affects activation and batchnorm)
|
| 93 |
+
'''
|
| 94 |
+
if not final_layer:
|
| 95 |
+
return nn.Sequential(
|
| 96 |
+
nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
|
| 97 |
+
nn.BatchNorm2d(output_channels),
|
| 98 |
+
nn.ReLU(inplace=True),
|
| 99 |
+
)
|
| 100 |
+
else:
|
| 101 |
+
return nn.Sequential(
|
| 102 |
+
nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
|
| 103 |
+
nn.Sigmoid(),
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
def forward(self, noise):
|
| 107 |
+
'''
|
| 108 |
+
Function for completing a forward pass of the Decoder: Given a noise vector,
|
| 109 |
+
returns a generated image.
|
| 110 |
+
Parameters:
|
| 111 |
+
noise: a noise tensor with dimensions (batch_size, z_dim)
|
| 112 |
+
'''
|
| 113 |
+
x = noise.view(len(noise), self.z_dim, 1, 1)
|
| 114 |
+
return self.gen(x)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class VAE(nn.Module):
|
| 118 |
+
'''
|
| 119 |
+
VAE Class
|
| 120 |
+
Values:
|
| 121 |
+
z_dim: the dimension of the noise vector, a scalar
|
| 122 |
+
im_chan: the number of channels of the output image, a scalar
|
| 123 |
+
MNIST is black-and-white, so that's our default
|
| 124 |
+
hidden_dim: the inner dimension, a scalar
|
| 125 |
+
'''
|
| 126 |
+
|
| 127 |
+
def __init__(self, z_dim=Z_DIM, im_chan=3):
|
| 128 |
+
super(VAE, self).__init__()
|
| 129 |
+
self.z_dim = z_dim
|
| 130 |
+
self.encode = Encoder(im_chan, z_dim)
|
| 131 |
+
self.decode = Decoder(z_dim, im_chan)
|
| 132 |
+
|
| 133 |
+
def forward(self, images):
|
| 134 |
+
'''
|
| 135 |
+
Function for completing a forward pass of the Decoder: Given a noise vector,
|
| 136 |
+
returns a generated image.
|
| 137 |
+
Parameters:
|
| 138 |
+
images: an image tensor with dimensions (batch_size, im_chan, im_height, im_width)
|
| 139 |
+
Returns:
|
| 140 |
+
decoding: the autoencoded image
|
| 141 |
+
q_dist: the z-distribution of the encoding
|
| 142 |
+
'''
|
| 143 |
+
q_mean, q_stddev = self.encode(images)
|
| 144 |
+
q_dist = Normal(q_mean, q_stddev)
|
| 145 |
+
z_sample = q_dist.rsample() # Sample once from each distribution, using the `rsample` notation
|
| 146 |
+
decoding = self.decode(z_sample)
|
| 147 |
+
return decoding, q_dist
|
embedding_models/vae/train.ipynb
ADDED
|
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 4,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"import torch\n",
|
| 10 |
+
"import torch.nn as nn\n",
|
| 11 |
+
"import numpy as np\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"from pathlib import Path\n",
|
| 14 |
+
"import os\n",
|
| 15 |
+
"from PIL import Image\n",
|
| 16 |
+
"\n",
|
| 17 |
+
"from model import VAE\n",
|
| 18 |
+
"from losses import *"
|
| 19 |
+
]
|
| 20 |
+
},
|
| 21 |
+
{
|
| 22 |
+
"cell_type": "code",
|
| 23 |
+
"execution_count": 2,
|
| 24 |
+
"metadata": {},
|
| 25 |
+
"outputs": [],
|
| 26 |
+
"source": [
|
| 27 |
+
"from torch.utils.data import DataLoader, Dataset\n",
|
| 28 |
+
"from torchvision import transforms\n",
|
| 29 |
+
"import pandas as pd\n",
|
| 30 |
+
"import re\n",
|
| 31 |
+
"from sklearn.model_selection import train_test_split"
|
| 32 |
+
]
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"cell_type": "code",
|
| 36 |
+
"execution_count": 1,
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"outputs": [],
|
| 39 |
+
"source": [
|
| 40 |
+
"IMAGE_FOLDER = './data/images/'"
|
| 41 |
+
]
|
| 42 |
+
},
|
| 43 |
+
{
|
| 44 |
+
"cell_type": "code",
|
| 45 |
+
"execution_count": 5,
|
| 46 |
+
"metadata": {},
|
| 47 |
+
"outputs": [],
|
| 48 |
+
"source": [
|
| 49 |
+
"image_names = os.listdir(IMAGE_FOLDER)\n",
|
| 50 |
+
"data = pd.DataFrame({'image_name': image_names})\n",
|
| 51 |
+
"data['label'] = data['image_name'].apply(lambda x: int(re.match('^\\d+', x)[0]))"
|
| 52 |
+
]
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"cell_type": "code",
|
| 56 |
+
"execution_count": null,
|
| 57 |
+
"metadata": {},
|
| 58 |
+
"outputs": [],
|
| 59 |
+
"source": [
|
| 60 |
+
"class StampDataset(Dataset):\n",
|
| 61 |
+
" def __init__(self, data, image_folder=Path(IMAGE_FOLDER), transform=None):\n",
|
| 62 |
+
" super().__init__()\n",
|
| 63 |
+
" self.image_folder = image_folder\n",
|
| 64 |
+
" self.data = data\n",
|
| 65 |
+
" self.transform = transform\n",
|
| 66 |
+
"\n",
|
| 67 |
+
" def __getitem__(self, idx):\n",
|
| 68 |
+
" image = Image.open(self.image_folder / self.data.iloc[idx]['image_name'])\n",
|
| 69 |
+
" label = self.data.iloc[idx]['label']\n",
|
| 70 |
+
" if self.transform:\n",
|
| 71 |
+
" image = self.transform(image)\n",
|
| 72 |
+
"\n",
|
| 73 |
+
" return image, label\n",
|
| 74 |
+
"\n",
|
| 75 |
+
" \n",
|
| 76 |
+
" def __len__(self):\n",
|
| 77 |
+
" return len(self.data)"
|
| 78 |
+
]
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"cell_type": "code",
|
| 82 |
+
"execution_count": 6,
|
| 83 |
+
"metadata": {},
|
| 84 |
+
"outputs": [],
|
| 85 |
+
"source": [
|
| 86 |
+
"train_data, val_data = train_test_split(data, test_size=0.3, shuffle=True, stratify=data['label'])"
|
| 87 |
+
]
|
| 88 |
+
},
|
| 89 |
+
{
|
| 90 |
+
"cell_type": "code",
|
| 91 |
+
"execution_count": null,
|
| 92 |
+
"metadata": {},
|
| 93 |
+
"outputs": [],
|
| 94 |
+
"source": [
|
| 95 |
+
"train_transform = transforms.Compose([\n",
|
| 96 |
+
" transforms.Resize((118, 118)),\n",
|
| 97 |
+
" transforms.RandomHorizontalFlip(0.5),\n",
|
| 98 |
+
" transforms.RandomVerticalFlip(0.5),\n",
|
| 99 |
+
" transforms.ToTensor(),\n",
|
| 100 |
+
" # transforms.Normalize((0.76302232, 0.77820438, 0.81879729), (0.16563211, 0.14949341, 0.1055889)),\n",
|
| 101 |
+
"])\n",
|
| 102 |
+
"\n",
|
| 103 |
+
"val_transform = transforms.Compose([\n",
|
| 104 |
+
" transforms.Resize((118, 118)),\n",
|
| 105 |
+
" transforms.ToTensor(),\n",
|
| 106 |
+
" # transforms.Normalize((0.76302232, 0.77820438, 0.81879729), (0.16563211, 0.14949341, 0.1055889)),\n",
|
| 107 |
+
"])\n",
|
| 108 |
+
"train_dataset = StampDataset(train_data, transform=train_transform)\n",
|
| 109 |
+
"val_dataset = StampDataset(val_data, transform=val_transform)\n",
|
| 110 |
+
"\n",
|
| 111 |
+
"train_loader = DataLoader(train_dataset, shuffle=True, batch_size=256)\n",
|
| 112 |
+
"val_loader = DataLoader(val_dataset, shuffle=True, batch_size=256)"
|
| 113 |
+
]
|
| 114 |
+
},
|
| 115 |
+
{
|
| 116 |
+
"cell_type": "code",
|
| 117 |
+
"execution_count": 8,
|
| 118 |
+
"metadata": {},
|
| 119 |
+
"outputs": [],
|
| 120 |
+
"source": [
|
| 121 |
+
"import pytorch_lightning as pl\n",
|
| 122 |
+
"from torch import optim\n",
|
| 123 |
+
"from pytorch_lightning.loggers import TensorBoardLogger\n",
|
| 124 |
+
"\n",
|
| 125 |
+
"from torchvision.utils import make_grid"
|
| 126 |
+
]
|
| 127 |
+
},
|
| 128 |
+
{
|
| 129 |
+
"cell_type": "code",
|
| 130 |
+
"execution_count": 9,
|
| 131 |
+
"metadata": {},
|
| 132 |
+
"outputs": [],
|
| 133 |
+
"source": [
|
| 134 |
+
"MEAN = torch.tensor((0.76302232, 0.77820438, 0.81879729)).view(3, 1, 1)\n",
|
| 135 |
+
"STD = torch.tensor((0.16563211, 0.14949341, 0.1055889)).view(3, 1, 1)"
|
| 136 |
+
]
|
| 137 |
+
},
|
| 138 |
+
{
|
| 139 |
+
"cell_type": "code",
|
| 140 |
+
"execution_count": 9,
|
| 141 |
+
"metadata": {},
|
| 142 |
+
"outputs": [],
|
| 143 |
+
"source": [
|
| 144 |
+
"class LitModel(pl.LightningModule):\n",
|
| 145 |
+
" def __init__(self, alpha=1e-3):\n",
|
| 146 |
+
" super().__init__()\n",
|
| 147 |
+
" self.vae = VAE()\n",
|
| 148 |
+
" self.vae_loss = VAELoss()\n",
|
| 149 |
+
" self.triplet_loss = BatchHardTripletLoss(margin=1.)\n",
|
| 150 |
+
" self.alpha = alpha\n",
|
| 151 |
+
" \n",
|
| 152 |
+
" def forward(self, x):\n",
|
| 153 |
+
" return self.vae(x)\n",
|
| 154 |
+
" \n",
|
| 155 |
+
" def configure_optimizers(self):\n",
|
| 156 |
+
" optimizer = optim.AdamW(self.parameters(), lr=3e-4)\n",
|
| 157 |
+
" return optimizer\n",
|
| 158 |
+
" # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 1000)\n",
|
| 159 |
+
" # return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler}\n",
|
| 160 |
+
"\n",
|
| 161 |
+
" def training_step(self, batch, batch_idx):\n",
|
| 162 |
+
" images, labels = batch\n",
|
| 163 |
+
" labels = labels.unsqueeze(1)\n",
|
| 164 |
+
" recon_images, encoding = self.vae(images)\n",
|
| 165 |
+
" vae_loss = self.vae_loss(recon_images, images, encoding)\n",
|
| 166 |
+
" self.log(\"train_vae_loss\", vae_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)\n",
|
| 167 |
+
" triplet_loss = self.triplet_loss(encoding.mean, labels)\n",
|
| 168 |
+
" self.log(\"train_triplet_loss\", triplet_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)\n",
|
| 169 |
+
" loss = self.alpha * triplet_loss + vae_loss\n",
|
| 170 |
+
" self.log(\"train_loss\", loss, on_epoch=True, prog_bar=True, logger=True)\n",
|
| 171 |
+
" return loss\n",
|
| 172 |
+
"\n",
|
| 173 |
+
" def validation_step(self, batch, batch_idx):\n",
|
| 174 |
+
" images, labels = batch\n",
|
| 175 |
+
" labels = labels.unsqueeze(1)\n",
|
| 176 |
+
" recon_images, encoding = self.vae(images)\n",
|
| 177 |
+
" vae_loss = self.vae_loss(recon_images, images, encoding)\n",
|
| 178 |
+
" self.log(\"val_vae_loss\", vae_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)\n",
|
| 179 |
+
" triplet_loss = self.triplet_loss(encoding.mean, labels)\n",
|
| 180 |
+
" self.log(\"val_triplet_loss\", triplet_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)\n",
|
| 181 |
+
" loss = self.alpha * triplet_loss + vae_loss\n",
|
| 182 |
+
" self.log(\"val_loss\", loss, on_epoch=True, prog_bar=True, logger=True)\n",
|
| 183 |
+
" return loss\n",
|
| 184 |
+
"\n",
|
| 185 |
+
" def on_validation_epoch_end(self):\n",
|
| 186 |
+
" images, _ = iter(val_loader).next()\n",
|
| 187 |
+
" image_unflat = images.detach().cpu()\n",
|
| 188 |
+
" image_grid = make_grid(image_unflat[:16], nrow=4)\n",
|
| 189 |
+
" self.logger.experiment.add_image('original images', image_grid, self.current_epoch)\n",
|
| 190 |
+
"\n",
|
| 191 |
+
" recon_images, _ = self.vae(images.to(self.device))\n",
|
| 192 |
+
" image_unflat = recon_images.detach().cpu()\n",
|
| 193 |
+
" image_grid = make_grid(image_unflat[:16], nrow=4)\n",
|
| 194 |
+
" self.logger.experiment.add_image('reconstructed images', image_grid, self.current_epoch)"
|
| 195 |
+
]
|
| 196 |
+
},
|
| 197 |
+
{
|
| 198 |
+
"cell_type": "code",
|
| 199 |
+
"execution_count": 10,
|
| 200 |
+
"metadata": {},
|
| 201 |
+
"outputs": [],
|
| 202 |
+
"source": [
|
| 203 |
+
"litmodel = LitModel()"
|
| 204 |
+
]
|
| 205 |
+
},
|
| 206 |
+
{
|
| 207 |
+
"cell_type": "code",
|
| 208 |
+
"execution_count": 11,
|
| 209 |
+
"metadata": {},
|
| 210 |
+
"outputs": [],
|
| 211 |
+
"source": [
|
| 212 |
+
"logger = TensorBoardLogger(\"reconstruction_logs\")"
|
| 213 |
+
]
|
| 214 |
+
},
|
| 215 |
+
{
|
| 216 |
+
"cell_type": "code",
|
| 217 |
+
"execution_count": 12,
|
| 218 |
+
"metadata": {},
|
| 219 |
+
"outputs": [],
|
| 220 |
+
"source": [
|
| 221 |
+
"epochs = 100"
|
| 222 |
+
]
|
| 223 |
+
},
|
| 224 |
+
{
|
| 225 |
+
"cell_type": "code",
|
| 226 |
+
"execution_count": null,
|
| 227 |
+
"metadata": {},
|
| 228 |
+
"outputs": [],
|
| 229 |
+
"source": [
|
| 230 |
+
"trainer = pl.Trainer(accelerator=\"auto\", max_epochs=epochs, logger=logger)\n",
|
| 231 |
+
"trainer.fit(model=litmodel, train_dataloaders=train_loader, val_dataloaders=val_loader)"
|
| 232 |
+
]
|
| 233 |
+
},
|
| 234 |
+
{
|
| 235 |
+
"cell_type": "code",
|
| 236 |
+
"execution_count": null,
|
| 237 |
+
"metadata": {},
|
| 238 |
+
"outputs": [],
|
| 239 |
+
"source": [
|
| 240 |
+
"%tensorboard"
|
| 241 |
+
]
|
| 242 |
+
},
|
| 243 |
+
{
|
| 244 |
+
"cell_type": "code",
|
| 245 |
+
"execution_count": 8,
|
| 246 |
+
"metadata": {},
|
| 247 |
+
"outputs": [],
|
| 248 |
+
"source": [
|
| 249 |
+
"device = 'cuda' if torch.cuda.is_available() else 'cpu'"
|
| 250 |
+
]
|
| 251 |
+
},
|
| 252 |
+
{
|
| 253 |
+
"cell_type": "code",
|
| 254 |
+
"execution_count": 11,
|
| 255 |
+
"metadata": {},
|
| 256 |
+
"outputs": [],
|
| 257 |
+
"source": [
|
| 258 |
+
"from huggingface_hub import hf_hub_download"
|
| 259 |
+
]
|
| 260 |
+
},
|
| 261 |
+
{
|
| 262 |
+
"cell_type": "code",
|
| 263 |
+
"execution_count": 12,
|
| 264 |
+
"metadata": {},
|
| 265 |
+
"outputs": [],
|
| 266 |
+
"source": [
|
| 267 |
+
"emb_model = torch.jit.load(hf_hub_download(repo_id=\"stamps-labs/vits8-stamp\", filename=\"vits8stamp-torchscript.pth\")).to(device)"
|
| 268 |
+
]
|
| 269 |
+
},
|
| 270 |
+
{
|
| 271 |
+
"cell_type": "code",
|
| 272 |
+
"execution_count": 21,
|
| 273 |
+
"metadata": {},
|
| 274 |
+
"outputs": [],
|
| 275 |
+
"source": [
|
| 276 |
+
"val_transform = transforms.Compose([\n",
|
| 277 |
+
" transforms.Resize((224, 224)),\n",
|
| 278 |
+
" transforms.ToTensor(),\n",
|
| 279 |
+
" # transforms.Normalize((0.76302232, 0.77820438, 0.81879729), (0.16563211, 0.14949341, 0.1055889)),\n",
|
| 280 |
+
"])"
|
| 281 |
+
]
|
| 282 |
+
},
|
| 283 |
+
{
|
| 284 |
+
"cell_type": "code",
|
| 285 |
+
"execution_count": 28,
|
| 286 |
+
"metadata": {},
|
| 287 |
+
"outputs": [],
|
| 288 |
+
"source": [
|
| 289 |
+
"train_data['embed'] = train_data['image_name'].apply(lambda x: emb_model(val_transform(Image.open(Path(IMAGE_FOLDER) / x)).unsqueeze(0).to(device))[0].tolist())"
|
| 290 |
+
]
|
| 291 |
+
},
|
| 292 |
+
{
|
| 293 |
+
"cell_type": "code",
|
| 294 |
+
"execution_count": 34,
|
| 295 |
+
"metadata": {},
|
| 296 |
+
"outputs": [
|
| 297 |
+
{
|
| 298 |
+
"name": "stderr",
|
| 299 |
+
"output_type": "stream",
|
| 300 |
+
"text": [
|
| 301 |
+
"C:\\Users\\javid\\AppData\\Local\\Temp\\ipykernel_23064\\1572292890.py:1: FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.\n",
|
| 302 |
+
" embeds = pd.DataFrame(train_data['embed'].tolist()).append(pd.DataFrame(val_data['embed'].tolist()), ignore_index=True)\n",
|
| 303 |
+
"C:\\Users\\javid\\AppData\\Local\\Temp\\ipykernel_23064\\1572292890.py:2: FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.\n",
|
| 304 |
+
" labels = pd.DataFrame(train_data['label']).append(pd.DataFrame(val_data['label']), ignore_index=True)\n"
|
| 305 |
+
]
|
| 306 |
+
}
|
| 307 |
+
],
|
| 308 |
+
"source": [
|
| 309 |
+
"embeds = pd.DataFrame(train_data['embed'].tolist()).append(pd.DataFrame(val_data['embed'].tolist()), ignore_index=True)\n",
|
| 310 |
+
"labels = pd.DataFrame(train_data['label']).append(pd.DataFrame(val_data['label']), ignore_index=True)"
|
| 311 |
+
]
|
| 312 |
+
},
|
| 313 |
+
{
|
| 314 |
+
"cell_type": "code",
|
| 315 |
+
"execution_count": 35,
|
| 316 |
+
"metadata": {},
|
| 317 |
+
"outputs": [],
|
| 318 |
+
"source": [
|
| 319 |
+
"embeds.to_csv('./all_embeds.tsv', sep='\\t', index=False, header=False)"
|
| 320 |
+
]
|
| 321 |
+
},
|
| 322 |
+
{
|
| 323 |
+
"cell_type": "code",
|
| 324 |
+
"execution_count": 36,
|
| 325 |
+
"metadata": {},
|
| 326 |
+
"outputs": [],
|
| 327 |
+
"source": [
|
| 328 |
+
"labels.to_csv('./all_labels.tsv', sep='\\t', index=False, header=False)"
|
| 329 |
+
]
|
| 330 |
+
},
|
| 331 |
+
{
|
| 332 |
+
"cell_type": "code",
|
| 333 |
+
"execution_count": 126,
|
| 334 |
+
"metadata": {},
|
| 335 |
+
"outputs": [],
|
| 336 |
+
"source": [
|
| 337 |
+
"torch.save(litmodel.vae.encode.state_dict(), './models/encoder.pth')"
|
| 338 |
+
]
|
| 339 |
+
},
|
| 340 |
+
{
|
| 341 |
+
"cell_type": "code",
|
| 342 |
+
"execution_count": 129,
|
| 343 |
+
"metadata": {},
|
| 344 |
+
"outputs": [],
|
| 345 |
+
"source": [
|
| 346 |
+
"im = train_dataset[0]"
|
| 347 |
+
]
|
| 348 |
+
},
|
| 349 |
+
{
|
| 350 |
+
"cell_type": "code",
|
| 351 |
+
"execution_count": 132,
|
| 352 |
+
"metadata": {},
|
| 353 |
+
"outputs": [
|
| 354 |
+
{
|
| 355 |
+
"data": {
|
| 356 |
+
"text/plain": [
|
| 357 |
+
"<All keys matched successfully>"
|
| 358 |
+
]
|
| 359 |
+
},
|
| 360 |
+
"execution_count": 132,
|
| 361 |
+
"metadata": {},
|
| 362 |
+
"output_type": "execute_result"
|
| 363 |
+
}
|
| 364 |
+
],
|
| 365 |
+
"source": [
|
| 366 |
+
"model = Encoder()\n",
|
| 367 |
+
"model.load_state_dict(torch.load('./models/encoder.pth'))"
|
| 368 |
+
]
|
| 369 |
+
}
|
| 370 |
+
],
|
| 371 |
+
"metadata": {
|
| 372 |
+
"kernelspec": {
|
| 373 |
+
"display_name": "Python 3",
|
| 374 |
+
"language": "python",
|
| 375 |
+
"name": "python3"
|
| 376 |
+
},
|
| 377 |
+
"language_info": {
|
| 378 |
+
"codemirror_mode": {
|
| 379 |
+
"name": "ipython",
|
| 380 |
+
"version": 3
|
| 381 |
+
},
|
| 382 |
+
"file_extension": ".py",
|
| 383 |
+
"mimetype": "text/x-python",
|
| 384 |
+
"name": "python",
|
| 385 |
+
"nbconvert_exporter": "python",
|
| 386 |
+
"pygments_lexer": "ipython3",
|
| 387 |
+
"version": "3.9.0"
|
| 388 |
+
},
|
| 389 |
+
"orig_nbformat": 4
|
| 390 |
+
},
|
| 391 |
+
"nbformat": 4,
|
| 392 |
+
"nbformat_minor": 2
|
| 393 |
+
}
|
embedding_models/vits8/__init__.py
ADDED
|
File without changes
|
embedding_models/vits8/example.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image
|
| 2 |
+
from model import ViTStamp
|
| 3 |
+
def get_embeddings(img_path: str):
|
| 4 |
+
model = ViTStamp()
|
| 5 |
+
image = Image.open(img_path)
|
| 6 |
+
embeddings = model(image=image)
|
| 7 |
+
return embeddings
|
| 8 |
+
|
| 9 |
+
if __name__ == "__main__":
|
| 10 |
+
print(get_embeddings("oml/data/test/images/99d_15.bmp"))
|
embedding_models/vits8/model.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torchvision import transforms
|
| 3 |
+
from huggingface_hub import hf_hub_download
|
| 4 |
+
|
| 5 |
+
class ViTStamp():
|
| 6 |
+
def __init__(self):
|
| 7 |
+
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 8 |
+
self.model = torch.jit.load(hf_hub_download(repo_id="stamps-labs/vits8-stamp", filename="vits8stamp-torchscript.pth"))
|
| 9 |
+
self.transform = transforms.ToTensor()
|
| 10 |
+
def __call__(self, image) -> torch.Tensor():
|
| 11 |
+
img_tensor = self.transform(image).cuda().unsqueeze(0) if self.device == "cuda" else self.transform(image).unsqueeze(0)
|
| 12 |
+
features = self.model(img_tensor)
|
| 13 |
+
return features
|
embedding_models/vits8/oml/__init__.py
ADDED
|
File without changes
|
embedding_models/vits8/oml/create_dataset.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import pandas as pd
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
|
| 7 |
+
parser = argparse.ArgumentParser("Create a dataset for training with OML",
|
| 8 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
| 9 |
+
|
| 10 |
+
parser.add_argument("--root-data-path", help="Path to images for dataset", default="data/train_val/")
|
| 11 |
+
parser.add_argument("--image-data-path", help="Image folder in root data path", default="images/")
|
| 12 |
+
parser.add_argument("--train-val-split",
|
| 13 |
+
help="In which ratio to split data in format train:val (For example 80:20)", default="80:20")
|
| 14 |
+
parser.add_argument("--separator",
|
| 15 |
+
help="What separator is used in image name to separate class name and instance (E.g. circle1_5, separator=_)",
|
| 16 |
+
default="_")
|
| 17 |
+
|
| 18 |
+
args = parser.parse_args()
|
| 19 |
+
config = vars(args)
|
| 20 |
+
|
| 21 |
+
root_path = config["root_data_path"]
|
| 22 |
+
image_path = config["image_data_path"]
|
| 23 |
+
separator = config["separator"]
|
| 24 |
+
|
| 25 |
+
train_prc, val_prc = tuple(int(num)/100 for num in config["train_val_split"].split(":"))
|
| 26 |
+
|
| 27 |
+
class_names = set()
|
| 28 |
+
for image in os.listdir(root_path+image_path):
|
| 29 |
+
if image.endswith(("png", "jpg", "bmp", "webp")):
|
| 30 |
+
img_name = image.split(".")[0]
|
| 31 |
+
Image.open(root_path+image_path+image).resize((224,224)).save(root_path+image_path+img_name+".png", "PNG")
|
| 32 |
+
if not image.endswith("png"):
|
| 33 |
+
os.remove(root_path+image_path+image)
|
| 34 |
+
img_name = img_name.split(separator)
|
| 35 |
+
class_name = img_name[0]+img_name[1]
|
| 36 |
+
class_names.add(class_name)
|
| 37 |
+
else:
|
| 38 |
+
print("Not all of the images are in supported format")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
#For each class in set assign its index in a set as a class label.
|
| 42 |
+
class_label_dict = {}
|
| 43 |
+
for ind, name in enumerate(class_names):
|
| 44 |
+
class_label_dict[name] = ind
|
| 45 |
+
|
| 46 |
+
class_count = len(class_names)
|
| 47 |
+
train_class_count = int(class_count*train_prc)
|
| 48 |
+
print(train_class_count)
|
| 49 |
+
|
| 50 |
+
df_dict = {"label": [],
|
| 51 |
+
"path": [],
|
| 52 |
+
"split": [],
|
| 53 |
+
"is_query": [],
|
| 54 |
+
"is_gallery": []}
|
| 55 |
+
for image in os.listdir(root_path+image_path):
|
| 56 |
+
if image.endswith((".png", ".jpg", ".bmp", ".webp")):
|
| 57 |
+
img_name = image.split(".")[0].split(separator)
|
| 58 |
+
class_name = img_name[0]+img_name[1]
|
| 59 |
+
label = class_label_dict[class_name]
|
| 60 |
+
path = image_path+image
|
| 61 |
+
split = "train" if label <= train_class_count else "validation"
|
| 62 |
+
is_query, is_gallery = (1, 1) if split=="validation" else (None, None)
|
| 63 |
+
df_dict["label"].append(label)
|
| 64 |
+
df_dict["path"].append(path)
|
| 65 |
+
df_dict["split"].append(split)
|
| 66 |
+
df_dict["is_query"].append(is_query)
|
| 67 |
+
df_dict["is_gallery"].append(is_gallery)
|
| 68 |
+
|
| 69 |
+
df = pd.DataFrame(df_dict)
|
| 70 |
+
|
| 71 |
+
df.to_csv(root_path+"df_stamps.csv", index=False)
|
embedding_models/vits8/oml/data/test/images/99d_15.bmp
ADDED
|
embedding_models/vits8/oml/data/test/images/99e_20.bmp
ADDED
|
embedding_models/vits8/oml/data/test/images/99f_25.bmp
ADDED
|
embedding_models/vits8/oml/data/test/images/99g_30.bmp
ADDED
|
embedding_models/vits8/oml/data/test/images/99h_35.bmp
ADDED
|
embedding_models/vits8/oml/data/test/images/99i_40.bmp
ADDED
|
embedding_models/vits8/oml/data/train_val/df_stamps.csv
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
label,path,split,is_query,is_gallery
|
| 2 |
+
0,images/circle6_1239.png,train,,
|
| 3 |
+
8,images/triangle19_1242.png,train,,
|
| 4 |
+
21,images/rectangle11_1248.png,train,,
|
| 5 |
+
39,images/triangle10_1232.png,validation,1.0,1.0
|
| 6 |
+
33,images/word14_1241.png,validation,1.0,1.0
|
| 7 |
+
38,images/word5_1233.png,validation,1.0,1.0
|
| 8 |
+
15,images/circle19_1236.png,train,,
|
| 9 |
+
22,images/circle15_1244.png,train,,
|
| 10 |
+
32,images/circle21_1249.png,train,,
|
| 11 |
+
26,images/oval20_1242.png,train,,
|
| 12 |
+
6,images/oval5_1237.png,train,,
|
| 13 |
+
23,images/word9_1241.png,train,,
|
| 14 |
+
9,images/triangle22_1238.png,train,,
|
| 15 |
+
31,images/circle12_1239.png,train,,
|
| 16 |
+
11,images/word21_1231.png,train,,
|
| 17 |
+
4,images/oval2_1235.png,train,,
|
| 18 |
+
20,images/rectangle18_1246.png,train,,
|
| 19 |
+
12,images/circle24_1234.png,train,,
|
| 20 |
+
5,images/circle2_1249.png,train,,
|
| 21 |
+
37,images/word22_1238.png,validation,1.0,1.0
|
| 22 |
+
34,images/triangle18_1247.png,validation,1.0,1.0
|
| 23 |
+
1,images/oval7_1241.png,train,,
|
| 24 |
+
10,images/triangle13_1240.png,train,,
|
| 25 |
+
14,images/rectangle12_1236.png,train,,
|
| 26 |
+
36,images/circle8_1237.png,validation,1.0,1.0
|
| 27 |
+
24,images/triangle9_1245.png,train,,
|
| 28 |
+
29,images/word23_1243.png,train,,
|
| 29 |
+
28,images/triangle11_1244.png,train,,
|
| 30 |
+
16,images/circle2_1246.png,train,,
|
| 31 |
+
30,images/circle3_1247.png,train,,
|
| 32 |
+
18,images/oval24_1248.png,train,,
|
| 33 |
+
2,images/oval12_1231.png,train,,
|
| 34 |
+
3,images/oval18_1234.png,train,,
|
| 35 |
+
25,images/rectangle11_1245.png,train,,
|
| 36 |
+
17,images/word9_1244.png,train,,
|
| 37 |
+
13,images/triangle14_1237.png,train,,
|
| 38 |
+
35,images/circle2_1233.png,validation,1.0,1.0
|
| 39 |
+
7,images/word18_1239.png,train,,
|
| 40 |
+
19,images/rectangle13_1236.png,train,,
|
| 41 |
+
27,images/circle24_1246.png,train,,
|
embedding_models/vits8/oml/data/train_val/images/circle12_1239.png
ADDED
|
embedding_models/vits8/oml/data/train_val/images/circle15_1244.png
ADDED
|
embedding_models/vits8/oml/data/train_val/images/circle19_1236.png
ADDED
|
embedding_models/vits8/oml/data/train_val/images/circle21_1249.png
ADDED
|
embedding_models/vits8/oml/data/train_val/images/circle24_1234.png
ADDED
|
embedding_models/vits8/oml/data/train_val/images/circle24_1246.png
ADDED
|
embedding_models/vits8/oml/data/train_val/images/circle2_1233.png
ADDED
|
embedding_models/vits8/oml/data/train_val/images/circle2_1246.png
ADDED
|
embedding_models/vits8/oml/data/train_val/images/circle2_1249.png
ADDED
|