qingshan777 committed
Commit 978d150 · verified · 1 Parent(s): 14502a1

Create app.py

Files changed (1)
  1. app.py ADDED (+104 -0)
@@ -0,0 +1,104 @@
import gradio as gr

import torch
import io
from PIL import Image
from transformers import (
    AutoImageProcessor,
    AutoTokenizer,
    AutoModelForCausalLM,
)
import numpy as np

model_root = "qihoo360/fg-clip-base"

# Load the FG-CLIP model and its tokenizer/image processor
# (trust_remote_code is required for the custom model class).
model = AutoModelForCausalLM.from_pretrained(model_root, trust_remote_code=True)
device = model.device
tokenizer = AutoTokenizer.from_pretrained(model_root)
image_processor = AutoImageProcessor.from_pretrained(model_root)

import math
import matplotlib
matplotlib.use('Agg')  # headless backend, so figures render without a display
import matplotlib.pyplot as plt


def Get_Densefeature(image, candidate_labels):
    """
    Takes an image and a comma-separated string of candidate labels,
    and returns a heatmap of the patch-level similarity between the
    image's dense features and the first label's text feature.
    """
    candidate_labels = [label.lstrip(" ") for label in candidate_labels.split(",") if label != ""]
    # print(candidate_labels)

    # Resize to the model's input resolution and preprocess.
    image_size = 224
    image = image.convert("RGB")
    image = image.resize((image_size, image_size))
    image_input = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].to(device)

    with torch.no_grad():
        # Dense (per-patch) image features and the text feature of the first label.
        dense_image_feature = model.get_image_dense_features(image_input)
        captions = [candidate_labels[0]]
        caption_input = torch.tensor(
            tokenizer(captions, max_length=77, padding="max_length", truncation=True).input_ids,
            dtype=torch.long,
            device=device,
        )
        text_feature = model.get_text_features(caption_input, walk_short_pos=True)
        # L2-normalize both feature sets before taking the dot product.
        text_feature = text_feature / text_feature.norm(p=2, dim=-1, keepdim=True)
        dense_image_feature = dense_image_feature / dense_image_feature.norm(p=2, dim=-1, keepdim=True)

    # Per-patch cosine similarity, reshaped into a square grid of patches.
    similarity = dense_image_feature.squeeze() @ text_feature.squeeze().T
    similarity = similarity.cpu().numpy()
    patch_size = int(math.sqrt(similarity.shape[0]))

    original_shape = (patch_size, patch_size)
    show_image = similarity.reshape(original_shape)

    # Render the similarity map to an in-memory PNG and return it as a PIL image.
    fig = plt.figure(figsize=(6, 6))
    plt.imshow(show_image)
    plt.title('Similarity Visualization')
    plt.axis('off')

    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    plt.close(fig)

    pil_img = Image.open(buf)
    # buf.close()
    return pil_img


with gr.Blocks() as demo:
    gr.Markdown("# FG-CLIP Densefeature")
    gr.Markdown(
        "This app uses the FG-CLIP model (qihoo360/fg-clip-base) to visualize dense features, running on CPU:"
    )
    gr.Markdown(
        "<span style='color: red; font-weight: bold;'>⚠️ (Run DenseFeature) only supports one class</span>"
    )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil")
            text_input = gr.Textbox(label="Input a label")
            dfs_button = gr.Button("Run Densefeature", visible=True)
        with gr.Column():
            dfs_output = gr.Image(label="Similarity Visualization", type="pil")

    examples = [
        ["./cat_dfclor.jpg", "white cat,"],
    ]
    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input],
    )
    dfs_button.click(fn=Get_Densefeature, inputs=[image_input, text_input], outputs=dfs_output)

demo.launch()
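For a quick check without the web UI, the core function can also be exercised directly. A minimal sketch, assuming the model and Get_Densefeature defined above are already loaded in the same session (with demo.launch() not started); the image path below is hypothetical:

from PIL import Image

# Hypothetical local test image; any RGB image works.
img = Image.open("cat_dfclor.jpg")
# Only the first comma-separated label is used by Get_Densefeature.
heatmap = Get_Densefeature(img, "white cat,")
heatmap.save("similarity_map.png")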