File size: 7,180 Bytes
5eca0b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16bb651
 
 
 
 
 
5eca0b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234acaf
5eca0b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0130234
 
5eca0b2
 
0130234
 
 
 
5eca0b2
 
16bb651
5eca0b2
 
0130234
 
 
5eca0b2
 
 
16bb651
5eca0b2
 
 
 
16bb651
 
5eca0b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0130234
5eca0b2
 
 
 
 
 
 
 
 
 
0130234
5eca0b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234acaf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
"""
app.py – Token‑wise heatmaps with JinaV4SimilarityMapper (similarity4)

───────── Feature checklist (for future edits) ─────────
1. Prompt + Image‑URL inputs (placeholders).                       ✅
2. Run → downloads image (≥512 h), gets tokens+heatmaps,           ✅
   auto‑selects first token, shows overlay, clears inputs.
3. Output widgets hidden until results are ready.                  ✅
4. Every run saved to examples/auto_<prompt-slug>_<timestamp>/ with:
     • prompt.txt, img_url.txt, image.jpg,
     • heatmaps.json, per‑token PNGs, preview_first_token.jpg.     ⚠️
   (Currently disabled: the save_run call in run_mapper is commented out.)
5. On startup, first 3 example folders rendered below output with
   layout: Prompt → Image URL → Tokens → Image+Heatmap.            ✅
6. Margins: 40 px before “Examples” heading, 25 px between         ✅
   successive examples (no separators, no extra HTML).
7. Works on gradio==5.35.0 (no gr.Box, no Button.style, etc.).     ✅
"""

import sys, signal, base64, re, io, json, time
from io import BytesIO
from pathlib import Path
from typing import Dict
import subprocess

import requests
import gradio as gr
from PIL import Image
from similarity import JinaV4SimilarityMapper

# Folder holding saved example runs (read by load_examples, written by
# save_run); created at import time if it does not exist yet.
EX_DIR = Path("examples")
EX_DIR.mkdir(exist_ok=True)

# Component class used for every token picker (a single-choice radio group).
ButtonsLike = gr.Radio
def buttons_update(toks):
    """Build a gradio update that repopulates and reveals the token picker.

    The choices become *toks*; the first token is pre-selected, or nothing is
    selected when *toks* is empty.
    """
    if toks:
        selected = toks[0]
    else:
        selected = None
    return gr.update(visible=True, choices=toks, value=selected)

# ───────── util functions ─────────
def _slug(t: str, n: int = 60) -> str:
    return re.sub(r"[^\w\-]+", "_", t.lower())[:n] or "x"

def overlay(tok: str, maps: Dict[str, str], base: Image.Image) -> Image.Image:
    """Draw the heatmap stored for *tok* on top of *base*.

    Args:
        tok: token whose heatmap should be shown.
        maps: token -> base64-encoded heatmap image (any format PIL can open).
        base: image to draw on.

    Returns:
        *base* itself when *tok* has no entry in *maps*; otherwise a new RGBA
        image with the heatmap alpha-composited over *base*. The heatmap is
        bilinearly resized to *base*'s size first when they differ.
    """
    if tok in maps:
        raw_png = base64.b64decode(maps[tok])
        heat = Image.open(BytesIO(raw_png)).convert("RGBA")
        # alpha_composite requires both images to have identical dimensions.
        needs_resize = heat.size != base.size
        if needs_resize:
            heat = heat.resize(base.size, Image.BILINEAR)
        return Image.alpha_composite(base.convert("RGBA"), heat)
    return base

def save_run(prompt: str, url: str, img: Image.Image, maps: Dict[str, str]) -> None:
    """Persist one mapper run as an example folder under EX_DIR.

    The folder is named auto_<prompt-slug>_<timestamp> and receives
    prompt.txt, img_url.txt, image.jpg, heatmaps.json, one
    heatmap_<index>_<token-slug>.png per token and, when at least one heatmap
    exists, preview_first_token.jpg (the first token's overlay).

    Args:
        prompt: query text; also slugged into the folder name.
        url: source image URL, stored verbatim.
        img: image the heatmaps were computed for.
        maps: token -> base64-encoded heatmap image.
    """
    ts   = time.strftime("%Y%m%d_%H%M%S")
    fldr = EX_DIR / f"auto_{_slug(prompt, 30)}_{ts}"
    fldr.mkdir(parents=True, exist_ok=True)

    # Explicit UTF-8 so non-ASCII prompts round-trip regardless of the locale.
    (fldr / "prompt.txt").write_text(prompt, encoding="utf-8")
    (fldr / "img_url.txt").write_text(url, encoding="utf-8")
    img.convert("RGB").save(fldr / "image.jpg", "JPEG")

    with (fldr / "heatmaps.json").open("w", encoding="utf-8") as f:
        json.dump(maps, f)

    # Prefix each file with the token's position: distinct tokens can slug to
    # the same text (e.g. "The"/"the", "," and "."), which previously made
    # later heatmaps silently overwrite earlier ones.
    for i, (tok, b64png) in enumerate(maps.items()):
        png_path = fldr / f"heatmap_{i:03d}_{_slug(tok, 30)}.png"
        png_path.write_bytes(base64.b64decode(b64png))

    # next(iter({})) would raise StopIteration, so only preview when possible.
    if maps:
        first = next(iter(maps))
        overlay(first, maps, img).convert("RGB").save(fldr / "preview_first_token.jpg", "JPEG")
    print(f"✨ Saved run to {fldr}", flush=True)

# ───────── load mapper ─────────
print("⏳ Loading JinaV4SimilarityMapper …", flush=True)
# NOTE(review): MAPPER is not referenced by the code visible in this file —
# run_mapper builds its own client per request. Kept in case another module
# imports it or construction has a warm-up effect; confirm before removing.
MAPPER = JinaV4SimilarityMapper(client_type="web")
# The ready message previously contained a mis-encoded "✅" ("βœ…").
print("✅ Mapper ready.", flush=True)

# ───────── load up to 3 example folders ─────────
def load_examples(n: int = 3):
    """Load up to *n* saved example runs from EX_DIR.

    Entries are visited in sorted name order. An entry counts only if it is a
    directory containing prompt.txt, img_url.txt, an image.* file and a
    non-empty heatmaps.json. Anything else is skipped without using up one of
    the *n* slots.

    Returns:
        A list of dicts with keys "prompt" and "url" (stripped text), "base"
        (the image converted to RGB) and "maps" (the parsed heatmaps.json,
        token -> base64 heatmap).
    """
    ex = []
    for fld in sorted(EX_DIR.iterdir()):
        if len(ex) >= n:
            break
        if not fld.is_dir():
            continue  # stray files (e.g. .DS_Store) are not examples
        p_txt, p_url, p_map = fld / "prompt.txt", fld / "img_url.txt", fld / "heatmaps.json"
        # Sorted so the choice is deterministic if several image.* files exist.
        p_img = next(iter(sorted(fld.glob("image.*"))), None)
        if not (p_txt.exists() and p_url.exists() and p_img and p_map.exists()):
            continue
        with p_map.open(encoding="utf-8") as f:
            maps = json.load(f)
        if not maps:
            # The UI pre-selects the first token; an empty map has none.
            continue
        with Image.open(p_img) as im:
            base = im.convert("RGB")
        ex.append(dict(
            prompt=p_txt.read_text(encoding="utf-8").strip(),
            url   =p_url.read_text(encoding="utf-8").strip(),
            base  =base,
            maps  =maps,
        ))
    return ex


static_examples = load_examples()

# ───────── backend for user Run ─────────
def run_mapper(prompt: str, img_url: str, api_key: str):
    """Backend for the "Run" button: compute per-token similarity heatmaps.

    Validates the three inputs and checks that *img_url* serves a decodable
    image. It then asks a freshly built JinaV4SimilarityMapper, configured
    with the caller's API key, for token heatmaps.

    Returns:
        A 7-tuple in the order of the click handler's outputs:
        - the token-picker update;
        - the token -> base64 heatmap dict;
        - the processed image;
        - the overlay-image update (first token);
        - the info-markdown update;
        - two empty strings that clear the prompt and image-URL textboxes.

    Raises:
        gr.Error: an input is missing or blank, the image cannot be downloaded
            or decoded, or the mapper returns no tokens.
    """
    # Surrounding whitespace (common when pasting) would break the URL / key.
    img_url = (img_url or "").strip()
    api_key = (api_key or "").strip()
    if not img_url:
        raise gr.Error("Please provide an image URL.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please provide a prompt.")
    if not api_key:
        raise gr.Error("Please provide a valid API key.")

    # Fail fast with a readable message before spending an API call. The
    # decoded image is only a probe: process_image() below receives the URL.
    try:
        with requests.get(img_url, timeout=10) as r:
            r.raise_for_status()
            Image.open(io.BytesIO(r.content)).convert("RGB")
    except Exception as e:
        raise gr.Error(f"Image load failed: {e}") from e

    # Built only after validation. A per-call client receives this user's key,
    # presumably to avoid mutating shared state such as the module-level MAPPER.
    new_client = JinaV4SimilarityMapper(client_type="web")
    new_client.model.set_api_key(api_key)
    img_proc, _, _ = new_client.process_image(img_url)
    toks, maps = new_client.get_token_similarity_maps(prompt, img_proc)
    if not toks:
        raise gr.Error("Mapper returned no tokens.")

    # Persisting runs (checklist item 4 in the module docstring) is disabled:
    # save_run(prompt, img_url, img_proc, maps)

    first_tok = toks[0]
    info      = f"**Prompt:** {prompt}\n\n**Image URL:** {img_url}"
    return (
        buttons_update(toks), maps, img_proc,
        gr.update(value=overlay(first_tok, maps, img_proc), visible=True),
        gr.update(value=info, visible=True),
        "", "")

# ───────── UI ─────────
# Stylesheet handed to gr.Blocks(css=...). Each selector targets an elem_id /
# elem_classes value that is set on a component in the layout below.
# NOTE(review): the module docstring's checklist specifies 25 px between
# successive examples, but .example-space uses 20 px — confirm which is intended.
css = """
#main-title { margin-bottom: 40px; }
#run-btn { margin: 20px 0; }
#examples-title { margin: 40px 0; }
.example-space { margin: 20px 0; }
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Jina Embeddings v4", elem_id="main-title")

    # ── User input ──
    prompt_in  = gr.Textbox(label="Prompt", placeholder="Describe what to query…")
    url_in     = gr.Textbox(label="Image URL", placeholder="https://example.com/image.jpg")
    # type="password" masks the key so the secret isn't shown in plain text.
    api_key_in = gr.Textbox(label="API Key", placeholder="Enter your Jina API key here",
                            type="password")
    run_btn    = gr.Button("Run", elem_id="run-btn")

    # ── Output area (hidden until run_mapper returns results) ──
    info_md   = gr.Markdown(visible=False)
    token_sel = ButtonsLike(choices=[], label="Tokens", interactive=True, visible=False)
    maps_st   = gr.State({})    # token -> base64 heatmap from the latest run
    img_st    = gr.State(None)  # processed image those heatmaps belong to
    img_out   = gr.Image(label="Image + Heatmap", visible=False)

    # The outputs list must match the order of run_mapper's 7-tuple.
    run_btn.click(run_mapper,
        [prompt_in, url_in, api_key_in],
        [token_sel, maps_st, img_st, img_out, info_md, prompt_in, url_in])

    # Re-render the overlay whenever another token is picked.
    (token_sel.select if hasattr(token_sel, "select") else token_sel.change)(
        overlay, [token_sel, maps_st, img_st], [img_out])

    # Margin before examples heading (see #examples-title in css)
    gr.Markdown("## Examples", elem_id="examples-title")

    # ── Saved examples: Prompt → Image URL → Tokens → Image+Heatmap ──
    for ex in static_examples:
        if not ex["maps"]:
            # Nothing to pre-select; next(iter({})) would raise StopIteration
            # and abort building the whole UI.
            continue
        gr.Markdown(f"**Prompt:** {ex['prompt']}")
        gr.Markdown(f"**Image URL:** {ex['url']}")
        ex_img_st  = gr.State(ex["base"])
        ex_map_st  = gr.State(ex["maps"])
        first      = next(iter(ex["maps"]))
        ex_btns    = ButtonsLike(choices=list(ex["maps"].keys()), value=first, interactive=True)
        ex_disp    = gr.Image(value=overlay(first, ex["maps"], ex["base"]))
        (ex_btns.select if hasattr(ex_btns, "select") else ex_btns.change)(
            overlay, [ex_btns, ex_map_st, ex_img_st], [ex_disp])
        # vertical margin after each example
        gr.Markdown("", elem_classes=["example-space"])

# ───────── graceful shutdown ─────────
def _shutdown(*_):
    """Signal handler: close the Gradio app, then exit with status 0.

    The (signum, frame) arguments supplied by signal.signal are ignored.
    """
    # The message previously contained a mis-encoded "🛑" ("πŸ›‘").
    print("🛑 Shutting down …", flush=True)
    demo.close()
    sys.exit(0)


signal.signal(signal.SIGINT, _shutdown)
signal.signal(signal.SIGTERM, _shutdown)

if __name__ == "__main__":
    # Serve on every network interface at port 7860, report handler errors in
    # the UI, and do not create a public share link.
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
        "share": False,
    }
    demo.launch(**launch_kwargs)