Spaces:
Build error
Build error
Commit
·
0ab9a32
0
Parent(s):
Duplicate from TencentARC/Caption-Anything
Browse filesCo-authored-by: wybertwang <[email protected]>
This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +36 -0
- .gitignore +135 -0
- DejaVuSansCondensed-Bold.ttf +0 -0
- Image/demo1.svg +0 -0
- Image/demo2.svg +0 -0
- Image/title.svg +1 -0
- LICENSE +28 -0
- README.md +14 -0
- app.py +369 -0
- app_huggingface.py +268 -0
- app_old.py +261 -0
- caas.py +114 -0
- caption_anything.py +132 -0
- captioner/README.md +13 -0
- captioner/__init__.py +15 -0
- captioner/base_captioner.py +200 -0
- captioner/blip.py +66 -0
- captioner/blip2.py +56 -0
- captioner/git.py +57 -0
- captioner/modeling_blip.py +1476 -0
- captioner/modeling_git.py +1587 -0
- captioner/vit_pixel_masks_utils.py +17 -0
- env.sh +6 -0
- image_editing_utils.py +69 -0
- requirements.txt +20 -0
- segmenter/__init__.py +8 -0
- segmenter/base_segmenter.py +156 -0
- segmenter/images/truck.jpg +0 -0
- segmenter/readme.md +68 -0
- segmenter/sam_vit_h_4b8939.pth +3 -0
- test_img/img0.png +0 -0
- test_img/img1.jpg +0 -0
- test_img/img1.jpg.raw_mask.png +0 -0
- test_img/img10.jpg +0 -0
- test_img/img10.jpg.raw_mask.png +0 -0
- test_img/img11.jpg +0 -0
- test_img/img12.jpg +0 -0
- test_img/img12.jpg.raw_mask.png +0 -0
- test_img/img13.jpg +0 -0
- test_img/img13.jpg.raw_mask.png +0 -0
- test_img/img14.jpg +0 -0
- test_img/img14.jpg.raw_mask.png +0 -0
- test_img/img15.jpg +0 -0
- test_img/img15.jpg.raw_mask.png +0 -0
- test_img/img16.jpg +0 -0
- test_img/img16.jpg.raw_mask.png +0 -0
- test_img/img17.jpg +0 -0
- test_img/img18.jpg +3 -0
- test_img/img19.jpg +0 -0
- test_img/img2.jpg +0 -0
.gitattributes
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
test_img/img18.jpg filter=lfs diff=lfs merge=lfs -text
|
36 |
+
test_img/img22.jpg filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
result/
|
2 |
+
model_cache/
|
3 |
+
*.pth
|
4 |
+
teng_grad_start.sh
|
5 |
+
|
6 |
+
# Byte-compiled / optimized / DLL files
|
7 |
+
__pycache__/
|
8 |
+
*.py[cod]
|
9 |
+
*$py.class
|
10 |
+
result/
|
11 |
+
|
12 |
+
# C extensions
|
13 |
+
*.so
|
14 |
+
|
15 |
+
# Distribution / packaging
|
16 |
+
.Python
|
17 |
+
build/
|
18 |
+
develop-eggs/
|
19 |
+
dist/
|
20 |
+
downloads/
|
21 |
+
eggs/
|
22 |
+
.eggs/
|
23 |
+
lib/
|
24 |
+
lib64/
|
25 |
+
parts/
|
26 |
+
sdist/
|
27 |
+
var/
|
28 |
+
wheels/
|
29 |
+
pip-wheel-metadata/
|
30 |
+
share/python-wheels/
|
31 |
+
*.egg-info/
|
32 |
+
.installed.cfg
|
33 |
+
*.egg
|
34 |
+
MANIFEST
|
35 |
+
|
36 |
+
# PyInstaller
|
37 |
+
# Usually these files are written by a python script from a template
|
38 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
39 |
+
*.manifest
|
40 |
+
*.spec
|
41 |
+
|
42 |
+
# Installer logs
|
43 |
+
pip-log.txt
|
44 |
+
pip-delete-this-directory.txt
|
45 |
+
|
46 |
+
# Unit test / coverage reports
|
47 |
+
htmlcov/
|
48 |
+
.tox/
|
49 |
+
.nox/
|
50 |
+
.coverage
|
51 |
+
.coverage.*
|
52 |
+
.cache
|
53 |
+
nosetests.xml
|
54 |
+
coverage.xml
|
55 |
+
*.cover
|
56 |
+
*.py,cover
|
57 |
+
.hypothesis/
|
58 |
+
.pytest_cache/
|
59 |
+
|
60 |
+
# Translations
|
61 |
+
*.mo
|
62 |
+
*.pot
|
63 |
+
|
64 |
+
# Django stuff:
|
65 |
+
*.log
|
66 |
+
local_settings.py
|
67 |
+
db.sqlite3
|
68 |
+
db.sqlite3-journal
|
69 |
+
|
70 |
+
# Flask stuff:
|
71 |
+
instance/
|
72 |
+
.webassets-cache
|
73 |
+
|
74 |
+
# Scrapy stuff:
|
75 |
+
.scrapy
|
76 |
+
|
77 |
+
# Sphinx documentation
|
78 |
+
docs/_build/
|
79 |
+
|
80 |
+
# PyBuilder
|
81 |
+
target/
|
82 |
+
|
83 |
+
# Jupyter Notebook
|
84 |
+
.ipynb_checkpoints
|
85 |
+
|
86 |
+
# IPython
|
87 |
+
profile_default/
|
88 |
+
ipython_config.py
|
89 |
+
|
90 |
+
# pyenv
|
91 |
+
.python-version
|
92 |
+
|
93 |
+
# pipenv
|
94 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
95 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
96 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
97 |
+
# install all needed dependencies.
|
98 |
+
#Pipfile.lock
|
99 |
+
|
100 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
101 |
+
__pypackages__/
|
102 |
+
|
103 |
+
# Celery stuff
|
104 |
+
celerybeat-schedule
|
105 |
+
celerybeat.pid
|
106 |
+
|
107 |
+
# SageMath parsed files
|
108 |
+
*.sage.py
|
109 |
+
|
110 |
+
# Environments
|
111 |
+
.env
|
112 |
+
.venv
|
113 |
+
env/
|
114 |
+
venv/
|
115 |
+
ENV/
|
116 |
+
env.bak/
|
117 |
+
venv.bak/
|
118 |
+
|
119 |
+
# Spyder project settings
|
120 |
+
.spyderproject
|
121 |
+
.spyproject
|
122 |
+
|
123 |
+
# Rope project settings
|
124 |
+
.ropeproject
|
125 |
+
|
126 |
+
# mkdocs documentation
|
127 |
+
/site
|
128 |
+
|
129 |
+
# mypy
|
130 |
+
.mypy_cache/
|
131 |
+
.dmypy.json
|
132 |
+
dmypy.json
|
133 |
+
|
134 |
+
# Pyre type checker
|
135 |
+
.pyre/
|
DejaVuSansCondensed-Bold.ttf
ADDED
Binary file (632 kB). View file
|
|
Image/demo1.svg
ADDED
|
Image/demo2.svg
ADDED
|
Image/title.svg
ADDED
|
LICENSE
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
BSD 3-Clause License
|
2 |
+
|
3 |
+
Copyright (c) 2023, Teng Wang
|
4 |
+
|
5 |
+
Redistribution and use in source and binary forms, with or without
|
6 |
+
modification, are permitted provided that the following conditions are met:
|
7 |
+
|
8 |
+
1. Redistributions of source code must retain the above copyright notice, this
|
9 |
+
list of conditions and the following disclaimer.
|
10 |
+
|
11 |
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
12 |
+
this list of conditions and the following disclaimer in the documentation
|
13 |
+
and/or other materials provided with the distribution.
|
14 |
+
|
15 |
+
3. Neither the name of the copyright holder nor the names of its
|
16 |
+
contributors may be used to endorse or promote products derived from
|
17 |
+
this software without specific prior written permission.
|
18 |
+
|
19 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
20 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
21 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
22 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
23 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
24 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
25 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
26 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
27 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
28 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
README.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Caption Anything
|
3 |
+
emoji: 📚
|
4 |
+
colorFrom: green
|
5 |
+
colorTo: green
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.24.1
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: apache-2.0
|
11 |
+
duplicated_from: TencentARC/Caption-Anything
|
12 |
+
---
|
13 |
+
|
14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from io import BytesIO
|
2 |
+
import string
|
3 |
+
import gradio as gr
|
4 |
+
import requests
|
5 |
+
from caption_anything import CaptionAnything
|
6 |
+
import torch
|
7 |
+
import json
|
8 |
+
import sys
|
9 |
+
import argparse
|
10 |
+
from caption_anything import parse_augment
|
11 |
+
import numpy as np
|
12 |
+
import PIL.ImageDraw as ImageDraw
|
13 |
+
from image_editing_utils import create_bubble_frame
|
14 |
+
import copy
|
15 |
+
from tools import mask_painter
|
16 |
+
from PIL import Image
|
17 |
+
import os
|
18 |
+
from captioner import build_captioner
|
19 |
+
from segment_anything import sam_model_registry
|
20 |
+
from text_refiner import build_text_refiner
|
21 |
+
from segmenter import build_segmenter
|
22 |
+
|
23 |
+
def download_checkpoint(url, folder, filename):
|
24 |
+
os.makedirs(folder, exist_ok=True)
|
25 |
+
filepath = os.path.join(folder, filename)
|
26 |
+
|
27 |
+
if not os.path.exists(filepath):
|
28 |
+
response = requests.get(url, stream=True)
|
29 |
+
with open(filepath, "wb") as f:
|
30 |
+
for chunk in response.iter_content(chunk_size=8192):
|
31 |
+
if chunk:
|
32 |
+
f.write(chunk)
|
33 |
+
|
34 |
+
return filepath
|
35 |
+
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
36 |
+
folder = "segmenter"
|
37 |
+
filename = "sam_vit_h_4b8939.pth"
|
38 |
+
|
39 |
+
download_checkpoint(checkpoint_url, folder, filename)
|
40 |
+
|
41 |
+
|
42 |
+
title = """<h1 align="center">Caption-Anything</h1>"""
|
43 |
+
description = """Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. Code: https://github.com/ttengwang/Caption-Anything
|
44 |
+
"""
|
45 |
+
|
46 |
+
examples = [
|
47 |
+
["test_img/img35.webp"],
|
48 |
+
["test_img/img2.jpg"],
|
49 |
+
["test_img/img5.jpg"],
|
50 |
+
["test_img/img12.jpg"],
|
51 |
+
["test_img/img14.jpg"],
|
52 |
+
["test_img/img0.png"],
|
53 |
+
["test_img/img1.jpg"],
|
54 |
+
]
|
55 |
+
|
56 |
+
args = parse_augment()
|
57 |
+
# args.device = 'cuda:5'
|
58 |
+
# args.disable_gpt = True
|
59 |
+
# args.enable_reduce_tokens = False
|
60 |
+
# args.port=20322
|
61 |
+
# args.captioner = 'blip'
|
62 |
+
# args.regular_box = True
|
63 |
+
shared_captioner = build_captioner(args.captioner, args.device, args)
|
64 |
+
shared_sam_model = sam_model_registry['vit_h'](checkpoint=args.segmenter_checkpoint).to(args.device)
|
65 |
+
|
66 |
+
|
67 |
+
def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, text_refiner=None, session_id=None):
|
68 |
+
segmenter = build_segmenter(args.segmenter, args.device, args, model=sam_model)
|
69 |
+
captioner = captioner
|
70 |
+
if session_id is not None:
|
71 |
+
print('Init caption anything for session {}'.format(session_id))
|
72 |
+
return CaptionAnything(args, api_key, captioner=captioner, segmenter=segmenter, text_refiner=text_refiner)
|
73 |
+
|
74 |
+
|
75 |
+
def init_openai_api_key(api_key=""):
|
76 |
+
text_refiner = None
|
77 |
+
if api_key and len(api_key) > 30:
|
78 |
+
try:
|
79 |
+
text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
|
80 |
+
text_refiner.llm('hi') # test
|
81 |
+
except:
|
82 |
+
text_refiner = None
|
83 |
+
openai_available = text_refiner is not None
|
84 |
+
return gr.update(visible = openai_available), gr.update(visible = openai_available), gr.update(visible = openai_available), gr.update(visible = True), gr.update(visible = True), gr.update(visible = True), text_refiner
|
85 |
+
|
86 |
+
|
87 |
+
def get_prompt(chat_input, click_state, click_mode):
|
88 |
+
inputs = json.loads(chat_input)
|
89 |
+
if click_mode == 'Continuous':
|
90 |
+
points = click_state[0]
|
91 |
+
labels = click_state[1]
|
92 |
+
for input in inputs:
|
93 |
+
points.append(input[:2])
|
94 |
+
labels.append(input[2])
|
95 |
+
elif click_mode == 'Single':
|
96 |
+
points = []
|
97 |
+
labels = []
|
98 |
+
for input in inputs:
|
99 |
+
points.append(input[:2])
|
100 |
+
labels.append(input[2])
|
101 |
+
click_state[0] = points
|
102 |
+
click_state[1] = labels
|
103 |
+
else:
|
104 |
+
raise NotImplementedError
|
105 |
+
|
106 |
+
prompt = {
|
107 |
+
"prompt_type":["click"],
|
108 |
+
"input_point":click_state[0],
|
109 |
+
"input_label":click_state[1],
|
110 |
+
"multimask_output":"True",
|
111 |
+
}
|
112 |
+
return prompt
|
113 |
+
|
114 |
+
def update_click_state(click_state, caption, click_mode):
|
115 |
+
if click_mode == 'Continuous':
|
116 |
+
click_state[2].append(caption)
|
117 |
+
elif click_mode == 'Single':
|
118 |
+
click_state[2] = [caption]
|
119 |
+
else:
|
120 |
+
raise NotImplementedError
|
121 |
+
|
122 |
+
|
123 |
+
def chat_with_points(chat_input, click_state, state, text_refiner):
|
124 |
+
if text_refiner is None:
|
125 |
+
response = "Text refiner is not initilzed, please input openai api key."
|
126 |
+
state = state + [(chat_input, response)]
|
127 |
+
return state, state
|
128 |
+
|
129 |
+
points, labels, captions = click_state
|
130 |
+
# point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting! Human: {chat_input}\nAI: "
|
131 |
+
# # "The image is of width {width} and height {height}."
|
132 |
+
point_chat_prompt = "a) Revised prompt: I am an AI trained to chat with you about an image based on specific points (w, h) you provide, along with their visual descriptions. Please note that (0, 0) refers to the top-left corner of the image, w refers to the width, and h refers to the height. Here are the points and their descriptions you've given me: {points_with_caps}. Now, let's chat! Human: {chat_input} AI:"
|
133 |
+
prev_visual_context = ""
|
134 |
+
pos_points = [f"{points[i][0]}, {points[i][1]}" for i in range(len(points)) if labels[i] == 1]
|
135 |
+
if len(captions):
|
136 |
+
prev_visual_context = ', '.join(pos_points) + captions[-1] + '\n'
|
137 |
+
else:
|
138 |
+
prev_visual_context = 'no point exists.'
|
139 |
+
chat_prompt = point_chat_prompt.format(**{"points_with_caps": prev_visual_context, "chat_input": chat_input})
|
140 |
+
response = text_refiner.llm(chat_prompt)
|
141 |
+
state = state + [(chat_input, response)]
|
142 |
+
return state, state
|
143 |
+
|
144 |
+
def inference_seg_cap(image_input, point_prompt, click_mode, language, sentiment, factuality,
|
145 |
+
length, image_embedding, state, click_state, original_size, input_size, text_refiner, evt:gr.SelectData):
|
146 |
+
|
147 |
+
model = build_caption_anything_with_models(
|
148 |
+
args,
|
149 |
+
api_key="",
|
150 |
+
captioner=shared_captioner,
|
151 |
+
sam_model=shared_sam_model,
|
152 |
+
text_refiner=text_refiner,
|
153 |
+
session_id=iface.app_id
|
154 |
+
)
|
155 |
+
|
156 |
+
model.segmenter.image_embedding = image_embedding
|
157 |
+
model.segmenter.predictor.original_size = original_size
|
158 |
+
model.segmenter.predictor.input_size = input_size
|
159 |
+
model.segmenter.predictor.is_image_set = True
|
160 |
+
|
161 |
+
if point_prompt == 'Positive':
|
162 |
+
coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
|
163 |
+
else:
|
164 |
+
coordinate = "[[{}, {}, 0]]".format(str(evt.index[0]), str(evt.index[1]))
|
165 |
+
|
166 |
+
controls = {'length': length,
|
167 |
+
'sentiment': sentiment,
|
168 |
+
'factuality': factuality,
|
169 |
+
'language': language}
|
170 |
+
|
171 |
+
# click_coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
|
172 |
+
# chat_input = click_coordinate
|
173 |
+
prompt = get_prompt(coordinate, click_state, click_mode)
|
174 |
+
print('prompt: ', prompt, 'controls: ', controls)
|
175 |
+
|
176 |
+
out = model.inference(image_input, prompt, controls, disable_gpt=True)
|
177 |
+
state = state + [(None, "Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]))]
|
178 |
+
# for k, v in out['generated_captions'].items():
|
179 |
+
# state = state + [(f'{k}: {v}', None)]
|
180 |
+
state = state + [("caption: {}".format(out['generated_captions']['raw_caption']), None)]
|
181 |
+
wiki = out['generated_captions'].get('wiki', "")
|
182 |
+
|
183 |
+
update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
|
184 |
+
text = out['generated_captions']['raw_caption']
|
185 |
+
# draw = ImageDraw.Draw(image_input)
|
186 |
+
# draw.text((evt.index[0], evt.index[1]), text, textcolor=(0,0,255), text_size=120)
|
187 |
+
input_mask = np.array(out['mask'].convert('P'))
|
188 |
+
image_input = mask_painter(np.array(image_input), input_mask)
|
189 |
+
origin_image_input = image_input
|
190 |
+
image_input = create_bubble_frame(image_input, text, (evt.index[0], evt.index[1]))
|
191 |
+
|
192 |
+
yield state, state, click_state, chat_input, image_input, wiki
|
193 |
+
if not args.disable_gpt and model.text_refiner:
|
194 |
+
refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'])
|
195 |
+
# new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
|
196 |
+
new_cap = refined_caption['caption']
|
197 |
+
refined_image_input = create_bubble_frame(origin_image_input, new_cap, (evt.index[0], evt.index[1]))
|
198 |
+
yield state, state, click_state, chat_input, refined_image_input, wiki
|
199 |
+
|
200 |
+
|
201 |
+
def upload_callback(image_input, state):
|
202 |
+
state = [] + [('Image size: ' + str(image_input.size), None)]
|
203 |
+
click_state = [[], [], []]
|
204 |
+
res = 1024
|
205 |
+
width, height = image_input.size
|
206 |
+
ratio = min(1.0 * res / max(width, height), 1.0)
|
207 |
+
if ratio < 1.0:
|
208 |
+
image_input = image_input.resize((int(width * ratio), int(height * ratio)))
|
209 |
+
print('Scaling input image to {}'.format(image_input.size))
|
210 |
+
|
211 |
+
model = build_caption_anything_with_models(
|
212 |
+
args,
|
213 |
+
api_key="",
|
214 |
+
captioner=shared_captioner,
|
215 |
+
sam_model=shared_sam_model,
|
216 |
+
session_id=iface.app_id
|
217 |
+
)
|
218 |
+
model.segmenter.set_image(image_input)
|
219 |
+
image_embedding = model.segmenter.image_embedding
|
220 |
+
original_size = model.segmenter.predictor.original_size
|
221 |
+
input_size = model.segmenter.predictor.input_size
|
222 |
+
return state, state, image_input, click_state, image_input, image_embedding, original_size, input_size
|
223 |
+
|
224 |
+
with gr.Blocks(
|
225 |
+
css='''
|
226 |
+
#image_upload{min-height:400px}
|
227 |
+
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 600px}
|
228 |
+
'''
|
229 |
+
) as iface:
|
230 |
+
state = gr.State([])
|
231 |
+
click_state = gr.State([[],[],[]])
|
232 |
+
origin_image = gr.State(None)
|
233 |
+
image_embedding = gr.State(None)
|
234 |
+
text_refiner = gr.State(None)
|
235 |
+
original_size = gr.State(None)
|
236 |
+
input_size = gr.State(None)
|
237 |
+
|
238 |
+
gr.Markdown(title)
|
239 |
+
gr.Markdown(description)
|
240 |
+
|
241 |
+
with gr.Row():
|
242 |
+
with gr.Column(scale=1.0):
|
243 |
+
with gr.Column(visible=False) as modules_not_need_gpt:
|
244 |
+
image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
|
245 |
+
example_image = gr.Image(type="pil", interactive=False, visible=False)
|
246 |
+
with gr.Row(scale=1.0):
|
247 |
+
with gr.Row(scale=0.4):
|
248 |
+
point_prompt = gr.Radio(
|
249 |
+
choices=["Positive", "Negative"],
|
250 |
+
value="Positive",
|
251 |
+
label="Point Prompt",
|
252 |
+
interactive=True)
|
253 |
+
click_mode = gr.Radio(
|
254 |
+
choices=["Continuous", "Single"],
|
255 |
+
value="Continuous",
|
256 |
+
label="Clicking Mode",
|
257 |
+
interactive=True)
|
258 |
+
with gr.Row(scale=0.4):
|
259 |
+
clear_button_clike = gr.Button(value="Clear Clicks", interactive=True)
|
260 |
+
clear_button_image = gr.Button(value="Clear Image", interactive=True)
|
261 |
+
with gr.Column(visible=False) as modules_need_gpt:
|
262 |
+
with gr.Row(scale=1.0):
|
263 |
+
language = gr.Dropdown(['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"], value="English", label="Language", interactive=True)
|
264 |
+
|
265 |
+
sentiment = gr.Radio(
|
266 |
+
choices=["Positive", "Natural", "Negative"],
|
267 |
+
value="Natural",
|
268 |
+
label="Sentiment",
|
269 |
+
interactive=True,
|
270 |
+
)
|
271 |
+
with gr.Row(scale=1.0):
|
272 |
+
factuality = gr.Radio(
|
273 |
+
choices=["Factual", "Imagination"],
|
274 |
+
value="Factual",
|
275 |
+
label="Factuality",
|
276 |
+
interactive=True,
|
277 |
+
)
|
278 |
+
length = gr.Slider(
|
279 |
+
minimum=10,
|
280 |
+
maximum=80,
|
281 |
+
value=10,
|
282 |
+
step=1,
|
283 |
+
interactive=True,
|
284 |
+
label="Length",
|
285 |
+
)
|
286 |
+
with gr.Column(visible=True) as modules_not_need_gpt3:
|
287 |
+
gr.Examples(
|
288 |
+
examples=examples,
|
289 |
+
inputs=[example_image],
|
290 |
+
)
|
291 |
+
with gr.Column(scale=0.5):
|
292 |
+
openai_api_key = gr.Textbox(
|
293 |
+
placeholder="Input openAI API key",
|
294 |
+
show_label=False,
|
295 |
+
label = "OpenAI API Key",
|
296 |
+
lines=1,
|
297 |
+
type="password")
|
298 |
+
with gr.Row(scale=0.5):
|
299 |
+
enable_chatGPT_button = gr.Button(value="Run with ChatGPT", interactive=True, variant='primary')
|
300 |
+
disable_chatGPT_button = gr.Button(value="Run without ChatGPT (Faster)", interactive=True, variant='primary')
|
301 |
+
with gr.Column(visible=False) as modules_need_gpt2:
|
302 |
+
wiki_output = gr.Textbox(lines=5, label="Wiki", max_lines=5)
|
303 |
+
with gr.Column(visible=False) as modules_not_need_gpt2:
|
304 |
+
chatbot = gr.Chatbot(label="Chat about Selected Object",).style(height=550,scale=0.5)
|
305 |
+
with gr.Column(visible=False) as modules_need_gpt3:
|
306 |
+
chat_input = gr.Textbox(lines=1, label="Chat Input")
|
307 |
+
with gr.Row():
|
308 |
+
clear_button_text = gr.Button(value="Clear Text", interactive=True)
|
309 |
+
submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
|
310 |
+
|
311 |
+
openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key], outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
|
312 |
+
enable_chatGPT_button.click(init_openai_api_key, inputs=[openai_api_key], outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
|
313 |
+
disable_chatGPT_button.click(init_openai_api_key, outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
|
314 |
+
|
315 |
+
clear_button_clike.click(
|
316 |
+
lambda x: ([[], [], []], x, ""),
|
317 |
+
[origin_image],
|
318 |
+
[click_state, image_input, wiki_output],
|
319 |
+
queue=False,
|
320 |
+
show_progress=False
|
321 |
+
)
|
322 |
+
clear_button_image.click(
|
323 |
+
lambda: (None, [], [], [[], [], []], "", ""),
|
324 |
+
[],
|
325 |
+
[image_input, chatbot, state, click_state, wiki_output, origin_image],
|
326 |
+
queue=False,
|
327 |
+
show_progress=False
|
328 |
+
)
|
329 |
+
clear_button_text.click(
|
330 |
+
lambda: ([], [], [[], [], []]),
|
331 |
+
[],
|
332 |
+
[chatbot, state, click_state],
|
333 |
+
queue=False,
|
334 |
+
show_progress=False
|
335 |
+
)
|
336 |
+
image_input.clear(
|
337 |
+
lambda: (None, [], [], [[], [], []], "", ""),
|
338 |
+
[],
|
339 |
+
[image_input, chatbot, state, click_state, wiki_output, origin_image],
|
340 |
+
queue=False,
|
341 |
+
show_progress=False
|
342 |
+
)
|
343 |
+
|
344 |
+
image_input.upload(upload_callback,[image_input, state], [chatbot, state, origin_image, click_state, image_input, image_embedding, original_size, input_size])
|
345 |
+
chat_input.submit(chat_with_points, [chat_input, click_state, state, text_refiner], [chatbot, state])
|
346 |
+
example_image.change(upload_callback,[example_image, state], [state, state, origin_image, click_state, image_input, image_embedding, original_size, input_size])
|
347 |
+
|
348 |
+
# select coordinate
|
349 |
+
image_input.select(inference_seg_cap,
|
350 |
+
inputs=[
|
351 |
+
origin_image,
|
352 |
+
point_prompt,
|
353 |
+
click_mode,
|
354 |
+
language,
|
355 |
+
sentiment,
|
356 |
+
factuality,
|
357 |
+
length,
|
358 |
+
image_embedding,
|
359 |
+
state,
|
360 |
+
click_state,
|
361 |
+
original_size,
|
362 |
+
input_size,
|
363 |
+
text_refiner
|
364 |
+
],
|
365 |
+
outputs=[chatbot, state, click_state, chat_input, image_input, wiki_output],
|
366 |
+
show_progress=False, queue=True)
|
367 |
+
|
368 |
+
iface.queue(concurrency_count=5, api_open=False, max_size=10)
|
369 |
+
iface.launch(server_name="0.0.0.0", enable_queue=True)
|
app_huggingface.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from io import BytesIO
|
2 |
+
import string
|
3 |
+
import gradio as gr
|
4 |
+
import requests
|
5 |
+
from caption_anything import CaptionAnything
|
6 |
+
import torch
|
7 |
+
import json
|
8 |
+
import sys
|
9 |
+
import argparse
|
10 |
+
from caption_anything import parse_augment
|
11 |
+
import numpy as np
|
12 |
+
import PIL.ImageDraw as ImageDraw
|
13 |
+
from image_editing_utils import create_bubble_frame
|
14 |
+
import copy
|
15 |
+
from tools import mask_painter
|
16 |
+
from PIL import Image
|
17 |
+
import os
|
18 |
+
|
19 |
+
def download_checkpoint(url, folder, filename):
|
20 |
+
os.makedirs(folder, exist_ok=True)
|
21 |
+
filepath = os.path.join(folder, filename)
|
22 |
+
|
23 |
+
if not os.path.exists(filepath):
|
24 |
+
response = requests.get(url, stream=True)
|
25 |
+
with open(filepath, "wb") as f:
|
26 |
+
for chunk in response.iter_content(chunk_size=8192):
|
27 |
+
if chunk:
|
28 |
+
f.write(chunk)
|
29 |
+
|
30 |
+
return filepath
|
31 |
+
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
32 |
+
folder = "segmenter"
|
33 |
+
filename = "sam_vit_h_4b8939.pth"
|
34 |
+
|
35 |
+
download_checkpoint(checkpoint_url, folder, filename)
|
36 |
+
|
37 |
+
|
38 |
+
title = """<h1 align="center">Caption-Anything</h1>"""
|
39 |
+
description = """Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. Code: https://github.com/ttengwang/Caption-Anything
|
40 |
+
"""
|
41 |
+
|
42 |
+
examples = [
|
43 |
+
["test_img/img2.jpg"],
|
44 |
+
["test_img/img5.jpg"],
|
45 |
+
["test_img/img12.jpg"],
|
46 |
+
["test_img/img14.jpg"],
|
47 |
+
]
|
48 |
+
|
49 |
+
args = parse_augment()
|
50 |
+
args.captioner = 'blip2'
|
51 |
+
args.seg_crop_mode = 'wo_bg'
|
52 |
+
args.regular_box = True
|
53 |
+
# args.device = 'cuda:5'
|
54 |
+
# args.disable_gpt = False
|
55 |
+
# args.enable_reduce_tokens = True
|
56 |
+
# args.port=20322
|
57 |
+
model = CaptionAnything(args)
|
58 |
+
|
59 |
+
def init_openai_api_key(api_key):
|
60 |
+
os.environ['OPENAI_API_KEY'] = api_key
|
61 |
+
model.init_refiner()
|
62 |
+
|
63 |
+
|
64 |
+
def get_prompt(chat_input, click_state):
|
65 |
+
points = click_state[0]
|
66 |
+
labels = click_state[1]
|
67 |
+
inputs = json.loads(chat_input)
|
68 |
+
for input in inputs:
|
69 |
+
points.append(input[:2])
|
70 |
+
labels.append(input[2])
|
71 |
+
|
72 |
+
prompt = {
|
73 |
+
"prompt_type":["click"],
|
74 |
+
"input_point":points,
|
75 |
+
"input_label":labels,
|
76 |
+
"multimask_output":"True",
|
77 |
+
}
|
78 |
+
return prompt
|
79 |
+
|
80 |
+
def chat_with_points(chat_input, click_state, state):
|
81 |
+
if not hasattr(model, "text_refiner"):
|
82 |
+
response = "Text refiner is not initilzed, please input openai api key."
|
83 |
+
state = state + [(chat_input, response)]
|
84 |
+
return state, state
|
85 |
+
|
86 |
+
points, labels, captions = click_state
|
87 |
+
# point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting! Human: {chat_input}\nAI: "
|
88 |
+
# # "The image is of width {width} and height {height}."
|
89 |
+
point_chat_prompt = "a) Revised prompt: I am an AI trained to chat with you about an image based on specific points (w, h) you provide, along with their visual descriptions. Please note that (0, 0) refers to the top-left corner of the image, w refers to the width, and h refers to the height. Here are the points and their descriptions you've given me: {points_with_caps}. Now, let's chat! Human: {chat_input} AI:"
|
90 |
+
prev_visual_context = ""
|
91 |
+
pos_points = [f"{points[i][0]}, {points[i][1]}" for i in range(len(points)) if labels[i] == 1]
|
92 |
+
if len(captions):
|
93 |
+
prev_visual_context = ', '.join(pos_points) + captions[-1] + '\n'
|
94 |
+
else:
|
95 |
+
prev_visual_context = 'no point exists.'
|
96 |
+
chat_prompt = point_chat_prompt.format(**{"points_with_caps": prev_visual_context, "chat_input": chat_input})
|
97 |
+
response = model.text_refiner.llm(chat_prompt)
|
98 |
+
state = state + [(chat_input, response)]
|
99 |
+
return state, state
|
100 |
+
|
101 |
+
def inference_seg_cap(image_input, point_prompt, language, sentiment, factuality, length, state, click_state, evt:gr.SelectData):
|
102 |
+
|
103 |
+
if point_prompt == 'Positive':
|
104 |
+
coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
|
105 |
+
else:
|
106 |
+
coordinate = "[[{}, {}, 0]]".format(str(evt.index[0]), str(evt.index[1]))
|
107 |
+
|
108 |
+
controls = {'length': length,
|
109 |
+
'sentiment': sentiment,
|
110 |
+
'factuality': factuality,
|
111 |
+
'language': language}
|
112 |
+
|
113 |
+
# click_coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
|
114 |
+
# chat_input = click_coordinate
|
115 |
+
prompt = get_prompt(coordinate, click_state)
|
116 |
+
print('prompt: ', prompt, 'controls: ', controls)
|
117 |
+
|
118 |
+
out = model.inference(image_input, prompt, controls)
|
119 |
+
state = state + [(None, "Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]))]
|
120 |
+
# for k, v in out['generated_captions'].items():
|
121 |
+
# state = state + [(f'{k}: {v}', None)]
|
122 |
+
state = state + [("caption: {}".format(out['generated_captions']['raw_caption']), None)]
|
123 |
+
wiki = out['generated_captions'].get('wiki', "")
|
124 |
+
click_state[2].append(out['generated_captions']['raw_caption'])
|
125 |
+
|
126 |
+
text = out['generated_captions']['raw_caption']
|
127 |
+
# draw = ImageDraw.Draw(image_input)
|
128 |
+
# draw.text((evt.index[0], evt.index[1]), text, textcolor=(0,0,255), text_size=120)
|
129 |
+
input_mask = np.array(Image.open(out['mask_save_path']).convert('P'))
|
130 |
+
image_input = mask_painter(np.array(image_input), input_mask)
|
131 |
+
origin_image_input = image_input
|
132 |
+
image_input = create_bubble_frame(image_input, text, (evt.index[0], evt.index[1]))
|
133 |
+
|
134 |
+
yield state, state, click_state, chat_input, image_input, wiki
|
135 |
+
if not args.disable_gpt and hasattr(model, "text_refiner"):
|
136 |
+
refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'])
|
137 |
+
# new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
|
138 |
+
new_cap = refined_caption['caption']
|
139 |
+
refined_image_input = create_bubble_frame(origin_image_input, new_cap, (evt.index[0], evt.index[1]))
|
140 |
+
yield state, state, click_state, chat_input, refined_image_input, wiki
|
141 |
+
|
142 |
+
|
143 |
+
def upload_callback(image_input, state):
|
144 |
+
state = [] + [('Image size: ' + str(image_input.size), None)]
|
145 |
+
click_state = [[], [], []]
|
146 |
+
model.segmenter.image = None
|
147 |
+
model.segmenter.image_embedding = None
|
148 |
+
model.segmenter.set_image(image_input)
|
149 |
+
return state, image_input, click_state
|
150 |
+
|
151 |
+
with gr.Blocks(
|
152 |
+
css='''
|
153 |
+
#image_upload{min-height:400px}
|
154 |
+
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 600px}
|
155 |
+
'''
|
156 |
+
) as iface:
|
157 |
+
state = gr.State([])
|
158 |
+
click_state = gr.State([[],[],[]])
|
159 |
+
origin_image = gr.State(None)
|
160 |
+
|
161 |
+
gr.Markdown(title)
|
162 |
+
gr.Markdown(description)
|
163 |
+
|
164 |
+
with gr.Row():
|
165 |
+
with gr.Column(scale=1.0):
|
166 |
+
image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
|
167 |
+
with gr.Row(scale=1.0):
|
168 |
+
point_prompt = gr.Radio(
|
169 |
+
choices=["Positive", "Negative"],
|
170 |
+
value="Positive",
|
171 |
+
label="Point Prompt",
|
172 |
+
interactive=True)
|
173 |
+
clear_button_clike = gr.Button(value="Clear Clicks", interactive=True)
|
174 |
+
clear_button_image = gr.Button(value="Clear Image", interactive=True)
|
175 |
+
with gr.Row(scale=1.0):
|
176 |
+
language = gr.Dropdown(['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"], value="English", label="Language", interactive=True)
|
177 |
+
|
178 |
+
sentiment = gr.Radio(
|
179 |
+
choices=["Positive", "Natural", "Negative"],
|
180 |
+
value="Natural",
|
181 |
+
label="Sentiment",
|
182 |
+
interactive=True,
|
183 |
+
)
|
184 |
+
with gr.Row(scale=1.0):
|
185 |
+
factuality = gr.Radio(
|
186 |
+
choices=["Factual", "Imagination"],
|
187 |
+
value="Factual",
|
188 |
+
label="Factuality",
|
189 |
+
interactive=True,
|
190 |
+
)
|
191 |
+
length = gr.Slider(
|
192 |
+
minimum=10,
|
193 |
+
maximum=80,
|
194 |
+
value=10,
|
195 |
+
step=1,
|
196 |
+
interactive=True,
|
197 |
+
label="Length",
|
198 |
+
)
|
199 |
+
|
200 |
+
with gr.Column(scale=0.5):
|
201 |
+
openai_api_key = gr.Textbox(
|
202 |
+
placeholder="Input your openAI API key and press Enter",
|
203 |
+
show_label=False,
|
204 |
+
label = "OpenAI API Key",
|
205 |
+
lines=1,
|
206 |
+
type="password"
|
207 |
+
)
|
208 |
+
openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key])
|
209 |
+
wiki_output = gr.Textbox(lines=6, label="Wiki")
|
210 |
+
chatbot = gr.Chatbot(label="Chat about Selected Object",).style(height=450,scale=0.5)
|
211 |
+
chat_input = gr.Textbox(lines=1, label="Chat Input")
|
212 |
+
with gr.Row():
|
213 |
+
clear_button_text = gr.Button(value="Clear Text", interactive=True)
|
214 |
+
submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
|
215 |
+
clear_button_clike.click(
|
216 |
+
lambda x: ([[], [], []], x, ""),
|
217 |
+
[origin_image],
|
218 |
+
[click_state, image_input, wiki_output],
|
219 |
+
queue=False,
|
220 |
+
show_progress=False
|
221 |
+
)
|
222 |
+
clear_button_image.click(
|
223 |
+
lambda: (None, [], [], [[], [], []], ""),
|
224 |
+
[],
|
225 |
+
[image_input, chatbot, state, click_state, wiki_output],
|
226 |
+
queue=False,
|
227 |
+
show_progress=False
|
228 |
+
)
|
229 |
+
clear_button_text.click(
|
230 |
+
lambda: ([], [], [[], [], []]),
|
231 |
+
[],
|
232 |
+
[chatbot, state, click_state],
|
233 |
+
queue=False,
|
234 |
+
show_progress=False
|
235 |
+
)
|
236 |
+
image_input.clear(
|
237 |
+
lambda: (None, [], [], [[], [], []], ""),
|
238 |
+
[],
|
239 |
+
[image_input, chatbot, state, click_state, wiki_output],
|
240 |
+
queue=False,
|
241 |
+
show_progress=False
|
242 |
+
)
|
243 |
+
|
244 |
+
examples = gr.Examples(
|
245 |
+
examples=examples,
|
246 |
+
inputs=[image_input],
|
247 |
+
)
|
248 |
+
|
249 |
+
image_input.upload(upload_callback,[image_input, state], [state, origin_image, click_state])
|
250 |
+
chat_input.submit(chat_with_points, [chat_input, click_state, state], [chatbot, state])
|
251 |
+
|
252 |
+
# select coordinate
|
253 |
+
image_input.select(inference_seg_cap,
|
254 |
+
inputs=[
|
255 |
+
origin_image,
|
256 |
+
point_prompt,
|
257 |
+
language,
|
258 |
+
sentiment,
|
259 |
+
factuality,
|
260 |
+
length,
|
261 |
+
state,
|
262 |
+
click_state
|
263 |
+
],
|
264 |
+
outputs=[chatbot, state, click_state, chat_input, image_input, wiki_output],
|
265 |
+
show_progress=False, queue=True)
|
266 |
+
|
267 |
+
iface.queue(concurrency_count=1, api_open=False, max_size=10)
|
268 |
+
iface.launch(server_name="0.0.0.0", enable_queue=True)
|
app_old.py
ADDED
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from io import BytesIO
|
2 |
+
import string
|
3 |
+
import gradio as gr
|
4 |
+
import requests
|
5 |
+
from caption_anything import CaptionAnything
|
6 |
+
import torch
|
7 |
+
import json
|
8 |
+
import sys
|
9 |
+
import argparse
|
10 |
+
from caption_anything import parse_augment
|
11 |
+
import os
|
12 |
+
|
13 |
+
# download sam checkpoint if not downloaded
|
14 |
+
def download_checkpoint(url, folder, filename):
|
15 |
+
os.makedirs(folder, exist_ok=True)
|
16 |
+
filepath = os.path.join(folder, filename)
|
17 |
+
|
18 |
+
if not os.path.exists(filepath):
|
19 |
+
response = requests.get(url, stream=True)
|
20 |
+
with open(filepath, "wb") as f:
|
21 |
+
for chunk in response.iter_content(chunk_size=8192):
|
22 |
+
if chunk:
|
23 |
+
f.write(chunk)
|
24 |
+
|
25 |
+
return filepath
|
26 |
+
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
27 |
+
folder = "segmenter"
|
28 |
+
filename = "sam_vit_h_4b8939.pth"
|
29 |
+
|
30 |
+
title = """<h1 align="center">Caption-Anything</h1>"""
|
31 |
+
description = """Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them.
|
32 |
+
<br> <strong>Code</strong>: GitHub repo: <a href='https://github.com/ttengwang/Caption-Anything' target='_blank'></a>
|
33 |
+
"""
|
34 |
+
|
35 |
+
examples = [
|
36 |
+
["test_img/img2.jpg", "[[1000, 700, 1]]"]
|
37 |
+
]
|
38 |
+
|
39 |
+
args = parse_augment()
|
40 |
+
|
41 |
+
def get_prompt(chat_input, click_state):
|
42 |
+
points = click_state[0]
|
43 |
+
labels = click_state[1]
|
44 |
+
inputs = json.loads(chat_input)
|
45 |
+
for input in inputs:
|
46 |
+
points.append(input[:2])
|
47 |
+
labels.append(input[2])
|
48 |
+
|
49 |
+
prompt = {
|
50 |
+
"prompt_type":["click"],
|
51 |
+
"input_point":points,
|
52 |
+
"input_label":labels,
|
53 |
+
"multimask_output":"True",
|
54 |
+
}
|
55 |
+
return prompt
|
56 |
+
|
57 |
+
def inference_seg_cap(image_input, chat_input, language, sentiment, factuality, length, state, click_state):
|
58 |
+
controls = {'length': length,
|
59 |
+
'sentiment': sentiment,
|
60 |
+
'factuality': factuality,
|
61 |
+
'language': language}
|
62 |
+
prompt = get_prompt(chat_input, click_state)
|
63 |
+
print('prompt: ', prompt, 'controls: ', controls)
|
64 |
+
out = model.inference(image_input, prompt, controls)
|
65 |
+
state = state + [(None, "Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]))]
|
66 |
+
for k, v in out['generated_captions'].items():
|
67 |
+
state = state + [(f'{k}: {v}', None)]
|
68 |
+
click_state[2].append(out['generated_captions']['raw_caption'])
|
69 |
+
image_output_mask = out['mask_save_path']
|
70 |
+
image_output_crop = out['crop_save_path']
|
71 |
+
return state, state, click_state, image_output_mask, image_output_crop
|
72 |
+
|
73 |
+
|
74 |
+
def upload_callback(image_input, state):
|
75 |
+
state = state + [('Image size: ' + str(image_input.size), None)]
|
76 |
+
return state
|
77 |
+
|
78 |
+
# get coordinate in format [[x,y,positive/negative]]
|
79 |
+
def get_select_coords(image_input, point_prompt, language, sentiment, factuality, length, state, click_state, evt: gr.SelectData):
|
80 |
+
print("point_prompt: ", point_prompt)
|
81 |
+
if point_prompt == 'Positive Point':
|
82 |
+
coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
|
83 |
+
else:
|
84 |
+
coordinate = "[[{}, {}, 0]]".format(str(evt.index[0]), str(evt.index[1]))
|
85 |
+
return (coordinate,) + inference_seg_cap(image_input, coordinate, language, sentiment, factuality, length, state, click_state)
|
86 |
+
|
87 |
+
def chat_with_points(chat_input, click_state, state):
|
88 |
+
points, labels, captions = click_state
|
89 |
+
# point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\n. Now begin chatting! Human: {chat_input}\nAI: "
|
90 |
+
# "The image is of width {width} and height {height}."
|
91 |
+
point_chat_prompt = "a) Revised prompt: I am an AI trained to chat with you about an image based on specific points (w, h) you provide, along with their visual descriptions. Please note that (0, 0) refers to the top-left corner of the image, w refers to the width, and h refers to the height. Here are the points and their descriptions you've given me: {points_with_caps}. Now, let's chat! Human: {chat_input} AI:"
|
92 |
+
prev_visual_context = ""
|
93 |
+
pos_points = [f"{points[i][0]}, {points[i][1]}" for i in range(len(points)) if labels[i] == 1]
|
94 |
+
prev_visual_context = ', '.join(pos_points) + captions[-1] + '\n'
|
95 |
+
chat_prompt = point_chat_prompt.format(**{"points_with_caps": prev_visual_context, "chat_input": chat_input})
|
96 |
+
response = model.text_refiner.llm(chat_prompt)
|
97 |
+
state = state + [(chat_input, response)]
|
98 |
+
return state, state
|
99 |
+
|
100 |
+
def init_openai_api_key(api_key):
|
101 |
+
# os.environ['OPENAI_API_KEY'] = api_key
|
102 |
+
global model
|
103 |
+
model = CaptionAnything(args, api_key)
|
104 |
+
|
105 |
+
css='''
|
106 |
+
#image_upload{min-height:200px}
|
107 |
+
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 200px}
|
108 |
+
'''
|
109 |
+
|
110 |
+
with gr.Blocks(css=css) as iface:
|
111 |
+
state = gr.State([])
|
112 |
+
click_state = gr.State([[],[],[]])
|
113 |
+
caption_state = gr.State([[]])
|
114 |
+
gr.Markdown(title)
|
115 |
+
gr.Markdown(description)
|
116 |
+
|
117 |
+
with gr.Column():
|
118 |
+
openai_api_key = gr.Textbox(
|
119 |
+
placeholder="Input your openAI API key and press Enter",
|
120 |
+
show_label=False,
|
121 |
+
lines=1,
|
122 |
+
type="password",
|
123 |
+
)
|
124 |
+
openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key])
|
125 |
+
|
126 |
+
with gr.Row():
|
127 |
+
with gr.Column(scale=0.7):
|
128 |
+
image_input = gr.Image(type="pil", interactive=True, label="Image", elem_id="image_upload").style(height=260,scale=1.0)
|
129 |
+
|
130 |
+
with gr.Row(scale=0.7):
|
131 |
+
point_prompt = gr.Radio(
|
132 |
+
choices=["Positive Point", "Negative Point"],
|
133 |
+
value="Positive Point",
|
134 |
+
label="Points",
|
135 |
+
interactive=True,
|
136 |
+
)
|
137 |
+
|
138 |
+
# with gr.Row():
|
139 |
+
language = gr.Radio(
|
140 |
+
choices=["English", "Chinese", "French", "Spanish", "Arabic", "Portuguese","Cantonese"],
|
141 |
+
value="English",
|
142 |
+
label="Language",
|
143 |
+
interactive=True,
|
144 |
+
)
|
145 |
+
sentiment = gr.Radio(
|
146 |
+
choices=["Positive", "Natural", "Negative"],
|
147 |
+
value="Natural",
|
148 |
+
label="Sentiment",
|
149 |
+
interactive=True,
|
150 |
+
)
|
151 |
+
factuality = gr.Radio(
|
152 |
+
choices=["Factual", "Imagination"],
|
153 |
+
value="Factual",
|
154 |
+
label="Factuality",
|
155 |
+
interactive=True,
|
156 |
+
)
|
157 |
+
length = gr.Slider(
|
158 |
+
minimum=5,
|
159 |
+
maximum=100,
|
160 |
+
value=10,
|
161 |
+
step=1,
|
162 |
+
interactive=True,
|
163 |
+
label="Length",
|
164 |
+
)
|
165 |
+
|
166 |
+
with gr.Column(scale=1.5):
|
167 |
+
with gr.Row():
|
168 |
+
image_output_mask= gr.Image(type="pil", interactive=False, label="Mask").style(height=260,scale=1.0)
|
169 |
+
image_output_crop= gr.Image(type="pil", interactive=False, label="Cropped Image by Mask", show_progress=False).style(height=260,scale=1.0)
|
170 |
+
chatbot = gr.Chatbot(label="Chat Output",).style(height=450,scale=0.5)
|
171 |
+
|
172 |
+
with gr.Row():
|
173 |
+
with gr.Column(scale=0.7):
|
174 |
+
prompt_input = gr.Textbox(lines=1, label="Input Prompt (A list of points like : [[100, 200, 1], [200,300,0]])")
|
175 |
+
prompt_input.submit(
|
176 |
+
inference_seg_cap,
|
177 |
+
[
|
178 |
+
image_input,
|
179 |
+
prompt_input,
|
180 |
+
language,
|
181 |
+
sentiment,
|
182 |
+
factuality,
|
183 |
+
length,
|
184 |
+
state,
|
185 |
+
click_state
|
186 |
+
],
|
187 |
+
[chatbot, state, click_state, image_output_mask, image_output_crop],
|
188 |
+
show_progress=False
|
189 |
+
)
|
190 |
+
|
191 |
+
image_input.upload(
|
192 |
+
upload_callback,
|
193 |
+
[image_input, state],
|
194 |
+
[chatbot]
|
195 |
+
)
|
196 |
+
|
197 |
+
with gr.Row():
|
198 |
+
clear_button = gr.Button(value="Clear Click", interactive=True)
|
199 |
+
clear_button.click(
|
200 |
+
lambda: ("", [[], [], []], None, None),
|
201 |
+
[],
|
202 |
+
[prompt_input, click_state, image_output_mask, image_output_crop],
|
203 |
+
queue=False,
|
204 |
+
show_progress=False
|
205 |
+
)
|
206 |
+
|
207 |
+
clear_button = gr.Button(value="Clear", interactive=True)
|
208 |
+
clear_button.click(
|
209 |
+
lambda: ("", [], [], [[], [], []], None, None),
|
210 |
+
[],
|
211 |
+
[prompt_input, chatbot, state, click_state, image_output_mask, image_output_crop],
|
212 |
+
queue=False,
|
213 |
+
show_progress=False
|
214 |
+
)
|
215 |
+
|
216 |
+
submit_button = gr.Button(
|
217 |
+
value="Submit", interactive=True, variant="primary"
|
218 |
+
)
|
219 |
+
submit_button.click(
|
220 |
+
inference_seg_cap,
|
221 |
+
[
|
222 |
+
image_input,
|
223 |
+
prompt_input,
|
224 |
+
language,
|
225 |
+
sentiment,
|
226 |
+
factuality,
|
227 |
+
length,
|
228 |
+
state,
|
229 |
+
click_state
|
230 |
+
],
|
231 |
+
[chatbot, state, click_state, image_output_mask, image_output_crop],
|
232 |
+
show_progress=False
|
233 |
+
)
|
234 |
+
|
235 |
+
# select coordinate
|
236 |
+
image_input.select(
|
237 |
+
get_select_coords,
|
238 |
+
inputs=[image_input,point_prompt,language,sentiment,factuality,length,state,click_state],
|
239 |
+
outputs=[prompt_input, chatbot, state, click_state, image_output_mask, image_output_crop],
|
240 |
+
show_progress=False
|
241 |
+
)
|
242 |
+
|
243 |
+
image_input.change(
|
244 |
+
lambda: ("", [], [[], [], []]),
|
245 |
+
[],
|
246 |
+
[chatbot, state, click_state],
|
247 |
+
queue=False,
|
248 |
+
)
|
249 |
+
|
250 |
+
with gr.Column(scale=1.5):
|
251 |
+
chat_input = gr.Textbox(lines=1, label="Chat Input")
|
252 |
+
chat_input.submit(chat_with_points, [chat_input, click_state, state], [chatbot, state])
|
253 |
+
|
254 |
+
|
255 |
+
examples = gr.Examples(
|
256 |
+
examples=examples,
|
257 |
+
inputs=[image_input, prompt_input],
|
258 |
+
)
|
259 |
+
|
260 |
+
iface.queue(concurrency_count=1, api_open=False, max_size=10)
|
261 |
+
iface.launch(server_name="0.0.0.0", enable_queue=True, server_port=args.port, share=args.gradio_share)
|
caas.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from captioner import build_captioner, BaseCaptioner
|
2 |
+
from segmenter import build_segmenter
|
3 |
+
from text_refiner import build_text_refiner
|
4 |
+
import os
|
5 |
+
import argparse
|
6 |
+
import pdb
|
7 |
+
import time
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
class CaptionAnything():
|
11 |
+
def __init__(self, args):
|
12 |
+
self.args = args
|
13 |
+
self.captioner = build_captioner(args.captioner, args.device, args)
|
14 |
+
self.segmenter = build_segmenter(args.segmenter, args.device, args)
|
15 |
+
if not args.disable_gpt:
|
16 |
+
self.init_refiner()
|
17 |
+
|
18 |
+
|
19 |
+
def init_refiner(self):
|
20 |
+
if os.environ.get('OPENAI_API_KEY', None):
|
21 |
+
self.text_refiner = build_text_refiner(self.args.text_refiner, self.args.device, self.args)
|
22 |
+
|
23 |
+
def inference(self, image, prompt, controls, disable_gpt=False):
|
24 |
+
# segment with prompt
|
25 |
+
print("CA prompt: ", prompt, "CA controls",controls)
|
26 |
+
seg_mask = self.segmenter.inference(image, prompt)[0, ...]
|
27 |
+
mask_save_path = f'result/mask_{time.time()}.png'
|
28 |
+
if not os.path.exists(os.path.dirname(mask_save_path)):
|
29 |
+
os.makedirs(os.path.dirname(mask_save_path))
|
30 |
+
new_p = Image.fromarray(seg_mask.astype('int') * 255.)
|
31 |
+
if new_p.mode != 'RGB':
|
32 |
+
new_p = new_p.convert('RGB')
|
33 |
+
new_p.save(mask_save_path)
|
34 |
+
print('seg_mask path: ', mask_save_path)
|
35 |
+
print("seg_mask.shape: ", seg_mask.shape)
|
36 |
+
# captioning with mask
|
37 |
+
if self.args.enable_reduce_tokens:
|
38 |
+
caption, crop_save_path = self.captioner.inference_with_reduced_tokens(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, regular_box = self.args.regular_box)
|
39 |
+
else:
|
40 |
+
caption, crop_save_path = self.captioner.inference_seg(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, regular_box = self.args.regular_box)
|
41 |
+
# refining with TextRefiner
|
42 |
+
context_captions = []
|
43 |
+
if self.args.context_captions:
|
44 |
+
context_captions.append(self.captioner.inference(image))
|
45 |
+
if not disable_gpt and hasattr(self, "text_refiner"):
|
46 |
+
refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions)
|
47 |
+
else:
|
48 |
+
refined_caption = {'raw_caption': caption}
|
49 |
+
out = {'generated_captions': refined_caption,
|
50 |
+
'crop_save_path': crop_save_path,
|
51 |
+
'mask_save_path': mask_save_path,
|
52 |
+
'context_captions': context_captions}
|
53 |
+
return out
|
54 |
+
|
55 |
+
def parse_augment():
|
56 |
+
parser = argparse.ArgumentParser()
|
57 |
+
parser.add_argument('--captioner', type=str, default="blip")
|
58 |
+
parser.add_argument('--segmenter', type=str, default="base")
|
59 |
+
parser.add_argument('--text_refiner', type=str, default="base")
|
60 |
+
parser.add_argument('--segmenter_checkpoint', type=str, default="segmenter/sam_vit_h_4b8939.pth")
|
61 |
+
parser.add_argument('--seg_crop_mode', type=str, default="w_bg", choices=['wo_bg', 'w_bg'], help="whether to add or remove background of the image when captioning")
|
62 |
+
parser.add_argument('--clip_filter', action="store_true", help="use clip to filter bad captions")
|
63 |
+
parser.add_argument('--context_captions', action="store_true", help="use surrounding captions to enhance current caption")
|
64 |
+
parser.add_argument('--regular_box', action="store_true", default = False, help="crop image with a regular box")
|
65 |
+
parser.add_argument('--device', type=str, default="cuda:0")
|
66 |
+
parser.add_argument('--port', type=int, default=6086, help="only useful when running gradio applications")
|
67 |
+
parser.add_argument('--debug', action="store_true")
|
68 |
+
parser.add_argument('--gradio_share', action="store_true")
|
69 |
+
parser.add_argument('--disable_gpt', action="store_true")
|
70 |
+
parser.add_argument('--enable_reduce_tokens', action="store_true", default=False)
|
71 |
+
parser.add_argument('--disable_reuse_features', action="store_true", default=False)
|
72 |
+
args = parser.parse_args()
|
73 |
+
|
74 |
+
if args.debug:
|
75 |
+
print(args)
|
76 |
+
return args
|
77 |
+
|
78 |
+
if __name__ == "__main__":
|
79 |
+
args = parse_augment()
|
80 |
+
# image_path = 'test_img/img3.jpg'
|
81 |
+
image_path = 'test_img/img13.jpg'
|
82 |
+
prompts = [
|
83 |
+
{
|
84 |
+
"prompt_type":["click"],
|
85 |
+
"input_point":[[500, 300], [1000, 500]],
|
86 |
+
"input_label":[1, 0],
|
87 |
+
"multimask_output":"True",
|
88 |
+
},
|
89 |
+
{
|
90 |
+
"prompt_type":["click"],
|
91 |
+
"input_point":[[900, 800]],
|
92 |
+
"input_label":[1],
|
93 |
+
"multimask_output":"True",
|
94 |
+
}
|
95 |
+
]
|
96 |
+
controls = {
|
97 |
+
"length": "30",
|
98 |
+
"sentiment": "positive",
|
99 |
+
# "imagination": "True",
|
100 |
+
"imagination": "False",
|
101 |
+
"language": "English",
|
102 |
+
}
|
103 |
+
|
104 |
+
model = CaptionAnything(args)
|
105 |
+
for prompt in prompts:
|
106 |
+
print('*'*30)
|
107 |
+
print('Image path: ', image_path)
|
108 |
+
image = Image.open(image_path)
|
109 |
+
print(image)
|
110 |
+
print('Visual controls (SAM prompt):\n', prompt)
|
111 |
+
print('Language controls:\n', controls)
|
112 |
+
out = model.inference(image_path, prompt, controls)
|
113 |
+
|
114 |
+
|
caption_anything.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from captioner import build_captioner, BaseCaptioner
|
2 |
+
from segmenter import build_segmenter
|
3 |
+
from text_refiner import build_text_refiner
|
4 |
+
import os
|
5 |
+
import argparse
|
6 |
+
import pdb
|
7 |
+
import time
|
8 |
+
from PIL import Image
|
9 |
+
import cv2
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
class CaptionAnything():
|
13 |
+
def __init__(self, args, api_key="", captioner=None, segmenter=None, text_refiner=None):
|
14 |
+
self.args = args
|
15 |
+
self.captioner = build_captioner(args.captioner, args.device, args) if captioner is None else captioner
|
16 |
+
self.segmenter = build_segmenter(args.segmenter, args.device, args) if segmenter is None else segmenter
|
17 |
+
|
18 |
+
self.text_refiner = None
|
19 |
+
if not args.disable_gpt:
|
20 |
+
if text_refiner is not None:
|
21 |
+
self.text_refiner = text_refiner
|
22 |
+
else:
|
23 |
+
self.init_refiner(api_key)
|
24 |
+
|
25 |
+
def init_refiner(self, api_key):
|
26 |
+
try:
|
27 |
+
self.text_refiner = build_text_refiner(self.args.text_refiner, self.args.device, self.args, api_key)
|
28 |
+
self.text_refiner.llm('hi') # test
|
29 |
+
except:
|
30 |
+
self.text_refiner = None
|
31 |
+
print('OpenAI GPT is not available')
|
32 |
+
|
33 |
+
def inference(self, image, prompt, controls, disable_gpt=False):
|
34 |
+
# segment with prompt
|
35 |
+
print("CA prompt: ", prompt, "CA controls",controls)
|
36 |
+
seg_mask = self.segmenter.inference(image, prompt)[0, ...]
|
37 |
+
if self.args.enable_morphologyex:
|
38 |
+
seg_mask = 255 * seg_mask.astype(np.uint8)
|
39 |
+
seg_mask = np.stack([seg_mask, seg_mask, seg_mask], axis = -1)
|
40 |
+
seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_OPEN, kernel = np.ones((6, 6), np.uint8))
|
41 |
+
seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_CLOSE, kernel = np.ones((6, 6), np.uint8))
|
42 |
+
seg_mask = seg_mask[:,:,0] > 0
|
43 |
+
mask_save_path = f'result/mask_{time.time()}.png'
|
44 |
+
if not os.path.exists(os.path.dirname(mask_save_path)):
|
45 |
+
os.makedirs(os.path.dirname(mask_save_path))
|
46 |
+
seg_mask_img = Image.fromarray(seg_mask.astype('int') * 255.)
|
47 |
+
if seg_mask_img.mode != 'RGB':
|
48 |
+
seg_mask_img = seg_mask_img.convert('RGB')
|
49 |
+
seg_mask_img.save(mask_save_path)
|
50 |
+
print('seg_mask path: ', mask_save_path)
|
51 |
+
print("seg_mask.shape: ", seg_mask.shape)
|
52 |
+
# captioning with mask
|
53 |
+
if self.args.enable_reduce_tokens:
|
54 |
+
caption, crop_save_path = self.captioner.inference_with_reduced_tokens(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, disable_regular_box = self.args.disable_regular_box)
|
55 |
+
else:
|
56 |
+
caption, crop_save_path = self.captioner.inference_seg(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, disable_regular_box = self.args.disable_regular_box)
|
57 |
+
# refining with TextRefiner
|
58 |
+
context_captions = []
|
59 |
+
if self.args.context_captions:
|
60 |
+
context_captions.append(self.captioner.inference(image))
|
61 |
+
if not disable_gpt and self.text_refiner is not None:
|
62 |
+
refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions)
|
63 |
+
else:
|
64 |
+
refined_caption = {'raw_caption': caption}
|
65 |
+
out = {'generated_captions': refined_caption,
|
66 |
+
'crop_save_path': crop_save_path,
|
67 |
+
'mask_save_path': mask_save_path,
|
68 |
+
'mask': seg_mask_img,
|
69 |
+
'context_captions': context_captions}
|
70 |
+
return out
|
71 |
+
|
72 |
+
def parse_augment():
|
73 |
+
parser = argparse.ArgumentParser()
|
74 |
+
parser.add_argument('--captioner', type=str, default="blip2")
|
75 |
+
parser.add_argument('--segmenter', type=str, default="base")
|
76 |
+
parser.add_argument('--text_refiner', type=str, default="base")
|
77 |
+
parser.add_argument('--segmenter_checkpoint', type=str, default="segmenter/sam_vit_h_4b8939.pth")
|
78 |
+
parser.add_argument('--seg_crop_mode', type=str, default="wo_bg", choices=['wo_bg', 'w_bg'], help="whether to add or remove background of the image when captioning")
|
79 |
+
parser.add_argument('--clip_filter', action="store_true", help="use clip to filter bad captions")
|
80 |
+
parser.add_argument('--context_captions', action="store_true", help="use surrounding captions to enhance current caption (TODO)")
|
81 |
+
parser.add_argument('--disable_regular_box', action="store_true", default = False, help="crop image with a regular box")
|
82 |
+
parser.add_argument('--device', type=str, default="cuda:0")
|
83 |
+
parser.add_argument('--port', type=int, default=6086, help="only useful when running gradio applications")
|
84 |
+
parser.add_argument('--debug', action="store_true")
|
85 |
+
parser.add_argument('--gradio_share', action="store_true")
|
86 |
+
parser.add_argument('--disable_gpt', action="store_true")
|
87 |
+
parser.add_argument('--enable_reduce_tokens', action="store_true", default=False)
|
88 |
+
parser.add_argument('--disable_reuse_features', action="store_true", default=False)
|
89 |
+
parser.add_argument('--enable_morphologyex', action="store_true", default=False)
|
90 |
+
args = parser.parse_args()
|
91 |
+
|
92 |
+
if args.debug:
|
93 |
+
print(args)
|
94 |
+
return args
|
95 |
+
|
96 |
+
if __name__ == "__main__":
|
97 |
+
args = parse_augment()
|
98 |
+
# image_path = 'test_img/img3.jpg'
|
99 |
+
image_path = 'test_img/img13.jpg'
|
100 |
+
prompts = [
|
101 |
+
{
|
102 |
+
"prompt_type":["click"],
|
103 |
+
"input_point":[[500, 300], [1000, 500]],
|
104 |
+
"input_label":[1, 0],
|
105 |
+
"multimask_output":"True",
|
106 |
+
},
|
107 |
+
{
|
108 |
+
"prompt_type":["click"],
|
109 |
+
"input_point":[[900, 800]],
|
110 |
+
"input_label":[1],
|
111 |
+
"multimask_output":"True",
|
112 |
+
}
|
113 |
+
]
|
114 |
+
controls = {
|
115 |
+
"length": "30",
|
116 |
+
"sentiment": "positive",
|
117 |
+
# "imagination": "True",
|
118 |
+
"imagination": "False",
|
119 |
+
"language": "English",
|
120 |
+
}
|
121 |
+
|
122 |
+
model = CaptionAnything(args, os.environ['OPENAI_API_KEY'])
|
123 |
+
for prompt in prompts:
|
124 |
+
print('*'*30)
|
125 |
+
print('Image path: ', image_path)
|
126 |
+
image = Image.open(image_path)
|
127 |
+
print(image)
|
128 |
+
print('Visual controls (SAM prompt):\n', prompt)
|
129 |
+
print('Language controls:\n', controls)
|
130 |
+
out = model.inference(image_path, prompt, controls)
|
131 |
+
|
132 |
+
|
captioner/README.md
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
To run BLIP/BLIP2, you should install transformers from source!
|
2 |
+
```
|
3 |
+
!pip install git+https://github.com/huggingface/transformers.git
|
4 |
+
```
|
5 |
+
To run filter module, you should install CLIP repo as a Python package as follow:
|
6 |
+
```
|
7 |
+
!pip install ftfy regex tqdm
|
8 |
+
!pip install git+https://github.com/openai/CLIP.git
|
9 |
+
```
|
10 |
+
To accelerate BLIP2 with int8, you should install accelerate
|
11 |
+
```
|
12 |
+
!pip install accelerate bitsandbytes
|
13 |
+
```
|
captioner/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .blip import BLIPCaptioner
|
2 |
+
from .blip2 import BLIP2Captioner
|
3 |
+
from .git import GITCaptioner
|
4 |
+
from .base_captioner import BaseCaptioner
|
5 |
+
|
6 |
+
|
7 |
+
def build_captioner(type, device, args=None):
|
8 |
+
if type == 'blip':
|
9 |
+
return BLIPCaptioner(device, enable_filter=args.clip_filter)
|
10 |
+
elif type == 'blip2':
|
11 |
+
return BLIP2Captioner(device, enable_filter=args.clip_filter)
|
12 |
+
elif type == 'git':
|
13 |
+
return GITCaptioner(device, enable_filter=args.clip_filter)
|
14 |
+
else:
|
15 |
+
raise NotImplementedError("")
|
captioner/base_captioner.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from PIL import Image, ImageDraw, ImageOps
|
3 |
+
from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
|
4 |
+
import json
|
5 |
+
import pdb
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
from typing import Union
|
9 |
+
import time
|
10 |
+
import clip
|
11 |
+
|
12 |
+
def boundary(inputs):
|
13 |
+
|
14 |
+
col = inputs.shape[1]
|
15 |
+
inputs = inputs.reshape(-1)
|
16 |
+
lens = len(inputs)
|
17 |
+
|
18 |
+
for i in range(lens):
|
19 |
+
if inputs[i] != False:
|
20 |
+
break
|
21 |
+
for j in range(lens):
|
22 |
+
if inputs[lens - 1 - j] != False:
|
23 |
+
break
|
24 |
+
start = i
|
25 |
+
end = lens - 1 - j
|
26 |
+
top = start // col
|
27 |
+
bottom = end // col
|
28 |
+
|
29 |
+
return top, bottom
|
30 |
+
|
31 |
+
def new_seg_to_box(seg_mask: Union[np.ndarray, Image.Image, str]):
|
32 |
+
|
33 |
+
if type(seg_mask) == str:
|
34 |
+
seg_mask = Image.open(seg_mask)
|
35 |
+
elif type(seg_mask) == np.ndarray:
|
36 |
+
seg_mask = Image.fromarray(seg_mask)
|
37 |
+
seg_mask = np.array(seg_mask) > 0
|
38 |
+
size = max(seg_mask.shape[0], seg_mask.shape[1])
|
39 |
+
top, bottom = boundary(seg_mask)
|
40 |
+
left, right = boundary(seg_mask.T)
|
41 |
+
return [left / size, top / size, right / size, bottom / size]
|
42 |
+
|
43 |
+
def seg_to_box(seg_mask: Union[np.ndarray, Image.Image, str]):
|
44 |
+
if type(seg_mask) == str:
|
45 |
+
seg_mask = cv2.imread(seg_mask, cv2.IMREAD_GRAYSCALE)
|
46 |
+
_, seg_mask = cv2.threshold(seg_mask, 127, 255, 0)
|
47 |
+
elif type(seg_mask) == np.ndarray:
|
48 |
+
assert seg_mask.ndim == 2 # only support single-channel segmentation mask
|
49 |
+
seg_mask = seg_mask.astype('uint8')
|
50 |
+
if seg_mask.dtype == 'bool':
|
51 |
+
seg_mask = seg_mask * 255
|
52 |
+
contours, hierarchy = cv2.findContours(seg_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
53 |
+
contours = np.concatenate(contours, axis=0)
|
54 |
+
rect = cv2.minAreaRect(contours)
|
55 |
+
box = cv2.boxPoints(rect)
|
56 |
+
if rect[-1] >= 45:
|
57 |
+
newstart = box.argmin(axis=0)[1] # leftmost
|
58 |
+
else:
|
59 |
+
newstart = box.argmax(axis=0)[0] # topmost
|
60 |
+
box = np.concatenate([box[newstart:], box[:newstart]], axis=0)
|
61 |
+
box = np.int0(box)
|
62 |
+
return box
|
63 |
+
|
64 |
+
def get_w_h(rect_points):
|
65 |
+
w = np.linalg.norm(rect_points[0] - rect_points[1], ord=2).astype('int')
|
66 |
+
h = np.linalg.norm(rect_points[0] - rect_points[3], ord=2).astype('int')
|
67 |
+
return w, h
|
68 |
+
|
69 |
+
def cut_box(img, rect_points):
|
70 |
+
w, h = get_w_h(rect_points)
|
71 |
+
dst_pts = np.array([[h, 0], [h, w], [0, w], [0, 0],], dtype="float32")
|
72 |
+
transform = cv2.getPerspectiveTransform(rect_points.astype("float32"), dst_pts)
|
73 |
+
cropped_img = cv2.warpPerspective(img, transform, (h, w))
|
74 |
+
return cropped_img
|
75 |
+
|
76 |
+
class BaseCaptioner:
|
77 |
+
def __init__(self, device, enable_filter=False):
|
78 |
+
print(f"Initializing ImageCaptioning to {device}")
|
79 |
+
self.device = device
|
80 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
81 |
+
self.processor = None
|
82 |
+
self.model = None
|
83 |
+
self.enable_filter = enable_filter
|
84 |
+
if enable_filter:
|
85 |
+
self.filter, self.preprocess = clip.load('ViT-B/32', device)
|
86 |
+
self.threshold = 0.2
|
87 |
+
|
88 |
+
@torch.no_grad()
|
89 |
+
def filter_caption(self, image: Union[np.ndarray, Image.Image, str], caption: str):
|
90 |
+
|
91 |
+
if type(image) == str: # input path
|
92 |
+
image = Image.open(image)
|
93 |
+
elif type(image) == np.ndarray:
|
94 |
+
image = Image.fromarray(image)
|
95 |
+
|
96 |
+
image = self.preprocess(image).unsqueeze(0).to(self.device) # (1, 3, 224, 224)
|
97 |
+
text = clip.tokenize(caption).to(self.device) # (1, 77)
|
98 |
+
image_features = self.filter.encode_image(image) # (1, 512)
|
99 |
+
text_features = self.filter.encode_text(text) # (1, 512)
|
100 |
+
image_features /= image_features.norm(dim = -1, keepdim = True)
|
101 |
+
text_features /= text_features.norm(dim = -1, keepdim = True)
|
102 |
+
similarity = torch.matmul(image_features, text_features.transpose(1, 0)).item()
|
103 |
+
if similarity < self.threshold:
|
104 |
+
print('There seems to be nothing where you clicked.')
|
105 |
+
out = ""
|
106 |
+
else:
|
107 |
+
out = caption
|
108 |
+
print(f'Clip score of the caption is {similarity}')
|
109 |
+
return out
|
110 |
+
|
111 |
+
|
112 |
+
def inference(self, image: Union[np.ndarray, Image.Image, str], filter: bool=False):
|
113 |
+
raise NotImplementedError()
|
114 |
+
|
115 |
+
def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, filter: bool=False):
|
116 |
+
raise NotImplementedError()
|
117 |
+
|
118 |
+
def inference_box(self, image: Union[np.ndarray, Image.Image, str], box: Union[list, np.ndarray], filter=False):
|
119 |
+
if type(image) == str: # input path
|
120 |
+
image = Image.open(image)
|
121 |
+
elif type(image) == np.ndarray:
|
122 |
+
image = Image.fromarray(image)
|
123 |
+
|
124 |
+
if np.array(box).size == 4: # [x0, y0, x1, y1], where (x0, y0), (x1, y1) represent top-left and bottom-right corners
|
125 |
+
size = max(image.width, image.height)
|
126 |
+
x1, y1, x2, y2 = box
|
127 |
+
image_crop = np.array(image.crop((x1 * size, y1 * size, x2 * size, y2 * size)))
|
128 |
+
elif np.array(box).size == 8: # four corners of an irregular rectangle
|
129 |
+
image_crop = cut_box(np.array(image), box)
|
130 |
+
|
131 |
+
crop_save_path = f'result/crop_{time.time()}.png'
|
132 |
+
Image.fromarray(image_crop).save(crop_save_path)
|
133 |
+
print(f'croped image saved in {crop_save_path}')
|
134 |
+
caption = self.inference(image_crop, filter)
|
135 |
+
return caption, crop_save_path
|
136 |
+
|
137 |
+
|
138 |
+
def inference_seg(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str], crop_mode="w_bg", filter=False, disable_regular_box = False):
|
139 |
+
if type(image) == str:
|
140 |
+
image = Image.open(image)
|
141 |
+
if type(seg_mask) == str:
|
142 |
+
seg_mask = Image.open(seg_mask)
|
143 |
+
elif type(seg_mask) == np.ndarray:
|
144 |
+
seg_mask = Image.fromarray(seg_mask)
|
145 |
+
seg_mask = seg_mask.resize(image.size)
|
146 |
+
seg_mask = np.array(seg_mask) > 0
|
147 |
+
|
148 |
+
if crop_mode=="wo_bg":
|
149 |
+
image = np.array(image) * seg_mask[:,:,np.newaxis] + (1 - seg_mask[:,:,np.newaxis]) * 255
|
150 |
+
image = np.uint8(image)
|
151 |
+
else:
|
152 |
+
image = np.array(image)
|
153 |
+
|
154 |
+
if disable_regular_box:
|
155 |
+
min_area_box = seg_to_box(seg_mask)
|
156 |
+
else:
|
157 |
+
min_area_box = new_seg_to_box(seg_mask)
|
158 |
+
return self.inference_box(image, min_area_box, filter)
|
159 |
+
|
160 |
+
|
161 |
+
def generate_seg_cropped_image(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str], crop_mode="w_bg", disable_regular_box = False):
|
162 |
+
if type(image) == str:
|
163 |
+
image = Image.open(image)
|
164 |
+
if type(seg_mask) == str:
|
165 |
+
seg_mask = Image.open(seg_mask)
|
166 |
+
elif type(seg_mask) == np.ndarray:
|
167 |
+
seg_mask = Image.fromarray(seg_mask)
|
168 |
+
seg_mask = seg_mask.resize(image.size)
|
169 |
+
seg_mask = np.array(seg_mask) > 0
|
170 |
+
|
171 |
+
if crop_mode=="wo_bg":
|
172 |
+
image = np.array(image) * seg_mask[:,:,np.newaxis] + (1- seg_mask[:,:,np.newaxis]) * 255
|
173 |
+
else:
|
174 |
+
image = np.array(image)
|
175 |
+
|
176 |
+
if disable_regular_box:
|
177 |
+
box = seg_to_box(seg_mask)
|
178 |
+
else:
|
179 |
+
box = new_seg_to_box(seg_mask)
|
180 |
+
|
181 |
+
if np.array(box).size == 4: # [x0, y0, x1, y1], where (x0, y0), (x1, y1) represent top-left and bottom-right corners
|
182 |
+
size = max(image.shape[0], image.shape[1])
|
183 |
+
x1, y1, x2, y2 = box
|
184 |
+
image_crop = np.array(image.crop((x1 * size, y1 * size, x2 * size, y2 * size)))
|
185 |
+
elif np.array(box).size == 8: # four corners of an irregular rectangle
|
186 |
+
image_crop = cut_box(np.array(image), box)
|
187 |
+
crop_save_path = f'result/crop_{time.time()}.png'
|
188 |
+
Image.fromarray(image_crop).save(crop_save_path)
|
189 |
+
print(f'croped image saved in {crop_save_path}')
|
190 |
+
return crop_save_path
|
191 |
+
|
192 |
+
|
193 |
+
if __name__ == '__main__':
|
194 |
+
model = BaseCaptioner(device='cuda:0')
|
195 |
+
image_path = 'test_img/img2.jpg'
|
196 |
+
seg_mask = np.zeros((15,15))
|
197 |
+
seg_mask[5:10, 5:10] = 1
|
198 |
+
seg_mask = 'image/SAM/img10.jpg.raw_mask.png'
|
199 |
+
print(model.inference_seg(image_path, seg_mask))
|
200 |
+
|
captioner/blip.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from PIL import Image, ImageDraw, ImageOps
|
3 |
+
from transformers import BlipProcessor
|
4 |
+
from .modeling_blip import BlipForConditionalGeneration
|
5 |
+
import json
|
6 |
+
import pdb
|
7 |
+
import cv2
|
8 |
+
import numpy as np
|
9 |
+
from typing import Union
|
10 |
+
from .base_captioner import BaseCaptioner
|
11 |
+
import torchvision.transforms.functional as F
|
12 |
+
|
13 |
+
|
14 |
+
class BLIPCaptioner(BaseCaptioner):
|
15 |
+
def __init__(self, device, enable_filter=False):
|
16 |
+
super().__init__(device, enable_filter)
|
17 |
+
self.device = device
|
18 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
19 |
+
self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
|
20 |
+
self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=self.torch_dtype).to(self.device)
|
21 |
+
|
22 |
+
@torch.no_grad()
|
23 |
+
def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False):
|
24 |
+
if type(image) == str: # input path
|
25 |
+
image = Image.open(image)
|
26 |
+
inputs = self.processor(image, return_tensors="pt").to(self.device, self.torch_dtype)
|
27 |
+
out = self.model.generate(**inputs, max_new_tokens=50)
|
28 |
+
captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
|
29 |
+
if self.enable_filter and filter:
|
30 |
+
captions = self.filter_caption(image, captions)
|
31 |
+
print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
|
32 |
+
return captions
|
33 |
+
|
34 |
+
@torch.no_grad()
|
35 |
+
def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg", filter=False, disable_regular_box = False):
|
36 |
+
crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode, disable_regular_box=disable_regular_box)
|
37 |
+
if type(image) == str: # input path
|
38 |
+
image = Image.open(image)
|
39 |
+
inputs = self.processor(image, return_tensors="pt")
|
40 |
+
pixel_values = inputs.pixel_values.to(self.device, self.torch_dtype)
|
41 |
+
_, _, H, W = pixel_values.shape
|
42 |
+
seg_mask = Image.fromarray(seg_mask.astype(float))
|
43 |
+
seg_mask = seg_mask.resize((H, W))
|
44 |
+
seg_mask = F.pil_to_tensor(seg_mask) > 0.5
|
45 |
+
seg_mask = seg_mask.float()
|
46 |
+
pixel_masks = seg_mask.unsqueeze(0).to(self.device)
|
47 |
+
out = self.model.generate(pixel_values=pixel_values, pixel_masks=pixel_masks, max_new_tokens=50)
|
48 |
+
captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
|
49 |
+
if self.enable_filter and filter:
|
50 |
+
captions = self.filter_caption(image, captions)
|
51 |
+
print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
|
52 |
+
return captions, crop_save_path
|
53 |
+
|
54 |
+
|
55 |
+
if __name__ == '__main__':
|
56 |
+
model = BLIPCaptioner(device='cuda:0')
|
57 |
+
# image_path = 'test_img/img2.jpg'
|
58 |
+
image_path = '/group/30042/wybertwang/project/woa_visgpt/chatARC/image/SAM/img10.jpg'
|
59 |
+
seg_mask = np.zeros((15,15))
|
60 |
+
seg_mask[5:10, 5:10] = 1
|
61 |
+
seg_mask = 'test_img/img10.jpg.raw_mask.png'
|
62 |
+
image_path = 'test_img/img2.jpg'
|
63 |
+
seg_mask = 'test_img/img2.jpg.raw_mask.png'
|
64 |
+
print(f'process image {image_path}')
|
65 |
+
print(model.inference_with_reduced_tokens(image_path, seg_mask))
|
66 |
+
|
captioner/blip2.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from PIL import Image, ImageDraw, ImageOps
|
3 |
+
from transformers import AutoProcessor, Blip2ForConditionalGeneration
|
4 |
+
import json
|
5 |
+
import pdb
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
from typing import Union
|
9 |
+
from .base_captioner import BaseCaptioner
|
10 |
+
|
11 |
+
class BLIP2Captioner(BaseCaptioner):
|
12 |
+
def __init__(self, device, dialogue: bool = False, enable_filter: bool = False):
|
13 |
+
super().__init__(device, enable_filter)
|
14 |
+
self.device = device
|
15 |
+
self.dialogue = dialogue
|
16 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
17 |
+
self.processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
18 |
+
self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", device_map = 'sequential', load_in_8bit=True)
|
19 |
+
@torch.no_grad()
|
20 |
+
def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False):
|
21 |
+
if type(image) == str: # input path
|
22 |
+
image = Image.open(image)
|
23 |
+
|
24 |
+
if not self.dialogue:
|
25 |
+
text_prompt = 'Context: ignore the white part in this image. Question: describe this image. Answer:'
|
26 |
+
inputs = self.processor(image, text = text_prompt, return_tensors="pt").to(self.device, self.torch_dtype)
|
27 |
+
out = self.model.generate(**inputs, max_new_tokens=50)
|
28 |
+
captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
|
29 |
+
if self.enable_filter and filter:
|
30 |
+
captions = self.filter_caption(image, captions)
|
31 |
+
print(f"\nProcessed ImageCaptioning by BLIP2Captioner, Output Text: {captions}")
|
32 |
+
return captions
|
33 |
+
else:
|
34 |
+
context = []
|
35 |
+
template = "Question: {} Answer: {}."
|
36 |
+
while(True):
|
37 |
+
input_texts = input()
|
38 |
+
if input_texts == 'end':
|
39 |
+
break
|
40 |
+
prompt = " ".join([template.format(context[i][0], context[i][1]) for i in range(len(context))]) + " Question: " + input_texts + " Answer:"
|
41 |
+
inputs = self.processor(image, text = prompt, return_tensors="pt").to(self.device, self.torch_dtype)
|
42 |
+
out = self.model.generate(**inputs, max_new_tokens=50)
|
43 |
+
captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
|
44 |
+
context.append((input_texts, captions))
|
45 |
+
|
46 |
+
return captions
|
47 |
+
|
48 |
+
if __name__ == '__main__':
|
49 |
+
|
50 |
+
dialogue = False
|
51 |
+
model = BLIP2Captioner(device='cuda:4', dialogue = dialogue, cache_dir = '/nvme-ssd/fjj/Caption-Anything/model_cache')
|
52 |
+
image_path = 'test_img/img2.jpg'
|
53 |
+
seg_mask = np.zeros((224,224))
|
54 |
+
seg_mask[50:200, 50:200] = 1
|
55 |
+
print(f'process image {image_path}')
|
56 |
+
print(model.inference_seg(image_path, seg_mask))
|
captioner/git.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import GitProcessor, AutoProcessor
|
2 |
+
from .modeling_git import GitForCausalLM
|
3 |
+
from PIL import Image
|
4 |
+
import torch
|
5 |
+
from .base_captioner import BaseCaptioner
|
6 |
+
import numpy as np
|
7 |
+
from typing import Union
|
8 |
+
import torchvision.transforms.functional as F
|
9 |
+
|
10 |
+
|
11 |
+
class GITCaptioner(BaseCaptioner):
|
12 |
+
def __init__(self, device, enable_filter=False):
|
13 |
+
super().__init__(device, enable_filter)
|
14 |
+
self.device = device
|
15 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
16 |
+
self.processor = AutoProcessor.from_pretrained("microsoft/git-large")
|
17 |
+
self.model = GitForCausalLM.from_pretrained("microsoft/git-large", torch_dtype=self.torch_dtype).to(self.device)
|
18 |
+
|
19 |
+
@torch.no_grad()
|
20 |
+
def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False):
|
21 |
+
if type(image) == str: # input path
|
22 |
+
image = Image.open(image)
|
23 |
+
pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device, self.torch_dtype)
|
24 |
+
generated_ids = self.model.generate(pixel_values=pixel_values, max_new_tokens=50)
|
25 |
+
generated_caption = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
26 |
+
if self.enable_filter and filter:
|
27 |
+
captions = self.filter_caption(image, captions)
|
28 |
+
print(f"\nProcessed ImageCaptioning by GITCaptioner, Output Text: {generated_caption}")
|
29 |
+
return generated_caption
|
30 |
+
|
31 |
+
@torch.no_grad()
|
32 |
+
def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg", filter=False, disable_regular_box = False):
|
33 |
+
crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode, disable_regular_box=disable_regular_box)
|
34 |
+
if type(image) == str: # input path
|
35 |
+
image = Image.open(image)
|
36 |
+
inputs = self.processor(images=image, return_tensors="pt")
|
37 |
+
pixel_values = inputs.pixel_values.to(self.device, self.torch_dtype)
|
38 |
+
_, _, H, W = pixel_values.shape
|
39 |
+
seg_mask = Image.fromarray(seg_mask.astype(float))
|
40 |
+
seg_mask = seg_mask.resize((H, W))
|
41 |
+
seg_mask = F.pil_to_tensor(seg_mask) > 0.5
|
42 |
+
seg_mask = seg_mask.float()
|
43 |
+
pixel_masks = seg_mask.unsqueeze(0).to(self.device)
|
44 |
+
out = self.model.generate(pixel_values=pixel_values, pixel_masks=pixel_masks, max_new_tokens=50)
|
45 |
+
captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
|
46 |
+
if self.enable_filter and filter:
|
47 |
+
captions = self.filter_caption(image, captions)
|
48 |
+
print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
|
49 |
+
return captions, crop_save_path
|
50 |
+
|
51 |
+
if __name__ == '__main__':
|
52 |
+
model = GITCaptioner(device='cuda:2', enable_filter=False)
|
53 |
+
image_path = 'test_img/img2.jpg'
|
54 |
+
seg_mask = np.zeros((224,224))
|
55 |
+
seg_mask[50:200, 50:200] = 1
|
56 |
+
print(f'process image {image_path}')
|
57 |
+
print(model.inference_with_reduced_tokens(image_path, seg_mask))
|
captioner/modeling_blip.py
ADDED
@@ -0,0 +1,1476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The Salesforce Team Authors and The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch BLIP model."""
|
16 |
+
|
17 |
+
from dataclasses import dataclass
|
18 |
+
from typing import Any, Optional, Tuple, Union
|
19 |
+
|
20 |
+
import torch
|
21 |
+
import torch.utils.checkpoint
|
22 |
+
from torch import nn
|
23 |
+
from torch.nn.functional import normalize
|
24 |
+
|
25 |
+
from transformers.activations import ACT2FN
|
26 |
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
27 |
+
from transformers.modeling_utils import PreTrainedModel
|
28 |
+
from transformers.utils import (
|
29 |
+
ModelOutput,
|
30 |
+
add_start_docstrings,
|
31 |
+
add_start_docstrings_to_model_forward,
|
32 |
+
logging,
|
33 |
+
replace_return_docstrings,
|
34 |
+
)
|
35 |
+
from transformers.models.blip.configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig
|
36 |
+
from transformers.models.blip.modeling_blip_text import BlipTextLMHeadModel, BlipTextModel
|
37 |
+
from .vit_pixel_masks_utils import ViTPatchMaskGenerator
|
38 |
+
|
39 |
+
logger = logging.get_logger(__name__)
|
40 |
+
|
41 |
+
_CHECKPOINT_FOR_DOC = "Salesforce/blip-vqa-base"
|
42 |
+
|
43 |
+
BLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
44 |
+
"Salesforce/blip-vqa-base",
|
45 |
+
"Salesforce/blip-vqa-capfit-large",
|
46 |
+
"Salesforce/blip-image-captioning-base",
|
47 |
+
"Salesforce/blip-image-captioning-large",
|
48 |
+
"Salesforce/blip-itm-base-coco",
|
49 |
+
"Salesforce/blip-itm-large-coco",
|
50 |
+
"Salesforce/blip-itm-base-flikr",
|
51 |
+
"Salesforce/blip-itm-large-flikr",
|
52 |
+
# See all BLIP models at https://huggingface.co/models?filter=blip
|
53 |
+
]
|
54 |
+
|
55 |
+
|
56 |
+
# Copied from transformers.models.clip.modeling_clip.contrastive_loss
|
57 |
+
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
|
58 |
+
return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
|
59 |
+
|
60 |
+
|
61 |
+
# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->blip
|
62 |
+
def blip_loss(similarity: torch.Tensor) -> torch.Tensor:
|
63 |
+
caption_loss = contrastive_loss(similarity)
|
64 |
+
image_loss = contrastive_loss(similarity.t())
|
65 |
+
return (caption_loss + image_loss) / 2.0
|
66 |
+
|
67 |
+
|
68 |
+
@dataclass
|
69 |
+
class BlipForConditionalGenerationModelOutput(ModelOutput):
|
70 |
+
"""
|
71 |
+
Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the
|
72 |
+
last hidden states. This class also adds the loss term from the text decoder.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
|
76 |
+
Languge modeling loss from the text decoder.
|
77 |
+
decoder_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*):
|
78 |
+
Prediction scores of the language modeling head of the text decoder model.
|
79 |
+
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*):
|
80 |
+
The image embeddings obtained after applying the Vision Transformer model to the input image.
|
81 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
82 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
83 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
|
84 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
85 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
86 |
+
|
87 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
88 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed):
|
89 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
90 |
+
sequence_length)`.
|
91 |
+
|
92 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
93 |
+
heads.
|
94 |
+
"""
|
95 |
+
|
96 |
+
loss: Optional[Tuple[torch.FloatTensor]] = None
|
97 |
+
decoder_logits: Optional[Tuple[torch.FloatTensor]] = None
|
98 |
+
image_embeds: Optional[torch.FloatTensor] = None
|
99 |
+
last_hidden_state: torch.FloatTensor = None
|
100 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
101 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
102 |
+
|
103 |
+
|
104 |
+
@dataclass
|
105 |
+
class BlipTextVisionModelOutput(ModelOutput):
|
106 |
+
"""
|
107 |
+
Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the
|
108 |
+
last hidden states. This class also adds the loss term from the text decoder.
|
109 |
+
|
110 |
+
Args:
|
111 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
112 |
+
Languge modeling loss from the text decoder.
|
113 |
+
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
|
114 |
+
The image embeddings obtained by applying the projection layer to the pooler_output.
|
115 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
116 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
117 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
118 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
119 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
120 |
+
|
121 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
122 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
123 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
124 |
+
sequence_length)`.
|
125 |
+
|
126 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
127 |
+
heads.
|
128 |
+
"""
|
129 |
+
|
130 |
+
loss: Optional[torch.FloatTensor] = None
|
131 |
+
image_embeds: Optional[torch.FloatTensor] = None
|
132 |
+
last_hidden_state: torch.FloatTensor = None
|
133 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
134 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
135 |
+
|
136 |
+
|
137 |
+
@dataclass
|
138 |
+
class BlipImageTextMatchingModelOutput(ModelOutput):
|
139 |
+
"""
|
140 |
+
Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the
|
141 |
+
last hidden states. This class also adds the loss term from the text decoder as well as the image-text similarity
|
142 |
+
scores.
|
143 |
+
|
144 |
+
Args:
|
145 |
+
itm_score (`torch.FloatTensor`):
|
146 |
+
The image-text similarity scores.
|
147 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
148 |
+
Languge modeling loss from the text decoder.
|
149 |
+
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
|
150 |
+
The image embeddings obtained by applying the projection layer to the pooler_output.
|
151 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
152 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
153 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
154 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
155 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
156 |
+
|
157 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
158 |
+
vision_pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*):
|
159 |
+
Last layer hidden-state of the vision of the vision-only branch of the model.
|
160 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
161 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
162 |
+
sequence_length)`.
|
163 |
+
|
164 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
165 |
+
heads.
|
166 |
+
question_embeds (`torch.FloatTensor`):
|
167 |
+
The question embeddings obtained by the text projection layer.
|
168 |
+
"""
|
169 |
+
|
170 |
+
itm_score: Optional[torch.FloatTensor] = None
|
171 |
+
loss: Optional[torch.FloatTensor] = None
|
172 |
+
image_embeds: Optional[torch.FloatTensor] = None
|
173 |
+
last_hidden_state: torch.FloatTensor = None
|
174 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
175 |
+
vision_pooler_output: Optional[torch.FloatTensor] = None
|
176 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
177 |
+
question_embeds: Optional[Tuple[torch.FloatTensor]] = None
|
178 |
+
|
179 |
+
|
180 |
+
@dataclass
|
181 |
+
class BlipOutput(ModelOutput):
|
182 |
+
"""
|
183 |
+
Args:
|
184 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
|
185 |
+
Contrastive loss for image-text similarity.
|
186 |
+
logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
|
187 |
+
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
|
188 |
+
similarity scores.
|
189 |
+
logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
|
190 |
+
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
|
191 |
+
similarity scores.
|
192 |
+
text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
|
193 |
+
The text embeddings obtained by applying the projection layer to the pooled output of [`BlipTextModel`].
|
194 |
+
image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
|
195 |
+
The image embeddings obtained by applying the projection layer to the pooled output of [`BlipVisionModel`].
|
196 |
+
text_model_output(`BaseModelOutputWithPooling`):
|
197 |
+
The output of the [`BlipTextModel`].
|
198 |
+
vision_model_output(`BaseModelOutputWithPooling`):
|
199 |
+
The output of the [`BlipVisionModel`].
|
200 |
+
"""
|
201 |
+
|
202 |
+
loss: Optional[torch.FloatTensor] = None
|
203 |
+
logits_per_image: torch.FloatTensor = None
|
204 |
+
logits_per_text: torch.FloatTensor = None
|
205 |
+
text_embeds: torch.FloatTensor = None
|
206 |
+
image_embeds: torch.FloatTensor = None
|
207 |
+
text_model_output: BaseModelOutputWithPooling = None
|
208 |
+
vision_model_output: BaseModelOutputWithPooling = None
|
209 |
+
|
210 |
+
def to_tuple(self) -> Tuple[Any]:
|
211 |
+
return tuple(
|
212 |
+
self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
|
213 |
+
for k in self.keys()
|
214 |
+
)
|
215 |
+
|
216 |
+
|
217 |
+
class BlipVisionEmbeddings(nn.Module):
|
218 |
+
def __init__(self, config: BlipVisionConfig):
|
219 |
+
super().__init__()
|
220 |
+
self.config = config
|
221 |
+
self.embed_dim = config.hidden_size
|
222 |
+
self.image_size = config.image_size
|
223 |
+
self.patch_size = config.patch_size
|
224 |
+
|
225 |
+
self.class_embedding = nn.Parameter(
|
226 |
+
torch.randn(1, 1, self.embed_dim),
|
227 |
+
)
|
228 |
+
|
229 |
+
self.patch_embedding = nn.Conv2d(
|
230 |
+
in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
|
231 |
+
)
|
232 |
+
|
233 |
+
self.num_patches = (self.image_size // self.patch_size) ** 2
|
234 |
+
self.num_positions = self.num_patches + 1
|
235 |
+
|
236 |
+
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
|
237 |
+
|
238 |
+
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
239 |
+
batch_size = pixel_values.shape[0]
|
240 |
+
target_dtype = self.patch_embedding.weight.dtype
|
241 |
+
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
|
242 |
+
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
243 |
+
|
244 |
+
class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
|
245 |
+
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
246 |
+
embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype)
|
247 |
+
return embeddings
|
248 |
+
|
249 |
+
|
250 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Blip
|
251 |
+
class BlipTextEmbeddings(nn.Module):
|
252 |
+
def __init__(self, config: BlipTextConfig):
|
253 |
+
super().__init__()
|
254 |
+
embed_dim = config.hidden_size
|
255 |
+
|
256 |
+
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
|
257 |
+
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
|
258 |
+
|
259 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
260 |
+
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
261 |
+
|
262 |
+
def forward(
|
263 |
+
self,
|
264 |
+
input_ids: Optional[torch.LongTensor] = None,
|
265 |
+
position_ids: Optional[torch.LongTensor] = None,
|
266 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
267 |
+
) -> torch.Tensor:
|
268 |
+
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
|
269 |
+
|
270 |
+
if position_ids is None:
|
271 |
+
position_ids = self.position_ids[:, :seq_length]
|
272 |
+
|
273 |
+
if inputs_embeds is None:
|
274 |
+
inputs_embeds = self.token_embedding(input_ids)
|
275 |
+
|
276 |
+
position_embeddings = self.position_embedding(position_ids)
|
277 |
+
embeddings = inputs_embeds + position_embeddings
|
278 |
+
|
279 |
+
return embeddings
|
280 |
+
|
281 |
+
|
282 |
+
class BlipAttention(nn.Module):
|
283 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
284 |
+
|
285 |
+
def __init__(self, config):
|
286 |
+
super().__init__()
|
287 |
+
self.config = config
|
288 |
+
self.embed_dim = config.hidden_size
|
289 |
+
self.num_heads = config.num_attention_heads
|
290 |
+
self.head_dim = self.embed_dim // self.num_heads
|
291 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
292 |
+
raise ValueError(
|
293 |
+
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
294 |
+
f" {self.num_heads})."
|
295 |
+
)
|
296 |
+
self.scale = self.head_dim**-0.5
|
297 |
+
self.dropout = nn.Dropout(config.attention_dropout)
|
298 |
+
|
299 |
+
self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim)
|
300 |
+
|
301 |
+
self.projection = nn.Linear(self.embed_dim, self.embed_dim)
|
302 |
+
|
303 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
304 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
305 |
+
|
306 |
+
def forward(
|
307 |
+
self,
|
308 |
+
hidden_states: torch.Tensor,
|
309 |
+
head_mask: Optional[torch.Tensor] = None,
|
310 |
+
output_attentions: Optional[bool] = False,
|
311 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
312 |
+
"""Input shape: Batch x Time x Channel"""
|
313 |
+
|
314 |
+
bsz, tgt_len, embed_dim = hidden_states.size()
|
315 |
+
|
316 |
+
mixed_qkv = self.qkv(hidden_states)
|
317 |
+
mixed_qkv = (
|
318 |
+
self.qkv(hidden_states)
|
319 |
+
.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads)
|
320 |
+
.permute(2, 0, 3, 1, 4)
|
321 |
+
)
|
322 |
+
query_states, key_states, value_states = (
|
323 |
+
mixed_qkv[0],
|
324 |
+
mixed_qkv[1],
|
325 |
+
mixed_qkv[2],
|
326 |
+
)
|
327 |
+
|
328 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
329 |
+
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
|
330 |
+
|
331 |
+
attention_scores = attention_scores * self.scale
|
332 |
+
|
333 |
+
# Normalize the attention scores to probabilities.
|
334 |
+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
335 |
+
|
336 |
+
# This is actually dropping out entire tokens to attend to, which might
|
337 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
338 |
+
attention_probs = self.dropout(attention_probs)
|
339 |
+
|
340 |
+
# Mask heads if we want to
|
341 |
+
if head_mask is not None:
|
342 |
+
attention_probs = attention_probs * head_mask
|
343 |
+
|
344 |
+
context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)
|
345 |
+
|
346 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
|
347 |
+
context_layer = context_layer.reshape(new_context_layer_shape)
|
348 |
+
|
349 |
+
output = self.projection(context_layer)
|
350 |
+
|
351 |
+
outputs = (output, attention_probs) if output_attentions else (output, None)
|
352 |
+
|
353 |
+
return outputs
|
354 |
+
|
355 |
+
|
356 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Blip
|
357 |
+
class BlipMLP(nn.Module):
|
358 |
+
def __init__(self, config):
|
359 |
+
super().__init__()
|
360 |
+
self.config = config
|
361 |
+
self.activation_fn = ACT2FN[config.hidden_act]
|
362 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
363 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
364 |
+
|
365 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
366 |
+
hidden_states = self.fc1(hidden_states)
|
367 |
+
hidden_states = self.activation_fn(hidden_states)
|
368 |
+
hidden_states = self.fc2(hidden_states)
|
369 |
+
return hidden_states
|
370 |
+
|
371 |
+
|
372 |
+
class BlipEncoderLayer(nn.Module):
|
373 |
+
def __init__(self, config: BlipConfig):
|
374 |
+
super().__init__()
|
375 |
+
self.embed_dim = config.hidden_size
|
376 |
+
self.self_attn = BlipAttention(config)
|
377 |
+
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
378 |
+
self.mlp = BlipMLP(config)
|
379 |
+
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
380 |
+
|
381 |
+
def forward(
|
382 |
+
self,
|
383 |
+
hidden_states: torch.Tensor,
|
384 |
+
attention_mask: torch.Tensor,
|
385 |
+
output_attentions: Optional[bool] = False,
|
386 |
+
) -> Tuple[torch.FloatTensor]:
|
387 |
+
"""
|
388 |
+
Args:
|
389 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
390 |
+
attention_mask (`torch.FloatTensor`): attention mask of size
|
391 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
392 |
+
`(config.encoder_attention_heads,)`.
|
393 |
+
output_attentions (`bool`, *optional*):
|
394 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
395 |
+
returned tensors for more detail.
|
396 |
+
"""
|
397 |
+
residual = hidden_states
|
398 |
+
|
399 |
+
hidden_states = self.layer_norm1(hidden_states)
|
400 |
+
hidden_states, attn_weights = self.self_attn(
|
401 |
+
hidden_states=hidden_states,
|
402 |
+
head_mask=attention_mask,
|
403 |
+
output_attentions=output_attentions,
|
404 |
+
)
|
405 |
+
hidden_states = hidden_states + residual
|
406 |
+
residual = hidden_states
|
407 |
+
hidden_states = self.layer_norm2(hidden_states)
|
408 |
+
hidden_states = self.mlp(hidden_states)
|
409 |
+
|
410 |
+
hidden_states = hidden_states + residual
|
411 |
+
|
412 |
+
outputs = (hidden_states,)
|
413 |
+
|
414 |
+
if output_attentions:
|
415 |
+
outputs += (attn_weights,)
|
416 |
+
|
417 |
+
return outputs
|
418 |
+
|
419 |
+
|
420 |
+
class BlipPreTrainedModel(PreTrainedModel):
|
421 |
+
"""
|
422 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
423 |
+
models.
|
424 |
+
"""
|
425 |
+
|
426 |
+
config_class = BlipConfig
|
427 |
+
base_model_prefix = "blip"
|
428 |
+
supports_gradient_checkpointing = True
|
429 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
430 |
+
|
431 |
+
def _init_weights(self, module):
|
432 |
+
"""Initialize the weights"""
|
433 |
+
factor = self.config.initializer_range
|
434 |
+
if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear):
|
435 |
+
module.weight.data.normal_(mean=0.0, std=factor)
|
436 |
+
if hasattr(module, "bias") and module.bias is not None:
|
437 |
+
module.bias.data.zero_()
|
438 |
+
|
439 |
+
if isinstance(module, BlipVisionEmbeddings):
|
440 |
+
if hasattr(self.config, "vision_config"):
|
441 |
+
factor = self.config.vision_config.initializer_range
|
442 |
+
nn.init.trunc_normal_(
|
443 |
+
module.position_embedding,
|
444 |
+
mean=0.0,
|
445 |
+
std=factor,
|
446 |
+
)
|
447 |
+
|
448 |
+
nn.init.trunc_normal_(
|
449 |
+
module.class_embedding,
|
450 |
+
mean=0.0,
|
451 |
+
std=factor,
|
452 |
+
)
|
453 |
+
|
454 |
+
elif isinstance(module, nn.LayerNorm):
|
455 |
+
module.bias.data.zero_()
|
456 |
+
module.weight.data.fill_(1.0)
|
457 |
+
elif isinstance(module, nn.Linear) and module.bias is not None:
|
458 |
+
module.bias.data.zero_()
|
459 |
+
|
460 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
461 |
+
if isinstance(module, BlipEncoder):
|
462 |
+
module.gradient_checkpointing = value
|
463 |
+
|
464 |
+
|
465 |
+
BLIP_START_DOCSTRING = r"""
|
466 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
467 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
468 |
+
etc.)
|
469 |
+
|
470 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
471 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
472 |
+
and behavior.
|
473 |
+
|
474 |
+
Parameters:
|
475 |
+
config ([`BlipConfig`]): Model configuration class with all the parameters of the model.
|
476 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
477 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
478 |
+
"""
|
479 |
+
|
480 |
+
BLIP_TEXT_INPUTS_DOCSTRING = r"""
|
481 |
+
Args:
|
482 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
483 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
484 |
+
it.
|
485 |
+
|
486 |
+
Indices can be obtained using [`AutoProcessor`]. See [`BlipProcessor.__call__`] for details.
|
487 |
+
|
488 |
+
[What are input IDs?](../glossary#input-ids)
|
489 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
490 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
491 |
+
|
492 |
+
- 1 for tokens that are **not masked**,
|
493 |
+
- 0 for tokens that are **masked**.
|
494 |
+
|
495 |
+
[What are attention masks?](../glossary#attention-mask)
|
496 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
497 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
498 |
+
config.max_position_embeddings - 1]`.
|
499 |
+
|
500 |
+
[What are position IDs?](../glossary#position-ids)
|
501 |
+
output_attentions (`bool`, *optional*):
|
502 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
503 |
+
tensors for more detail.
|
504 |
+
output_hidden_states (`bool`, *optional*):
|
505 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
506 |
+
more detail.
|
507 |
+
return_dict (`bool`, *optional*):
|
508 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
509 |
+
"""
|
510 |
+
|
511 |
+
BLIP_VISION_INPUTS_DOCSTRING = r"""
|
512 |
+
Args:
|
513 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
514 |
+
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
515 |
+
[`BlipImageProcessor`]. See [`BlipImageProcessor.__call__`] for details.
|
516 |
+
output_attentions (`bool`, *optional*):
|
517 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
518 |
+
tensors for more detail.
|
519 |
+
output_hidden_states (`bool`, *optional*):
|
520 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
521 |
+
more detail.
|
522 |
+
return_dict (`bool`, *optional*):
|
523 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
524 |
+
"""
|
525 |
+
|
526 |
+
BLIP_INPUTS_DOCSTRING = r"""
|
527 |
+
Args:
|
528 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
529 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
530 |
+
it.
|
531 |
+
|
532 |
+
Indices can be obtained using [`AutoProcessor`]. See [`BlipProcessor.__call__`] for details.
|
533 |
+
|
534 |
+
[What are input IDs?](../glossary#input-ids)
|
535 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
536 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
537 |
+
|
538 |
+
- 1 for tokens that are **not masked**,
|
539 |
+
- 0 for tokens that are **masked**.
|
540 |
+
|
541 |
+
[What are attention masks?](../glossary#attention-mask)
|
542 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
543 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
544 |
+
config.max_position_embeddings - 1]`.
|
545 |
+
|
546 |
+
[What are position IDs?](../glossary#position-ids)
|
547 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
548 |
+
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
549 |
+
[`BlipImageProcessor`]. See [`BlipImageProcessor.__call__`] for details.
|
550 |
+
return_loss (`bool`, *optional*):
|
551 |
+
Whether or not to return the contrastive loss.
|
552 |
+
output_attentions (`bool`, *optional*):
|
553 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
554 |
+
tensors for more detail.
|
555 |
+
output_hidden_states (`bool`, *optional*):
|
556 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
557 |
+
more detail.
|
558 |
+
return_dict (`bool`, *optional*):
|
559 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
560 |
+
"""
|
561 |
+
|
562 |
+
|
563 |
+
class BlipEncoder(nn.Module):
|
564 |
+
"""
|
565 |
+
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
566 |
+
[`BlipEncoderLayer`].
|
567 |
+
|
568 |
+
Args:
|
569 |
+
config (`BlipConfig`):
|
570 |
+
The corresponding vision configuration for the `BlipEncoder`.
|
571 |
+
"""
|
572 |
+
|
573 |
+
def __init__(self, config: BlipConfig):
|
574 |
+
super().__init__()
|
575 |
+
self.config = config
|
576 |
+
self.layers = nn.ModuleList([BlipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
577 |
+
self.gradient_checkpointing = False
|
578 |
+
|
579 |
+
def forward(
|
580 |
+
self,
|
581 |
+
inputs_embeds,
|
582 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
583 |
+
output_attentions: Optional[bool] = None,
|
584 |
+
output_hidden_states: Optional[bool] = None,
|
585 |
+
return_dict: Optional[bool] = None,
|
586 |
+
) -> Union[Tuple, BaseModelOutput]:
|
587 |
+
r"""
|
588 |
+
Args:
|
589 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
590 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
591 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
592 |
+
than the model's internal embedding lookup matrix.
|
593 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
594 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
595 |
+
|
596 |
+
- 1 for tokens that are **not masked**,
|
597 |
+
- 0 for tokens that are **masked**.
|
598 |
+
|
599 |
+
[What are attention masks?](../glossary#attention-mask)
|
600 |
+
output_attentions (`bool`, *optional*):
|
601 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
602 |
+
returned tensors for more detail.
|
603 |
+
output_hidden_states (`bool`, *optional*):
|
604 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
605 |
+
for more detail.
|
606 |
+
return_dict (`bool`, *optional*):
|
607 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
608 |
+
"""
|
609 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
610 |
+
output_hidden_states = (
|
611 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
612 |
+
)
|
613 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
614 |
+
|
615 |
+
encoder_states = () if output_hidden_states else None
|
616 |
+
all_attentions = () if output_attentions else None
|
617 |
+
|
618 |
+
hidden_states = inputs_embeds
|
619 |
+
for idx, encoder_layer in enumerate(self.layers):
|
620 |
+
if output_hidden_states:
|
621 |
+
encoder_states = encoder_states + (hidden_states,)
|
622 |
+
if self.gradient_checkpointing and self.training:
|
623 |
+
|
624 |
+
def create_custom_forward(module):
|
625 |
+
def custom_forward(*inputs):
|
626 |
+
return module(*inputs, output_attentions)
|
627 |
+
|
628 |
+
return custom_forward
|
629 |
+
|
630 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
631 |
+
create_custom_forward(encoder_layer),
|
632 |
+
hidden_states,
|
633 |
+
attention_mask,
|
634 |
+
)
|
635 |
+
else:
|
636 |
+
layer_outputs = encoder_layer(
|
637 |
+
hidden_states,
|
638 |
+
attention_mask,
|
639 |
+
output_attentions=output_attentions,
|
640 |
+
)
|
641 |
+
|
642 |
+
hidden_states = layer_outputs[0]
|
643 |
+
|
644 |
+
if output_attentions:
|
645 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
646 |
+
|
647 |
+
if output_hidden_states:
|
648 |
+
encoder_states = encoder_states + (hidden_states,)
|
649 |
+
|
650 |
+
if not return_dict:
|
651 |
+
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
652 |
+
return BaseModelOutput(
|
653 |
+
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
654 |
+
)
|
655 |
+
|
656 |
+
|
657 |
+
class BlipVisionModel(BlipPreTrainedModel):
|
658 |
+
main_input_name = "pixel_values"
|
659 |
+
config_class = BlipVisionConfig
|
660 |
+
|
661 |
+
def __init__(self, config: BlipVisionConfig):
|
662 |
+
super().__init__(config)
|
663 |
+
self.config = config
|
664 |
+
embed_dim = config.hidden_size
|
665 |
+
self.embeddings = BlipVisionEmbeddings(config)
|
666 |
+
self.patch_mask_generator = ViTPatchMaskGenerator(config.patch_size)
|
667 |
+
self.encoder = BlipEncoder(config)
|
668 |
+
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
669 |
+
|
670 |
+
self.post_init()
|
671 |
+
|
672 |
+
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
|
673 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=BlipVisionConfig)
|
674 |
+
def forward(
|
675 |
+
self,
|
676 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
677 |
+
pixel_masks: Optional[torch.LongTensor] = None,
|
678 |
+
output_attentions: Optional[bool] = None,
|
679 |
+
output_hidden_states: Optional[bool] = None,
|
680 |
+
return_dict: Optional[bool] = None,
|
681 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
682 |
+
r"""
|
683 |
+
Returns:
|
684 |
+
|
685 |
+
"""
|
686 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
687 |
+
output_hidden_states = (
|
688 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
689 |
+
)
|
690 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
691 |
+
|
692 |
+
if pixel_values is None:
|
693 |
+
raise ValueError("You have to specify pixel_values")
|
694 |
+
|
695 |
+
hidden_states = self.embeddings(pixel_values)
|
696 |
+
B, N, D = hidden_states.shape
|
697 |
+
# print('Before mask:', hidden_states.shape)
|
698 |
+
if pixel_masks is not None:
|
699 |
+
assert pixel_masks.shape[0] == 1
|
700 |
+
patch_masks = self.patch_mask_generator(pixel_masks)
|
701 |
+
# print(patch_masks.shape)
|
702 |
+
patch_masks = patch_masks.unsqueeze(-1).expand_as(hidden_states)
|
703 |
+
hidden_states = hidden_states.masked_select(patch_masks).view(B, -1, D)
|
704 |
+
# print('After mask:', hidden_states.shape)
|
705 |
+
|
706 |
+
encoder_outputs = self.encoder(
|
707 |
+
inputs_embeds=hidden_states,
|
708 |
+
output_attentions=output_attentions,
|
709 |
+
output_hidden_states=output_hidden_states,
|
710 |
+
return_dict=return_dict,
|
711 |
+
)
|
712 |
+
|
713 |
+
last_hidden_state = encoder_outputs[0]
|
714 |
+
last_hidden_state = self.post_layernorm(last_hidden_state)
|
715 |
+
|
716 |
+
pooled_output = last_hidden_state[:, 0, :]
|
717 |
+
pooled_output = self.post_layernorm(pooled_output)
|
718 |
+
|
719 |
+
if not return_dict:
|
720 |
+
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
721 |
+
|
722 |
+
return BaseModelOutputWithPooling(
|
723 |
+
last_hidden_state=last_hidden_state,
|
724 |
+
pooler_output=pooled_output,
|
725 |
+
hidden_states=encoder_outputs.hidden_states,
|
726 |
+
attentions=encoder_outputs.attentions,
|
727 |
+
)
|
728 |
+
|
729 |
+
def get_input_embeddings(self):
|
730 |
+
return self.embeddings
|
731 |
+
|
732 |
+
|
733 |
+
@add_start_docstrings(BLIP_START_DOCSTRING)
|
734 |
+
class BlipModel(BlipPreTrainedModel):
|
735 |
+
config_class = BlipConfig
|
736 |
+
|
737 |
+
def __init__(self, config: BlipConfig):
|
738 |
+
super().__init__(config)
|
739 |
+
|
740 |
+
if not isinstance(config.text_config, BlipTextConfig):
|
741 |
+
raise ValueError(
|
742 |
+
"config.text_config is expected to be of type BlipTextConfig but is of type"
|
743 |
+
f" {type(config.text_config)}."
|
744 |
+
)
|
745 |
+
|
746 |
+
if not isinstance(config.vision_config, BlipVisionConfig):
|
747 |
+
raise ValueError(
|
748 |
+
"config.vision_config is expected to be of type BlipVisionConfig but is of type"
|
749 |
+
f" {type(config.vision_config)}."
|
750 |
+
)
|
751 |
+
|
752 |
+
text_config = config.text_config
|
753 |
+
vision_config = config.vision_config
|
754 |
+
|
755 |
+
self.projection_dim = config.projection_dim
|
756 |
+
self.text_embed_dim = text_config.hidden_size
|
757 |
+
self.vision_embed_dim = vision_config.hidden_size
|
758 |
+
|
759 |
+
self.text_model = BlipTextModel(text_config)
|
760 |
+
self.vision_model = BlipVisionModel(vision_config)
|
761 |
+
|
762 |
+
self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
|
763 |
+
self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
|
764 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value)
|
765 |
+
|
766 |
+
# Initialize weights and apply final processing
|
767 |
+
self.post_init()
|
768 |
+
|
769 |
+
@add_start_docstrings_to_model_forward(BLIP_TEXT_INPUTS_DOCSTRING)
|
770 |
+
def get_text_features(
|
771 |
+
self,
|
772 |
+
input_ids: Optional[torch.Tensor] = None,
|
773 |
+
attention_mask: Optional[torch.Tensor] = None,
|
774 |
+
position_ids: Optional[torch.Tensor] = None,
|
775 |
+
return_dict: Optional[bool] = None,
|
776 |
+
) -> torch.FloatTensor:
|
777 |
+
r"""
|
778 |
+
Returns:
|
779 |
+
text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
|
780 |
+
applying the projection layer to the pooled output of [`BlipTextModel`].
|
781 |
+
|
782 |
+
Examples:
|
783 |
+
|
784 |
+
```python
|
785 |
+
>>> from transformers import AutoProcessor, BlipModel
|
786 |
+
|
787 |
+
>>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
|
788 |
+
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
789 |
+
|
790 |
+
>>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
|
791 |
+
>>> text_features = model.get_text_features(**inputs)
|
792 |
+
```"""
|
793 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
794 |
+
|
795 |
+
text_outputs = self.text_model(
|
796 |
+
input_ids=input_ids,
|
797 |
+
attention_mask=attention_mask,
|
798 |
+
position_ids=position_ids,
|
799 |
+
return_dict=return_dict,
|
800 |
+
)
|
801 |
+
|
802 |
+
pooled_output = text_outputs[1]
|
803 |
+
text_features = self.text_projection(pooled_output)
|
804 |
+
|
805 |
+
return text_features
|
806 |
+
|
807 |
+
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
|
808 |
+
def get_image_features(
|
809 |
+
self,
|
810 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
811 |
+
return_dict: Optional[bool] = None,
|
812 |
+
) -> torch.FloatTensor:
|
813 |
+
r"""
|
814 |
+
Returns:
|
815 |
+
image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
|
816 |
+
applying the projection layer to the pooled output of [`BlipVisionModel`].
|
817 |
+
|
818 |
+
Examples:
|
819 |
+
|
820 |
+
```python
|
821 |
+
>>> from PIL import Image
|
822 |
+
>>> import requests
|
823 |
+
>>> from transformers import AutoProcessor, BlipModel
|
824 |
+
|
825 |
+
>>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
|
826 |
+
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
827 |
+
|
828 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
829 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
830 |
+
|
831 |
+
>>> inputs = processor(images=image, return_tensors="pt")
|
832 |
+
|
833 |
+
>>> image_features = model.get_image_features(**inputs)
|
834 |
+
```"""
|
835 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
836 |
+
|
837 |
+
vision_outputs = self.vision_model(
|
838 |
+
pixel_values=pixel_values,
|
839 |
+
return_dict=return_dict,
|
840 |
+
)
|
841 |
+
|
842 |
+
pooled_output = vision_outputs[1] # pooled_output
|
843 |
+
image_features = self.visual_projection(pooled_output)
|
844 |
+
|
845 |
+
return image_features
|
846 |
+
|
847 |
+
@add_start_docstrings_to_model_forward(BLIP_INPUTS_DOCSTRING)
|
848 |
+
@replace_return_docstrings(output_type=BlipOutput, config_class=BlipConfig)
|
849 |
+
def forward(
|
850 |
+
self,
|
851 |
+
input_ids: Optional[torch.LongTensor] = None,
|
852 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
853 |
+
pixel_masks: Optional[torch.FloatTensor] = None,
|
854 |
+
attention_mask: Optional[torch.Tensor] = None,
|
855 |
+
position_ids: Optional[torch.LongTensor] = None,
|
856 |
+
return_loss: Optional[bool] = None,
|
857 |
+
output_attentions: Optional[bool] = None,
|
858 |
+
output_hidden_states: Optional[bool] = None,
|
859 |
+
return_dict: Optional[bool] = None,
|
860 |
+
) -> Union[Tuple, BlipOutput]:
|
861 |
+
r"""
|
862 |
+
Returns:
|
863 |
+
|
864 |
+
Examples:
|
865 |
+
|
866 |
+
```python
|
867 |
+
>>> from PIL import Image
|
868 |
+
>>> import requests
|
869 |
+
>>> from transformers import AutoProcessor, BlipModel
|
870 |
+
|
871 |
+
>>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
|
872 |
+
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
873 |
+
|
874 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
875 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
876 |
+
|
877 |
+
>>> inputs = processor(
|
878 |
+
... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
|
879 |
+
... )
|
880 |
+
|
881 |
+
>>> outputs = model(**inputs)
|
882 |
+
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
|
883 |
+
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
|
884 |
+
```"""
|
885 |
+
# Use BLIP model's config for some fields (if specified) instead of those of vision & text components.
|
886 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
887 |
+
output_hidden_states = (
|
888 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
889 |
+
)
|
890 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
891 |
+
|
892 |
+
vision_outputs = self.vision_model(
|
893 |
+
pixel_values=pixel_values,
|
894 |
+
pixel_masks=pixel_masks,
|
895 |
+
output_attentions=output_attentions,
|
896 |
+
output_hidden_states=output_hidden_states,
|
897 |
+
return_dict=return_dict,
|
898 |
+
)
|
899 |
+
|
900 |
+
text_outputs = self.text_model(
|
901 |
+
input_ids=input_ids,
|
902 |
+
attention_mask=attention_mask,
|
903 |
+
position_ids=position_ids,
|
904 |
+
output_attentions=output_attentions,
|
905 |
+
output_hidden_states=output_hidden_states,
|
906 |
+
return_dict=return_dict,
|
907 |
+
)
|
908 |
+
|
909 |
+
image_embeds = vision_outputs[1]
|
910 |
+
image_embeds = self.visual_projection(image_embeds)
|
911 |
+
|
912 |
+
text_embeds = text_outputs[1]
|
913 |
+
text_embeds = self.text_projection(text_embeds)
|
914 |
+
|
915 |
+
# normalized features
|
916 |
+
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
|
917 |
+
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
|
918 |
+
|
919 |
+
# cosine similarity as logits
|
920 |
+
logit_scale = self.logit_scale.exp()
|
921 |
+
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
|
922 |
+
logits_per_image = logits_per_text.t()
|
923 |
+
|
924 |
+
loss = None
|
925 |
+
if return_loss:
|
926 |
+
loss = blip_loss(logits_per_text)
|
927 |
+
|
928 |
+
if not return_dict:
|
929 |
+
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
|
930 |
+
return ((loss,) + output) if loss is not None else output
|
931 |
+
|
932 |
+
return BlipOutput(
|
933 |
+
loss=loss,
|
934 |
+
logits_per_image=logits_per_image,
|
935 |
+
logits_per_text=logits_per_text,
|
936 |
+
text_embeds=text_embeds,
|
937 |
+
image_embeds=image_embeds,
|
938 |
+
text_model_output=text_outputs,
|
939 |
+
vision_model_output=vision_outputs,
|
940 |
+
)
|
941 |
+
|
942 |
+
|
943 |
+
@add_start_docstrings(
|
944 |
+
"""
|
945 |
+
BLIP Model for image captioning. The model consists of a vision encoder and a text decoder. One can optionally pass
|
946 |
+
`input_ids` to the model, which serve as a text prompt, to make the text decoder continue the prompt. Otherwise,
|
947 |
+
the decoder starts generating text from the [BOS] (beginning-of-sequence) token. will start generating the caption
|
948 |
+
from the text input. If no text input is provided, the decoder will start with the [BOS] token only.
|
949 |
+
""",
|
950 |
+
BLIP_START_DOCSTRING,
|
951 |
+
)
|
952 |
+
class BlipForConditionalGeneration(BlipPreTrainedModel):
|
953 |
+
config_class = BlipConfig
|
954 |
+
_keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"]
|
955 |
+
main_input_name = "pixel_values"
|
956 |
+
|
957 |
+
def __init__(self, config: BlipConfig):
|
958 |
+
super().__init__(config)
|
959 |
+
|
960 |
+
self.vision_model = BlipVisionModel(config.vision_config)
|
961 |
+
|
962 |
+
self.text_decoder = BlipTextLMHeadModel(config.text_config)
|
963 |
+
|
964 |
+
self.decoder_input_ids = config.text_config.bos_token_id
|
965 |
+
self.decoder_pad_token_id = config.text_config.pad_token_id
|
966 |
+
|
967 |
+
# Initialize weights and apply final processing
|
968 |
+
self.post_init()
|
969 |
+
|
970 |
+
def get_input_embeddings(self) -> nn.Module:
|
971 |
+
return self.vision_model.embeddings.patch_embedding
|
972 |
+
|
973 |
+
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
|
974 |
+
@replace_return_docstrings(output_type=BlipForConditionalGenerationModelOutput, config_class=BlipVisionConfig)
|
975 |
+
def forward(
|
976 |
+
self,
|
977 |
+
pixel_values: torch.FloatTensor,
|
978 |
+
input_ids: Optional[torch.LongTensor] = None,
|
979 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
980 |
+
output_attentions: Optional[bool] = None,
|
981 |
+
output_hidden_states: Optional[bool] = None,
|
982 |
+
labels: Optional[torch.LongTensor] = None,
|
983 |
+
return_dict: Optional[bool] = None,
|
984 |
+
) -> Union[Tuple, BlipForConditionalGenerationModelOutput]:
|
985 |
+
r"""
|
986 |
+
Returns:
|
987 |
+
|
988 |
+
Examples:
|
989 |
+
|
990 |
+
```python
|
991 |
+
>>> from PIL import Image
|
992 |
+
>>> import requests
|
993 |
+
>>> from transformers import AutoProcessor, BlipForConditionalGeneration
|
994 |
+
|
995 |
+
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
996 |
+
>>> model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
|
997 |
+
|
998 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
999 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
1000 |
+
>>> text = "A picture of"
|
1001 |
+
|
1002 |
+
>>> inputs = processor(images=image, text=text, return_tensors="pt")
|
1003 |
+
|
1004 |
+
>>> outputs = model(**inputs)
|
1005 |
+
```"""
|
1006 |
+
|
1007 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1008 |
+
|
1009 |
+
vision_outputs = self.vision_model(
|
1010 |
+
pixel_values=pixel_values,
|
1011 |
+
output_attentions=output_attentions,
|
1012 |
+
output_hidden_states=output_hidden_states,
|
1013 |
+
return_dict=return_dict,
|
1014 |
+
)
|
1015 |
+
|
1016 |
+
image_embeds = vision_outputs[0]
|
1017 |
+
|
1018 |
+
outputs = self.text_decoder(
|
1019 |
+
input_ids=input_ids,
|
1020 |
+
attention_mask=attention_mask,
|
1021 |
+
encoder_hidden_states=image_embeds,
|
1022 |
+
labels=labels,
|
1023 |
+
return_dict=return_dict,
|
1024 |
+
reduction="mean",
|
1025 |
+
)
|
1026 |
+
|
1027 |
+
if not return_dict:
|
1028 |
+
outputs = (outputs[0], outputs[1], image_embeds, vision_outputs[0]) + vision_outputs[2:]
|
1029 |
+
return tuple(output for output in outputs if output is not None)
|
1030 |
+
|
1031 |
+
return BlipForConditionalGenerationModelOutput(
|
1032 |
+
loss=outputs.loss,
|
1033 |
+
decoder_logits=outputs.logits,
|
1034 |
+
image_embeds=image_embeds,
|
1035 |
+
last_hidden_state=vision_outputs.last_hidden_state,
|
1036 |
+
hidden_states=vision_outputs.hidden_states,
|
1037 |
+
attentions=vision_outputs.attentions,
|
1038 |
+
)
|
1039 |
+
|
1040 |
+
@torch.no_grad()
|
1041 |
+
def generate(
|
1042 |
+
self,
|
1043 |
+
pixel_values: torch.FloatTensor,
|
1044 |
+
pixel_masks: torch.Tensor = None,
|
1045 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1046 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
1047 |
+
**generate_kwargs,
|
1048 |
+
) -> torch.LongTensor:
|
1049 |
+
r"""
|
1050 |
+
Overrides *generate* function to be able to use the model as a conditional generator
|
1051 |
+
|
1052 |
+
Parameters:
|
1053 |
+
pixel_values (*torch.FloatTensor* of shape *(batch_size, image_width, image_height)*:
|
1054 |
+
Input image to be processed
|
1055 |
+
input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
|
1056 |
+
The sequence used as a prompt for the generation.
|
1057 |
+
attention_mask (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
|
1058 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
1059 |
+
|
1060 |
+
|
1061 |
+
Examples:
|
1062 |
+
```python
|
1063 |
+
>>> from PIL import Image
|
1064 |
+
>>> import requests
|
1065 |
+
>>> from transformers import AutoProcessor, BlipForConditionalGeneration
|
1066 |
+
|
1067 |
+
>>> model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
|
1068 |
+
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
1069 |
+
|
1070 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
1071 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
1072 |
+
|
1073 |
+
>>> inputs = processor(images=image, return_tensors="pt")
|
1074 |
+
|
1075 |
+
>>> outputs = model.generate(**inputs)
|
1076 |
+
>>> print(processor.decode(outputs[0], skip_special_tokens=True))
|
1077 |
+
two cats are laying on a couch
|
1078 |
+
```
|
1079 |
+
"""
|
1080 |
+
|
1081 |
+
batch_size = pixel_values.shape[0]
|
1082 |
+
vision_outputs = self.vision_model(
|
1083 |
+
pixel_values=pixel_values,
|
1084 |
+
pixel_masks=pixel_masks,
|
1085 |
+
)
|
1086 |
+
|
1087 |
+
image_embeds = vision_outputs[0]
|
1088 |
+
|
1089 |
+
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device)
|
1090 |
+
|
1091 |
+
if isinstance(input_ids, list):
|
1092 |
+
input_ids = torch.LongTensor(input_ids)
|
1093 |
+
elif input_ids is None:
|
1094 |
+
input_ids = (
|
1095 |
+
torch.LongTensor([[self.decoder_input_ids, self.config.text_config.eos_token_id]])
|
1096 |
+
.repeat(batch_size, 1)
|
1097 |
+
.to(image_embeds.device)
|
1098 |
+
)
|
1099 |
+
|
1100 |
+
input_ids[:, 0] = self.config.text_config.bos_token_id
|
1101 |
+
attention_mask = attention_mask[:, :-1] if attention_mask is not None else None
|
1102 |
+
|
1103 |
+
outputs = self.text_decoder.generate(
|
1104 |
+
input_ids=input_ids[:, :-1],
|
1105 |
+
eos_token_id=self.config.text_config.sep_token_id,
|
1106 |
+
pad_token_id=self.config.text_config.pad_token_id,
|
1107 |
+
attention_mask=attention_mask,
|
1108 |
+
encoder_hidden_states=image_embeds,
|
1109 |
+
encoder_attention_mask=image_attention_mask,
|
1110 |
+
**generate_kwargs,
|
1111 |
+
)
|
1112 |
+
|
1113 |
+
return outputs
|
1114 |
+
|
1115 |
+
|
1116 |
+
@add_start_docstrings(
|
1117 |
+
"""
|
1118 |
+
BLIP Model for visual question answering. The model consists of a vision encoder, a text encoder as well as a text
|
1119 |
+
decoder. The vision encoder will encode the input image, the text encoder will encode the input question together
|
1120 |
+
with the encoding of the image, and the text decoder will output the answer to the question.
|
1121 |
+
""",
|
1122 |
+
BLIP_START_DOCSTRING,
|
1123 |
+
)
|
1124 |
+
class BlipForQuestionAnswering(BlipPreTrainedModel):
|
1125 |
+
config_class = BlipConfig
|
1126 |
+
_keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"]
|
1127 |
+
|
1128 |
+
def __init__(self, config: BlipConfig):
|
1129 |
+
super().__init__(config)
|
1130 |
+
|
1131 |
+
self.vision_model = BlipVisionModel(config.vision_config)
|
1132 |
+
|
1133 |
+
self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False)
|
1134 |
+
|
1135 |
+
self.text_decoder = BlipTextLMHeadModel(config.text_config)
|
1136 |
+
|
1137 |
+
self.decoder_pad_token_id = config.text_config.pad_token_id
|
1138 |
+
self.decoder_start_token_id = config.text_config.bos_token_id
|
1139 |
+
|
1140 |
+
# Initialize weights and apply final processing
|
1141 |
+
self.post_init()
|
1142 |
+
|
1143 |
+
def get_input_embeddings(self) -> nn.Module:
|
1144 |
+
return self.vision_model.embeddings.patch_embedding
|
1145 |
+
|
1146 |
+
# Adapted from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right
|
1147 |
+
def _shift_right(self, input_ids):
|
1148 |
+
pad_token_id = self.decoder_pad_token_id
|
1149 |
+
|
1150 |
+
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
1151 |
+
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
1152 |
+
shifted_input_ids[..., 0] = self.decoder_start_token_id
|
1153 |
+
|
1154 |
+
# replace possible -100 values in labels by `pad_token_id`
|
1155 |
+
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
|
1156 |
+
|
1157 |
+
return shifted_input_ids
|
1158 |
+
|
1159 |
+
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
|
1160 |
+
@replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig)
|
1161 |
+
def forward(
|
1162 |
+
self,
|
1163 |
+
input_ids: torch.LongTensor,
|
1164 |
+
pixel_values: torch.FloatTensor,
|
1165 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
1166 |
+
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
1167 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
1168 |
+
output_attentions: Optional[bool] = None,
|
1169 |
+
output_hidden_states: Optional[bool] = None,
|
1170 |
+
labels: Optional[torch.LongTensor] = None,
|
1171 |
+
return_dict: Optional[bool] = None,
|
1172 |
+
) -> Union[Tuple, BlipTextVisionModelOutput]:
|
1173 |
+
r"""
|
1174 |
+
Returns:
|
1175 |
+
|
1176 |
+
Examples:
|
1177 |
+
|
1178 |
+
```python
|
1179 |
+
>>> from PIL import Image
|
1180 |
+
>>> import requests
|
1181 |
+
>>> from transformers import AutoProcessor, BlipForQuestionAnswering
|
1182 |
+
|
1183 |
+
>>> model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
|
1184 |
+
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
|
1185 |
+
|
1186 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
1187 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
1188 |
+
|
1189 |
+
>>> # training
|
1190 |
+
>>> text = "How many cats are in the picture?"
|
1191 |
+
>>> label = "2"
|
1192 |
+
>>> inputs = processor(images=image, text=text, return_tensors="pt")
|
1193 |
+
>>> labels = processor(text=label, return_tensors="pt").input_ids
|
1194 |
+
|
1195 |
+
>>> inputs["labels"] = labels
|
1196 |
+
>>> outputs = model(**inputs)
|
1197 |
+
>>> loss = outputs.loss
|
1198 |
+
>>> loss.backward()
|
1199 |
+
|
1200 |
+
>>> # inference
|
1201 |
+
>>> text = "How many cats are in the picture?"
|
1202 |
+
>>> inputs = processor(images=image, text=text, return_tensors="pt")
|
1203 |
+
>>> outputs = model.generate(**inputs)
|
1204 |
+
>>> print(processor.decode(outputs[0], skip_special_tokens=True))
|
1205 |
+
2
|
1206 |
+
```"""
|
1207 |
+
if labels is None and decoder_input_ids is None:
|
1208 |
+
raise ValueError(
|
1209 |
+
"Either `decoder_input_ids` or `labels` should be passed when calling `forward` with"
|
1210 |
+
" `BlipForQuestionAnswering`. if you are training the model make sure that `labels` is passed, if you"
|
1211 |
+
" are using the model for inference make sure that `decoder_input_ids` is passed or call `generate`"
|
1212 |
+
)
|
1213 |
+
|
1214 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1215 |
+
|
1216 |
+
vision_outputs = self.vision_model(
|
1217 |
+
pixel_values=pixel_values,
|
1218 |
+
output_attentions=output_attentions,
|
1219 |
+
output_hidden_states=output_hidden_states,
|
1220 |
+
return_dict=return_dict,
|
1221 |
+
)
|
1222 |
+
|
1223 |
+
image_embeds = vision_outputs[0]
|
1224 |
+
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long)
|
1225 |
+
|
1226 |
+
question_embeds = self.text_encoder(
|
1227 |
+
input_ids=input_ids,
|
1228 |
+
attention_mask=attention_mask,
|
1229 |
+
encoder_hidden_states=image_embeds,
|
1230 |
+
encoder_attention_mask=image_attention_mask,
|
1231 |
+
return_dict=return_dict,
|
1232 |
+
)
|
1233 |
+
|
1234 |
+
question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state
|
1235 |
+
|
1236 |
+
if labels is not None and decoder_input_ids is None:
|
1237 |
+
# get decoder inputs from shifting lm labels to the right - this is used in training mode
|
1238 |
+
decoder_input_ids = self._shift_right(labels)
|
1239 |
+
# replace possible -100 values in labels by `pad_token_id`
|
1240 |
+
labels = labels.masked_fill(labels == self.decoder_pad_token_id, -100)
|
1241 |
+
|
1242 |
+
answer_output = self.text_decoder(
|
1243 |
+
input_ids=decoder_input_ids,
|
1244 |
+
attention_mask=decoder_attention_mask,
|
1245 |
+
encoder_hidden_states=question_embeds,
|
1246 |
+
encoder_attention_mask=attention_mask,
|
1247 |
+
labels=labels,
|
1248 |
+
return_dict=return_dict,
|
1249 |
+
reduction="mean",
|
1250 |
+
)
|
1251 |
+
|
1252 |
+
if labels is not None:
|
1253 |
+
decoder_loss = answer_output.loss.mean() if return_dict else answer_output[0].mean()
|
1254 |
+
else:
|
1255 |
+
decoder_loss = None
|
1256 |
+
|
1257 |
+
if not return_dict:
|
1258 |
+
outputs = (decoder_loss, image_embeds, vision_outputs[0]) + vision_outputs[2:]
|
1259 |
+
return tuple(output for output in outputs if output is not None)
|
1260 |
+
|
1261 |
+
return BlipTextVisionModelOutput(
|
1262 |
+
loss=decoder_loss,
|
1263 |
+
image_embeds=image_embeds,
|
1264 |
+
last_hidden_state=vision_outputs.last_hidden_state,
|
1265 |
+
hidden_states=vision_outputs.hidden_states,
|
1266 |
+
attentions=vision_outputs.attentions,
|
1267 |
+
)
|
1268 |
+
|
1269 |
+
@torch.no_grad()
|
1270 |
+
def generate(
|
1271 |
+
self,
|
1272 |
+
input_ids: torch.LongTensor,
|
1273 |
+
pixel_values: torch.FloatTensor,
|
1274 |
+
pixel_masks: torch.Tensor = None,
|
1275 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
1276 |
+
**generate_kwargs,
|
1277 |
+
) -> torch.LongTensor:
|
1278 |
+
r"""
|
1279 |
+
Overrides *generate* function to be able to use the model as a conditional generator
|
1280 |
+
|
1281 |
+
Parameters:
|
1282 |
+
input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*):
|
1283 |
+
The sequence used as a prompt for the generation.
|
1284 |
+
pixel_values (*torch.FloatTensor* of shape *(batch_size, image_width, image_height)*:
|
1285 |
+
Input image to be processed
|
1286 |
+
attention_mask (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
|
1287 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`. `1` for
|
1288 |
+
tokens that are NOT MASKED, `0` for MASKED tokens.
|
1289 |
+
**generate_kwargs:
|
1290 |
+
Additional arguments passed to the *generate* function of the decoder
|
1291 |
+
|
1292 |
+
|
1293 |
+
Examples:
|
1294 |
+
```python
|
1295 |
+
>>> from PIL import Image
|
1296 |
+
>>> import requests
|
1297 |
+
>>> from transformers import AutoProcessor, BlipForQuestionAnswering
|
1298 |
+
|
1299 |
+
>>> model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
|
1300 |
+
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
|
1301 |
+
|
1302 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
1303 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
1304 |
+
>>> text = "How many cats are in the picture?"
|
1305 |
+
|
1306 |
+
>>> inputs = processor(images=image, text=text, return_tensors="pt")
|
1307 |
+
|
1308 |
+
>>> outputs = model.generate(**inputs)
|
1309 |
+
>>> print(processor.decode(outputs[0], skip_special_tokens=True))
|
1310 |
+
2
|
1311 |
+
```
|
1312 |
+
"""
|
1313 |
+
vision_outputs = self.vision_model(
|
1314 |
+
pixel_values=pixel_values,
|
1315 |
+
pixel_masks=pixel_masks
|
1316 |
+
)
|
1317 |
+
|
1318 |
+
image_embeds = vision_outputs[0]
|
1319 |
+
|
1320 |
+
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device)
|
1321 |
+
|
1322 |
+
if isinstance(input_ids, list):
|
1323 |
+
input_ids = torch.LongTensor(input_ids)
|
1324 |
+
|
1325 |
+
question_outputs = self.text_encoder(
|
1326 |
+
input_ids=input_ids,
|
1327 |
+
attention_mask=attention_mask,
|
1328 |
+
encoder_hidden_states=image_embeds,
|
1329 |
+
encoder_attention_mask=image_attention_mask,
|
1330 |
+
return_dict=False,
|
1331 |
+
)
|
1332 |
+
|
1333 |
+
question_embeds = question_outputs[0]
|
1334 |
+
|
1335 |
+
question_attention_mask = torch.ones(question_embeds.size()[:-1], dtype=torch.long).to(question_embeds.device)
|
1336 |
+
|
1337 |
+
bos_ids = torch.full(
|
1338 |
+
(question_embeds.size(0), 1), fill_value=self.decoder_start_token_id, device=question_embeds.device
|
1339 |
+
)
|
1340 |
+
|
1341 |
+
outputs = self.text_decoder.generate(
|
1342 |
+
input_ids=bos_ids,
|
1343 |
+
eos_token_id=self.config.text_config.sep_token_id,
|
1344 |
+
pad_token_id=self.config.text_config.pad_token_id,
|
1345 |
+
encoder_hidden_states=question_embeds,
|
1346 |
+
encoder_attention_mask=question_attention_mask,
|
1347 |
+
**generate_kwargs,
|
1348 |
+
)
|
1349 |
+
|
1350 |
+
return outputs
|
1351 |
+
|
1352 |
+
|
1353 |
+
@add_start_docstrings(
|
1354 |
+
"""
|
1355 |
+
BLIP Model with a vision and text projector, and a classification head on top. The model is used in the context of
|
1356 |
+
image-text retrieval. Given an image and a text, the model returns the probability of the text being relevant to
|
1357 |
+
the image.
|
1358 |
+
""",
|
1359 |
+
BLIP_START_DOCSTRING,
|
1360 |
+
)
|
1361 |
+
class BlipForImageTextRetrieval(BlipPreTrainedModel):
|
1362 |
+
config_class = BlipConfig
|
1363 |
+
|
1364 |
+
def __init__(self, config: BlipConfig):
|
1365 |
+
super().__init__(config)
|
1366 |
+
|
1367 |
+
self.vision_model = BlipVisionModel(config.vision_config)
|
1368 |
+
|
1369 |
+
self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False)
|
1370 |
+
|
1371 |
+
# vision projection layer
|
1372 |
+
self.vision_proj = nn.Linear(config.vision_config.hidden_size, config.image_text_hidden_size)
|
1373 |
+
|
1374 |
+
# text projection layer
|
1375 |
+
self.text_proj = nn.Linear(config.text_config.hidden_size, config.image_text_hidden_size)
|
1376 |
+
|
1377 |
+
# image text matching head
|
1378 |
+
self.itm_head = nn.Linear(config.text_config.hidden_size, 2)
|
1379 |
+
|
1380 |
+
self.decoder_pad_token_id = (
|
1381 |
+
config.text_config.pad_token_id
|
1382 |
+
if not hasattr(config, "decoder_pad_token_id")
|
1383 |
+
else config.decoder_pad_token_id
|
1384 |
+
)
|
1385 |
+
self.decoder_start_token_id = (
|
1386 |
+
config.text_config.bos_token_id
|
1387 |
+
if not hasattr(config, "decoder_start_token_id")
|
1388 |
+
else config.decoder_start_token_id
|
1389 |
+
)
|
1390 |
+
|
1391 |
+
# Initialize weights and apply final processing
|
1392 |
+
self.post_init()
|
1393 |
+
|
1394 |
+
def get_input_embeddings(self) -> nn.Module:
|
1395 |
+
return self.vision_model.embeddings.patch_embedding
|
1396 |
+
|
1397 |
+
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
|
1398 |
+
@replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig)
|
1399 |
+
def forward(
|
1400 |
+
self,
|
1401 |
+
input_ids: torch.LongTensor,
|
1402 |
+
pixel_values: torch.FloatTensor,
|
1403 |
+
use_itm_head: Optional[bool] = True,
|
1404 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
1405 |
+
output_attentions: Optional[bool] = None,
|
1406 |
+
output_hidden_states: Optional[bool] = None,
|
1407 |
+
return_dict: Optional[bool] = None,
|
1408 |
+
) -> Union[Tuple, BlipTextVisionModelOutput]:
|
1409 |
+
r"""
|
1410 |
+
Returns:
|
1411 |
+
|
1412 |
+
Examples:
|
1413 |
+
|
1414 |
+
```python
|
1415 |
+
>>> from PIL import Image
|
1416 |
+
>>> import requests
|
1417 |
+
>>> from transformers import AutoProcessor, BlipForImageTextRetrieval
|
1418 |
+
|
1419 |
+
>>> model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco")
|
1420 |
+
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip-itm-base-coco")
|
1421 |
+
|
1422 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
1423 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
1424 |
+
>>> text = "an image of a cat"
|
1425 |
+
|
1426 |
+
>>> inputs = processor(images=image, text=text, return_tensors="pt")
|
1427 |
+
>>> outputs = model(**inputs)
|
1428 |
+
```
|
1429 |
+
"""
|
1430 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1431 |
+
|
1432 |
+
vision_outputs = self.vision_model(
|
1433 |
+
pixel_values=pixel_values,
|
1434 |
+
output_attentions=output_attentions,
|
1435 |
+
output_hidden_states=output_hidden_states,
|
1436 |
+
return_dict=return_dict,
|
1437 |
+
)
|
1438 |
+
|
1439 |
+
image_embeds = vision_outputs[0]
|
1440 |
+
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long)
|
1441 |
+
|
1442 |
+
if use_itm_head:
|
1443 |
+
question_embeds = self.text_encoder(
|
1444 |
+
input_ids=input_ids,
|
1445 |
+
attention_mask=attention_mask,
|
1446 |
+
encoder_hidden_states=image_embeds,
|
1447 |
+
encoder_attention_mask=image_atts,
|
1448 |
+
return_dict=return_dict,
|
1449 |
+
)
|
1450 |
+
question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state
|
1451 |
+
|
1452 |
+
output = self.itm_head(question_embeds[:, 0, :])
|
1453 |
+
else:
|
1454 |
+
question_embeds = self.text_encoder(
|
1455 |
+
input_ids=input_ids,
|
1456 |
+
attention_mask=attention_mask,
|
1457 |
+
return_dict=return_dict,
|
1458 |
+
)
|
1459 |
+
question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state
|
1460 |
+
|
1461 |
+
image_feat = normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
|
1462 |
+
text_feat = normalize(self.text_proj(question_embeds[:, 0, :]), dim=-1)
|
1463 |
+
|
1464 |
+
output = image_feat @ text_feat.t()
|
1465 |
+
|
1466 |
+
if not return_dict:
|
1467 |
+
outputs = (output, vision_outputs[0]) + vision_outputs[2:] + (question_embeds,)
|
1468 |
+
return tuple(output for output in outputs if output is not None)
|
1469 |
+
|
1470 |
+
return BlipImageTextMatchingModelOutput(
|
1471 |
+
itm_score=output,
|
1472 |
+
last_hidden_state=vision_outputs.last_hidden_state,
|
1473 |
+
hidden_states=vision_outputs.hidden_states,
|
1474 |
+
attentions=vision_outputs.attentions,
|
1475 |
+
question_embeds=question_embeds,
|
1476 |
+
)
|
captioner/modeling_git.py
ADDED
@@ -0,0 +1,1587 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 Microsoft Research and The HuggingFace Inc. team.
|
3 |
+
# All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""PyTorch GIT model."""
|
17 |
+
|
18 |
+
|
19 |
+
import math
|
20 |
+
from dataclasses import dataclass
|
21 |
+
from typing import List, Optional, Tuple, Union
|
22 |
+
|
23 |
+
import torch
|
24 |
+
import torch.utils.checkpoint
|
25 |
+
from torch import nn
|
26 |
+
from torch.nn import CrossEntropyLoss
|
27 |
+
|
28 |
+
from transformers.activations import ACT2FN
|
29 |
+
from transformers.file_utils import ModelOutput
|
30 |
+
from transformers.modeling_outputs import (
|
31 |
+
BaseModelOutput,
|
32 |
+
BaseModelOutputWithPast,
|
33 |
+
BaseModelOutputWithPooling,
|
34 |
+
CausalLMOutputWithPast,
|
35 |
+
)
|
36 |
+
from transformers.modeling_utils import PreTrainedModel
|
37 |
+
from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
38 |
+
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
39 |
+
from transformers.models.git.configuration_git import GitConfig, GitVisionConfig
|
40 |
+
from .vit_pixel_masks_utils import ViTPatchMaskGenerator
|
41 |
+
|
42 |
+
|
43 |
+
logger = logging.get_logger(__name__)
|
44 |
+
|
45 |
+
_CHECKPOINT_FOR_DOC = "microsoft/git-base"
|
46 |
+
_CONFIG_FOR_DOC = "GitConfig"
|
47 |
+
|
48 |
+
GIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
49 |
+
"microsoft/git-base",
|
50 |
+
# See all GIT models at https://huggingface.co/models?filter=git
|
51 |
+
]
|
52 |
+
|
53 |
+
|
54 |
+
@dataclass
|
55 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Git
|
56 |
+
class GitVisionModelOutput(ModelOutput):
|
57 |
+
"""
|
58 |
+
Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
|
62 |
+
The image embeddings obtained by applying the projection layer to the pooler_output.
|
63 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
64 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
65 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
66 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
67 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
68 |
+
|
69 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
70 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
71 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
72 |
+
sequence_length)`.
|
73 |
+
|
74 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
75 |
+
heads.
|
76 |
+
"""
|
77 |
+
|
78 |
+
image_embeds: Optional[torch.FloatTensor] = None
|
79 |
+
last_hidden_state: torch.FloatTensor = None
|
80 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
81 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
82 |
+
|
83 |
+
|
84 |
+
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
85 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
86 |
+
"""
|
87 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
88 |
+
"""
|
89 |
+
bsz, src_len = mask.size()
|
90 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
91 |
+
|
92 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
93 |
+
|
94 |
+
inverted_mask = 1.0 - expanded_mask
|
95 |
+
|
96 |
+
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
97 |
+
|
98 |
+
|
99 |
+
class GitEmbeddings(nn.Module):
|
100 |
+
"""Construct the embeddings from word and position embeddings."""
|
101 |
+
|
102 |
+
def __init__(self, config):
|
103 |
+
super().__init__()
|
104 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
105 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
106 |
+
|
107 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
108 |
+
# any TensorFlow checkpoint file
|
109 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
110 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
111 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
112 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
113 |
+
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
114 |
+
|
115 |
+
def forward(
|
116 |
+
self,
|
117 |
+
input_ids: Optional[torch.LongTensor] = None,
|
118 |
+
position_ids: Optional[torch.LongTensor] = None,
|
119 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
120 |
+
past_key_values_length: int = 0,
|
121 |
+
) -> torch.Tensor:
|
122 |
+
if input_ids is not None:
|
123 |
+
input_shape = input_ids.size()
|
124 |
+
else:
|
125 |
+
input_shape = inputs_embeds.size()[:-1]
|
126 |
+
|
127 |
+
seq_length = input_shape[1]
|
128 |
+
|
129 |
+
if position_ids is None:
|
130 |
+
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
131 |
+
|
132 |
+
if inputs_embeds is None:
|
133 |
+
embeddings = self.word_embeddings(input_ids)
|
134 |
+
else:
|
135 |
+
embeddings = inputs_embeds
|
136 |
+
|
137 |
+
if self.position_embedding_type == "absolute":
|
138 |
+
position_embeddings = self.position_embeddings(position_ids)
|
139 |
+
embeddings += position_embeddings
|
140 |
+
embeddings = self.LayerNorm(embeddings)
|
141 |
+
embeddings = self.dropout(embeddings)
|
142 |
+
return embeddings
|
143 |
+
|
144 |
+
|
145 |
+
class GitSelfAttention(nn.Module):
|
146 |
+
def __init__(self, config, position_embedding_type=None):
|
147 |
+
super().__init__()
|
148 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
149 |
+
raise ValueError(
|
150 |
+
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
151 |
+
f"heads ({config.num_attention_heads})"
|
152 |
+
)
|
153 |
+
|
154 |
+
self.num_attention_heads = config.num_attention_heads
|
155 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
156 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
157 |
+
self.image_patch_tokens = int((config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1)
|
158 |
+
if config.num_image_with_embedding is not None:
|
159 |
+
self.image_patch_tokens *= config.num_image_with_embedding
|
160 |
+
|
161 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
162 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
163 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
164 |
+
|
165 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
166 |
+
self.position_embedding_type = position_embedding_type or getattr(
|
167 |
+
config, "position_embedding_type", "absolute"
|
168 |
+
)
|
169 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
170 |
+
self.max_position_embeddings = config.max_position_embeddings
|
171 |
+
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
172 |
+
|
173 |
+
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
174 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
175 |
+
x = x.view(new_x_shape)
|
176 |
+
return x.permute(0, 2, 1, 3)
|
177 |
+
|
178 |
+
def forward(
|
179 |
+
self,
|
180 |
+
hidden_states: torch.Tensor,
|
181 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
182 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
183 |
+
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
184 |
+
output_attentions: Optional[bool] = False,
|
185 |
+
pixel_values_present: Optional[bool] = False,
|
186 |
+
image_token_num: Optional[int] = None
|
187 |
+
) -> Tuple[torch.Tensor]:
|
188 |
+
mixed_query_layer = self.query(hidden_states)
|
189 |
+
if image_token_num is not None:
|
190 |
+
cutoff = image_token_num
|
191 |
+
else:
|
192 |
+
cutoff = self.image_patch_tokens if pixel_values_present else 0
|
193 |
+
if past_key_value is not None:
|
194 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
195 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
196 |
+
key_layer = torch.cat([key_layer[:, :, :cutoff, :], past_key_value[0], key_layer[:, :, -1:, :]], dim=2)
|
197 |
+
value_layer = torch.cat(
|
198 |
+
[value_layer[:, :, :cutoff, :], past_key_value[1], value_layer[:, :, -1:, :]], dim=2
|
199 |
+
)
|
200 |
+
else:
|
201 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
202 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
203 |
+
|
204 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
205 |
+
|
206 |
+
use_cache = past_key_value is not None
|
207 |
+
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
208 |
+
# Further calls to cross_attention layer can then reuse all cross-attention
|
209 |
+
# key/value_states (first "if" case)
|
210 |
+
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
211 |
+
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
212 |
+
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
213 |
+
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
214 |
+
# NOTE: like in other caches, we store the text component. In GIT it means we discard the image component.
|
215 |
+
past_key_value = (
|
216 |
+
key_layer[:, :, cutoff:, :],
|
217 |
+
value_layer[:, :, cutoff:, :],
|
218 |
+
)
|
219 |
+
|
220 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
221 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
222 |
+
|
223 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
224 |
+
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
225 |
+
if use_cache:
|
226 |
+
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
|
227 |
+
-1, 1
|
228 |
+
)
|
229 |
+
else:
|
230 |
+
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
231 |
+
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
232 |
+
distance = position_ids_l - position_ids_r
|
233 |
+
|
234 |
+
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
235 |
+
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
236 |
+
|
237 |
+
if self.position_embedding_type == "relative_key":
|
238 |
+
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
239 |
+
attention_scores = attention_scores + relative_position_scores
|
240 |
+
elif self.position_embedding_type == "relative_key_query":
|
241 |
+
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
242 |
+
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
243 |
+
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
244 |
+
|
245 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
246 |
+
if attention_mask is not None:
|
247 |
+
# Apply the attention mask is (precomputed for all layers in GitModel forward() function)
|
248 |
+
attention_scores = attention_scores + attention_mask
|
249 |
+
|
250 |
+
# Normalize the attention scores to probabilities.
|
251 |
+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
252 |
+
|
253 |
+
# This is actually dropping out entire tokens to attend to, which might
|
254 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
255 |
+
attention_probs = self.dropout(attention_probs)
|
256 |
+
|
257 |
+
# Mask heads if we want to
|
258 |
+
if head_mask is not None:
|
259 |
+
attention_probs = attention_probs * head_mask
|
260 |
+
|
261 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
262 |
+
|
263 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
264 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
265 |
+
context_layer = context_layer.view(new_context_layer_shape)
|
266 |
+
|
267 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
268 |
+
|
269 |
+
outputs = outputs + (past_key_value,)
|
270 |
+
return outputs
|
271 |
+
|
272 |
+
|
273 |
+
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput
|
274 |
+
class GitSelfOutput(nn.Module):
|
275 |
+
def __init__(self, config):
|
276 |
+
super().__init__()
|
277 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
278 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
279 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
280 |
+
|
281 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
282 |
+
hidden_states = self.dense(hidden_states)
|
283 |
+
hidden_states = self.dropout(hidden_states)
|
284 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
285 |
+
return hidden_states
|
286 |
+
|
287 |
+
|
288 |
+
class GitAttention(nn.Module):
|
289 |
+
# Copied from transformers.models.bert.modeling_bert.BertAttention.__init__ with Bert->Git
|
290 |
+
def __init__(self, config, position_embedding_type=None):
|
291 |
+
super().__init__()
|
292 |
+
self.self = GitSelfAttention(config, position_embedding_type=position_embedding_type)
|
293 |
+
self.output = GitSelfOutput(config)
|
294 |
+
self.pruned_heads = set()
|
295 |
+
|
296 |
+
# Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
|
297 |
+
def prune_heads(self, heads):
|
298 |
+
if len(heads) == 0:
|
299 |
+
return
|
300 |
+
heads, index = find_pruneable_heads_and_indices(
|
301 |
+
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
302 |
+
)
|
303 |
+
|
304 |
+
# Prune linear layers
|
305 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
306 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
307 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
308 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
309 |
+
|
310 |
+
# Update hyper params and store pruned heads
|
311 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
312 |
+
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
313 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
314 |
+
|
315 |
+
def forward(
|
316 |
+
self,
|
317 |
+
hidden_states: torch.Tensor,
|
318 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
319 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
320 |
+
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
321 |
+
output_attentions: Optional[bool] = False,
|
322 |
+
pixel_values_present: Optional[bool] = False,
|
323 |
+
image_token_num: Optional[int] = None
|
324 |
+
) -> Tuple[torch.Tensor]:
|
325 |
+
self_outputs = self.self(
|
326 |
+
hidden_states,
|
327 |
+
attention_mask,
|
328 |
+
head_mask,
|
329 |
+
past_key_value,
|
330 |
+
output_attentions,
|
331 |
+
pixel_values_present,
|
332 |
+
image_token_num
|
333 |
+
)
|
334 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
335 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
336 |
+
return outputs
|
337 |
+
|
338 |
+
|
339 |
+
# Copied from transformers.models.bert.modeling_bert.BertIntermediate
|
340 |
+
class GitIntermediate(nn.Module):
|
341 |
+
def __init__(self, config):
|
342 |
+
super().__init__()
|
343 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
344 |
+
if isinstance(config.hidden_act, str):
|
345 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
346 |
+
else:
|
347 |
+
self.intermediate_act_fn = config.hidden_act
|
348 |
+
|
349 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
350 |
+
hidden_states = self.dense(hidden_states)
|
351 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
352 |
+
return hidden_states
|
353 |
+
|
354 |
+
|
355 |
+
# Copied from transformers.models.bert.modeling_bert.BertOutput
|
356 |
+
class GitOutput(nn.Module):
|
357 |
+
def __init__(self, config):
|
358 |
+
super().__init__()
|
359 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
360 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
361 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
362 |
+
|
363 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
364 |
+
hidden_states = self.dense(hidden_states)
|
365 |
+
hidden_states = self.dropout(hidden_states)
|
366 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
367 |
+
return hidden_states
|
368 |
+
|
369 |
+
|
370 |
+
class GitLayer(nn.Module):
|
371 |
+
def __init__(self, config):
|
372 |
+
super().__init__()
|
373 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
374 |
+
self.seq_len_dim = 1
|
375 |
+
self.attention = GitAttention(config)
|
376 |
+
self.intermediate = GitIntermediate(config)
|
377 |
+
self.output = GitOutput(config)
|
378 |
+
|
379 |
+
def forward(
|
380 |
+
self,
|
381 |
+
hidden_states: torch.Tensor,
|
382 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
383 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
384 |
+
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
385 |
+
output_attentions: Optional[bool] = False,
|
386 |
+
pixel_values_present: Optional[bool] = False,
|
387 |
+
image_token_num: Optional[bool] = None,
|
388 |
+
) -> Tuple[torch.Tensor]:
|
389 |
+
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
390 |
+
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
391 |
+
self_attention_outputs = self.attention(
|
392 |
+
hidden_states,
|
393 |
+
attention_mask,
|
394 |
+
head_mask,
|
395 |
+
output_attentions=output_attentions,
|
396 |
+
past_key_value=self_attn_past_key_value,
|
397 |
+
pixel_values_present=pixel_values_present,
|
398 |
+
image_token_num=image_token_num
|
399 |
+
)
|
400 |
+
attention_output = self_attention_outputs[0]
|
401 |
+
|
402 |
+
# if decoder, the last output is tuple of self-attn cache
|
403 |
+
outputs = self_attention_outputs[1:-1]
|
404 |
+
present_key_value = self_attention_outputs[-1]
|
405 |
+
|
406 |
+
layer_output = apply_chunking_to_forward(
|
407 |
+
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
408 |
+
)
|
409 |
+
outputs = (layer_output,) + outputs
|
410 |
+
|
411 |
+
# if decoder, return the attn key/values as the last output
|
412 |
+
outputs = outputs + (present_key_value,)
|
413 |
+
|
414 |
+
return outputs
|
415 |
+
|
416 |
+
def feed_forward_chunk(self, attention_output):
|
417 |
+
intermediate_output = self.intermediate(attention_output)
|
418 |
+
layer_output = self.output(intermediate_output, attention_output)
|
419 |
+
return layer_output
|
420 |
+
|
421 |
+
|
422 |
+
class GitEncoder(nn.Module):
|
423 |
+
# Copied from transformers.models.bert.modeling_bert.BertEncoder.__init__ with Bert->Git
|
424 |
+
def __init__(self, config):
|
425 |
+
super().__init__()
|
426 |
+
self.config = config
|
427 |
+
self.layer = nn.ModuleList([GitLayer(config) for _ in range(config.num_hidden_layers)])
|
428 |
+
self.gradient_checkpointing = False
|
429 |
+
|
430 |
+
def forward(
|
431 |
+
self,
|
432 |
+
hidden_states: torch.Tensor,
|
433 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
434 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
435 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
436 |
+
use_cache: Optional[bool] = None,
|
437 |
+
output_attentions: Optional[bool] = False,
|
438 |
+
output_hidden_states: Optional[bool] = False,
|
439 |
+
pixel_values_present: Optional[bool] = False,
|
440 |
+
image_token_num: Optional[int] = None,
|
441 |
+
return_dict: Optional[bool] = True,
|
442 |
+
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
|
443 |
+
if self.gradient_checkpointing and self.training:
|
444 |
+
if use_cache:
|
445 |
+
logger.warning_once(
|
446 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
447 |
+
)
|
448 |
+
use_cache = False
|
449 |
+
|
450 |
+
all_hidden_states = () if output_hidden_states else None
|
451 |
+
all_self_attentions = () if output_attentions else None
|
452 |
+
|
453 |
+
next_decoder_cache = () if use_cache else None
|
454 |
+
for i, layer_module in enumerate(self.layer):
|
455 |
+
if output_hidden_states:
|
456 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
457 |
+
|
458 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
459 |
+
past_key_value = past_key_values[i] if past_key_values is not None else None
|
460 |
+
|
461 |
+
if self.gradient_checkpointing and self.training:
|
462 |
+
|
463 |
+
def create_custom_forward(module):
|
464 |
+
def custom_forward(*inputs):
|
465 |
+
return module(*inputs, past_key_value, output_attentions)
|
466 |
+
|
467 |
+
return custom_forward
|
468 |
+
|
469 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
470 |
+
create_custom_forward(layer_module),
|
471 |
+
hidden_states,
|
472 |
+
attention_mask,
|
473 |
+
layer_head_mask,
|
474 |
+
)
|
475 |
+
else:
|
476 |
+
layer_outputs = layer_module(
|
477 |
+
hidden_states,
|
478 |
+
attention_mask,
|
479 |
+
layer_head_mask,
|
480 |
+
past_key_value,
|
481 |
+
output_attentions,
|
482 |
+
pixel_values_present,
|
483 |
+
image_token_num,
|
484 |
+
|
485 |
+
)
|
486 |
+
|
487 |
+
hidden_states = layer_outputs[0]
|
488 |
+
if use_cache:
|
489 |
+
next_decoder_cache += (layer_outputs[-1],)
|
490 |
+
if output_attentions:
|
491 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
492 |
+
|
493 |
+
if output_hidden_states:
|
494 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
495 |
+
|
496 |
+
if not return_dict:
|
497 |
+
return tuple(
|
498 |
+
v
|
499 |
+
for v in [
|
500 |
+
hidden_states,
|
501 |
+
next_decoder_cache,
|
502 |
+
all_hidden_states,
|
503 |
+
all_self_attentions,
|
504 |
+
]
|
505 |
+
if v is not None
|
506 |
+
)
|
507 |
+
return BaseModelOutputWithPast(
|
508 |
+
last_hidden_state=hidden_states,
|
509 |
+
past_key_values=next_decoder_cache,
|
510 |
+
hidden_states=all_hidden_states,
|
511 |
+
attentions=all_self_attentions,
|
512 |
+
)
|
513 |
+
|
514 |
+
|
515 |
+
class GitPreTrainedModel(PreTrainedModel):
|
516 |
+
"""
|
517 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
518 |
+
models.
|
519 |
+
"""
|
520 |
+
|
521 |
+
config_class = GitConfig
|
522 |
+
base_model_prefix = "git"
|
523 |
+
supports_gradient_checkpointing = True
|
524 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
525 |
+
|
526 |
+
def _init_weights(self, module):
|
527 |
+
"""Initialize the weights"""
|
528 |
+
if isinstance(module, GitVisionEmbeddings):
|
529 |
+
nn.init.normal_(module.class_embedding, mean=0.0, std=self.config.initializer_range)
|
530 |
+
nn.init.normal_(module.patch_embedding.weight, std=self.config.initializer_range)
|
531 |
+
nn.init.normal_(module.position_embedding.weight, std=self.config.initializer_range)
|
532 |
+
if isinstance(module, nn.Linear):
|
533 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
534 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
535 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
536 |
+
if module.bias is not None:
|
537 |
+
module.bias.data.zero_()
|
538 |
+
elif isinstance(module, nn.Embedding):
|
539 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
540 |
+
if module.padding_idx is not None:
|
541 |
+
module.weight.data[module.padding_idx].zero_()
|
542 |
+
elif isinstance(module, nn.LayerNorm):
|
543 |
+
module.bias.data.zero_()
|
544 |
+
module.weight.data.fill_(1.0)
|
545 |
+
|
546 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
547 |
+
if isinstance(module, (GitEncoder, GitVisionEncoder)):
|
548 |
+
module.gradient_checkpointing = value
|
549 |
+
|
550 |
+
|
551 |
+
GIT_START_DOCSTRING = r"""
|
552 |
+
|
553 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
554 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
555 |
+
etc.)
|
556 |
+
|
557 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
558 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
559 |
+
and behavior.
|
560 |
+
|
561 |
+
Parameters:
|
562 |
+
config ([`GitConfig`]): Model configuration class with all the parameters of the model.
|
563 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
564 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
565 |
+
"""
|
566 |
+
|
567 |
+
GIT_INPUTS_DOCSTRING = r"""
|
568 |
+
Args:
|
569 |
+
input_ids (`torch.LongTensor` of shape `({0})`):
|
570 |
+
Indices of input sequence tokens in the vocabulary.
|
571 |
+
|
572 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
573 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
574 |
+
|
575 |
+
[What are input IDs?](../glossary#input-ids)
|
576 |
+
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
|
577 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
578 |
+
|
579 |
+
- 1 for tokens that are **not masked**,
|
580 |
+
- 0 for tokens that are **masked**.
|
581 |
+
|
582 |
+
[What are attention masks?](../glossary#attention-mask)
|
583 |
+
|
584 |
+
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
585 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
586 |
+
config.max_position_embeddings - 1]`.
|
587 |
+
|
588 |
+
[What are position IDs?](../glossary#position-ids)
|
589 |
+
|
590 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
591 |
+
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
|
592 |
+
[`CLIPImageProcessor.__call__`] for details.
|
593 |
+
|
594 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
595 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
596 |
+
|
597 |
+
- 1 indicates the head is **not masked**,
|
598 |
+
- 0 indicates the head is **masked**.
|
599 |
+
|
600 |
+
inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
|
601 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
602 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
603 |
+
model's internal embedding lookup matrix.
|
604 |
+
output_attentions (`bool`, *optional*):
|
605 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
606 |
+
tensors for more detail.
|
607 |
+
output_hidden_states (`bool`, *optional*):
|
608 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
609 |
+
more detail.
|
610 |
+
return_dict (`bool`, *optional*):
|
611 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
612 |
+
"""
|
613 |
+
|
614 |
+
|
615 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Git
|
616 |
+
class GitVisionEmbeddings(nn.Module):
|
617 |
+
def __init__(self, config: GitVisionConfig):
|
618 |
+
super().__init__()
|
619 |
+
self.config = config
|
620 |
+
self.embed_dim = config.hidden_size
|
621 |
+
self.image_size = config.image_size
|
622 |
+
self.patch_size = config.patch_size
|
623 |
+
|
624 |
+
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
|
625 |
+
|
626 |
+
self.patch_embedding = nn.Conv2d(
|
627 |
+
in_channels=config.num_channels,
|
628 |
+
out_channels=self.embed_dim,
|
629 |
+
kernel_size=self.patch_size,
|
630 |
+
stride=self.patch_size,
|
631 |
+
bias=False,
|
632 |
+
)
|
633 |
+
|
634 |
+
self.num_patches = (self.image_size // self.patch_size) ** 2
|
635 |
+
self.num_positions = self.num_patches + 1
|
636 |
+
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
637 |
+
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
|
638 |
+
|
639 |
+
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
640 |
+
batch_size = pixel_values.shape[0]
|
641 |
+
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
|
642 |
+
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
643 |
+
|
644 |
+
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
645 |
+
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
646 |
+
embeddings = embeddings + self.position_embedding(self.position_ids)
|
647 |
+
return embeddings
|
648 |
+
|
649 |
+
|
650 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPMLP
|
651 |
+
class GitVisionMLP(nn.Module):
|
652 |
+
def __init__(self, config):
|
653 |
+
super().__init__()
|
654 |
+
self.config = config
|
655 |
+
self.activation_fn = ACT2FN[config.hidden_act]
|
656 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
657 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
658 |
+
|
659 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
660 |
+
hidden_states = self.fc1(hidden_states)
|
661 |
+
hidden_states = self.activation_fn(hidden_states)
|
662 |
+
hidden_states = self.fc2(hidden_states)
|
663 |
+
return hidden_states
|
664 |
+
|
665 |
+
|
666 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPAttention
|
667 |
+
class GitVisionAttention(nn.Module):
|
668 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
669 |
+
|
670 |
+
def __init__(self, config):
|
671 |
+
super().__init__()
|
672 |
+
self.config = config
|
673 |
+
self.embed_dim = config.hidden_size
|
674 |
+
self.num_heads = config.num_attention_heads
|
675 |
+
self.head_dim = self.embed_dim // self.num_heads
|
676 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
677 |
+
raise ValueError(
|
678 |
+
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
679 |
+
f" {self.num_heads})."
|
680 |
+
)
|
681 |
+
self.scale = self.head_dim**-0.5
|
682 |
+
self.dropout = config.attention_dropout
|
683 |
+
|
684 |
+
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
685 |
+
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
686 |
+
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
687 |
+
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
688 |
+
|
689 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
690 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
691 |
+
|
692 |
+
def forward(
|
693 |
+
self,
|
694 |
+
hidden_states: torch.Tensor,
|
695 |
+
attention_mask: Optional[torch.Tensor] = None,
|
696 |
+
causal_attention_mask: Optional[torch.Tensor] = None,
|
697 |
+
output_attentions: Optional[bool] = False,
|
698 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
699 |
+
"""Input shape: Batch x Time x Channel"""
|
700 |
+
|
701 |
+
bsz, tgt_len, embed_dim = hidden_states.size()
|
702 |
+
|
703 |
+
# get query proj
|
704 |
+
query_states = self.q_proj(hidden_states) * self.scale
|
705 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
706 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
707 |
+
|
708 |
+
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
709 |
+
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
710 |
+
key_states = key_states.view(*proj_shape)
|
711 |
+
value_states = value_states.view(*proj_shape)
|
712 |
+
|
713 |
+
src_len = key_states.size(1)
|
714 |
+
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
715 |
+
|
716 |
+
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
717 |
+
raise ValueError(
|
718 |
+
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
719 |
+
f" {attn_weights.size()}"
|
720 |
+
)
|
721 |
+
|
722 |
+
# apply the causal_attention_mask first
|
723 |
+
if causal_attention_mask is not None:
|
724 |
+
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
725 |
+
raise ValueError(
|
726 |
+
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
727 |
+
f" {causal_attention_mask.size()}"
|
728 |
+
)
|
729 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
|
730 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
731 |
+
|
732 |
+
if attention_mask is not None:
|
733 |
+
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
734 |
+
raise ValueError(
|
735 |
+
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
736 |
+
)
|
737 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
738 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
739 |
+
|
740 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
741 |
+
|
742 |
+
if output_attentions:
|
743 |
+
# this operation is a bit akward, but it's required to
|
744 |
+
# make sure that attn_weights keeps its gradient.
|
745 |
+
# In order to do so, attn_weights have to reshaped
|
746 |
+
# twice and have to be reused in the following
|
747 |
+
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
748 |
+
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
749 |
+
else:
|
750 |
+
attn_weights_reshaped = None
|
751 |
+
|
752 |
+
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
753 |
+
|
754 |
+
attn_output = torch.bmm(attn_probs, value_states)
|
755 |
+
|
756 |
+
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
757 |
+
raise ValueError(
|
758 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
759 |
+
f" {attn_output.size()}"
|
760 |
+
)
|
761 |
+
|
762 |
+
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
763 |
+
attn_output = attn_output.transpose(1, 2)
|
764 |
+
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
765 |
+
|
766 |
+
attn_output = self.out_proj(attn_output)
|
767 |
+
|
768 |
+
return attn_output, attn_weights_reshaped
|
769 |
+
|
770 |
+
|
771 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->GitVision
|
772 |
+
class GitVisionEncoderLayer(nn.Module):
|
773 |
+
def __init__(self, config: GitVisionConfig):
|
774 |
+
super().__init__()
|
775 |
+
self.embed_dim = config.hidden_size
|
776 |
+
self.self_attn = GitVisionAttention(config)
|
777 |
+
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
778 |
+
self.mlp = GitVisionMLP(config)
|
779 |
+
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
780 |
+
|
781 |
+
def forward(
|
782 |
+
self,
|
783 |
+
hidden_states: torch.Tensor,
|
784 |
+
attention_mask: torch.Tensor,
|
785 |
+
causal_attention_mask: torch.Tensor,
|
786 |
+
output_attentions: Optional[bool] = False,
|
787 |
+
) -> Tuple[torch.FloatTensor]:
|
788 |
+
"""
|
789 |
+
Args:
|
790 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
791 |
+
attention_mask (`torch.FloatTensor`): attention mask of size
|
792 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
793 |
+
`(config.encoder_attention_heads,)`.
|
794 |
+
output_attentions (`bool`, *optional*):
|
795 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
796 |
+
returned tensors for more detail.
|
797 |
+
"""
|
798 |
+
residual = hidden_states
|
799 |
+
|
800 |
+
hidden_states = self.layer_norm1(hidden_states)
|
801 |
+
hidden_states, attn_weights = self.self_attn(
|
802 |
+
hidden_states=hidden_states,
|
803 |
+
attention_mask=attention_mask,
|
804 |
+
causal_attention_mask=causal_attention_mask,
|
805 |
+
output_attentions=output_attentions,
|
806 |
+
)
|
807 |
+
hidden_states = residual + hidden_states
|
808 |
+
|
809 |
+
residual = hidden_states
|
810 |
+
hidden_states = self.layer_norm2(hidden_states)
|
811 |
+
hidden_states = self.mlp(hidden_states)
|
812 |
+
hidden_states = residual + hidden_states
|
813 |
+
|
814 |
+
outputs = (hidden_states,)
|
815 |
+
|
816 |
+
if output_attentions:
|
817 |
+
outputs += (attn_weights,)
|
818 |
+
|
819 |
+
return outputs
|
820 |
+
|
821 |
+
|
822 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->GitVision, CLIPConfig
|
823 |
+
class GitVisionEncoder(nn.Module):
|
824 |
+
"""
|
825 |
+
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
826 |
+
[`GitVisionEncoderLayer`].
|
827 |
+
|
828 |
+
Args:
|
829 |
+
config: GitVisionConfig
|
830 |
+
"""
|
831 |
+
|
832 |
+
def __init__(self, config: GitVisionConfig):
|
833 |
+
super().__init__()
|
834 |
+
self.config = config
|
835 |
+
self.layers = nn.ModuleList([GitVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
836 |
+
self.gradient_checkpointing = False
|
837 |
+
|
838 |
+
def forward(
|
839 |
+
self,
|
840 |
+
inputs_embeds,
|
841 |
+
attention_mask: Optional[torch.Tensor] = None,
|
842 |
+
causal_attention_mask: Optional[torch.Tensor] = None,
|
843 |
+
output_attentions: Optional[bool] = None,
|
844 |
+
output_hidden_states: Optional[bool] = None,
|
845 |
+
return_dict: Optional[bool] = None,
|
846 |
+
) -> Union[Tuple, BaseModelOutput]:
|
847 |
+
r"""
|
848 |
+
Args:
|
849 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
850 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
851 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
852 |
+
than the model's internal embedding lookup matrix.
|
853 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
854 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
855 |
+
|
856 |
+
- 1 for tokens that are **not masked**,
|
857 |
+
- 0 for tokens that are **masked**.
|
858 |
+
|
859 |
+
[What are attention masks?](../glossary#attention-mask)
|
860 |
+
causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
861 |
+
Causal mask for the text model. Mask values selected in `[0, 1]`:
|
862 |
+
|
863 |
+
- 1 for tokens that are **not masked**,
|
864 |
+
- 0 for tokens that are **masked**.
|
865 |
+
|
866 |
+
[What are attention masks?](../glossary#attention-mask)
|
867 |
+
output_attentions (`bool`, *optional*):
|
868 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
869 |
+
returned tensors for more detail.
|
870 |
+
output_hidden_states (`bool`, *optional*):
|
871 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
872 |
+
for more detail.
|
873 |
+
return_dict (`bool`, *optional*):
|
874 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
875 |
+
"""
|
876 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
877 |
+
output_hidden_states = (
|
878 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
879 |
+
)
|
880 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
881 |
+
|
882 |
+
encoder_states = () if output_hidden_states else None
|
883 |
+
all_attentions = () if output_attentions else None
|
884 |
+
|
885 |
+
hidden_states = inputs_embeds
|
886 |
+
for idx, encoder_layer in enumerate(self.layers):
|
887 |
+
if output_hidden_states:
|
888 |
+
encoder_states = encoder_states + (hidden_states,)
|
889 |
+
if self.gradient_checkpointing and self.training:
|
890 |
+
|
891 |
+
def create_custom_forward(module):
|
892 |
+
def custom_forward(*inputs):
|
893 |
+
return module(*inputs, output_attentions)
|
894 |
+
|
895 |
+
return custom_forward
|
896 |
+
|
897 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
898 |
+
create_custom_forward(encoder_layer),
|
899 |
+
hidden_states,
|
900 |
+
attention_mask,
|
901 |
+
causal_attention_mask,
|
902 |
+
)
|
903 |
+
else:
|
904 |
+
layer_outputs = encoder_layer(
|
905 |
+
hidden_states,
|
906 |
+
attention_mask,
|
907 |
+
causal_attention_mask,
|
908 |
+
output_attentions=output_attentions,
|
909 |
+
)
|
910 |
+
|
911 |
+
hidden_states = layer_outputs[0]
|
912 |
+
|
913 |
+
if output_attentions:
|
914 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
915 |
+
|
916 |
+
if output_hidden_states:
|
917 |
+
encoder_states = encoder_states + (hidden_states,)
|
918 |
+
|
919 |
+
if not return_dict:
|
920 |
+
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
921 |
+
return BaseModelOutput(
|
922 |
+
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
923 |
+
)
|
924 |
+
|
925 |
+
|
926 |
+
GIT_VISION_INPUTS_DOCSTRING = r"""
|
927 |
+
Args:
|
928 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
929 |
+
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
930 |
+
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
931 |
+
output_attentions (`bool`, *optional*):
|
932 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
933 |
+
tensors for more detail.
|
934 |
+
output_hidden_states (`bool`, *optional*):
|
935 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
936 |
+
more detail.
|
937 |
+
return_dict (`bool`, *optional*):
|
938 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
939 |
+
"""
|
940 |
+
|
941 |
+
|
942 |
+
class GitVisionTransformer(nn.Module):
|
943 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.__init__ with CLIPEncoder->GitVisionEncoder, CLIP->Git
|
944 |
+
def __init__(self, config: GitVisionConfig):
|
945 |
+
super().__init__()
|
946 |
+
self.config = config
|
947 |
+
embed_dim = config.hidden_size
|
948 |
+
|
949 |
+
self.embeddings = GitVisionEmbeddings(config)
|
950 |
+
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
951 |
+
self.patch_mask_generator = ViTPatchMaskGenerator(config.patch_size)
|
952 |
+
self.encoder = GitVisionEncoder(config)
|
953 |
+
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
954 |
+
|
955 |
+
@add_start_docstrings_to_model_forward(GIT_VISION_INPUTS_DOCSTRING)
|
956 |
+
@replace_return_docstrings(output_type=BaseModelOutput, config_class=GitVisionConfig)
|
957 |
+
def forward(
|
958 |
+
self,
|
959 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
960 |
+
pixel_masks: Optional[torch.Tensor] = None,
|
961 |
+
output_attentions: Optional[bool] = None,
|
962 |
+
output_hidden_states: Optional[bool] = None,
|
963 |
+
return_dict: Optional[bool] = None,
|
964 |
+
) -> Union[Tuple, BaseModelOutput]:
|
965 |
+
r"""
|
966 |
+
Returns:
|
967 |
+
|
968 |
+
"""
|
969 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
970 |
+
output_hidden_states = (
|
971 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
972 |
+
)
|
973 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
974 |
+
|
975 |
+
if pixel_values is None:
|
976 |
+
raise ValueError("You have to specify pixel_values")
|
977 |
+
|
978 |
+
hidden_states = self.embeddings(pixel_values)
|
979 |
+
B, N, D = hidden_states.shape
|
980 |
+
# print('Before mask:', hidden_states.shape)
|
981 |
+
if pixel_masks is not None:
|
982 |
+
assert pixel_masks.shape[0] == 1
|
983 |
+
patch_masks = self.patch_mask_generator(pixel_masks)
|
984 |
+
# print(patch_masks.shape)
|
985 |
+
patch_masks = patch_masks.unsqueeze(-1).expand_as(hidden_states)
|
986 |
+
hidden_states = hidden_states.masked_select(patch_masks).view(B, -1, D)
|
987 |
+
# print('After mask:', hidden_states.shape)
|
988 |
+
hidden_states = self.pre_layrnorm(hidden_states)
|
989 |
+
|
990 |
+
encoder_outputs = self.encoder(
|
991 |
+
inputs_embeds=hidden_states,
|
992 |
+
output_attentions=output_attentions,
|
993 |
+
output_hidden_states=output_hidden_states,
|
994 |
+
return_dict=return_dict,
|
995 |
+
)
|
996 |
+
|
997 |
+
last_hidden_state = encoder_outputs[0]
|
998 |
+
|
999 |
+
last_hidden_state = self.post_layernorm(last_hidden_state)
|
1000 |
+
|
1001 |
+
if not return_dict:
|
1002 |
+
return (last_hidden_state,) + encoder_outputs[1:]
|
1003 |
+
|
1004 |
+
return BaseModelOutput(
|
1005 |
+
last_hidden_state=last_hidden_state,
|
1006 |
+
hidden_states=encoder_outputs.hidden_states,
|
1007 |
+
attentions=encoder_outputs.attentions,
|
1008 |
+
)
|
1009 |
+
|
1010 |
+
|
1011 |
+
@add_start_docstrings(
|
1012 |
+
"""The vision model from CLIP, used in GIT, without any head or projection on top.""",
|
1013 |
+
GIT_START_DOCSTRING,
|
1014 |
+
)
|
1015 |
+
class GitVisionModel(GitPreTrainedModel):
|
1016 |
+
config_class = GitVisionConfig
|
1017 |
+
main_input_name = "pixel_values"
|
1018 |
+
|
1019 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModel.__init__ with CLIP->Git
|
1020 |
+
def __init__(self, config: GitVisionConfig):
|
1021 |
+
super().__init__(config)
|
1022 |
+
self.vision_model = GitVisionTransformer(config)
|
1023 |
+
# Initialize weights and apply final processing
|
1024 |
+
self.post_init()
|
1025 |
+
|
1026 |
+
def get_input_embeddings(self) -> nn.Module:
|
1027 |
+
return self.vision_model.embeddings.patch_embedding
|
1028 |
+
|
1029 |
+
@add_start_docstrings_to_model_forward(GIT_VISION_INPUTS_DOCSTRING)
|
1030 |
+
@replace_return_docstrings(output_type=BaseModelOutput, config_class=GitVisionConfig)
|
1031 |
+
def forward(
|
1032 |
+
self,
|
1033 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
1034 |
+
pixel_masks: Optional[torch.Tensor] = None,
|
1035 |
+
output_attentions: Optional[bool] = None,
|
1036 |
+
output_hidden_states: Optional[bool] = None,
|
1037 |
+
return_dict: Optional[bool] = None,
|
1038 |
+
) -> Union[Tuple, BaseModelOutput]:
|
1039 |
+
r"""
|
1040 |
+
Returns:
|
1041 |
+
|
1042 |
+
Examples:
|
1043 |
+
|
1044 |
+
```python
|
1045 |
+
>>> from PIL import Image
|
1046 |
+
>>> import requests
|
1047 |
+
>>> from transformers import AutoProcessor, GitVisionModel
|
1048 |
+
|
1049 |
+
>>> processor = AutoProcessor.from_pretrained("microsoft/git-base")
|
1050 |
+
>>> model = GitVisionModel.from_pretrained("microsoft/git-base")
|
1051 |
+
|
1052 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
1053 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
1054 |
+
|
1055 |
+
>>> inputs = processor(images=image, return_tensors="pt")
|
1056 |
+
|
1057 |
+
>>> outputs = model(**inputs)
|
1058 |
+
>>> last_hidden_state = outputs.last_hidden_state
|
1059 |
+
```"""
|
1060 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1061 |
+
|
1062 |
+
return self.vision_model(
|
1063 |
+
pixel_values=pixel_values,
|
1064 |
+
pixel_masks=pixel_masks,
|
1065 |
+
output_attentions=output_attentions,
|
1066 |
+
output_hidden_states=output_hidden_states,
|
1067 |
+
return_dict=return_dict,
|
1068 |
+
)
|
1069 |
+
|
1070 |
+
|
1071 |
+
class GitProjection(nn.Module):
|
1072 |
+
def __init__(self, config: GitConfig):
|
1073 |
+
super().__init__()
|
1074 |
+
self.config = config
|
1075 |
+
self.visual_projection = nn.Sequential(
|
1076 |
+
nn.Linear(config.vision_config.hidden_size, config.hidden_size),
|
1077 |
+
nn.LayerNorm(config.hidden_size, eps=config.vision_config.layer_norm_eps),
|
1078 |
+
)
|
1079 |
+
|
1080 |
+
def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
|
1081 |
+
return self.visual_projection(embeddings)
|
1082 |
+
|
1083 |
+
|
1084 |
+
@add_start_docstrings(
|
1085 |
+
"The bare GIT Model transformer consisting of a CLIP image encoder and text decoder outputting raw hidden-states"
|
1086 |
+
" without any specific head on top.",
|
1087 |
+
GIT_START_DOCSTRING,
|
1088 |
+
)
|
1089 |
+
class GitModel(GitPreTrainedModel):
|
1090 |
+
def __init__(self, config):
|
1091 |
+
super().__init__(config)
|
1092 |
+
self.config = config
|
1093 |
+
|
1094 |
+
self.embeddings = GitEmbeddings(config)
|
1095 |
+
self.image_encoder = GitVisionModel(config.vision_config)
|
1096 |
+
self.encoder = GitEncoder(config)
|
1097 |
+
|
1098 |
+
self.visual_projection = GitProjection(config)
|
1099 |
+
|
1100 |
+
if config.num_image_with_embedding is not None:
|
1101 |
+
self.img_temperal_embedding = nn.ParameterList(
|
1102 |
+
nn.Parameter(torch.zeros(1, 1, config.vision_config.hidden_size))
|
1103 |
+
for _ in range(config.num_image_with_embedding)
|
1104 |
+
)
|
1105 |
+
|
1106 |
+
# Initialize weights and apply final processing
|
1107 |
+
self.post_init()
|
1108 |
+
|
1109 |
+
def get_input_embeddings(self):
|
1110 |
+
return self.embeddings.word_embeddings
|
1111 |
+
|
1112 |
+
def set_input_embeddings(self, value):
|
1113 |
+
self.embeddings.word_embeddings = value
|
1114 |
+
|
1115 |
+
def _prune_heads(self, heads_to_prune):
|
1116 |
+
"""
|
1117 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
1118 |
+
class PreTrainedModel
|
1119 |
+
"""
|
1120 |
+
for layer, heads in heads_to_prune.items():
|
1121 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
1122 |
+
|
1123 |
+
def _generate_future_mask(self, size: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
|
1124 |
+
# Default mask is for forward direction. Flip for backward direction.
|
1125 |
+
mask = torch.triu(torch.ones(size, size, device=device, dtype=dtype), diagonal=1)
|
1126 |
+
mask = mask.masked_fill(mask == 1, float("-inf"))
|
1127 |
+
return mask
|
1128 |
+
|
1129 |
+
def create_attention_mask(self, tgt, memory, tgt_mask, past_key_values_length, memory_key_padding_mask=None):
|
1130 |
+
num_tgt = tgt.shape[1]
|
1131 |
+
num_memory = memory.shape[1]
|
1132 |
+
device = tgt.device
|
1133 |
+
dtype = tgt.dtype
|
1134 |
+
top_left = torch.zeros((num_memory, num_memory), device=device, dtype=dtype)
|
1135 |
+
top_right = torch.full(
|
1136 |
+
(num_memory, num_tgt + past_key_values_length),
|
1137 |
+
float("-inf"),
|
1138 |
+
device=tgt.device,
|
1139 |
+
dtype=dtype,
|
1140 |
+
)
|
1141 |
+
bottom_left = torch.zeros(
|
1142 |
+
(num_tgt, num_memory),
|
1143 |
+
dtype=dtype,
|
1144 |
+
device=tgt_mask.device,
|
1145 |
+
)
|
1146 |
+
|
1147 |
+
if past_key_values_length > 0:
|
1148 |
+
tgt_mask = torch.zeros(
|
1149 |
+
(tgt_mask.shape[0], tgt_mask.shape[0] + past_key_values_length),
|
1150 |
+
dtype=dtype,
|
1151 |
+
device=tgt_mask.device,
|
1152 |
+
)
|
1153 |
+
|
1154 |
+
left = torch.cat((top_left, bottom_left), dim=0)
|
1155 |
+
right = torch.cat((top_right, tgt_mask.to(dtype)), dim=0)
|
1156 |
+
|
1157 |
+
full_attention_mask = torch.cat((left, right), dim=1)[None, :]
|
1158 |
+
|
1159 |
+
if memory_key_padding_mask is None:
|
1160 |
+
memory_key_padding_mask = torch.full((memory.shape[0], memory.shape[1]), fill_value=False, device=device)
|
1161 |
+
# if it is False, it means valid. That is, it is not a padding
|
1162 |
+
if memory_key_padding_mask.dtype != torch.bool:
|
1163 |
+
raise ValueError("Memory key padding mask must be a boolean tensor.")
|
1164 |
+
zero_negative_infinity = torch.zeros_like(memory_key_padding_mask, dtype=tgt.dtype)
|
1165 |
+
zero_negative_infinity[memory_key_padding_mask] = float("-inf")
|
1166 |
+
full_attention_mask = full_attention_mask.expand(
|
1167 |
+
(memory_key_padding_mask.shape[0], num_memory + num_tgt, num_memory + past_key_values_length + num_tgt)
|
1168 |
+
)
|
1169 |
+
full_attention_mask = full_attention_mask.clone()
|
1170 |
+
origin_left = full_attention_mask[:, :, :num_memory]
|
1171 |
+
update = zero_negative_infinity[:, None, :]
|
1172 |
+
full_attention_mask[:, :, :num_memory] = origin_left + update
|
1173 |
+
|
1174 |
+
# add axis for multi-head
|
1175 |
+
full_attention_mask = full_attention_mask[:, None, :, :]
|
1176 |
+
|
1177 |
+
return full_attention_mask
|
1178 |
+
|
1179 |
+
@add_start_docstrings_to_model_forward(GIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
1180 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
|
1181 |
+
def forward(
|
1182 |
+
self,
|
1183 |
+
input_ids: Optional[torch.Tensor] = None,
|
1184 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1185 |
+
position_ids: Optional[torch.Tensor] = None,
|
1186 |
+
pixel_values: Optional[torch.Tensor] = None,
|
1187 |
+
pixel_masks: Optional[torch.Tensor] = None,
|
1188 |
+
head_mask: Optional[torch.Tensor] = None,
|
1189 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1190 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1191 |
+
use_cache: Optional[bool] = None,
|
1192 |
+
output_attentions: Optional[bool] = None,
|
1193 |
+
output_hidden_states: Optional[bool] = None,
|
1194 |
+
return_dict: Optional[bool] = None,
|
1195 |
+
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPooling]:
|
1196 |
+
r"""
|
1197 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
1198 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
1199 |
+
|
1200 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
1201 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
1202 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
1203 |
+
use_cache (`bool`, *optional*):
|
1204 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
1205 |
+
`past_key_values`).
|
1206 |
+
|
1207 |
+
Returns:
|
1208 |
+
|
1209 |
+
Examples:
|
1210 |
+
|
1211 |
+
```python
|
1212 |
+
>>> from transformers import AutoProcessor, AutoModel
|
1213 |
+
>>> import requests
|
1214 |
+
>>> from PIL import Image
|
1215 |
+
|
1216 |
+
>>> processor = AutoProcessor.from_pretrained("microsoft/git-base")
|
1217 |
+
>>> model = AutoModel.from_pretrained("microsoft/git-base")
|
1218 |
+
|
1219 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
1220 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
1221 |
+
|
1222 |
+
>>> text = "this is an image of two cats"
|
1223 |
+
|
1224 |
+
>>> inputs = processor(text, images=image, return_tensors="pt")
|
1225 |
+
|
1226 |
+
>>> outputs = model(**inputs)
|
1227 |
+
>>> last_hidden_state = outputs.last_hidden_state
|
1228 |
+
```"""
|
1229 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1230 |
+
output_hidden_states = (
|
1231 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1232 |
+
)
|
1233 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
1234 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1235 |
+
|
1236 |
+
if input_ids is not None and inputs_embeds is not None:
|
1237 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
1238 |
+
elif input_ids is not None:
|
1239 |
+
input_shape = input_ids.size()
|
1240 |
+
elif inputs_embeds is not None:
|
1241 |
+
input_shape = inputs_embeds.size()[:-1]
|
1242 |
+
else:
|
1243 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
1244 |
+
|
1245 |
+
seq_length = input_shape[1]
|
1246 |
+
|
1247 |
+
# past_key_values_length
|
1248 |
+
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
1249 |
+
|
1250 |
+
# Prepare head mask if needed
|
1251 |
+
# 1.0 in head_mask indicate we keep the head
|
1252 |
+
# attention_probs has shape bsz x n_heads x N x N
|
1253 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
1254 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
1255 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
1256 |
+
|
1257 |
+
projected_visual_features = None
|
1258 |
+
if pixel_values is not None:
|
1259 |
+
if pixel_values.ndim == 4:
|
1260 |
+
# here we assume pixel_values is of shape (batch_size, num_channels, height, width)
|
1261 |
+
visual_features = self.image_encoder(pixel_values=pixel_values, pixel_masks=pixel_masks).last_hidden_state
|
1262 |
+
|
1263 |
+
elif pixel_values.ndim == 5:
|
1264 |
+
# here we assume pixel_values is of shape (batch_size, num_frames, num_channels, height, width)
|
1265 |
+
visual_features = []
|
1266 |
+
for frame_idx in range(pixel_values.shape[1]):
|
1267 |
+
visual_features_frame = self.image_encoder(pixel_values[:, frame_idx, :, :]).last_hidden_state
|
1268 |
+
visual_features_frame += self.img_temperal_embedding[frame_idx]
|
1269 |
+
visual_features.append(visual_features_frame)
|
1270 |
+
|
1271 |
+
# finally, concatenate all features along sequence dimension
|
1272 |
+
visual_features = torch.cat(visual_features, dim=1)
|
1273 |
+
|
1274 |
+
else:
|
1275 |
+
raise ValueError("pixel_values must be of rank 4 or 5")
|
1276 |
+
|
1277 |
+
projected_visual_features = self.visual_projection(visual_features)
|
1278 |
+
image_token_num = projected_visual_features.shape[1]
|
1279 |
+
embedding_output = self.embeddings(
|
1280 |
+
input_ids=input_ids,
|
1281 |
+
position_ids=position_ids,
|
1282 |
+
inputs_embeds=inputs_embeds,
|
1283 |
+
past_key_values_length=past_key_values_length,
|
1284 |
+
)
|
1285 |
+
|
1286 |
+
if projected_visual_features is None:
|
1287 |
+
projected_visual_features = torch.zeros(
|
1288 |
+
(embedding_output.shape[0], 0, embedding_output.shape[2]),
|
1289 |
+
dtype=embedding_output.dtype,
|
1290 |
+
device=embedding_output.device,
|
1291 |
+
)
|
1292 |
+
|
1293 |
+
# Repeat visual features to match embedding batch size.
|
1294 |
+
projected_visual_features = projected_visual_features.repeat(
|
1295 |
+
embedding_output.size(0) // projected_visual_features.size(0), 1, 1
|
1296 |
+
)
|
1297 |
+
|
1298 |
+
# concatenate patch token and text token embeddings
|
1299 |
+
hidden_states = torch.cat((projected_visual_features, embedding_output), dim=1)
|
1300 |
+
|
1301 |
+
# By default, an additive causal mask is created
|
1302 |
+
# for masking the future (one direction).
|
1303 |
+
tgt_mask = self._generate_future_mask(seq_length, embedding_output.dtype, embedding_output.device)
|
1304 |
+
|
1305 |
+
# Create an attention mask of shape (batch_size, 1, tgt_seq_len, src_seq_len)
|
1306 |
+
combined_attention_mask = self.create_attention_mask(
|
1307 |
+
tgt=embedding_output,
|
1308 |
+
memory=projected_visual_features,
|
1309 |
+
tgt_mask=tgt_mask,
|
1310 |
+
past_key_values_length=past_key_values_length,
|
1311 |
+
)
|
1312 |
+
|
1313 |
+
if attention_mask is not None:
|
1314 |
+
# if the user provides an attention mask, we add it to the default one
|
1315 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
1316 |
+
expanded_attn_mask = _expand_mask(attention_mask, embedding_output.dtype, tgt_len=input_shape[-1]).to(
|
1317 |
+
embedding_output.device
|
1318 |
+
)
|
1319 |
+
if past_key_values_length > 0:
|
1320 |
+
expanded_attn_mask = expanded_attn_mask[:, :, -past_key_values_length:, :]
|
1321 |
+
else:
|
1322 |
+
combined_attention_mask[:, :, -input_shape[1] :, -input_shape[1] :] += expanded_attn_mask
|
1323 |
+
|
1324 |
+
encoder_outputs = self.encoder(
|
1325 |
+
hidden_states,
|
1326 |
+
attention_mask=combined_attention_mask,
|
1327 |
+
head_mask=head_mask,
|
1328 |
+
past_key_values=past_key_values,
|
1329 |
+
use_cache=use_cache,
|
1330 |
+
output_attentions=output_attentions,
|
1331 |
+
output_hidden_states=output_hidden_states,
|
1332 |
+
return_dict=return_dict,
|
1333 |
+
pixel_values_present=pixel_values is not None,
|
1334 |
+
image_token_num=image_token_num
|
1335 |
+
)
|
1336 |
+
sequence_output = encoder_outputs[0]
|
1337 |
+
|
1338 |
+
if not return_dict:
|
1339 |
+
return (sequence_output,) + encoder_outputs[1:]
|
1340 |
+
|
1341 |
+
return BaseModelOutputWithPast(
|
1342 |
+
last_hidden_state=sequence_output,
|
1343 |
+
past_key_values=encoder_outputs.past_key_values,
|
1344 |
+
hidden_states=encoder_outputs.hidden_states,
|
1345 |
+
attentions=encoder_outputs.attentions,
|
1346 |
+
)
|
1347 |
+
|
1348 |
+
|
1349 |
+
@add_start_docstrings(
|
1350 |
+
"""GIT Model with a `language modeling` head on top for autoregressive language modeling.""", GIT_START_DOCSTRING
|
1351 |
+
)
|
1352 |
+
class GitForCausalLM(GitPreTrainedModel):
|
1353 |
+
def __init__(self, config):
|
1354 |
+
super().__init__(config)
|
1355 |
+
|
1356 |
+
self.git = GitModel(config)
|
1357 |
+
self.output = nn.Linear(config.hidden_size, config.vocab_size)
|
1358 |
+
|
1359 |
+
# Initialize weights and apply final processing
|
1360 |
+
self.post_init()
|
1361 |
+
|
1362 |
+
def get_output_embeddings(self):
|
1363 |
+
return self.output
|
1364 |
+
|
1365 |
+
def set_output_embeddings(self, new_embeddings):
|
1366 |
+
self.output = new_embeddings
|
1367 |
+
|
1368 |
+
@add_start_docstrings_to_model_forward(GIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
1369 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
1370 |
+
def forward(
|
1371 |
+
self,
|
1372 |
+
input_ids: Optional[torch.Tensor] = None,
|
1373 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1374 |
+
position_ids: Optional[torch.Tensor] = None,
|
1375 |
+
pixel_values: Optional[torch.Tensor] = None,
|
1376 |
+
pixel_masks: Optional[torch.Tensor] = None,
|
1377 |
+
head_mask: Optional[torch.Tensor] = None,
|
1378 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1379 |
+
labels: Optional[torch.Tensor] = None,
|
1380 |
+
past_key_values: Optional[List[torch.Tensor]] = None,
|
1381 |
+
use_cache: Optional[bool] = None,
|
1382 |
+
output_attentions: Optional[bool] = None,
|
1383 |
+
output_hidden_states: Optional[bool] = None,
|
1384 |
+
return_dict: Optional[bool] = None,
|
1385 |
+
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
|
1386 |
+
r"""
|
1387 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1388 |
+
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
1389 |
+
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
|
1390 |
+
ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`
|
1391 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
1392 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
1393 |
+
|
1394 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
1395 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
1396 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
1397 |
+
use_cache (`bool`, *optional*):
|
1398 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
1399 |
+
`past_key_values`).
|
1400 |
+
|
1401 |
+
Returns:
|
1402 |
+
|
1403 |
+
Examples:
|
1404 |
+
|
1405 |
+
Image captioning example:
|
1406 |
+
|
1407 |
+
```python
|
1408 |
+
>>> from transformers import AutoProcessor, AutoModelForCausalLM
|
1409 |
+
>>> import requests
|
1410 |
+
>>> from PIL import Image
|
1411 |
+
|
1412 |
+
>>> processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
|
1413 |
+
>>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")
|
1414 |
+
|
1415 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
1416 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
1417 |
+
|
1418 |
+
>>> pixel_values = processor(images=image, return_tensors="pt").pixel_values
|
1419 |
+
|
1420 |
+
>>> generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
|
1421 |
+
>>> generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
1422 |
+
>>> print(generated_caption)
|
1423 |
+
two cats sleeping on a pink blanket next to remotes.
|
1424 |
+
```
|
1425 |
+
|
1426 |
+
Visual question answering (VQA) example:
|
1427 |
+
|
1428 |
+
```python
|
1429 |
+
>>> from transformers import AutoProcessor, AutoModelForCausalLM
|
1430 |
+
>>> from huggingface_hub import hf_hub_download
|
1431 |
+
>>> from PIL import Image
|
1432 |
+
|
1433 |
+
>>> processor = AutoProcessor.from_pretrained("microsoft/git-base-textvqa")
|
1434 |
+
>>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-textvqa")
|
1435 |
+
|
1436 |
+
>>> file_path = hf_hub_download(repo_id="nielsr/textvqa-sample", filename="bus.png", repo_type="dataset")
|
1437 |
+
>>> image = Image.open(file_path).convert("RGB")
|
1438 |
+
|
1439 |
+
>>> pixel_values = processor(images=image, return_tensors="pt").pixel_values
|
1440 |
+
|
1441 |
+
>>> question = "what does the front of the bus say at the top?"
|
1442 |
+
|
1443 |
+
>>> input_ids = processor(text=question, add_special_tokens=False).input_ids
|
1444 |
+
>>> input_ids = [processor.tokenizer.cls_token_id] + input_ids
|
1445 |
+
>>> input_ids = torch.tensor(input_ids).unsqueeze(0)
|
1446 |
+
|
1447 |
+
>>> generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
|
1448 |
+
>>> print(processor.batch_decode(generated_ids, skip_special_tokens=True))
|
1449 |
+
['what does the front of the bus say at the top? special']
|
1450 |
+
```
|
1451 |
+
|
1452 |
+
Video captioning example:
|
1453 |
+
|
1454 |
+
```python
|
1455 |
+
>>> import av
|
1456 |
+
>>> import numpy as np
|
1457 |
+
>>> from PIL import Image
|
1458 |
+
>>> from huggingface_hub import hf_hub_download
|
1459 |
+
>>> from transformers import AutoProcessor, AutoModelForCausalLM
|
1460 |
+
|
1461 |
+
>>> processor = AutoProcessor.from_pretrained("microsoft/git-base-vatex")
|
1462 |
+
>>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vatex")
|
1463 |
+
|
1464 |
+
>>> # set seed for reproducability
|
1465 |
+
>>> np.random.seed(45)
|
1466 |
+
|
1467 |
+
|
1468 |
+
>>> def read_video_pyav(container, indices):
|
1469 |
+
... '''
|
1470 |
+
... Decode the video with PyAV decoder.
|
1471 |
+
... Args:
|
1472 |
+
... container (`av.container.input.InputContainer`): PyAV container.
|
1473 |
+
... indices (`List[int]`): List of frame indices to decode.
|
1474 |
+
... Returns:
|
1475 |
+
... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
|
1476 |
+
... '''
|
1477 |
+
... frames = []
|
1478 |
+
... container.seek(0)
|
1479 |
+
... start_index = indices[0]
|
1480 |
+
... end_index = indices[-1]
|
1481 |
+
... for i, frame in enumerate(container.decode(video=0)):
|
1482 |
+
... if i > end_index:
|
1483 |
+
... break
|
1484 |
+
... if i >= start_index and i in indices:
|
1485 |
+
... frames.append(frame)
|
1486 |
+
... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
|
1487 |
+
|
1488 |
+
|
1489 |
+
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
|
1490 |
+
... converted_len = int(clip_len * frame_sample_rate)
|
1491 |
+
... end_idx = np.random.randint(converted_len, seg_len)
|
1492 |
+
... start_idx = end_idx - converted_len
|
1493 |
+
... indices = np.linspace(start_idx, end_idx, num=clip_len)
|
1494 |
+
... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
|
1495 |
+
... return indices
|
1496 |
+
|
1497 |
+
|
1498 |
+
>>> # load video
|
1499 |
+
>>> file_path = hf_hub_download(
|
1500 |
+
... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
|
1501 |
+
... )
|
1502 |
+
>>> container = av.open(file_path)
|
1503 |
+
|
1504 |
+
>>> # sample frames
|
1505 |
+
>>> num_frames = model.config.num_image_with_embedding
|
1506 |
+
>>> indices = sample_frame_indices(
|
1507 |
+
... clip_len=num_frames, frame_sample_rate=4, seg_len=container.streams.video[0].frames
|
1508 |
+
... )
|
1509 |
+
>>> frames = read_video_pyav(container, indices)
|
1510 |
+
|
1511 |
+
>>> pixel_values = processor(images=list(frames), return_tensors="pt").pixel_values
|
1512 |
+
|
1513 |
+
>>> generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
|
1514 |
+
|
1515 |
+
>>> print("Generated caption:", processor.batch_decode(generated_ids, skip_special_tokens=True))
|
1516 |
+
Generated caption: ['a woman is sitting at a table and she is talking about the food she is holding.']
|
1517 |
+
```
|
1518 |
+
"""
|
1519 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1520 |
+
if labels is not None:
|
1521 |
+
use_cache = False
|
1522 |
+
|
1523 |
+
outputs = self.git(
|
1524 |
+
input_ids,
|
1525 |
+
attention_mask=attention_mask,
|
1526 |
+
position_ids=position_ids,
|
1527 |
+
pixel_values=pixel_values,
|
1528 |
+
pixel_masks=pixel_masks,
|
1529 |
+
head_mask=head_mask,
|
1530 |
+
inputs_embeds=inputs_embeds,
|
1531 |
+
past_key_values=past_key_values,
|
1532 |
+
use_cache=use_cache,
|
1533 |
+
output_attentions=output_attentions,
|
1534 |
+
output_hidden_states=output_hidden_states,
|
1535 |
+
return_dict=return_dict,
|
1536 |
+
)
|
1537 |
+
|
1538 |
+
sequence_output = outputs[0]
|
1539 |
+
logits = self.output(sequence_output)
|
1540 |
+
|
1541 |
+
loss = None
|
1542 |
+
if labels is not None:
|
1543 |
+
# we are doing next-token prediction; shift prediction scores and input ids by one
|
1544 |
+
num_image_tokens = self.git.encoder.layer[0].attention.self.image_patch_tokens
|
1545 |
+
shifted_logits = logits[:, num_image_tokens:-1, :].contiguous()
|
1546 |
+
labels = labels[:, 1:].contiguous()
|
1547 |
+
loss_fct = CrossEntropyLoss()
|
1548 |
+
loss = loss_fct(shifted_logits.view(-1, self.config.vocab_size), labels.view(-1))
|
1549 |
+
|
1550 |
+
if not return_dict:
|
1551 |
+
output = (logits,) + outputs[1:]
|
1552 |
+
return ((loss,) + output) if loss is not None else output
|
1553 |
+
|
1554 |
+
return CausalLMOutputWithPast(
|
1555 |
+
loss=loss,
|
1556 |
+
logits=logits,
|
1557 |
+
past_key_values=outputs.past_key_values,
|
1558 |
+
hidden_states=outputs.hidden_states,
|
1559 |
+
attentions=outputs.attentions,
|
1560 |
+
)
|
1561 |
+
|
1562 |
+
def prepare_inputs_for_generation(
|
1563 |
+
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
|
1564 |
+
):
|
1565 |
+
# cut decoder_input_ids if past_key_values is used
|
1566 |
+
if past_key_values is not None:
|
1567 |
+
input_ids = input_ids[:, -1:]
|
1568 |
+
|
1569 |
+
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
1570 |
+
input_shape = input_ids.shape
|
1571 |
+
if attention_mask is None:
|
1572 |
+
attention_mask = input_ids.new_ones(input_shape)
|
1573 |
+
|
1574 |
+
return {
|
1575 |
+
"input_ids": input_ids,
|
1576 |
+
"attention_mask": attention_mask,
|
1577 |
+
"pixel_values": kwargs.get("pixel_values", None),
|
1578 |
+
"pixel_masks": kwargs.get("pixel_masks", None),
|
1579 |
+
"past_key_values": past_key_values,
|
1580 |
+
"use_cache": use_cache,
|
1581 |
+
}
|
1582 |
+
|
1583 |
+
def _reorder_cache(self, past_key_values, beam_idx):
|
1584 |
+
reordered_past = ()
|
1585 |
+
for layer_past in past_key_values:
|
1586 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
1587 |
+
return reordered_past
|
captioner/vit_pixel_masks_utils.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
|
6 |
+
class ViTPatchMaskGenerator(nn.Module):
|
7 |
+
def __init__(self, patch_size) -> None:
|
8 |
+
super(ViTPatchMaskGenerator, self).__init__()
|
9 |
+
self.patch_size = patch_size
|
10 |
+
self.pool = nn.MaxPool2d(kernel_size=patch_size, stride=patch_size)
|
11 |
+
|
12 |
+
def forward(self, pixel_masks):
|
13 |
+
patch_mask = self.pool(pixel_masks)
|
14 |
+
patch_mask = patch_mask.bool().flatten(1)
|
15 |
+
cls_token_mask = patch_mask.new_ones([patch_mask.shape[0], 1]).bool()
|
16 |
+
patch_mask = torch.cat([cls_token_mask, patch_mask], dim=-1)
|
17 |
+
return patch_mask
|
env.sh
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
conda create -n caption_anything python=3.8 -y
|
2 |
+
source activate caption_anything
|
3 |
+
pip install -r requirements.txt
|
4 |
+
# cd [email protected]
|
5 |
+
# wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
|
6 |
+
|
image_editing_utils.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image, ImageDraw, ImageFont
|
2 |
+
import copy
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
def wrap_text(text, font, max_width):
|
6 |
+
lines = []
|
7 |
+
words = text.split(' ')
|
8 |
+
current_line = ''
|
9 |
+
|
10 |
+
for word in words:
|
11 |
+
if font.getsize(current_line + word)[0] <= max_width:
|
12 |
+
current_line += word + ' '
|
13 |
+
else:
|
14 |
+
lines.append(current_line)
|
15 |
+
current_line = word + ' '
|
16 |
+
|
17 |
+
lines.append(current_line)
|
18 |
+
return lines
|
19 |
+
|
20 |
+
def create_bubble_frame(image, text, point, font_path='DejaVuSansCondensed-Bold.ttf', font_size_ratio=0.025):
|
21 |
+
# Load the image
|
22 |
+
if type(image) == np.ndarray:
|
23 |
+
image = Image.fromarray(image)
|
24 |
+
|
25 |
+
image = copy.deepcopy(image)
|
26 |
+
width, height = image.size
|
27 |
+
|
28 |
+
# Calculate max_text_width and font_size based on image dimensions and total number of characters
|
29 |
+
total_chars = len(text)
|
30 |
+
max_text_width = int(0.4 * width)
|
31 |
+
font_size = int(height * font_size_ratio)
|
32 |
+
|
33 |
+
# Load the font
|
34 |
+
font = ImageFont.truetype(font_path, font_size)
|
35 |
+
|
36 |
+
# Wrap the text to fit within the max_text_width
|
37 |
+
lines = wrap_text(text, font, max_text_width)
|
38 |
+
text_width = max([font.getsize(line)[0] for line in lines])
|
39 |
+
_, text_height = font.getsize(lines[0])
|
40 |
+
text_height = text_height * len(lines)
|
41 |
+
|
42 |
+
# Define bubble frame dimensions
|
43 |
+
padding = 10
|
44 |
+
bubble_width = text_width + 2 * padding
|
45 |
+
bubble_height = text_height + 2 * padding
|
46 |
+
|
47 |
+
# Create a new image for the bubble frame
|
48 |
+
bubble = Image.new('RGBA', (bubble_width, bubble_height), (255, 255, 255, 0))
|
49 |
+
|
50 |
+
# Draw the bubble frame on the new image
|
51 |
+
draw = ImageDraw.Draw(bubble)
|
52 |
+
# draw.rectangle([(0, 0), (bubble_width - 1, bubble_height - 1)], fill=(255, 255, 255, 0), outline=(255, 255, 255, 0), width=2)
|
53 |
+
|
54 |
+
# Draw the wrapped text line by line
|
55 |
+
y_text = padding
|
56 |
+
for line in lines:
|
57 |
+
draw.text((padding, y_text), line, font=font, fill=(255, 255, 255, 255))
|
58 |
+
y_text += font.getsize(line)[1]
|
59 |
+
|
60 |
+
# Calculate the bubble frame position
|
61 |
+
x, y = point
|
62 |
+
if x + bubble_width > width:
|
63 |
+
x = width - bubble_width
|
64 |
+
if y + bubble_height > height:
|
65 |
+
y = height - bubble_height
|
66 |
+
|
67 |
+
# Paste the bubble frame onto the image
|
68 |
+
image.paste(bubble, (x, y), bubble)
|
69 |
+
return image
|
requirements.txt
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
https://download.pytorch.org/whl/cu111/torch-1.10.1%2Bcu111-cp38-cp38-linux_x86_64.whl
|
2 |
+
https://download.pytorch.org/whl/cu111/torchvision-0.11.2%2Bcu111-cp38-cp38-linux_x86_64.whl
|
3 |
+
https://download.pytorch.org/whl/cu111/torchaudio-0.10.1%2Bcu111-cp38-cp38-linux_x86_64.whl
|
4 |
+
openai
|
5 |
+
pillow
|
6 |
+
langchain==0.0.101
|
7 |
+
git+https://github.com/huggingface/transformers.git
|
8 |
+
ftfy
|
9 |
+
regex
|
10 |
+
tqdm
|
11 |
+
git+https://github.com/openai/CLIP.git
|
12 |
+
git+https://github.com/facebookresearch/segment-anything.git
|
13 |
+
opencv-python
|
14 |
+
pycocotools
|
15 |
+
matplotlib
|
16 |
+
onnxruntime
|
17 |
+
onnx
|
18 |
+
https://gradio-builds.s3.amazonaws.com/3e68e5e882a6790ac5b457bd33f4edf9b695af90/gradio-3.24.1-py3-none-any.whl
|
19 |
+
accelerate
|
20 |
+
bitsandbytes
|
segmenter/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from segmenter.base_segmenter import BaseSegmenter
|
2 |
+
|
3 |
+
|
4 |
+
def build_segmenter(type, device, args=None, model=None):
|
5 |
+
if type == 'base':
|
6 |
+
return BaseSegmenter(device, args.segmenter_checkpoint, reuse_feature=not args.disable_reuse_features, model=model)
|
7 |
+
else:
|
8 |
+
raise NotImplementedError()
|
segmenter/base_segmenter.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import torch
|
3 |
+
import cv2
|
4 |
+
from PIL import Image, ImageDraw, ImageOps
|
5 |
+
import numpy as np
|
6 |
+
from typing import Union
|
7 |
+
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import PIL
|
10 |
+
|
11 |
+
class BaseSegmenter:
|
12 |
+
def __init__(self, device, checkpoint, model_type='vit_h', reuse_feature = True, model=None):
|
13 |
+
print(f"Initializing BaseSegmenter to {device}")
|
14 |
+
self.device = device
|
15 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
16 |
+
self.processor = None
|
17 |
+
self.model_type = model_type
|
18 |
+
if model is None:
|
19 |
+
self.checkpoint = checkpoint
|
20 |
+
self.model = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
|
21 |
+
self.model.to(device=self.device)
|
22 |
+
else:
|
23 |
+
self.model = model
|
24 |
+
self.reuse_feature = reuse_feature
|
25 |
+
self.predictor = SamPredictor(self.model)
|
26 |
+
self.mask_generator = SamAutomaticMaskGenerator(self.model)
|
27 |
+
self.image_embedding = None
|
28 |
+
self.image = None
|
29 |
+
|
30 |
+
|
31 |
+
@torch.no_grad()
|
32 |
+
def set_image(self, image: Union[np.ndarray, Image.Image, str]):
|
33 |
+
if type(image) == str: # input path
|
34 |
+
image = Image.open(image)
|
35 |
+
image = np.array(image)
|
36 |
+
elif type(image) == Image.Image:
|
37 |
+
image = np.array(image)
|
38 |
+
self.image = image
|
39 |
+
if self.reuse_feature:
|
40 |
+
self.predictor.set_image(image)
|
41 |
+
self.image_embedding = self.predictor.get_image_embedding()
|
42 |
+
print(self.image_embedding.shape)
|
43 |
+
|
44 |
+
|
45 |
+
@torch.no_grad()
|
46 |
+
def inference(self, image, control):
|
47 |
+
if 'everything' in control['prompt_type']:
|
48 |
+
masks = self.mask_generator.generate(image)
|
49 |
+
new_masks = np.concatenate([mask["segmentation"][np.newaxis,:] for mask in masks])
|
50 |
+
return new_masks
|
51 |
+
else:
|
52 |
+
if not self.reuse_feature or self.image_embedding is None:
|
53 |
+
self.set_image(image)
|
54 |
+
self.predictor.set_image(self.image)
|
55 |
+
else:
|
56 |
+
assert self.image_embedding is not None
|
57 |
+
self.predictor.features = self.image_embedding
|
58 |
+
|
59 |
+
if 'mutimask_output' in control:
|
60 |
+
masks, scores, logits = self.predictor.predict(
|
61 |
+
point_coords = np.array(control['input_point']),
|
62 |
+
point_labels = np.array(control['input_label']),
|
63 |
+
multimask_output = True,
|
64 |
+
)
|
65 |
+
elif 'input_boxes' in control:
|
66 |
+
transformed_boxes = self.predictor.transform.apply_boxes_torch(
|
67 |
+
torch.tensor(control["input_boxes"], device=self.predictor.device),
|
68 |
+
image.shape[:2]
|
69 |
+
)
|
70 |
+
masks, _, _ = self.predictor.predict_torch(
|
71 |
+
point_coords=None,
|
72 |
+
point_labels=None,
|
73 |
+
boxes=transformed_boxes,
|
74 |
+
multimask_output=False,
|
75 |
+
)
|
76 |
+
masks = masks.squeeze(1).cpu().numpy()
|
77 |
+
|
78 |
+
else:
|
79 |
+
input_point = np.array(control['input_point']) if 'click' in control['prompt_type'] else None
|
80 |
+
input_label = np.array(control['input_label']) if 'click' in control['prompt_type'] else None
|
81 |
+
input_box = np.array(control['input_box']) if 'box' in control['prompt_type'] else None
|
82 |
+
|
83 |
+
masks, scores, logits = self.predictor.predict(
|
84 |
+
point_coords = input_point,
|
85 |
+
point_labels = input_label,
|
86 |
+
box = input_box,
|
87 |
+
multimask_output = False,
|
88 |
+
)
|
89 |
+
|
90 |
+
if 0 in control['input_label']:
|
91 |
+
mask_input = logits[np.argmax(scores), :, :]
|
92 |
+
masks, scores, logits = self.predictor.predict(
|
93 |
+
point_coords=input_point,
|
94 |
+
point_labels=input_label,
|
95 |
+
box = input_box,
|
96 |
+
mask_input=mask_input[None, :, :],
|
97 |
+
multimask_output=False,
|
98 |
+
)
|
99 |
+
|
100 |
+
return masks
|
101 |
+
|
102 |
+
if __name__ == "__main__":
|
103 |
+
image_path = 'segmenter/images/truck.jpg'
|
104 |
+
prompts = [
|
105 |
+
# {
|
106 |
+
# "prompt_type":["click"],
|
107 |
+
# "input_point":[[500, 375]],
|
108 |
+
# "input_label":[1],
|
109 |
+
# "multimask_output":"True",
|
110 |
+
# },
|
111 |
+
{
|
112 |
+
"prompt_type":["click"],
|
113 |
+
"input_point":[[1000, 600], [1325, 625]],
|
114 |
+
"input_label":[1, 0],
|
115 |
+
},
|
116 |
+
# {
|
117 |
+
# "prompt_type":["click", "box"],
|
118 |
+
# "input_box":[425, 600, 700, 875],
|
119 |
+
# "input_point":[[575, 750]],
|
120 |
+
# "input_label": [0]
|
121 |
+
# },
|
122 |
+
# {
|
123 |
+
# "prompt_type":["box"],
|
124 |
+
# "input_boxes": [
|
125 |
+
# [75, 275, 1725, 850],
|
126 |
+
# [425, 600, 700, 875],
|
127 |
+
# [1375, 550, 1650, 800],
|
128 |
+
# [1240, 675, 1400, 750],
|
129 |
+
# ]
|
130 |
+
# },
|
131 |
+
# {
|
132 |
+
# "prompt_type":["everything"]
|
133 |
+
# },
|
134 |
+
]
|
135 |
+
|
136 |
+
init_time = time.time()
|
137 |
+
segmenter = BaseSegmenter(
|
138 |
+
device='cuda',
|
139 |
+
# checkpoint='sam_vit_h_4b8939.pth',
|
140 |
+
checkpoint='segmenter/sam_vit_h_4b8939.pth',
|
141 |
+
model_type='vit_h',
|
142 |
+
reuse_feature=True
|
143 |
+
)
|
144 |
+
print(f'init time: {time.time() - init_time}')
|
145 |
+
|
146 |
+
image_path = 'test_img/img2.jpg'
|
147 |
+
infer_time = time.time()
|
148 |
+
for i, prompt in enumerate(prompts):
|
149 |
+
print(f'{prompt["prompt_type"]} mode')
|
150 |
+
image = Image.open(image_path)
|
151 |
+
segmenter.set_image(np.array(image))
|
152 |
+
masks = segmenter.inference(np.array(image), prompt)
|
153 |
+
Image.fromarray(masks[0]).save('seg.png')
|
154 |
+
print(masks.shape)
|
155 |
+
|
156 |
+
print(f'infer time: {time.time() - infer_time}')
|
segmenter/images/truck.jpg
ADDED
![]() |
segmenter/readme.md
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### Prepare SAM
|
2 |
+
```
|
3 |
+
pip install git+https://github.com/facebookresearch/segment-anything.git
|
4 |
+
```
|
5 |
+
or
|
6 |
+
```
|
7 |
+
git clone [email protected]:facebookresearch/segment-anything.git
|
8 |
+
cd segment-anything; pip install -e .
|
9 |
+
```
|
10 |
+
|
11 |
+
```
|
12 |
+
pip install opencv-python pycocotools matplotlib onnxruntime onnx
|
13 |
+
```
|
14 |
+
### Download the checkpoint:
|
15 |
+
|
16 |
+
https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
|
17 |
+
|
18 |
+
### Inference
|
19 |
+
|
20 |
+
The prompts are in json format:
|
21 |
+
|
22 |
+
```
|
23 |
+
prompts = [
|
24 |
+
{
|
25 |
+
"prompt_type":["click"],
|
26 |
+
"input_point":[[500, 375]],
|
27 |
+
"input_label":[1],
|
28 |
+
"multimask_output":"True",
|
29 |
+
},
|
30 |
+
{
|
31 |
+
"prompt_type":["click"],
|
32 |
+
"input_point":[[500, 375], [1125, 625]],
|
33 |
+
"input_label":[1, 0],
|
34 |
+
},
|
35 |
+
{
|
36 |
+
"prompt_type":["click", "box"],
|
37 |
+
"input_box":[425, 600, 700, 875],
|
38 |
+
"input_point":[[575, 750]],
|
39 |
+
"input_label": [0]
|
40 |
+
},
|
41 |
+
{
|
42 |
+
"prompt_type":["box"],
|
43 |
+
"input_boxes": [
|
44 |
+
[75, 275, 1725, 850],
|
45 |
+
[425, 600, 700, 875],
|
46 |
+
[1375, 550, 1650, 800],
|
47 |
+
[1240, 675, 1400, 750],
|
48 |
+
]
|
49 |
+
},
|
50 |
+
{
|
51 |
+
"prompt_type":["everything"]
|
52 |
+
},
|
53 |
+
]
|
54 |
+
```
|
55 |
+
|
56 |
+
In `base_segmenter.py`:
|
57 |
+
```
|
58 |
+
segmenter = BaseSegmenter(
|
59 |
+
device='cuda',
|
60 |
+
checkpoint='sam_vit_h_4b8939.pth',
|
61 |
+
model_type='vit_h'
|
62 |
+
)
|
63 |
+
|
64 |
+
for i, prompt in enumerate(prompts):
|
65 |
+
masks = segmenter.inference(image_path, prompt)
|
66 |
+
```
|
67 |
+
|
68 |
+
Outputs are masks (True and False numpy Matrix), shape: (num of masks, height, weight)
|
segmenter/sam_vit_h_4b8939.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
|
3 |
+
size 2564550879
|
test_img/img0.png
ADDED
![]() |
test_img/img1.jpg
ADDED
![]() |
test_img/img1.jpg.raw_mask.png
ADDED
![]() |
test_img/img10.jpg
ADDED
![]() |
test_img/img10.jpg.raw_mask.png
ADDED
![]() |
test_img/img11.jpg
ADDED
![]() |
test_img/img12.jpg
ADDED
![]() |
test_img/img12.jpg.raw_mask.png
ADDED
![]() |
test_img/img13.jpg
ADDED
![]() |
test_img/img13.jpg.raw_mask.png
ADDED
![]() |
test_img/img14.jpg
ADDED
![]() |
test_img/img14.jpg.raw_mask.png
ADDED
![]() |
test_img/img15.jpg
ADDED
![]() |
test_img/img15.jpg.raw_mask.png
ADDED
![]() |
test_img/img16.jpg
ADDED
![]() |
test_img/img16.jpg.raw_mask.png
ADDED
![]() |
test_img/img17.jpg
ADDED
![]() |
test_img/img18.jpg
ADDED
![]() |
Git LFS Details
|
test_img/img19.jpg
ADDED
![]() |
test_img/img2.jpg
ADDED
![]() |