File size: 2,116 Bytes
61c7634
e15dae8
c9911aa
17bb1f6
 
fae45ed
61c7634
236866f
c34d9ea
8465f52
 
3ce1a28
34700d7
78f266b
34700d7
d30d4ce
34700d7
921054e
ec11b9a
6f3fb83
62635cf
17bb1f6
5c5bd98
6f3fb83
8465f52
 
34700d7
8465f52
 
34700d7
 
 
 
cd7c7ec
6115563
31fec50
34700d7
 
62635cf
a1b8369
 
1377bb8
8465f52
 
1377bb8
5c5bd98
8465f52
 
 
 
 
 
 
 
 
 
ec11b9a
db2e7bb
5c5bd98
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from turtle import color, onclick
import streamlit as st
from PIL import Image, ImageOps
import glob
import json
import requests
import random
import io

# Seed per-session state only when a key is missing. Streamlit reruns this
# script on every interaction, so values set earlier in the session must
# survive untouched.
#   show        -- whether the groundtruth target image is revealed
#   example_idx -- position in the validation list of the example on screen
for _key, _default in (('show', False), ('example_idx', 0)):
    if _key not in st.session_state:
        st.session_state[_key] = _default

st.set_page_config(layout="wide")
# The sampling control is a button in the left column (col1); the image-index
# counter lives in the right column (col2).
st.markdown("**This is a demo of the *ImageCoDe* dataset. Sample an example with the button on the left and compare all the images with the index counter on the right. Toggle the buttons to show/hide the groundtruth target image!**")

col1, col2 = st.columns(2)

# Base URL of the public validation-set image repository (note: ends in '/').
prefix = 'https://raw.githubusercontent.com/BennoKrojer/imagecode-val-set/main/image-sets-val/'

# Load the dataset metadata. Use context managers so the file handles are
# closed promptly instead of leaking on every script rerun.
#   set2ids      -- maps an image-set name to the list of its image file names
#   descriptions -- validation entries, unpacked later as (set, target, text)
with open('set2ids.json', 'r', encoding='utf-8') as fh:
    set2ids = json.load(fh)
with open('valid_list.json', 'r', encoding='utf-8') as fh:
    descriptions = json.load(fh)

# Reveal / hide the groundtruth target; the choice persists across reruns.
if col1.button('Show groundtruth target image'):
    st.session_state.show = True
if col2.button('Hide groundtruth target image'):
    st.session_state.show = False

# Draw a new random validation example on request; otherwise keep showing the
# one stored in session state.
if col1.button('Sample an example (description + corresponding images) from the validation set'):
    st.session_state.example_idx = random.randrange(len(descriptions))

# Each validation entry unpacks as (image-set name, target index, description).
img_set, idx, descr = descriptions[st.session_state.example_idx]
idx = int(idx)

# One raw-GitHub URL per image in the set. `prefix` may already end in '/',
# so normalise it first to avoid a '//' inside the URL path.
base_url = prefix.rstrip('/')
images = [f'{base_url}/{img_set}/{name}' for name in set2ids[img_set]]

# Limit the selectable index to the images actually in this set
# (for a 10-image set this is exactly 0..9, as before).
last_index = len(images) - 1
index = int(col2.number_input(f'Image Index from 0 to {last_index}', value=0, min_value=0, max_value=last_index))


col1.markdown('**Description**:')
col1.markdown(descr)

# Keep the plain URL of the selected image for the large left-hand preview;
# st.image renders it directly from the URL.
img = images[index]

# Download only the selected image so it can be framed with a blue border in
# the right-hand gallery; all other gallery entries remain URLs.
# A timeout stops a stalled download from hanging the app, and
# raise_for_status surfaces HTTP errors instead of a confusing PIL decode error.
response = requests.get(img, timeout=30)
response.raise_for_status()
images[index] = ImageOps.expand(Image.open(io.BytesIO(response.content)), border=20, fill='blue')

# One string caption per gallery image. The target image is labelled only
# while the groundtruth is being shown.
caps = [str(i) for i in range(len(images))]
if st.session_state.show:
    caps[idx] = f'{idx} (TARGET IMAGE)'
cap = caps[index]

col1.image(img, use_column_width=True, caption=cap)
col2.image(images, width=175, caption=caps)