import onnxruntime as ort
import numpy
import gradio as gr
from PIL import Image
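
# Gradio demo: draw two doodles, embed each with a small ONNX model
# (tiny_doodle_embedding.onnx), and report how similar they are via the dot
# product of the two embeddings.  The session is fed a float32 tensor named
# 'input' with shape (1, 32, 32); see process_input below.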

ort_sess = ort.InferenceSession('tiny_doodle_embedding.onnx')


def get_bounds(img):
    # Find the bounding box of the drawing.  Assumes white strokes on a
    # black background (which process_input below guarantees).  If nothing
    # is drawn, the bounds come back inverted (left > right, top > bottom).
    left = img.shape[1]
    right = 0
    top = img.shape[0]
    bottom = 0
    min_color = numpy.min(img)
    max_color = numpy.max(img)
    mean_color = 0.5*(min_color+max_color)
    # Do this the dumb way.
    for y in range(0, img.shape[0]):
        for x in range(0, img.shape[1]):
            if img[y,x] > mean_color:
                left = min(left, x)
                right = max(right, x)
                top = min(top, y)
                bottom = max(bottom, y)
    return (top, bottom, left, right)
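
# A vectorized equivalent of get_bounds (a sketch; the loop above is what the
# app runs).  Same caveat: on an empty canvas no pixel clears the threshold,
# so the min()/max() calls would raise.
def get_bounds_fast(img):
    mean_color = 0.5 * (numpy.min(img) + numpy.max(img))
    ys, xs = numpy.where(img > mean_color)
    return (ys.min(), ys.max(), xs.min(), xs.max())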

def resize_maxpool(img, out_width: int, out_height: int):
    out = numpy.zeros((out_height, out_width), dtype=img.dtype)
    scale_factor_y = img.shape[0] // out_height
    scale_factor_x = img.shape[1] // out_width
    for y in range(0, out.shape[0]):
        for x in range(0, out.shape[1]):
            out[y,x] = numpy.max(img[y*scale_factor_y:(y+1)*scale_factor_y, x*scale_factor_x:(x+1)*scale_factor_x])
    return out
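
# A vectorized equivalent of resize_maxpool (a sketch; assumes the input is at
# least out_height x out_width).  Like the loops above, it drops any remainder
# rows/columns that don't divide evenly into the output grid.
def resize_maxpool_fast(img, out_width: int, out_height: int):
    sy = img.shape[0] // out_height
    sx = img.shape[1] // out_width
    trimmed = img[:out_height * sy, :out_width * sx]
    return trimmed.reshape(out_height, sy, out_width, sx).max(axis=(1, 3))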

def process_input(input_msg):
    img = input_msg["composite"]
    # The sketchpad delivers 255 for the white background and 0 for drawn
    # strokes; threshold at the midpoint and flip so strokes become 1.0 floats.
    img_mean = 0.5 * (numpy.max(img) + numpy.min(img))
    img = 1.0 * (img < img_mean)
    crop_area = get_bounds(img)
    # Crop to the drawing; get_bounds returns inclusive bounds, hence the +1s.
    img = img[crop_area[0]:crop_area[1]+1, crop_area[2]:crop_area[3]+1]
    img = resize_maxpool(img, 32, 32)
    img = numpy.expand_dims(img, axis=0)  # Add a leading dim: (1, 32, 32).
    return img
    

def compare(input_img_a, input_img_b):
    text_out = ""

    img_a = process_input(input_img_a)
    img_b = process_input(input_img_b)

    # We could stack these into one batch and run a single inference
    # (see embed_batch below).
    a_embedding = ort_sess.run(None, {'input': img_a.astype(numpy.float32)})[0]
    b_embedding = ort_sess.run(None, {'input': img_b.astype(numpy.float32)})[0]
    # Normalization is disabled: the embeddings are compared with a raw dot
    # product rather than cosine similarity.
    text_out += f"img_a_embedding: {a_embedding}\n"
    text_out += f"img_b_embedding: {b_embedding}\n"
    sim = numpy.dot(a_embedding, b_embedding.T)
    print(sim)
    print(text_out)
    # Show the two 32x32 model inputs side by side, plus the similarity score.
    return Image.fromarray(numpy.clip(numpy.hstack([img_a[0], img_b[0]]) * 255, 0, 255).astype(numpy.uint8)), sim[0][0], text_out

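# Batched variant of the two per-image ort_sess.run calls in compare (a
# sketch; assumes the ONNX model accepts a leading batch dimension, which
# isn't verified here).
def embed_batch(imgs):
    # imgs: a list of (1, 32, 32) float arrays; stack into one (N, 32, 32) batch.
    batch = numpy.vstack(imgs).astype(numpy.float32)
    return ort_sess.run(None, {'input': batch})[0]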

demo = gr.Interface(
    fn=compare,
    inputs=[
        gr.Sketchpad(image_mode='L', type='numpy'),
        gr.Sketchpad(image_mode='L', type='numpy'),
    ],
    outputs=["image", "number", "text"],
)

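# share=True asks Gradio for a temporary public URL in addition to the local server.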
demo.launch(share=True)