import os

import google.generativeai as genai
import gradio as gr
from dotenv import load_dotenv
from transformers import pipeline

# Read the Gemini API key from the environment (e.g. a .env file that defines
# GOOGLE_API_KEY) rather than hard-coding it in the source.
load_dotenv()
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
model_vision = genai.GenerativeModel('gemini-1.5-flash')

def gemini_response_vision(input_texts, image):
    """Ask Gemini to describe `image`, optionally guided by the text prompt `input_texts`."""
    if input_texts != "":
        response = model_vision.generate_content([input_texts, image])
    else:
        response = model_vision.generate_content(image)
    return response.text
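
# A minimal standalone sketch of calling the helper (commented out; assumes
# GOOGLE_API_KEY is set in the environment -- an assumption, not part of the demo --
# and uses the "car.png" image that the examples below also reference):
#
#   from PIL import Image
#   reply = gemini_response_vision("What season does this scene suggest?",
#                                  Image.open("car.png"))
#   print(reply)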

# Zero-shot CLIP classification pipelines, keyed by the model names offered in the UI.
pipes = {
    "ViT/B-16": pipeline("zero-shot-image-classification", model="openai/clip-vit-base-patch16"),
    "ViT/L-14": pipeline("zero-shot-image-classification", model="openai/clip-vit-large-patch14"),
}
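
# For reference, a sketch of a direct pipeline call (commented out; assumes a
# local "car.png"); each call returns a list of {"label": ..., "score": ...}
# dicts sorted by descending score:
#
#   from PIL import Image
#   res = pipes["ViT/B-16"](images=Image.open("car.png"),
#                           candidate_labels=["car", "bike", "truck"],
#                           hypothesis_template="a photo of a {}.")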
inputs = [
    gr.Image(type='pil', label="Image"),
    gr.Textbox(lines=1, label="Candidate Labels",
               placeholder="Add class labels, separated by commas"),
    gr.Radio(choices=["ViT/B-16", "ViT/L-14"], type="value", label="Model"),
    gr.Textbox(lines=1, label="Prompt Template Prefix",
               placeholder="Optional prompt template as prefix",
               value="a photo of a {}"),
    gr.Textbox(lines=1, label="Prompt Template Suffix",
               placeholder="Optional prompt template as suffix",
               value="in {} {} {} from {} with {}."),
    gr.Textbox(lines=1, label="Prior Domains",
               placeholder="Add domain labels, separated by commas"),
]
images="festival.jpg"

def shot(image, labels_text, model_name, hypothesis_template_prefix, hypothesis_template_suffix, domains_text):
    """Classify `image` against comma-separated candidate labels with a domain-aware prompt."""
    labels = [label.strip() for label in labels_text.strip().split(",")]

    if domains_text != '':
        # Use the user-supplied comma-separated domain labels directly.
        domains = [domain.strip() for domain in domains_text.strip().split(",")]
    else:
        # Otherwise, ask Gemini to infer the five domain shifts from the image.
        input_text = (
            "You are an expert for domain knowledge analysis. Please describe the image "
            "from five domain shifts, including Weather (clear, sandstorm, foggy, rainy, snowy), "
            "Season (spring-summer, autumn, winter), Time (day, night), Angle (front, side, top) "
            "and Occlusion (no occlusion, light occlusion, partial occlusion, moderate occlusion, "
            "heavy occlusion). You are supposed to recognize each domain from the above domain "
            "shifts based on the image. Finally, you only need to output a list of domains like "
            "[clear, autumn, night, front, light occlusion]"
        )
        ans = gemini_response_vision(input_texts=input_text, image=image)
        print(ans)
        domains = [domain.strip() for domain in ans.strip().strip("[]").split(",")]

    # The suffix template carries one "{}" slot per domain, so the default
    # "in {} {} {} from {} with {}." expects exactly five domain labels.
    hypothesis_template = hypothesis_template_prefix + ' ' + hypothesis_template_suffix.format(*domains)
    print(hypothesis_template)

    res = pipes[model_name](images=image,
                            candidate_labels=labels,
                            hypothesis_template=hypothesis_template)
    return {dic["label"]: dic["score"] for dic in res}
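
# End-to-end sanity check, as a commented sketch (assumes "car.png" exists;
# passing an explicit domain string skips the Gemini call entirely):
#
#   from PIL import Image
#   scores = shot(Image.open("car.png"), "car, bike, truck", "ViT/B-16",
#                 "a photo of a {}", "in {} {} {} from {} with {}.",
#                 "clear, winter, day, front, moderate occlusion")
#   # scores maps each candidate label to its classification score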

# Example domain string: clear, winter, day, front, moderate occlusion

iface = gr.Interface(shot,
                     inputs,
                     "label",
                     examples=[
                         # ["festival.jpg", "lantern, firecracker, couplet", "ViT/B-16", "a photo of a {}", "in {} {} {} from {} with {}.", "clear, autumn, day, side, partial occlusion"],
                         ["car.png", "car, bike, truck", "ViT/B-16", "a photo of a {}", "in {} {} {} from {} with {}.", ""]],
                     description="""<p>Paper: <a href='https://arxiv.org/pdf/2403.02714'>https://arxiv.org/pdf/2403.02714</a><br>
                     To begin the demo, provide a picture (upload one manually or select one of the examples) and add class labels one by one. Optionally, you can also add a prompt template as a prefix to the class labels.</p>""",
                     title="Cross-Domain Recognition")

iface.launch()
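
# Tip: call iface.launch(share=True) instead to expose a temporary public Gradio URL.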