sahalhes committed on
Commit
3bfce0f
·
1 Parent(s): 1da5841
Files changed (2) hide show
  1. app.py +16 -22
  2. requirements.txt +2 -3
app.py CHANGED
@@ -1,33 +1,27 @@
1
  import gradio as gr
2
- from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
3
- import torch
4
- from PIL import Image
5
 
6
- model_name = "nlpconnect/vit-gpt2-image-captioning"
7
- model = VisionEncoderDecoderModel.from_pretrained(model_name)
8
- feature_extractor = ViTImageProcessor.from_pretrained(model_name)
9
- tokenizer = AutoTokenizer.from_pretrained(model_name)
10
-
11
- device = "cuda" if torch.cuda.is_available() else "cpu"
12
- model.to(device)
13
 
 
14
  def generate_caption(image):
15
- if image is None:
16
- return "Please upload an image."
17
-
18
- pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
19
- pixel_values = pixel_values.to(device)
20
 
21
- output_ids = model.generate(pixel_values, max_length=16, num_beams=4)
22
- caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
23
- return caption
24
 
25
  demo = gr.Interface(
26
  fn=generate_caption,
27
  inputs=gr.Image(type="pil"),
28
- outputs="text",
29
- title="🖼️ Image Caption Generator",
30
- description="Upload an image and get a caption describing it using a VisionEncoderDecoder model (ViT + GPT-2)."
 
31
  )
32
 
33
- demo.launch()
 
 
 
1
  import gradio as gr
2
+ from transformers import pipeline
 
 
3
 
4
+ # Load pipeline
5
+ captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
 
 
 
 
 
6
 
7
+ # Define Gradio interface function
8
  def generate_caption(image):
9
+ result = captioner(image)
10
+ return result[0]["generated_text"].strip()
 
 
 
11
 
12
+ # Create Gradio app
13
+ title = "🖼️ Image Caption Generator using ViT-GPT2"
14
+ description = "Upload an image and get a descriptive caption using the `nlpconnect/vit-gpt2-image-captioning` model by Hugging Face Transformers."
15
 
16
  demo = gr.Interface(
17
  fn=generate_caption,
18
  inputs=gr.Image(type="pil"),
19
+ outputs=gr.Textbox(label="Generated Caption"),
20
+ title=title,
21
+ description=description,
22
+ allow_flagging="never"
23
  )
24
 
25
+ # Launch the app
26
+ if __name__ == "__main__":
27
+ demo.launch()
requirements.txt CHANGED
@@ -1,4 +1,3 @@
1
- transformers
2
  torch
3
- gradio
4
- Pillow
 
1
+ transformers>=4.37.0
2
  torch
3
+ gradio>=4.0.0