# Spaces:
# Sleeping
# Sleeping
import time

import gradio as gr
import torch
from transformers import CLIPModel, CLIPProcessor
def parse_labels(text):
    """Split a comma-separated label string into a list of clean labels.

    Whitespace around each label is stripped and empty entries are dropped,
    so ``"cat,dog , ,bird"`` yields ``["cat", "dog", "bird"]``. ``None`` or an
    empty string yields ``[]``.
    """
    if not text:
        return []
    return [label.strip() for label in text.split(",") if label.strip()]


def format_label_scores(labels, scores):
    """Render one ``"<label> - <score>"`` line per label, score to 4 decimals."""
    return "\n".join(f"{label} - {score:.4f}" for label, score in zip(labels, scores))


def get_zero_shot_classification_tab():
    """Build the "Zero-Shot Classification" Gradio tab.

    Loads two CLIP checkpoints once, then wires a form (image, comma-separated
    labels, model choice) to a handler that returns per-label softmax scores
    and the seconds spent on preprocessing plus the forward pass.

    NOTE(review): opens a ``gr.TabItem``, so the caller presumably invokes this
    inside a ``gr.Blocks``/``gr.Tabs`` context — confirm at the call site.

    Returns:
        The ``gr.TabItem`` that was created.
    """
    model_names = [
        "openai/clip-vit-large-patch14",
        "patrickjohncyh/fashion-clip",
    ]
    # Load every checkpoint once when the tab is built, not on each request.
    model_map = {
        name: (CLIPModel.from_pretrained(name), CLIPProcessor.from_pretrained(name))
        for name in model_names
    }

    def gradio_process(model_name, image, text):
        """Classify ``image`` against the labels in ``text`` using ``model_name``.

        Returns ``[result_text, seconds_elapsed]`` for the two output boxes.
        Raises ``gr.Error`` (displayed in the UI) when an input is missing.
        """
        if model_name not in model_map:
            raise gr.Error("Please select a model.")
        if image is None:
            raise gr.Error("Please upload an image.")
        labels = parse_labels(text)
        if not labels:
            raise gr.Error("Please enter at least one label.")

        model, processor = model_map[model_name]
        start = time.perf_counter()
        inputs = processor(
            text=labels, images=image, return_tensors="pt", padding=True
        )
        # Pure inference: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = model(**inputs)
        # A single image was passed, so row 0 holds its scores over all labels.
        probs = outputs.logits_per_image.softmax(dim=1)[0]
        time_spent = time.perf_counter() - start

        return [format_label_scores(labels, probs.tolist()), time_spent]

    with gr.TabItem("Zero-Shot Classification") as zero_shot_image_classification_tab:
        gr.Markdown("# Zero-Shot Image Classification")
        with gr.Row():
            with gr.Column():
                # Input components
                input_image = gr.Image(label="Upload Image", type="pil")
                input_text = gr.Textbox(label="Labels (comma separated)")
                # Preselect a model so the handler never receives None.
                model_selector = gr.Dropdown(
                    model_names,
                    value=model_names[0],
                    label="Select Model",
                )
                # Process button
                process_btn = gr.Button("Classify")
            with gr.Column():
                # Output components
                elapsed_result = gr.Textbox(label="Seconds elapsed", lines=1)
                output_text = gr.Textbox(label="Classification")
        # Connect the input components to the processing function
        process_btn.click(
            fn=gradio_process,
            inputs=[model_selector, input_image, input_text],
            outputs=[output_text, elapsed_result],
        )
    return zero_shot_image_classification_tab