feat: support meta SSL watermarking
Browse files- SSL_watermark.py +87 -0
- app.py +23 -6
- dino_r50.pth +3 -0
- image_utils.py +80 -0
- out2048.pth +3 -0
- requirements.txt +4 -0
- torch_utils.py +84 -0
    	
        SSL_watermark.py
    ADDED
    
    | @@ -0,0 +1,87 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import numpy as np
         | 
| 2 | 
            +
            import json
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            from torchvision import transforms
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import torch_utils 
         | 
| 8 | 
            +
            import image_utils
         | 
| 9 | 
            +
             | 
# Pick the GPU when one is available; every tensor and module below is placed here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fixed seeds: the carrier drawn below acts as the watermark key, so it must be
# reproducible between the encoding run and any later decoding run.
torch.manual_seed(0)
np.random.seed(0)

print('Building backbone and normalization layer...')
backbone = torch_utils.build_backbone(path='dino_r50.pth')
normlayer = torch_utils.load_normalization_layer(path='out2048.pth')
# Move the whole wrapper to `device`: the inputs built in encode()/decode() are
# sent to `device`, so the backbone must live there as well.
model = torch_utils.NormLayerWrapper(backbone, normlayer).to(device)
# The weights are never trained here (only the image is optimised), so skip
# computing and storing their gradients during encode()'s backward pass.
model.requires_grad_(False)

print('Building the hypercone...')
FPR = 1e-6
angle = 1.462771101178447 # value for FPR=1e-6 and D=2048
rho = 1 + np.tan(angle)**2
# Draw the carrier with the CPU generator (so its values do not depend on the
# device in use), normalise it to unit length, then move it next to the features.
carrier = torch.randn(1, 2048)
carrier /= torch.norm(carrier, dim=1, keepdim=True)
carrier = carrier.to(device)

# PIL image -> float tensor normalised with the ImageNet channel statistics.
default_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
| 31 | 
            +
             | 
def encode(image, epochs=10, psnr=44, lambda_w=1, lambda_i=1):
    """Embed the watermark into ``image`` by optimising its pixels.

    The image is run through ``default_transform`` and optimised with Adam for
    ``epochs`` steps. Each step first constrains the perturbation (SSIM
    attenuation, then PSNR clipping against ``psnr``), extracts features with
    ``model`` and minimises ``lambda_w * loss_R + lambda_i * loss_l2_img``.
    One JSON log line is printed per step.

    Args:
        image: input accepted by ``default_transform``.
        epochs: number of optimisation steps.
        psnr: PSNR value handed to ``image_utils.psnr_clip``.
        lambda_w: weight of the watermark (hypercone) loss.
        lambda_i: weight of the squared L2 image-distortion loss.

    Returns:
        The watermarked image produced by ``transforms.ToPILImage``.
    """
    reference = default_transform(image).to(device, non_blocking=True).unsqueeze(0)
    img = reference.clone().to(device, non_blocking=True)
    img.requires_grad = True
    optimizer = torch.optim.Adam([img], lr=1e-2)

    def constrain(candidate):
        # Perceptual constraints applied to the perturbation candidate - reference.
        attenuated = image_utils.ssim_attenuation(candidate, reference)
        return image_utils.psnr_clip(attenuated, reference, psnr)

    for iteration in range(epochs):
        print(f'iteration: {iteration}')
        x = constrain(img)

        ft = model(x)  # BxCxWxH -> BxD

        dot_product = ft @ carrier.T  # BxD @ Dx1 -> Bx1
        norm = torch.norm(ft, dim=-1, keepdim=True)  # Bx1
        cosines = torch.abs(dot_product / norm)
        pvalue = torch_utils.cosine_pvalue(cosines.item(), ft.shape[-1])
        log10_pvalue = np.log10(pvalue)
        # Negative once the feature lies inside the hypercone around the carrier.
        loss_R = -(rho * dot_product**2 - norm**2)

        loss_l2_img = torch.norm(x - reference)**2
        loss = lambda_w*loss_R + lambda_i*loss_l2_img

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        logs = {
            "keyword": "img_optim",
            "iteration": iteration,
            "loss": loss.item(),
            "loss_R": loss_R.item(),
            "loss_l2_img": loss_l2_img.item(),
            "log10_pvalue": log10_pvalue.item(),
        }
        print("__log__:%s" % json.dumps(logs))

    # Re-apply the constraints to the final iterate, quantise to integer pixel
    # values, then convert back to an image.
    marked = image_utils.round_pixel(constrain(img))
    marked = marked.squeeze(0).detach().cpu()
    return transforms.ToPILImage()(image_utils.unnormalize_img(marked).squeeze(0))
| 75 | 
            +
             | 
def decode(image):
    """Report whether ``image`` carries the watermark and its p-value.

    Features are extracted with ``model`` and compared against ``carrier``.
    The image is declared "marked" when the hypercone loss is negative.

    Args:
        image: input accepted by ``default_transform``. Objects exposing a
            ``mode`` other than "RGB" (e.g. RGBA or greyscale PIL images) are
            converted first, because the transform normalises exactly three
            channels.

    Returns:
        A human-readable string with the decision and the p-value.
    """
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    # Pure inference: no autograd graph is needed for detection.
    with torch.no_grad():
        img = default_transform(image).to(device, non_blocking=True).unsqueeze(0)
        ft = model(img)  # BxCxWxH -> BxD

        dot_product = ft @ carrier.T  # BxD @ Dx1 -> Bx1
        norm = torch.norm(ft, dim=-1, keepdim=True)  # Bx1
        cosines = torch.abs(dot_product / norm)
        loss_R = -(rho * dot_product**2 - norm**2)

    log10_pvalue = np.log10(torch_utils.cosine_pvalue(cosines.item(), ft.shape[-1]))

    text_marked = "marked" if loss_R.item() < 0 else "unmarked"
    return f'Image is {text_marked}, with p-value={10**log10_pvalue}'
    	
        app.py
    CHANGED
    
    | @@ -1,6 +1,7 @@ | |
| 1 | 
             
            import gradio as gr 
         | 
| 2 | 
             
            from steganography import Steganography
         | 
| 3 | 
             
            from utils import draw_multiple_line_text, generate_qr_code
         | 
|  | |
| 4 |  | 
| 5 |  | 
| 6 | 
             
            TITLE = """<h2 align="center"> ✍️ Invisible Watermark </h2>"""
         | 
| @@ -8,20 +9,27 @@ TITLE = """<h2 align="center"> ✍️ Invisible Watermark </h2>""" | |
| 8 |  | 
| 9 | 
             
            def apply_watermark(radio_button, input_image, watermark_image, watermark_text, watermark_url):
         | 
| 10 | 
             
                input_image = input_image.convert('RGB')
         | 
| 11 | 
            -
             | 
| 12 | 
             
                if radio_button == "Image":
         | 
| 13 | 
             
                    watermark_image = watermark_image.resize((input_image.width, input_image.height)).convert('L').convert('RGB')
         | 
| 14 | 
             
                    return Steganography().merge(input_image, watermark_image, digit=7)
         | 
| 15 | 
             
                elif radio_button == "Text":
         | 
| 16 | 
             
                    watermark_image = draw_multiple_line_text(input_image.size, watermark_text)
         | 
| 17 | 
             
                    return Steganography().merge(input_image, watermark_image, digit=7)
         | 
| 18 | 
            -
                 | 
| 19 | 
             
                    size = min(input_image.width, input_image.height)
         | 
| 20 | 
             
                    watermark_image = generate_qr_code(watermark_url).resize((size, size)).convert('RGB')
         | 
| 21 | 
             
                    return Steganography().merge(input_image, watermark_image, digit=7)
         | 
|  | |
|  | |
|  | |
| 22 |  | 
| 23 | 
            -
            def extract_watermark(input_image_to_extract):
         | 
| 24 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
| 25 |  | 
| 26 |  | 
| 27 | 
             
            with gr.Blocks() as demo:
         | 
| @@ -34,7 +42,7 @@ with gr.Blocks() as demo: | |
| 34 | 
             
                            with gr.Blocks():
         | 
| 35 | 
             
                                gr.Markdown("### Which type of watermark you want to apply?")
         | 
| 36 | 
             
                                radio_button = gr.Radio(
         | 
| 37 | 
            -
                                    choices=["QRCode", "Text", "Image"], 
         | 
| 38 | 
             
                                    label="Watermark type", 
         | 
| 39 | 
             
                                    value="QRCode",
         | 
| 40 | 
             
                                    # info="Which type of watermark you want to apply?"
         | 
| @@ -82,6 +90,11 @@ with gr.Blocks() as demo: | |
| 82 | 
             
                        with gr.Column():
         | 
| 83 | 
             
                            gr.Markdown("### Image to extract watermark")
         | 
| 84 | 
             
                            input_image_to_extract = gr.Image(type='pil')
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 85 | 
             
                        with gr.Column():
         | 
| 86 | 
             
                            gr.Markdown("### Extracted watermark")
         | 
| 87 | 
             
                            extracted_watermark = gr.Image(type='pil')
         | 
| @@ -97,6 +110,10 @@ with gr.Blocks() as demo: | |
| 97 | 
             
                    inputs=[radio_button, input_image, watermark_image, watermark_text, watermark_url], 
         | 
| 98 | 
             
                    outputs=[output_image]
         | 
| 99 | 
             
                )
         | 
| 100 | 
            -
                extract_button.click( | 
|  | |
|  | |
|  | |
|  | |
| 101 |  | 
| 102 | 
             
            demo.launch()
         | 
|  | |
| 1 | 
             
            import gradio as gr 
         | 
| 2 | 
             
            from steganography import Steganography
         | 
| 3 | 
             
            from utils import draw_multiple_line_text, generate_qr_code
         | 
| 4 | 
            +
            from SSL_watermark import encode, decode
         | 
| 5 |  | 
| 6 |  | 
| 7 | 
             
            TITLE = """<h2 align="center"> ✍️ Invisible Watermark </h2>"""
         | 
|  | |
| 9 |  | 
def apply_watermark(radio_button, input_image, watermark_image, watermark_text, watermark_url):
    """Apply the watermark type selected by ``radio_button`` to ``input_image``.

    "Image", "Text" and "QRCode" build a watermark picture and merge it into
    the RGB-converted input with ``Steganography``; any other selection runs
    the SSL watermark encoder for 5 steps.
    """
    input_image = input_image.convert('RGB')
    print(f'radio_button: {radio_button}')

    # Anything that is not a steganography mode falls through to the SSL encoder.
    if radio_button not in ("Image", "Text", "QRCode"):
        print('start encoding ssl watermark...')
        return encode(input_image, epochs=5)

    if radio_button == "Image":
        target_size = (input_image.width, input_image.height)
        overlay = watermark_image.resize(target_size).convert('L').convert('RGB')
    elif radio_button == "Text":
        overlay = draw_multiple_line_text(input_image.size, watermark_text)
    else:
        size = min(input_image.width, input_image.height)
        overlay = generate_qr_code(watermark_url).resize((size, size)).convert('RGB')

    return Steganography().merge(input_image, overlay, digit=7)
| 26 |  | 
def extract_watermark(extract_radio_button, input_image_to_extract):
    """Extract a watermark with the method chosen by ``extract_radio_button``.

    'Steganography' un-merges the hidden picture and returns it as RGBA; any
    other choice runs the SSL decoder and renders its text report as an image
    of the same size as the input.
    """
    if extract_radio_button != 'Steganography':
        decoded_info = decode(image=input_image_to_extract)
        return draw_multiple_line_text(
            input_image_size=input_image_to_extract.size,
            text=decoded_info,
        )

    extractor = Steganography()
    hidden = extractor.unmerge(input_image_to_extract.convert('RGB'), digit=7)
    return hidden.convert('RGBA')
| 33 |  | 
| 34 |  | 
| 35 | 
             
            with gr.Blocks() as demo:
         | 
|  | |
| 42 | 
             
                            with gr.Blocks():
         | 
| 43 | 
             
                                gr.Markdown("### Which type of watermark you want to apply?")
         | 
| 44 | 
             
                                radio_button = gr.Radio(
         | 
| 45 | 
            +
                                    choices=["QRCode", "Text", "Image", "SSL Watermark"], 
         | 
| 46 | 
             
                                    label="Watermark type", 
         | 
| 47 | 
             
                                    value="QRCode",
         | 
| 48 | 
             
                                    # info="Which type of watermark you want to apply?"
         | 
|  | |
| 90 | 
             
                        with gr.Column():
         | 
| 91 | 
             
                            gr.Markdown("### Image to extract watermark")
         | 
| 92 | 
             
                            input_image_to_extract = gr.Image(type='pil')
         | 
| 93 | 
            +
                            extract_radio_button = gr.Radio(
         | 
| 94 | 
            +
                                choices=["Steganography", "SSL Watermark"], 
         | 
| 95 | 
            +
                                label="Extract methods", 
         | 
| 96 | 
            +
                                value="Steganography"
         | 
| 97 | 
            +
                            )
         | 
| 98 | 
             
                        with gr.Column():
         | 
| 99 | 
             
                            gr.Markdown("### Extracted watermark")
         | 
| 100 | 
             
                            extracted_watermark = gr.Image(type='pil')
         | 
|  | |
| 110 | 
             
                    inputs=[radio_button, input_image, watermark_image, watermark_text, watermark_url], 
         | 
| 111 | 
             
                    outputs=[output_image]
         | 
| 112 | 
             
                )
         | 
| 113 | 
            +
                extract_button.click(
         | 
| 114 | 
            +
                    fn=extract_watermark, 
         | 
| 115 | 
            +
                    inputs=[extract_radio_button, input_image_to_extract], 
         | 
| 116 | 
            +
                    outputs=[extracted_watermark]
         | 
| 117 | 
            +
                )
         | 
| 118 |  | 
| 119 | 
             
            demo.launch()
         | 
    	
        dino_r50.pth
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:ab26d85d00cb1be8e757cf8820cf0fd8aa729ea7e21b1cf6c44875952ba8eb0f
         | 
| 3 | 
            +
            size 788803344
         | 
    	
        image_utils.py
    ADDED
    
    | @@ -0,0 +1,80 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import numpy as np
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            from torchvision import transforms
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import torch.nn.functional as F
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from torch.autograd.variable import Variable
         | 
| 9 | 
            +
             | 
# ImageNet channel statistics, kept as a torchvision transform so that the mean
# and std lists can be read back from it below.
NORMALIZE_IMAGENET = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Per-channel statistics shaped (C, 1, 1) so they broadcast over image tensors.
image_mean = torch.Tensor(NORMALIZE_IMAGENET.mean).reshape(-1, 1, 1).to(device)
image_std = torch.Tensor(NORMALIZE_IMAGENET.std).reshape(-1, 1, 1).to(device)
| 15 | 
            +
             | 
def normalize_img(x):
    """Move ``x`` to ``device`` and standardise it with ``image_mean`` / ``image_std``."""
    centered = x.to(device) - image_mean
    return centered / image_std
| 18 | 
            +
             | 
def unnormalize_img(x):
    """Move ``x`` to ``device`` and undo the standardisation done by ``normalize_img``."""
    rescaled = x.to(device) * image_std
    return rescaled + image_mean
| 21 | 
            +
             | 
def round_pixel(x):
    """Quantise a normalised image to integer pixel values in [0, 255].

    The image is un-normalised, scaled to the 0-255 range, rounded and clamped,
    then mapped back to normalised space.
    """
    pixels = 255 * unnormalize_img(x)
    quantized = pixels.round().clamp(0, 255)
    return normalize_img(quantized / 255.0)
| 27 | 
            +
             | 
def project_linf(x, y, radius):
    """Clamp x - y so that the L-infinity distance, measured in 0-255 pixel units, is at most ``radius``."""
    # Express the perturbation in pixel units before clamping.
    pixel_delta = 255 * ((x - y) * image_std)
    clamped = pixel_delta.clamp(-radius, radius)
    # Map the clamped perturbation back to normalised space.
    return y + (clamped / 255.0) / image_std
| 35 | 
            +
             | 
def psnr_clip(x, y, target_psnr):
    """Scale down x - y so that PSNR(x, y) >= target_psnr.

    The PSNR is measured on the perturbation expressed in 0-255 pixel units.
    When it is below ``target_psnr`` the perturbation is shrunk uniformly so
    that the PSNR equals the target; otherwise it is left untouched. For a
    zero perturbation the computed PSNR is +inf, so no scaling is applied.

    Args:
        x: perturbed image in normalised space.
        y: reference image in normalised space.
        target_psnr: minimum PSNR (dB) to enforce.

    Returns:
        ``y`` plus the (possibly rescaled) perturbation, in normalised space.
    """
    delta = x - y
    delta = 255 * (delta * image_std)
    psnr = 20*np.log10(255) - 10*torch.log10(torch.mean(delta**2))
    if psnr < target_psnr:
        # Scaling delta by 10^((psnr - target)/20) raises the PSNR to the target.
        delta = (torch.sqrt(10**((psnr-target_psnr)/10))) * delta
    delta = (delta / 255.0) / image_std
    return y + delta
| 46 | 
            +
             | 
def ssim_heatmap(img1, img2, window_size):
    """Compute the per-pixel, per-channel SSIM map between two image batches.

    A normalised Gaussian window (sigma 1.5) of side ``window_size`` is applied
    to each channel independently (depthwise convolution) to obtain local
    means, variances and covariance, which are combined with the SSIM
    constants C1 = 0.01**2 and C2 = 0.03**2.

    Args:
        img1: first batch, laid out as (N, C, H, W).
        img2: second batch, same layout as ``img1``.
        window_size: side length of the Gaussian window.

    Returns:
        The SSIM map; for an odd ``window_size`` it has the same shape as the inputs.
    """
    channels = img1.shape[1]
    pad = window_size // 2

    # Build the Gaussian in float32 on the inputs' device, normalise it, and
    # only then cast to the inputs' dtype, so the window always matches the
    # tensors it is convolved with.
    gauss = torch.tensor(
        [np.exp(-(i - window_size//2)**2/float(2*1.5**2)) for i in range(window_size)],
        dtype=torch.float32,
    ).to(img1.device, non_blocking=True)
    gauss = (gauss / gauss.sum()).unsqueeze(1)
    window_2d = gauss.mm(gauss.t()).float().unsqueeze(0).unsqueeze(0)
    # One copy of the window per channel (previously hard-coded to 3 channels).
    window = window_2d.expand(channels, 1, window_size, window_size).contiguous()
    window = window.to(dtype=img1.dtype)

    def local_mean(t):
        # Depthwise Gaussian filtering: every channel is smoothed on its own.
        return F.conv2d(t, window, padding=pad, groups=channels)

    mu1 = local_mean(img1)
    mu2 = local_mean(img2)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    sigma1_sq = local_mean(img1*img1) - mu1_sq
    sigma2_sq = local_mean(img2*img2) - mu2_sq
    sigma12 = local_mean(img1*img2) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
    return ssim_map
| 72 | 
            +
             | 
def ssim_attenuation(x, y):
    """Attenuate the perturbation x - y using the SSIM heatmap of (x, y).

    The heatmap (window size 17) is summed over the channel dimension and
    clamped at zero, then multiplied into the perturbation before it is added
    back to ``y``.
    """
    heatmap = ssim_heatmap(x, y, window_size=17)  # 1xCxHxW
    weight = torch.clamp_min(torch.sum(heatmap, dim=1, keepdim=True), 0)
    return y + (x - y) * weight
    	
        out2048.pth
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:4b256188454d8f7cf440de048df398e2a3209136a52cd7cdac834f5792f526a3
         | 
| 3 | 
            +
            size 16786561
         | 
    	
        requirements.txt
    CHANGED
    
    | @@ -1,4 +1,8 @@ | |
|  | |
|  | |
| 1 | 
             
            Pillow
         | 
| 2 | 
             
            click
         | 
| 3 | 
             
            gradio
         | 
| 4 | 
             
            qrcode
         | 
|  | |
|  | 
|  | |
| 1 | 
            +
            torch==1.10.1
         | 
| 2 | 
            +
            torchvision==0.11.2
         | 
| 3 | 
             
            Pillow
         | 
| 4 | 
             
            click
         | 
| 5 | 
             
            gradio
         | 
| 6 | 
             
            qrcode
         | 
| 7 | 
            +
            scipy
         | 
| 8 | 
            +
            # json is part of the Python standard library; it must not be listed as a pip requirement
         | 
    	
        torch_utils.py
    ADDED
    
    | @@ -0,0 +1,84 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import numpy as np
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import torch.nn as nn
         | 
| 5 | 
            +
            from torchvision import models
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            from scipy.optimize import root_scalar
         | 
| 8 | 
            +
            from scipy.special import betainc
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         | 
| 11 | 
            +
             | 
def build_backbone(path, name='resnet50'):
    """Build a torchvision backbone and load its weights from a checkpoint.

    The classifier (``fc``) is replaced by an identity so the model outputs
    features. The state dict is taken from the checkpoint itself or from its
    'state_dict', 'model_state_dict' or 'teacher' entry (the last one present
    wins), and the "module." / "backbone." substrings are removed from the
    parameter names before loading.

    Args:
        path: checkpoint file readable by ``torch.load``.
        name: constructor name looked up in ``torchvision.models``.

    Returns:
        The model, moved to ``device``.
    """
    import warnings

    model = getattr(models, name)(pretrained=False)
    model.head = nn.Identity()
    model.fc = nn.Identity()
    checkpoint = torch.load(path, map_location=device)
    state_dict = checkpoint
    for ckpt_key in ['state_dict', 'model_state_dict', 'teacher']:
        if ckpt_key in checkpoint:
            state_dict = checkpoint[ckpt_key]
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
    # strict=False tolerates extra checkpoint entries, but parameters the model
    # expects and the checkpoint lacks would silently keep their random
    # initialisation, so surface them instead of discarding the report.
    load_report = model.load_state_dict(state_dict, strict=False)
    if load_report.missing_keys:
        warnings.warn(
            f"{len(load_report.missing_keys)} backbone parameter(s) were not found in "
            f"'{path}' and keep their random initialisation, e.g. {load_report.missing_keys[:5]}"
        )
    # Return the model on `device`, as load_normalization_layer does for its layer.
    return model.to(device)
| 26 | 
            +
             | 
def get_linear_layer(weight, bias):
    """Return an ``nn.Linear`` whose parameters are the given ``weight`` and ``bias``.

    Used to apply a fixed affine map to features (whitening or centering).
    ``weight`` is laid out as (out_features, in_features).
    """
    out_features, in_features = weight.shape
    # NOTE: the layer is built with its default (random) initialisation before
    # the parameters are replaced; this draws from torch's global RNG, which
    # SSL_watermark seeds before building the model and drawing the carrier.
    linear = nn.Linear(in_features, out_features)
    linear.weight, linear.bias = nn.Parameter(weight), nn.Parameter(bias)
    return linear
| 34 | 
            +
             | 
def load_normalization_layer(path):
    """Load the normalization layer stored at ``path`` and return it on ``device``.

    The checkpoint must provide 'weight' and 'bias'. When ``path`` contains
    'whitening' or 'out', both are multiplied by the weight's input dimension
    before the linear layer is built.
    """
    checkpoint = torch.load(path, map_location=device)
    weight, bias = checkpoint['weight'], checkpoint['bias']
    if any(tag in path for tag in ('whitening', 'out')):
        dim_in = weight.shape[1]
        weight = torch.nn.Parameter(dim_in * weight)
        bias = torch.nn.Parameter(dim_in * bias)
    layer = get_linear_layer(weight, bias)
    return layer.to(device, non_blocking=True)
| 48 | 
            +
             | 
class NormLayerWrapper(nn.Module):
    """Chain a backbone with a normalization head; both are put in eval mode."""

    def __init__(self, backbone, head):
        super().__init__()
        # Module.eval() returns the module itself, so switch mode and register at once.
        self.backbone = backbone.eval()
        self.head = head.eval()

    def forward(self, x):
        """Return ``head(backbone(x))``."""
        return self.head(self.backbone(x))
| 62 | 
            +
             | 
def cosine_pvalue(c, d, k=1):
    """
    Returns the probability that the absolute value of the projection
    between random unit vectors is higher than c
    Args:
        c: cosine value (scalar); negative values yield 1.0
        d: dimension of the features
        k: number of dimensions of the projection (must be positive)
    Returns:
        The p-value, a probability in [0, 1].
    """
    assert k>0
    a = (d - k) / 2.0
    b = k / 2.0
    if c < 0:
        return 1.0
    # A cosine computed in floating point can land marginally above 1, which
    # would make 1 - c**2 negative and betainc return NaN; clamp to betainc's
    # valid domain so such inputs give a p-value of 0 instead.
    x = min(max(1 - c ** 2, 0.0), 1.0)
    return betainc(a, b, x)
| 78 | 
            +
             | 
def pvalue_angle(dim, k=1, angle=None, proba=None):
    """Return the angle in [0, pi/2] whose cosine p-value equals ``proba``.

    Args:
        dim: dimension of the features.
        k: number of dimensions of the projection.
        angle: not used by this function.
        proba: target p-value; must be a number.
    """
    def residual(theta):
        # Zero when the p-value at angle theta matches the requested probability.
        return cosine_pvalue(np.cos(theta), dim, k) - proba

    solution = root_scalar(residual, x0=0.49*np.pi, bracket=[0, np.pi/2])
    return solution.root
