Spaces:
Sleeping
Sleeping
import os
import re
from turtle import title  # NOTE(review): unused; importing turtle pulls in tkinter, which may be missing on headless hosts — consider removing.

import google.generativeai as genai
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from PIL import Image
from transformers import pipeline
# Configure the Gemini client from the environment instead of hard-coding the key.
# Set GOOGLE_API_KEY as an environment variable (e.g. a Space secret) or in a
# local .env file, which load_dotenv() reads into the environment.
# NOTE(review): an API key was previously committed here in plain text; treat it
# as compromised and revoke it.
load_dotenv()
_GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
if _GOOGLE_API_KEY:
    genai.configure(api_key=_GOOGLE_API_KEY)
else:
    # Keep the app usable: classification with user-supplied domains does not need Gemini.
    print("Warning: GOOGLE_API_KEY is not set; automatic domain inference with Gemini will fail.")

# Vision-capable Gemini model used to infer the image's domains.
model_vision = genai.GenerativeModel('gemini-1.5-flash')
def gemini_response_vision(input_texts, image):
    """Ask the Gemini vision model about *image*, optionally with a text prompt.

    Args:
        input_texts: Prompt sent together with the image. When it is empty
            (``""`` or ``None``), only the image is sent.
        image: Image passed to ``model_vision.generate_content`` (this app
            passes the PIL image from the Gradio input).

    Returns:
        The text of Gemini's response.

    Exceptions raised by the Gemini client propagate unchanged to the caller.
    """
    contents = [input_texts, image] if input_texts else image
    return model_vision.generate_content(contents).text
# UI model name -> Hugging Face CLIP checkpoint.
_CLIP_CHECKPOINTS = {
    "ViT/B-16": "openai/clip-vit-base-patch16",
    "ViT/L-14": "openai/clip-vit-large-patch14",
}

# One zero-shot image-classification pipeline per UI model name.
# Both are built eagerly, in the order listed above, when the module loads.
pipes = {
    name: pipeline("zero-shot-image-classification", model=checkpoint)
    for name, checkpoint in _CLIP_CHECKPOINTS.items()
}
# Gradio input components, in the same order as the parameters of `shot`.
inputs = [
    gr.Image(type='pil', label="Image"),
    gr.Textbox(
        lines=1,
        label="Candidate Labels",
        placeholder="Comma-separated class labels, e.g. car, bus, truck",
    ),
    # Offer exactly the models that have a pipeline. The default avoids passing
    # None to `shot` when the user never touches the radio.
    gr.Radio(choices=list(pipes), type="value", label="Model", value="ViT/B-16"),
    gr.Textbox(
        lines=1,
        label="Prompt Template Prefix",
        placeholder="Optional prompt template as prefix",
        value="a photo of a {}",
    ),
    gr.Textbox(
        lines=1,
        label="Prompt Template Suffix",
        placeholder="Optional prompt template as suffix",
        value="in {} {} {} from {} with {}.",
    ),
    gr.Textbox(
        lines=1,
        label="Prior Domains",
        placeholder="Comma-separated domains; leave empty to infer them with Gemini",
    ),
]
# Prompt asking Gemini to name the image's domain along each of the five
# domain-shift axes. The answer is expected as a bracketed, comma-separated list.
_DOMAIN_PROMPT = "You are an expert for domain knowledge analysis. Please describe the image from five domain shifts, including Weather (clear, sandstorm, foggy, rainy, snowy), Season (spring-summer, autumn, winter), Time (day, night), Angle (front, side, top) and Occlusion (no occlusion, light occlusion, partial occlusion, moderate occlusion, heavy occlusion). You are supposed to recognize each domain from the above domain shifts based on the image. Finally, you only need to output a list of domains like [clear, autumn, night, front, light occlusion]"


def _split_csv(text):
    """Split comma-separated *text* into stripped, non-empty items ("" or None -> [])."""
    return [item.strip() for item in (text or "").split(",") if item.strip()]


def _parse_domain_list(answer):
    """Return the domains in a Gemini answer such as ``"[clear, autumn, night]"``.

    Uses the first ``[...]`` group when present, so surrounding prose, markdown
    fences or trailing newlines are ignored. Otherwise the whole answer is split.
    """
    match = re.search(r"\[([^\[\]]*)\]", answer or "")
    return _split_csv(match.group(1) if match else answer)


def shot(image, labels_text, model_name, hypothesis_template_prefix, hypothesis_template_suffix, domains_text):
    """Classify *image* zero-shot with CLIP using a domain-aware prompt.

    The prompt is ``prefix + " " + suffix``. The suffix's ``{}`` slots are
    filled with the domains, and the prefix's ``{}`` is left for the pipeline
    to fill with each candidate label.

    Args:
        image: PIL image from the Gradio input.
        labels_text: Comma-separated candidate class labels.
        model_name: Key into ``pipes`` (e.g. ``"ViT/B-16"``).
        hypothesis_template_prefix: Template holding the ``{}`` label slot;
            empty means just ``"{}"``.
        hypothesis_template_suffix: Template whose ``{}`` slots take the domains.
        domains_text: Comma-separated prior domains. When empty, Gemini infers
            them from the image.

    Returns:
        Mapping of candidate label -> score, as produced by the pipeline.

    Raises:
        gr.Error: On missing image/labels/model, a prefix without ``{}``, or
            domains that do not fit the suffix template.
    """
    if image is None:
        raise gr.Error("Please provide an image.")
    labels = _split_csv(labels_text)
    if not labels:
        raise gr.Error("Please provide at least one candidate label.")
    if model_name not in pipes:
        raise gr.Error(f"Please select a model: {', '.join(pipes)}.")

    domains = _split_csv(domains_text)
    if not domains:
        # No prior domains given: ask Gemini to infer them from the image.
        ans = gemini_response_vision(input_texts=_DOMAIN_PROMPT, image=image)
        print(ans)
        domains = _parse_domain_list(ans)

    prefix = (hypothesis_template_prefix or "").strip() or "{}"
    if "{}" not in prefix:
        # Without a label slot every candidate would get the same text.
        raise gr.Error("The prefix template must contain '{}' for the class label.")

    # Escape braces so the pipeline's own template.format(label) fills only the label slot.
    escaped = [domain.replace("{", "{{").replace("}", "}}") for domain in domains]
    suffix_template = hypothesis_template_suffix or ""
    try:
        suffix = suffix_template.format(*escaped)
    except (IndexError, KeyError) as err:
        raise gr.Error(
            f"Could not fill the suffix template {suffix_template!r} with the "
            f"{len(domains)} domain(s) {domains}."
        ) from err

    hypothesis_template = prefix + ' ' + suffix
    print(hypothesis_template)
    res = pipes[model_name](images=image,
                            candidate_labels=labels,
                            hypothesis_template=hypothesis_template)
    return {dic["label"]: dic["score"] for dic in res}
# Default prompt templates shared by every example row.
_EXAMPLE_PREFIX = "a photo of a {}"
_EXAMPLE_SUFFIX = "in {} {} {} from {} with {}."

# Example row: image, labels, model, prefix, suffix, prior domains.
# Prior domains follow the suffix slot order (weather, season, time, angle,
# occlusion), e.g. "clear, winter, day, front, moderate occlusion".
# An empty string lets Gemini infer them from the image.
iface = gr.Interface(
    shot,
    inputs,
    "label",
    examples=[
        ["car_occlusion.png", "car, bus, truck", "ViT/B-16", _EXAMPLE_PREFIX, _EXAMPLE_SUFFIX, ""],
        ["foggy_motorcycle.png", "car, motorcycle, truck", "ViT/B-16", _EXAMPLE_PREFIX, _EXAMPLE_SUFFIX, ""],
        ["night_truck.png", "car, bus, truck", "ViT/B-16", _EXAMPLE_PREFIX, _EXAMPLE_SUFFIX, "clear, autumn, night, side, no occlusion"],
        ["rainy_front_truck.png", "car, bus, truck", "ViT/B-16", _EXAMPLE_PREFIX, _EXAMPLE_SUFFIX, ""],
        ["truck_top.png", "car, bus, truck", "ViT/B-16", _EXAMPLE_PREFIX, _EXAMPLE_SUFFIX, "clear, summer, day, top, no occlusion"],
    ],
    description="""<p> Demo for Domain CLIP, an algorithm for Training-Free Adaptive Domain Generalization. For more information about our project, refer to our paper.<br><br>
Paper: <a href='https://arxiv.org/pdf/2403.02714'>https://arxiv.org/pdf/2403.02714</a> <br>
To begin with the demo, provide a picture (either upload manually, or select from the given examples) and enter the class labels as a comma-separated list. Optionally, you can also edit the prompt templates used as a prefix and suffix to the class labels, and enter the prior domains (weather, season, time, angle, occlusion) as a comma-separated list; leave the domains empty to let Gemini infer them from the image. <br></p>""",
    title="Cross-Domain Recognition",
)
iface.launch()