yuanze committed · Commit f15a1cd · Parent(s): 6dc0a5f
init
Files changed:
- .gitignore +2 -0
- README.md +3 -3
- app.py +129 -0
- change_setup.txt +38 -0
- data/objaverse_uni3d_3D_embeddings.pt +3 -0
- data/objaverse_uni3d_image_above_embeddings.pt +3 -0
- data/objaverse_uni3d_image_back_embeddings.pt +3 -0
- data/objaverse_uni3d_image_below_embeddings.pt +3 -0
- data/objaverse_uni3d_image_diag_above_embeddings.pt +3 -0
- data/objaverse_uni3d_image_diag_below_embeddings.pt +3 -0
- data/objaverse_uni3d_image_front_embeddings.pt +3 -0
- data/objaverse_uni3d_image_left_embeddings.pt +3 -0
- data/objaverse_uni3d_image_right_embeddings.pt +3 -0
- data/objaverse_uni3d_text_embeddings.pt +3 -0
- data/source_id_list.pt +3 -0
- dockerfile +19 -0
- feature_extractors/__init__.py +56 -0
- feature_extractors/uni3d_embedding_encoder.py +337 -0
- packages +1 -0
- requirements.txt +9 -0
- utils/bpe_simple_vocab_16e6.txt.gz +3 -0
- utils/tokenizer.py +147 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
.cache
__pycache__/
README.md
CHANGED
@@ -1,10 +1,10 @@
 ---
 title: LD T3D
-emoji:
+emoji: 🐳
 colorFrom: indigo
 colorTo: yellow
-sdk:
-
+sdk: docker
+app_port: 7860
 app_file: app.py
 pinned: false
 ---
app.py
ADDED
@@ -0,0 +1,129 @@
import os
import gradio as gr
import numpy as np
import torch
import functools
from datasets import load_dataset
from feature_extractors.uni3d_embedding_encoder import Uni3dEmbeddingEncoder

# os.environ['HTTP_PROXY'] = 'http://192.168.48.17:18000'
# os.environ['HTTPS_PROXY'] = 'http://192.168.48.17:18000'

MAX_BATCH_SIZE = 16
MAX_QUEUE_SIZE = 10
MAX_K_RETRIEVAL = 20
cache_dir = "./.cache"

encoder = Uni3dEmbeddingEncoder(cache_dir)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
source_id_list = torch.load("data/source_id_list.pt")
source_to_id = {source_id: i for i, source_id in enumerate(source_id_list)}
dataset = load_dataset("VAST-AI/LD-T3D", name="rendered_imgs_diag_above", split="base", cache_dir=cache_dir)

@functools.lru_cache()
def get_embedding(option, modality, angle=None):
    save_path = f'data/objaverse_{option}_{modality + (("_" + str(angle)) if angle is not None else "")}_embeddings.pt'
    if os.path.exists(save_path):
        return torch.load(save_path)
    else:
        raise gr.Error(f"Embedding file not found: {save_path}")

def predict(xb, xq, top_k):
    xb = xb.to(xq.device)
    sim = xq @ xb.T  # (nq, nb)
    _, indices = sim.topk(k=top_k, largest=True)
    return indices

def get_image(index):
    return dataset[index]["image"]

def retrieve_3D_models(textual_query, top_k, modality_list):
    if textual_query == "":
        raise gr.Error("Please enter a textual query")
    if len(textual_query.split()) > 20:
        gr.Warning("Retrieval result may be inaccurate due to long textual query")
    if len(modality_list) == 0:
        raise gr.Error("Please select at least one modality")

    def _retrieve_3D_models(query, top_k, modals: list):
        option = "uni3d"
        op = "add"
        is_text = True if "text" in modals else False
        is_3D = True if "3D" in modals else False
        if is_text:
            modals.remove("text")
        if is_3D:
            modals.remove("3D")
        angles = modals

        # get base embeddings
        embeddings = []
        if is_text:
            embeddings.append(get_embedding(option, "text"))
        if len(angles) > 0:
            for angle in angles:
                embeddings.append(get_embedding(option, "image", angle=angle))
        if is_3D:
            embeddings.append(get_embedding(option, "3D"))

        ## fuse base embeddings
        if len(embeddings) > 1:
            if op == "concat":
                embeddings = torch.cat(embeddings, dim=-1)
            elif op == "add":
                embeddings = sum(embeddings)
            else:
                raise ValueError(f"Unsupported operation: {op}")
            embeddings /= embeddings.norm(dim=-1, keepdim=True)
        else:
            embeddings = embeddings[0]

        # encode query embeddings
        xq = encoder.encode_query(query)
        if op == "concat":
            xq = xq.repeat(1, embeddings.shape[-1] // xq.shape[-1])  # repeat to be aligned with the xb
        xq /= xq.norm(dim=-1, keepdim=True)

        pred_ind_list = predict(embeddings, xq, top_k)
        return pred_ind_list[0].cpu().tolist()  # we have only one query

    indices = _retrieve_3D_models(textual_query, top_k, modality_list)
    return [get_image(index) for index in indices]

def launch():
    with gr.Blocks() as demo:
        with gr.Row():
            textual_query = gr.Textbox(label="Textual Query", autofocus=True,
                                       placeholder="A chair with a wooden frame and a cushioned seat")
            modality_list = gr.CheckboxGroup(label="Modality List", value=[],
                                             choices=["text", "front", "back", "left", "right", "above",
                                                      "below", "diag_above", "diag_below", "3D"])
        with gr.Row():
            top_k = gr.Slider(minimum=1, maximum=MAX_K_RETRIEVAL, step=1, label="Top K Retrieval Result",
                              value=5, scale=2)
            run = gr.Button("Search", scale=1)
            clear_button = gr.ClearButton(scale=1)
        with gr.Row():
            output = gr.Gallery(format="webp", label="Retrieval Result", columns=5, type="pil")
        run.click(retrieve_3D_models, [textual_query, top_k, modality_list], output,
                  # batch=True, max_batch_size=MAX_BATCH_SIZE
                  )
        clear_button.click(lambda: ["", 5, [], []], outputs=[textual_query, top_k, modality_list, output])
        examples = gr.Examples(examples=[["An ice cream with a cherry on top", 10, ["text", "front", "back", "left", "right", "above", "below", "diag_above", "diag_below", "3D"]],
                                         ["A mid-age castle", 10, ["text", "front", "back", "left", "right", "above", "below", "diag_above", "diag_below", "3D"]],
                                         ["A coke", 10, ["text", "front", "back", "left", "right", "above", "below", "diag_above", "diag_below", "3D"]]],
                               inputs=[textual_query, top_k, modality_list],
                               # cache_examples=True,
                               outputs=output,
                               fn=retrieve_3D_models)

    demo.queue(max_size=10)

    # os.environ.pop('HTTP_PROXY')
    # os.environ.pop('HTTPS_PROXY')

    demo.launch(server_name='0.0.0.0')

if __name__ == "__main__":
    launch()
    # print(len(retrieve_3D_models("A chair with a wooden frame and a cushioned seat", 5, ["3D", "diag_above", "diag_below"])))
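For reference, the retrieval core of app.py reduces to "sum the selected per-modality embedding tables, L2-normalize, and take the top-k cosine similarities against the encoded query." A minimal standalone sketch of that step follows; it is not part of the commit, and the random tensors and the 1024-dimensional shape are assumptions standing in for the precomputed data/*.pt files.

import torch

# Stand-ins for two precomputed gallery embedding tables (one row per 3D model).
num_models, dim = 1000, 1024
text_emb = torch.randn(num_models, dim)
image_emb = torch.randn(num_models, dim)

# "add" fusion as in _retrieve_3D_models: sum the modalities, then normalize
# so the dot product below is a cosine similarity.
xb = text_emb + image_emb
xb = xb / xb.norm(dim=-1, keepdim=True)

# Stand-in for encoder.encode_query(query): one L2-normalized query vector.
xq = torch.randn(1, dim)
xq = xq / xq.norm(dim=-1, keepdim=True)

sim = xq @ xb.T                      # (1, num_models) similarities
top_k = sim.topk(k=5, largest=True)  # indices of the 5 best-matching models
print(top_k.indices[0].tolist())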
change_setup.txt
ADDED
@@ -0,0 +1,38 @@
import glob
import os
import os.path as osp

from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

this_dir = osp.dirname(osp.abspath(__file__))
_ext_src_root = osp.join("pointnet2_ops", "_ext-src")
_ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
    osp.join(_ext_src_root, "src", "*.cu")
)
_ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))

requirements = ["torch>=1.4"]

exec(open(osp.join("pointnet2_ops", "_version.py")).read())

setup(
    name="pointnet2_ops",
    version=__version__,
    author="Erik Wijmans",
    packages=find_packages(),
    install_requires=requirements,
    ext_modules=[
        CUDAExtension(
            name="pointnet2_ops._ext",
            sources=_ext_sources,
            extra_compile_args={
                "cxx": ["-O3"],
                "nvcc": ["-O3", "-Xfatbin", "-compress-all"],
            },
            include_dirs=[osp.join(this_dir, _ext_src_root, "include")],
        )
    ],
    cmdclass={"build_ext": BuildExtension},
    include_package_data=True,
)
data/objaverse_uni3d_3D_embeddings.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b05400ab75009785535bd78d859db0a902176fbeb5df2ef73e55a95990ded1b8
size 365511995
data/objaverse_uni3d_image_above_embeddings.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c0708d9bfb4df4e6f86a21bd5a1096401c8c037e84575e6d0397efdb1b138289
size 365512104
data/objaverse_uni3d_image_back_embeddings.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5667981bc1215e1f60c034ff8e2d214da6186a2f3212061b8ed3e1c32073ad6e
size 365512104
data/objaverse_uni3d_image_below_embeddings.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f91df0329424657666dd9a5b3181d52f9155ad545dc22a2f725f24f9b854abbd
size 365512104
data/objaverse_uni3d_image_diag_above_embeddings.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b44e2ee38885128e9080c75ee1d311fee8f718375e867c2209273649455c89a7
size 365512035
data/objaverse_uni3d_image_diag_below_embeddings.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:79eb0da600d75874e22bbfcca6001669eb14f06ec37326bf5148521db82f3e34
size 365512035
data/objaverse_uni3d_image_front_embeddings.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:016208fa7a76e959840c128c30e178a0b43a570cf7a8e6cfd6fcdb442f6b72db
size 365512104
data/objaverse_uni3d_image_left_embeddings.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c5db0c17a56ebbb0fa1323b105dfe04386f8d7f88c876bc24b943e8713a01076
size 365512035
data/objaverse_uni3d_image_right_embeddings.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f5fb149475c79b465157d5b2cfe2af4ad8947ff23f99577da264c2632bc9d770
size 365512035
data/objaverse_uni3d_text_embeddings.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a2d908630bcc8a5a231e8b5d11714c63a3e8b6d78427a82a833da9219b2a7263
size 365512020
data/source_id_list.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c218ccb58d0045b0b6671c1378ee43362054b890f9895d7cac3de727683a9a76
size 3747900
dockerfile
ADDED
@@ -0,0 +1,19 @@

FROM nvcr.io/nvidia/pytorch:23.08

LABEL maintainer="yuanze"
LABEL email="[email protected]"

# Install webp support
RUN apt update && apt install libwebp-dev -y

RUN pip install -r requirements.txt

# note that you may need to modify the TORCH_CUDA_ARCH_LIST in the setup.py file
ENV TORCH_CUDA_ARCH_LIST="8.6"

# Install Pointnet2_PyTorch
RUN git clone https://github.com/erikwijmans/Pointnet2_PyTorch.git \
    && mv -f backup_install.txt Pointnet2_PyTorch/pointnet2_ops_lib/setup.py \
    && cd Pointnet2_PyTorch/pointnet2_ops_lib \
    && pip install .
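The TORCH_CUDA_ARCH_LIST value above is hard-coded to 8.6 (Ampere). If the Space runs on different hardware, one way to find the right value is to query the GPU's compute capability from PyTorch; this short check is not part of the commit and assumes a CUDA device is visible at build or run time.

import torch

# Print the compute capability of GPU 0, e.g. "8.6" for an RTX 30xx / A10.
# That string is the value to put in TORCH_CUDA_ARCH_LIST before building pointnet2_ops.
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability(0)
    print(f"{major}.{minor}")
else:
    print("No CUDA device visible; TORCH_CUDA_ARCH_LIST must be set manually.")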
feature_extractors/__init__.py
ADDED
@@ -0,0 +1,56 @@
from collections.abc import Sequence
from abc import ABC, abstractmethod
import torch
from PIL.Image import Image

class FeatureExtractor(ABC):
    @abstractmethod
    def encode_image(self, img_list: Sequence[Image]) -> torch.Tensor:
        """
        Encode the input images and return the corresponding embeddings.

        Args:
            img_list: A list of PIL.Image.Image objects.

        Returns:
            The embeddings of the input images. The shape should be (len(img_list), embedding_dim).
        """
        raise NotImplementedError

    @abstractmethod
    def encode_text(self, text_list: Sequence[str]) -> torch.Tensor:
        """
        Encode the input text data and return the corresponding embeddings.

        Args:
            text_list: A list of strings.

        Returns:
            The embeddings of the input text data. The shape should be (len(text_list), embedding_dim).
        """
        raise NotImplementedError

    @abstractmethod
    def encode_3D(self, pc_tensor: torch.Tensor) -> torch.Tensor:
        """
        Encode the input 3D point cloud and return the corresponding embeddings.

        Args:
            pc_tensor: A tensor of shape (B, N, 3 + 3).

        Returns:
            The embeddings of the input 3D point cloud. The shape should be (B, embedding_dim).
        """
        raise NotImplementedError

    @abstractmethod
    def encode_query(self, queries: Sequence[str]) -> torch.Tensor:
        """Encode the queries and return the corresponding embeddings.

        Args:
            queries: A list of strings.

        Returns:
            The embeddings of the queries. The shape should be (len(queries), embedding_dim).
        """
        raise NotImplementedError
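To illustrate the contract this abstract class defines, a toy extractor that returns random embeddings could look like the sketch below. It is not part of the commit; the class name and the 1024-dimensional embedding size are assumptions for illustration only.

from collections.abc import Sequence

import torch
from PIL.Image import Image

from feature_extractors import FeatureExtractor


class RandomFeatureExtractor(FeatureExtractor):
    """Hypothetical stand-in that satisfies the interface with random vectors."""

    def __init__(self, embedding_dim: int = 1024):
        self.embedding_dim = embedding_dim

    def encode_image(self, img_list: Sequence[Image]) -> torch.Tensor:
        return torch.randn(len(img_list), self.embedding_dim)

    def encode_text(self, text_list: Sequence[str]) -> torch.Tensor:
        return torch.randn(len(text_list), self.embedding_dim)

    def encode_3D(self, pc_tensor: torch.Tensor) -> torch.Tensor:
        return torch.randn(pc_tensor.shape[0], self.embedding_dim)

    def encode_query(self, queries: Sequence[str]) -> torch.Tensor:
        return self.encode_text(queries)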
feature_extractors/uni3d_embedding_encoder.py
ADDED
@@ -0,0 +1,337 @@
"""
See https://github.com/baaivision/Uni3D for source code
"""
import os
import torch
import torch.nn as nn
import timm
import numpy as np
from pointnet2_ops import pointnet2_utils
import open_clip
from huggingface_hub import hf_hub_download
import sys
sys.path.append('')
from feature_extractors import FeatureExtractor
from utils.tokenizer import SimpleTokenizer

import logging

def fps(data, number):
    '''
    data B N 3
    number int
    '''
    fps_idx = pointnet2_utils.furthest_point_sample(data, number)
    fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous()
    return fps_data

# https://github.com/Strawberry-Eat-Mango/PCT_Pytorch/blob/main/util.py
def knn_point(nsample, xyz, new_xyz):
    """
    Input:
        nsample: max sample number in local region
        xyz: all points, [B, N, C]
        new_xyz: query points, [B, S, C]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    sqrdists = square_distance(new_xyz, xyz)
    _, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
    return group_idx

def square_distance(src, dst):
    """
    Calculate Euclid distance between each two points.
    src^T * dst = xn * xm + yn * ym + zn * zm;
    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
         = sum(src**2, dim=-1) + sum(dst**2, dim=-1) - 2 * src^T * dst
    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist


class PatchDropout(nn.Module):
    """
    https://arxiv.org/abs/2212.00794
    """

    def __init__(self, prob, exclude_first_token=True):
        super().__init__()
        assert 0 <= prob < 1.
        self.prob = prob
        self.exclude_first_token = exclude_first_token  # exclude CLS token
        logging.info("patch dropout prob is {}".format(prob))

    def forward(self, x):
        # if not self.training or self.prob == 0.:
        #     return x

        if self.exclude_first_token:
            cls_tokens, x = x[:, :1], x[:, 1:]
        else:
            cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])

        batch = x.size()[0]
        num_tokens = x.size()[1]

        batch_indices = torch.arange(batch)
        batch_indices = batch_indices[..., None]

        keep_prob = 1 - self.prob
        num_patches_keep = max(1, int(num_tokens * keep_prob))

        rand = torch.randn(batch, num_tokens)
        patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices

        x = x[batch_indices, patch_indices_keep]

        if self.exclude_first_token:
            x = torch.cat((cls_tokens, x), dim=1)

        return x


class Group(nn.Module):
    def __init__(self, num_group, group_size):
        super().__init__()
        self.num_group = num_group
        self.group_size = group_size

    def forward(self, xyz, color):
        '''
        input: B N 3
        ---------------------------
        output: B G M 3
        center : B G 3
        '''
        batch_size, num_points, _ = xyz.shape
        # fps the centers out
        center = fps(xyz, self.num_group)  # B G 3
        # knn to get the neighborhood
        # _, idx = self.knn(xyz, center)  # B G M
        idx = knn_point(self.group_size, xyz, center)  # B G M
        assert idx.size(1) == self.num_group
        assert idx.size(2) == self.group_size
        idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
        idx = idx + idx_base
        idx = idx.view(-1)
        neighborhood = xyz.view(batch_size * num_points, -1)[idx, :]
        neighborhood = neighborhood.view(batch_size, self.num_group, self.group_size, 3).contiguous()

        neighborhood_color = color.view(batch_size * num_points, -1)[idx, :]
        neighborhood_color = neighborhood_color.view(batch_size, self.num_group, self.group_size, 3).contiguous()

        # normalize
        neighborhood = neighborhood - center.unsqueeze(2)

        features = torch.cat((neighborhood, neighborhood_color), dim=-1)
        return neighborhood, center, features

class Encoder(nn.Module):
    def __init__(self, encoder_channel):
        super().__init__()
        self.encoder_channel = encoder_channel
        self.first_conv = nn.Sequential(
            nn.Conv1d(6, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 256, 1)
        )
        self.second_conv = nn.Sequential(
            nn.Conv1d(512, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, self.encoder_channel, 1)
        )

    def forward(self, point_groups):
        '''
        point_groups : B G N 3
        -----------------
        feature_global : B G C
        '''
        bs, g, n, _ = point_groups.shape
        point_groups = point_groups.reshape(bs * g, n, 6)
        # encoder
        feature = self.first_conv(point_groups.transpose(2, 1))  # BG 256 n
        feature_global = torch.max(feature, dim=2, keepdim=True)[0]  # BG 256 1
        feature = torch.cat([feature_global.expand(-1, -1, n), feature], dim=1)  # BG 512 n
        feature = self.second_conv(feature)  # BG 1024 n
        feature_global = torch.max(feature, dim=2, keepdim=False)[0]  # BG 1024
        return feature_global.reshape(bs, g, self.encoder_channel)

class PointcloudEncoder(nn.Module):
    def __init__(self, point_transformer):
        # use the giant branch of uni3d
        super().__init__()
        from easydict import EasyDict
        self.trans_dim = 1408
        self.embed_dim = 1024
        self.group_size = 64
        self.num_group = 512
        # grouper
        self.group_divider = Group(num_group=self.num_group, group_size=self.group_size)
        # define the encoder
        self.encoder_dim = 512
        self.encoder = Encoder(encoder_channel=self.encoder_dim)

        # bridge encoder and transformer
        self.encoder2trans = nn.Linear(self.encoder_dim, self.trans_dim)

        # bridge transformer and clip embedding
        self.trans2embed = nn.Linear(self.trans_dim, self.embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
        self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim))

        self.pos_embed = nn.Sequential(
            nn.Linear(3, 128),
            nn.GELU(),
            nn.Linear(128, self.trans_dim)
        )
        # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
        self.patch_dropout = PatchDropout(0.) if 0. > 0. else nn.Identity()
        self.visual = point_transformer


    def forward(self, pts, colors):
        # divide the point cloud in the same form. This is important
        _, center, features = self.group_divider(pts, colors)

        # encoder the input cloud patches
        group_input_tokens = self.encoder(features)  # B G N
        group_input_tokens = self.encoder2trans(group_input_tokens)
        # prepare cls
        cls_tokens = self.cls_token.expand(group_input_tokens.size(0), -1, -1)
        cls_pos = self.cls_pos.expand(group_input_tokens.size(0), -1, -1)
        # add pos embedding
        pos = self.pos_embed(center)
        # final input
        x = torch.cat((cls_tokens, group_input_tokens), dim=1)
        pos = torch.cat((cls_pos, pos), dim=1)
        # transformer
        x = x + pos
        # x = x.half()

        # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
        x = self.patch_dropout(x)

        x = self.visual.pos_drop(x)

        # ModuleList not support forward
        for i, blk in enumerate(self.visual.blocks):
            x = blk(x)
        x = self.visual.norm(x[:, 0, :])
        x = self.visual.fc_norm(x)

        x = self.trans2embed(x)
        return x

class Uni3D(nn.Module):
    def __init__(self, point_encoder):
        super().__init__()
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.point_encoder = point_encoder

    def encode_pc(self, pc):
        xyz = pc[:, :, :3].contiguous()
        color = pc[:, :, 3:].contiguous()
        pc_feat = self.point_encoder(xyz, color)
        return pc_feat

    def forward(self, pc, text, image):
        text_embed_all = text
        image_embed = image
        pc_embed = self.encode_pc(pc)
        return {'text_embed': text_embed_all,
                'pc_embed': pc_embed,
                'image_embed': image_embed,
                'logit_scale': self.logit_scale.exp()}

def get_metric_names(model):
    return ['loss', 'uni3d_loss', 'pc_image_acc', 'pc_text_acc']

def create_uni3d(uni3d_path):
    # create transformer blocks for point cloud via timm
    point_transformer = timm.create_model("eva_giant_patch14_560")

    # create whole point cloud encoder
    point_encoder = PointcloudEncoder(point_transformer)

    # uni3d model
    model = Uni3D(point_encoder=point_encoder,)

    checkpoint = torch.load(uni3d_path, map_location='cpu')
    logging.info('loaded checkpoint {}'.format(uni3d_path))
    sd = checkpoint['module']
    if next(iter(sd.items()))[0].startswith('module'):
        sd = {k[len('module.'):]: v for k, v in sd.items()}
    model.load_state_dict(sd)
    return model

class Uni3dEmbeddingEncoder(FeatureExtractor):
    def __init__(self, cache_dir, **kwargs) -> None:
        bpe_path = "utils/bpe_simple_vocab_16e6.txt.gz"
        uni3d_path = os.path.join(cache_dir, "Uni3D", "modelzoo", "uni3d-g", "model.pt")  # concat the subfolder as hf_hub_download will put it here
        clip_path = os.path.join(cache_dir, "Uni3D", "open_clip_pytorch_model.bin")

        if not os.path.exists(uni3d_path):
            hf_hub_download("BAAI/Uni3D", "model.pt", subfolder="modelzoo/uni3d-g", cache_dir=cache_dir,
                            local_dir=cache_dir + os.sep + "Uni3D")
        if not os.path.exists(clip_path):
            hf_hub_download("timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k", "open_clip_pytorch_model.bin",
                            cache_dir=cache_dir, local_dir=cache_dir + os.sep + "Uni3D")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = SimpleTokenizer(bpe_path)
        self.model = create_uni3d(uni3d_path)
        self.model.eval()
        self.model.to(self.device)
        self.clip_model, _, self.preprocess = open_clip.create_model_and_transforms(model_name="EVA02-E-14-plus", pretrained=clip_path)
        self.clip_model.to(self.device)

    def pc_norm(self, pc):
        """ pc: NxC, return NxC """
        centroid = np.mean(pc, axis=0)
        pc = pc - centroid
        m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
        pc = pc / m
        return pc

    @torch.no_grad()
    def encode_3D(self, data):
        pc = data.to(device=self.device, non_blocking=True)
        pc_features = self.model.encode_pc(pc)
        pc_features = pc_features / pc_features.norm(dim=-1, keepdim=True)
        return pc_features.float()

    @torch.no_grad()
    def encode_text(self, input_text):
        texts = self.tokenizer(input_text).to(device=self.device, non_blocking=True)
        if len(texts.shape) < 2:
            texts = texts[None, ...]
        class_embeddings = self.clip_model.encode_text(texts)
        class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
        return class_embeddings.float()

    @torch.no_grad()
    def encode_image(self, img_tensor_list):
        image = img_tensor_list.to(device=self.device, non_blocking=True)
        image_features = self.clip_model.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_features.float()

    def encode_query(self, query_list):
        return self.encode_text(query_list)

    def get_img_transform(self):
        return self.preprocess
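The square_distance helper above relies on the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>, which is what its docstring spells out. A short self-contained check of that identity against torch.cdist is given below; it is not part of the commit and the tensor shapes are arbitrary.

import torch

# Verify the expansion used in square_distance():
#   ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>
B, N, M, C = 2, 128, 64, 3
src, dst = torch.randn(B, N, C), torch.randn(B, M, C)

dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
dist += torch.sum(src ** 2, -1).view(B, N, 1)
dist += torch.sum(dst ** 2, -1).view(B, 1, M)

# torch.cdist gives the (unsquared) pairwise Euclidean distances.
assert torch.allclose(dist, torch.cdist(src, dst) ** 2, atol=1e-4)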
packages
ADDED
@@ -0,0 +1 @@
libwebp-dev
requirements.txt
ADDED
@@ -0,0 +1,9 @@
gradio
datasets
timm
pillow
open-clip-torch
huggingface_hub
ftfy
regex
easydict
utils/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
size 1356917
utils/tokenizer.py
ADDED
@@ -0,0 +1,147 @@
# copied from github.com/baaivision/Uni3D
# # Modified from github.com/openai/CLIP
import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re
import torch


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text

    def __call__(self, texts, context_length=77):
        if isinstance(texts, str):
            texts = [texts]

        sot_token = self.encoder["<|startoftext|>"]
        eot_token = self.encoder["<|endoftext|>"]
        all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            tokens = tokens[:context_length]
            result[i, :len(tokens)] = torch.tensor(tokens)

        if len(result) == 1:
            return result[0]
        return result
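A small usage sketch for the tokenizer, not part of the commit; it assumes the bundled utils/bpe_simple_vocab_16e6.txt.gz and the ftfy/regex dependencies from requirements.txt are available.

from utils.tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer("utils/bpe_simple_vocab_16e6.txt.gz")

# A single string yields one (77,) LongTensor of token ids, zero-padded,
# with <|startoftext|> and <|endoftext|> added around the BPE tokens.
ids = tokenizer("A chair with a wooden frame and a cushioned seat")
print(ids.shape)  # torch.Size([77])

# decode() operates on raw BPE ids (no start/end tokens or padding).
print(tokenizer.decode(tokenizer.encode("a cushioned seat")))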