Sravanth committed
Commit 71feab2 · 1 parent: 1ff279b

create app.py

Files changed (1)
  1. app.py +89 -0
app.py ADDED
@@ -0,0 +1,89 @@
+ import gradio as gr
+ import torch
+ import open_clip
+ from transformers import AutoProcessor, AutoModelForCausalLM, BlipForConditionalGeneration, Blip2ForConditionalGeneration
+
+ # GIT checkpoints: the large model fine-tuned on COCO and on TextCaps
+ git_processor_large_coco = AutoProcessor.from_pretrained("microsoft/git-large-coco")
+ git_model_large_coco = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")
+
+ git_processor_large_textcaps = AutoProcessor.from_pretrained("microsoft/git-large-r-textcaps")
+ git_model_large_textcaps = AutoModelForCausalLM.from_pretrained("microsoft/git-large-r-textcaps")
+
+ # BLIP large captioning checkpoint
+ blip_processor_large = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
+ blip_model_large = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
+
+ # BLIP-2 with OPT-6.7B, loaded in 8-bit to cut GPU memory (requires bitsandbytes and accelerate)
+ blip2_processor_8_bit = AutoProcessor.from_pretrained("Salesforce/blip2-opt-6.7b")
+ blip2_model_8_bit = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-6.7b", device_map="auto", load_in_8bit=True)
+
+ # CoCa from open_clip, fine-tuned on MS-COCO
+ coca_model, _, coca_transform = open_clip.create_model_and_transforms(
+     model_name="coca_ViT-L-14",
+     pretrained="mscoco_finetuned_laion2B-s13B-b90k"
+ )
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # The 8-bit BLIP-2 model is already placed by device_map="auto", so it is not moved here
+ git_model_large_coco.to(device)
+ git_model_large_textcaps.to(device)
+ blip_model_large.to(device)
+ coca_model.to(device)
+
+ def generate_caption(processor, model, image, tokenizer=None, use_float_16=False):
+     inputs = processor(images=image, return_tensors="pt").to(device)
+
+     # The 8-bit BLIP-2 model computes in float16, so its pixel values must match
+     if use_float_16:
+         inputs = inputs.to(torch.float16)
+
+     generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
+
+     # Decode with a dedicated tokenizer when one is supplied, otherwise with the processor
+     if tokenizer is not None:
+         generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+     else:
+         generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+     return generated_caption
+
+
+ def generate_caption_coca(model, transform, image):
+     im = transform(image).unsqueeze(0).to(device)
+     # autocast only affects CUDA ops, so this is effectively a no-op on CPU
+     with torch.no_grad(), torch.cuda.amp.autocast():
+         generated = model.generate(im, seq_len=20)
+     # open_clip.decode keeps the special tokens, so strip them from the caption
+     return open_clip.decode(generated[0].detach()).split("<end_of_text>")[0].replace("<start_of_text>", "")
+
+
+ def generate_captions(image):
+     caption_git_large_coco = generate_caption(git_processor_large_coco, git_model_large_coco, image)
+     caption_git_large_textcaps = generate_caption(git_processor_large_textcaps, git_model_large_textcaps, image)
+     caption_blip_large = generate_caption(blip_processor_large, blip_model_large, image)
+     caption_coca = generate_caption_coca(coca_model, coca_transform, image)
+     caption_blip2_8_bit = generate_caption(blip2_processor_8_bit, blip2_model_8_bit, image, use_float_16=True).strip()
+
+     return caption_git_large_coco, caption_git_large_textcaps, caption_blip_large, caption_coca, caption_blip2_8_bit
+
+
+ examples = [["Image1.jpg"], ["Image2.jpg"], ["Image3.jpg"]]
+ # One output box per model, in the order returned by generate_captions
+ outputs = [
+     gr.Textbox(label="GIT-large fine-tuned on COCO"),
+     gr.Textbox(label="GIT-large fine-tuned on TextCaps"),
+     gr.Textbox(label="BLIP-large"),
+     gr.Textbox(label="CoCa"),
+     gr.Textbox(label="BLIP-2 OPT-6.7B (8-bit)"),
+ ]
+
+ title = "Interactive demo: comparing image captioning models"
+ description = "Image Caption Generator by Sravanth Kurmala"
+ article = "Assignment for Listed Inc"
+
+ interface = gr.Interface(fn=generate_captions,
+                          inputs=gr.Image(type="pil"),
+                          outputs=outputs,
+                          examples=examples,
+                          title=title,
+                          description=description,
+                          article=article)
+ interface.queue().launch(debug=True)