Spaces: Runtime error
Commit · 7978a78
Parent(s): init
Browse files
- .gitattributes +36 -0
- .gitignore +2 -0
- README.md +13 -0
- app.py +181 -0
- assets/leo.svg +0 -0
- assets/obj_features/3RScan-0cac759b-8d6f-2d13-8e3b-2e3bc1ee1158.pth +3 -0
- assets/obj_features/3RScan-0cac760d-8d6f-2d13-8ea2-109ce4da9ac9.pth +3 -0
- assets/obj_features/3RScan-752cc597-920c-26f5-8c1b-a8a5c90a21d7.pth +3 -0
- assets/scene_meshes/3RScan-0cac759b-8d6f-2d13-8e3b-2e3bc1ee1158.glb +3 -0
- assets/scene_meshes/3RScan-0cac760d-8d6f-2d13-8ea2-109ce4da9ac9.glb +3 -0
- assets/scene_meshes/3RScan-752cc597-920c-26f5-8c1b-a8a5c90a21d7.glb +3 -0
- model/cfg.yaml +21 -0
- model/leo_agent.py +210 -0
- requirements.txt +7 -0
- utils.py +184 -0
.gitattributes
ADDED
@@ -0,0 +1,36 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.glb filter=lfs diff=lfs merge=lfs -text
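Note: the first 35 patterns appear to be the default LFS rules a new Space ships with; the final `*.glb` rule is the project-specific addition so the scene meshes under `assets/scene_meshes/` go through Git LFS too (the `*.pth` rule already covers the precomputed object features).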
.gitignore
ADDED
@@ -0,0 +1,2 @@
+__pycache__
+logs/
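`logs/` is ignored because it is where `vote_response()` in `utils.py` below writes its daily JSON dialogue logs at runtime.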
README.md
ADDED
@@ -0,0 +1,13 @@
+---
+title: LEO
+emoji: 🦁
+colorFrom: purple
+colorTo: blue
+sdk: gradio
+sdk_version: 4.10.0
+app_file: app.py
+pinned: false
+license: mit
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
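The `---` front matter is the Spaces config header the Hub reads to build the app: `sdk: gradio` with `sdk_version: 4.10.0` selects the runtime, and `app_file: app.py` names the entry point.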
app.py
ADDED
@@ -0,0 +1,181 @@
+import os
+
+import gradio as gr
+
+from utils import *
+
+
+with gr.Blocks(title='LEO Demo') as demo:
+    gr.HTML(value="<h1 align='center'>An Embodied Generalist Agent in 3D World</h1>")
+    gr.HTML(value="<div align='center' style='margin-top:-1em; margin-bottom:-1em;'><img src='/file=assets/leo.svg' width='4%'></div>")
+    # gr.HTML(value="<img src='/file=assets/teaser.png' alt='Teaser' width='760px' style='display: block; margin: auto;'>")
+    gr.HTML(value="<p align='center' style='font-size: 1.2em; color: #485fc7;'><a href='https://arxiv.org/abs/2311.12871' target='_blank'>arXiv</a> | <a href='https://embodied-generalist.github.io/' target='_blank'>Project Page</a> | <a href='https://github.com/embodied-generalist/embodied-generalist' target='_blank'>Code</a></p>")
+    gr.HTML(value="<p align='center' style='font-size: 1.15em;'><i>LEO: an embodied generalist agent capable of perceiving, grounding, reasoning, planning, and acting in 3D world.</i></p>")
+
+    with gr.Row():
+        with gr.Column(scale=5):
+            dropdown_scene = gr.Dropdown(
+                choices=MESH_NAMES,
+                value=MESH_NAMES[0],
+                interactive=True,
+                label='Select a 3D scene',
+            )
+            model_3d = gr.Model3D(
+                value=os.path.join(MESH_DIR, f'{MESH_NAMES[0]}.glb'),
+                clear_color=[0.0, 0.0, 0.0, 0.0],
+                label='3D Scene',
+                camera_position=(90, 30, 10),
+                height=659,
+            )
+            gr.HTML(
+                """<center><strong>
+                👆 SCROLL and DRAG on the 3D Scene
+                to zoom in/out and rotate. Press CTRL and DRAG to pan.
+                </strong></center>
+                """
+            )
+        with gr.Column(scale=5):
+            dropdown_conversation_mode = gr.Dropdown(
+                choices=['Single-round mode', 'Multi-round mode'],
+                value='Single-round mode',
+                interactive=True,
+                label='Select conversation mode',
+            )
+            chatbot = gr.Chatbot(label='Chat with LEO')
+            with gr.Row():
+                with gr.Column(scale=8):
+                    user_chat_input = gr.Textbox(
+                        placeholder="Enter text here to chat with LEO",
+                        show_label=False,
+                        autofocus=True,
+                    )
+                with gr.Column(scale=2, min_width=0):
+                    send_button = gr.Button('Send', variant='primary', scale=2)
+            with gr.Row():
+                upvote_button = gr.Button(value='👍 Upvote', interactive=False)
+                downvote_button = gr.Button(value='👎 Downvote', interactive=False)
+                flag_button = gr.Button(value='⚠️ Flag', interactive=False)
+                clear_button = gr.Button(value='🗑️ Clear', interactive=False)
+            with gr.Row():
+                with gr.Accordion(label="Examples for user instruction:", open=True):
+                    gr.Examples(
+                        examples=[
+                            ["How many armchairs are there in this room?"],
+                            ["Is there a radio in the room?"],
+                            ["Where is the wardrobe located?TODO"],
+                            ["What is the shape of the shelf in front of the picture?TODO"],
+                            ["Plan for the task: Tidy up and arrange the nursery room.TODO"],
+                        ],
+                        inputs=user_chat_input,
+                    )
+
+            # generation_config
+            with gr.Accordion('Parameters', open=False):
+                repetition_penalty = gr.Slider(
+                    minimum=0.0,
+                    maximum=10.0,
+                    value=3.0,
+                    step=1.0,
+                    interactive=True,
+                    label='Repetition penalty',
+                )
+                length_penalty = gr.Slider(
+                    minimum=0.0,
+                    maximum=10.0,
+                    value=1.0,
+                    step=1.0,
+                    interactive=True,
+                    label="Length penalty",
+                )
+    gr.Markdown("### Terms of Service")
+    gr.HTML(
+        """By using this service, users are required to agree to the following terms:
+        the service is a research preview intended for non-commercial use only
+        and may collect user dialogue data for future research."""
+    )
+    gr.Markdown("### Acknowledgment")
+    gr.HTML(
+        """Template adapted from <a href="https://llava.hliu.cc/">LLaVA</a> and
+        <a href="http://sled-whistler.eecs.umich.edu:7777/">LLM-Grounder</a>."""
+    )
+
+    # Event handling
+    button_list = [upvote_button, downvote_button, flag_button, clear_button]
+
+    dropdown_scene.change(
+        fn=change_scene,
+        inputs=[dropdown_scene],
+        outputs=[model_3d, chatbot],
+        queue=False,
+    )
+
+    dropdown_conversation_mode.change(
+        fn=clear_history,
+        inputs=[],
+        outputs=[chatbot, user_chat_input] + button_list,
+        queue=False,
+    )
+
+    user_chat_input.submit(
+        fn=receive_instruction,
+        inputs=[chatbot, user_chat_input],
+        outputs=[chatbot, user_chat_input, send_button] + button_list,
+        queue=False,
+    ).then(
+        fn=generate_response,
+        inputs=[
+            chatbot,
+            dropdown_scene,
+            dropdown_conversation_mode,
+            repetition_penalty,
+            length_penalty,
+        ],
+        outputs=[chatbot, send_button] + button_list,
+        scroll_to_output=True,
+    )
+
+    send_button.click(
+        fn=receive_instruction,
+        inputs=[chatbot, user_chat_input],
+        outputs=[chatbot, user_chat_input, send_button] + button_list,
+        queue=False,
+    ).then(
+        fn=generate_response,
+        inputs=[
+            chatbot,
+            dropdown_scene,
+            dropdown_conversation_mode,
+            repetition_penalty,
+            length_penalty,
+        ],
+        outputs=[chatbot, send_button] + button_list,
+        scroll_to_output=True,
+    )
+
+    upvote_button.click(
+        upvote_response,
+        [chatbot, dropdown_scene, dropdown_conversation_mode],
+        [user_chat_input, upvote_button, downvote_button, flag_button],
+        queue=False,
+    )
+    downvote_button.click(
+        downvote_response,
+        [chatbot, dropdown_scene, dropdown_conversation_mode],
+        [user_chat_input, upvote_button, downvote_button, flag_button],
+        queue=False,
+    )
+    flag_button.click(
+        flag_response,
+        [chatbot, dropdown_scene, dropdown_conversation_mode],
+        [user_chat_input, upvote_button, downvote_button, flag_button],
+        queue=False,
+    )
+    clear_button.click(
+        fn=clear_history,
+        inputs=[],
+        outputs=[chatbot, user_chat_input] + button_list,
+        queue=False,
+    )
+
+
+demo.queue().launch(share=True, allowed_paths=['assets'])
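Both the Enter-key submit and the Send button run the same two-step chain: an unqueued `receive_instruction` that echoes the user turn immediately, followed by the `generate_response` generator that streams the reply. A minimal standalone sketch of that pattern, assuming hypothetical `echo`/`stream` callbacks that are not part of this commit:

```python
import time

import gradio as gr

def echo(history, msg):
    # fast, unqueued step: show the user turn right away and clear the textbox
    history.append((msg, None))
    return history, ""

def stream(history):
    # queued generator step: fill in the assistant turn character by character
    reply = "placeholder reply"  # stands in for agent.generate(...)
    for i in range(1, len(reply) + 1):
        history[-1] = (history[-1][0], reply[:i])
        yield history
        time.sleep(0.01)

with gr.Blocks() as demo:
    chat = gr.Chatbot()
    box = gr.Textbox()
    box.submit(echo, [chat, box], [chat, box], queue=False).then(stream, [chat], [chat])

demo.queue().launch()
```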
assets/leo.svg
ADDED
assets/obj_features/3RScan-0cac759b-8d6f-2d13-8e3b-2e3bc1ee1158.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5642bb84ba04d10c5aa199dbcd5ea1ab01df0d2517719a2a2e943381f11bd25b
+size 1002083
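These three lines are a Git LFS pointer file, not the tensor itself: `oid sha256:...` identifies the stored blob and `size` gives its byte length; `git lfs` swaps the pointer for the real `.pth` payload at checkout. The five remaining asset entries below follow the same format.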
assets/obj_features/3RScan-0cac760d-8d6f-2d13-8ea2-109ce4da9ac9.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eae2324173b34331b6dad37c89a75db275d1d23fbb1f1d7478573085cdf1d733
+size 1002083
assets/obj_features/3RScan-752cc597-920c-26f5-8c1b-a8a5c90a21d7.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:50a9e124ea270cbe23b59fbddb527d5cf61005c657bd3f5f41535998ba84d9b6
+size 1002083
assets/scene_meshes/3RScan-0cac759b-8d6f-2d13-8e3b-2e3bc1ee1158.glb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6d197483b3be1f6f1395faa3a8b413ee23335fd8f081456b63db96f5928291b1
+size 9632176
assets/scene_meshes/3RScan-0cac760d-8d6f-2d13-8ea2-109ce4da9ac9.glb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:419988aa4781ec7d0a06e9087c8a918a20c389c50b210daa6b3c47be981b28ac
+size 9445868
assets/scene_meshes/3RScan-752cc597-920c-26f5-8c1b-a8a5c90a21d7.glb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0db74afa2648056c839840ba8a11d832012b6f70114668835c2da82d5ae07ec2
+size 11326324
model/cfg.yaml
ADDED
@@ -0,0 +1,21 @@
+use_ckpt: hf
+hf_ckpt_path: [huangjy-pku/embodied-generalist, weights/leo_noact_hf.pth]
+local_ckpt_path: /mnt/huangjiangyong/leo/hf_assets/weights/leo_noact_lora.pth
+model:
+  name: LeoAgentLLM
+  # vision modules omitted
+  llm:
+    name: Vicuna7B
+    use_ckpt: hf
+    hf_cfg_path: huangjy-pku/vicuna-7b
+    local_cfg_path: /mnt/huangjiangyong/vicuna-7b
+    truncation_side: right
+    prompt: ""
+    max_out_len: 256
+    max_context_len: 256  # for prompt_after_obj
+    lora:
+      flag: True
+      rank: 16
+      alpha: 16
+      target_modules: [q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj]
+      dropout: 0.0
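For reference, this file is parsed with `yaml.safe_load` and wrapped in an `OmegaConf` object by `load_agent()` in `utils.py` below; `LeoAgentLLM` then descends into the `model` sub-tree. A minimal sketch of that access pattern:

```python
import yaml
from omegaconf import OmegaConf

# Mirrors load_agent() in utils.py: plain YAML -> OmegaConf for attribute access
with open('model/cfg.yaml') as f:
    cfg = OmegaConf.create(yaml.safe_load(f))

print(cfg.use_ckpt)             # 'hf' -> checkpoint fetched from the Hub
print(cfg.model.llm.name)       # 'Vicuna7B'
print(cfg.model.llm.lora.rank)  # 16, the LoRA rank LeoAgentLLM applies
```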
model/leo_agent.py
ADDED
@@ -0,0 +1,210 @@
+import torch
+import torch.nn as nn
+from huggingface_hub import snapshot_download
+from peft import get_peft_model, LoraConfig
+from transformers import LlamaForCausalLM, LlamaTokenizer
+
+
+def disabled_train(self, mode=True):
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+class LeoAgentLLM(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        if hasattr(cfg, 'model'):
+            cfg = cfg.model
+
+        # LLM
+        if cfg.llm.use_ckpt == 'hf':
+            llm_cfg_path = snapshot_download(cfg.llm.hf_cfg_path)
+        else:
+            llm_cfg_path = cfg.llm.local_cfg_path
+        self.llm_tokenizer = LlamaTokenizer.from_pretrained(llm_cfg_path, use_fast=False,
+                                                            truncation_side=cfg.llm.truncation_side)
+        self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+        self.llm_tokenizer.add_special_tokens({'bos_token': '<s>'})
+        self.llm_tokenizer.add_special_tokens({'eos_token': '</s>'})
+        self.llm_tokenizer.add_special_tokens({'unk_token': '</s>'})
+        self.llm_model = LlamaForCausalLM.from_pretrained(llm_cfg_path, torch_dtype=torch.float16)
+        self.llm_model.resize_token_embeddings(len(self.llm_tokenizer))
+
+        for param in self.llm_model.parameters():
+            param.requires_grad = False
+        self.llm_model.eval()
+        self.llm_model.train = disabled_train
+
+        # LoRA-based LLM fine-tuning
+        if cfg.llm.lora.flag:
+            lora_config = LoraConfig(
+                r=cfg.llm.lora.rank,
+                lora_alpha=cfg.llm.lora.alpha,
+                target_modules=cfg.llm.lora.target_modules,
+                lora_dropout=cfg.llm.lora.dropout,
+                bias='none',
+                modules_to_save=[],
+            )
+            self.llm_model = get_peft_model(self.llm_model, peft_config=lora_config)
+
+        self.max_context_len = cfg.llm.max_context_len
+
+    @property
+    def device(self):
+        return list(self.parameters())[0].device
+
+    def build_right_justified_sequence(self, data_dict):
+        """
+        Concat six sequences: `prompt_before_obj`, `prompt_middle_1`, `img_tokens`, `prompt_middle_2`, `obj_tokens`, `prompt_after_obj`.
+        Return right justified sequence for causal LM: <pad>, <role/situation>, <img>, <objs>, <instruction>.
+        """
+        bs = len(data_dict['prompt_before_obj'])
+
+        self.llm_tokenizer.padding_side = 'left'
+        text_input_tokens_pre = self.llm_tokenizer(
+            data_dict['prompt_before_obj'],
+            return_tensors='pt',
+            padding='longest'
+        ).to(self.device)  # [PAD, BOS, tokens], (B, T1)
+
+        text_input_tokens_mid1 = self.llm_tokenizer(
+            data_dict['prompt_middle_1'],
+            return_tensors='pt',
+            padding='longest'
+        ).to(self.device)
+
+        img_tokens = data_dict['img_tokens'].to(self.device)
+        img_masks = data_dict['img_masks'].to(self.device)
+        img_masks = img_masks.reshape(-1, 1).repeat(1, img_tokens.size(1))
+
+        text_input_tokens_mid2 = self.llm_tokenizer(
+            data_dict['prompt_middle_2'],
+            return_tensors='pt',
+            padding='longest'
+        ).to(self.device)
+
+        obj_tokens = data_dict['obj_tokens'].to(self.device)
+        obj_masks = data_dict['obj_masks'].to(self.device)
+
+        self.llm_tokenizer.padding_side = 'right'  # no need to be 'left', as padding tokens will be shifted
+        self.llm_tokenizer.truncation_side = 'left'  # truncate history
+        text_input_tokens_post = self.llm_tokenizer(
+            data_dict['prompt_after_obj'],
+            return_tensors='pt',
+            padding='longest',
+            truncation=True,
+            max_length=self.max_context_len,
+        ).to(self.device)  # [BOS, tokens, PAD], (B, T3)
+
+        # hardcode, remove bos, make "tokenize subseq and concat" equivalent to "tokenize the whole seq"
+        assert text_input_tokens_mid1.attention_mask.all() and text_input_tokens_mid2.attention_mask.all(), \
+            "prompt_middle should be the same and thus no padding"
+
+        text_input_tokens_mid1.input_ids = text_input_tokens_mid1.input_ids[:, 1:]
+        text_input_tokens_mid1.attention_mask = text_input_tokens_mid1.attention_mask[:, 1:]
+        for i in range(bs):
+            if not img_masks[i].any():
+                # no image input, also mask the text prompt for image tokens
+                text_input_tokens_mid1.attention_mask[i].fill_(0)
+
+        text_input_tokens_mid2.input_ids[:, 0] = 869  # 1 (bos) -> 869 (▁.)
+        text_input_tokens_post.input_ids[:, 0] = 869  # 1 (bos) -> 869 (▁.)
+
+        inputs_embeds_pre = self.llm_model.get_input_embeddings()(text_input_tokens_pre.input_ids)
+        inputs_embeds_mid1 = self.llm_model.get_input_embeddings()(text_input_tokens_mid1.input_ids)
+        inputs_embeds_mid2 = self.llm_model.get_input_embeddings()(text_input_tokens_mid2.input_ids)
+        inputs_embeds_post = self.llm_model.get_input_embeddings()(text_input_tokens_post.input_ids)
+
+        # since img_tokens, prompt_mid, obj_tokens are fixed length without padding, we concat them first
+        inputs_embeds_mid = torch.cat([inputs_embeds_mid1, img_tokens, inputs_embeds_mid2, obj_tokens], dim=1)
+        attn_mask_mid = torch.cat([
+            text_input_tokens_mid1.attention_mask, img_masks,
+            text_input_tokens_mid2.attention_mask, obj_masks
+        ], dim=1)
+
+        post_pad_length = torch.logical_not(text_input_tokens_post.attention_mask).sum(-1)
+
+        bs, l1, hidden_dim = inputs_embeds_pre.shape
+        _, l2, _ = inputs_embeds_mid.shape
+        _, l3, _ = inputs_embeds_post.shape
+
+        inputs_embeds = torch.zeros(
+            bs, l1+l2+l3, hidden_dim
+        ).type(inputs_embeds_pre.dtype).to(self.device)
+
+        attention_mask = torch.zeros(
+            bs, l1+l2+l3
+        ).type(obj_masks.dtype).to(self.device)
+
+        # assign by chunks
+        for i in range(bs):
+            post_pad_len = post_pad_length[i]
+
+            if post_pad_len > 0:
+                inputs_embeds[i, :post_pad_len] = inputs_embeds_post[i, -post_pad_len:]
+                attention_mask[i, :post_pad_len] = 0
+                inputs_embeds[i, post_pad_len+l1+l2:] = inputs_embeds_post[i, :-post_pad_len]
+                attention_mask[i, post_pad_len+l1+l2:] = 1
+            else:
+                # no padding
+                inputs_embeds[i, -l3:] = inputs_embeds_post[i]
+                attention_mask[i, -l3:] = 1
+
+            inputs_embeds[i, post_pad_len: post_pad_len+l1] = inputs_embeds_pre[i]
+            attention_mask[i, post_pad_len: post_pad_len+l1] = text_input_tokens_pre.attention_mask[i]
+
+            inputs_embeds[i, post_pad_len+l1: post_pad_len+l1+l2] = inputs_embeds_mid[i]
+            attention_mask[i, post_pad_len+l1: post_pad_len+l1+l2] = attn_mask_mid[i]
+
+        return inputs_embeds, attention_mask
+
+    @torch.no_grad()
+    def generate(
+        self,
+        data_dict,
+        use_nucleus_sampling=False,
+        num_beams=5,
+        max_length=256,
+        min_length=1,
+        repetition_penalty=3.0,
+        length_penalty=1,
+        num_captions=1,
+        temperature=1,
+    ):
+        assert 'img_tokens' in data_dict and 'obj_tokens' in data_dict, "Visual features should have been processed offline."
+
+        inputs_embeds, attention_mask = self.build_right_justified_sequence(data_dict=data_dict)
+        bs = inputs_embeds.shape[0]
+
+        # give bos token as condition
+        bos_tokens = self.llm_tokenizer(
+            [self.llm_tokenizer.bos_token] * bs,
+            return_tensors='pt',
+        ).to(self.device)
+        bos_tokens_ids = bos_tokens.input_ids[:, 0:1]  # (B, 1)
+        bos_tokens_attn = bos_tokens.attention_mask[:, 0:1]  # (B, 1)
+
+        # prepare a `bos_token`
+        bos_embeds = self.llm_model.get_input_embeddings()(bos_tokens_ids)  # (B, 1, D)
+        inputs_embeds = torch.cat([inputs_embeds, bos_embeds], dim=1)  # (B, T1+O+T2+1, D)
+        attention_mask = torch.cat([attention_mask, bos_tokens_attn], dim=1)  # (B, T1+O+T2+1)
+
+        outputs = self.llm_model.generate(
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            do_sample=use_nucleus_sampling,
+            temperature=temperature,
+            num_beams=num_beams,
+            max_length=max_length,
+            min_length=min_length,
+            repetition_penalty=repetition_penalty,
+            length_penalty=length_penalty,
+            num_return_sequences=num_captions,
+        )
+
+        outputs[outputs == 0] = 2  # convert output id 0 (unk_token) to 2 (eos_token)
+
+        output_text = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        output_text = [text.strip() for text in output_text]
+        return output_text
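The `generate` entry point consumes a single `data_dict`. For reference, a shape sketch of the batch-size-1 dict that `generate_response()` in `utils.py` below actually builds; the placeholder strings and the object count `O = 60` are illustrative, not values from this commit (`obj_tokens`/`obj_masks` really come from the precomputed `assets/obj_features/*.pth`):

```python
import torch

# Hypothetical shapes for illustration only; D = 4096 matches the app's img_tokens.
data_dict = {
    'prompt_before_obj': ['<role/situation prompt>'],      # list of B strings
    'prompt_middle_1': ['Ego-view image:'],
    'img_tokens': torch.zeros(1, 1, 4096).float(),         # (B, 1, D) placeholder feature
    'img_masks': torch.zeros(1, 1).bool(),                 # all-False -> image prompt masked out
    'prompt_middle_2': ['Objects (including you) in the scene:'],
    'obj_tokens': torch.zeros(1, 60, 4096).float(),        # (B, O, D), O objects (hypothetical O)
    'obj_masks': torch.ones(1, 60).bool(),
    'prompt_after_obj': ['USER: <instruction> ASSISTANT:'],
}
```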
requirements.txt
ADDED
@@ -0,0 +1,7 @@
+--extra-index-url https://download.pytorch.org/whl/cu116
+omegaconf==2.3.0
+peft==0.5.0
+pyyaml==6.0.1
+sentencepiece
+torch==1.13.0+cu116
+transformers==4.28.1
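The `--extra-index-url` line adds the PyTorch CUDA 11.6 wheel index so the `torch==1.13.0+cu116` pin resolves; everything installs with `pip install -r requirements.txt`.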
utils.py
ADDED
@@ -0,0 +1,184 @@
+
+import datetime
+import json
+import os
+import time
+
+import gradio as gr
+import torch
+import yaml
+from huggingface_hub import hf_hub_download
+from omegaconf import OmegaConf
+
+from model.leo_agent import LeoAgentLLM
+
+LOG_DIR = 'logs'
+MESH_DIR = 'assets/scene_meshes'
+MESH_NAMES = [os.path.splitext(fname)[0] for fname in os.listdir(MESH_DIR)]
+ENABLE_BUTTON = gr.update(interactive=True)
+DISABLE_BUTTON = gr.update(interactive=False)
+DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ROLE_PROMPT = "You are an AI visual assistant situated in a 3D scene. "\
+              "You can perceive (1) an ego-view image (accessible when necessary) and (2) the objects (including yourself) in the scene (always accessible). "\
+              "You should properly respond to the USER's instruction according to the given visual information. "
+EGOVIEW_PROMPT = "Ego-view image:"
+OBJECTS_PROMPT = "Objects (including you) in the scene:"
+TASK_PROMPT = "USER: {instruction} ASSISTANT:"
+OBJ_FEATS_DIR = 'assets/obj_features'
+
+
+def load_agent():
+    # build model
+    with open('model/cfg.yaml') as f:
+        cfg = yaml.safe_load(f)
+    cfg = OmegaConf.create(cfg)
+    agent = LeoAgentLLM(cfg)
+
+    # load checkpoint
+    if cfg.use_ckpt == 'hf':
+        ckpt_path = hf_hub_download(cfg.hf_ckpt_path[0], cfg.hf_ckpt_path[1])
+    else:
+        ckpt_path = cfg.local_ckpt_path
+    ckpt = torch.load(ckpt_path, map_location='cpu')
+    agent.load_state_dict(ckpt, strict=False)
+
+    agent.eval()
+    agent.to(DEVICE)
+    return agent
+
+agent = load_agent()
+
+
+def get_log_fname():
+    t = datetime.datetime.now()
+    fname = os.path.join(LOG_DIR, f'{t.year}-{t.month:02d}-{t.day:02d}.json')
+    return fname
+
+
+def change_scene(dropdown_scene: str):
+    # reset 3D scene and chatbot history
+    return os.path.join(MESH_DIR, f'{dropdown_scene}.glb'), None
+
+
+def receive_instruction(chatbot: gr.Chatbot, user_chat_input: gr.Textbox):
+    # display user input, after submitting user message, before inference
+    chatbot.append((user_chat_input, None))
+    return (chatbot, gr.update(value=""),) + (DISABLE_BUTTON,) * 5
+
+
+def generate_response(
+    chatbot: gr.Chatbot,
+    dropdown_scene: gr.Dropdown,
+    dropdown_conversation_mode: gr.Dropdown,
+    repetition_penalty: float, length_penalty: float
+):
+    # response starts
+    chatbot[-1] = (chatbot[-1][0], "▌")
+    yield (chatbot,) + (DISABLE_BUTTON,) * 5
+
+    # create data_dict, batch_size = 1
+    data_dict = {
+        'prompt_before_obj': [ROLE_PROMPT],
+        'prompt_middle_1': [EGOVIEW_PROMPT],
+        'prompt_middle_2': [OBJECTS_PROMPT],
+        'img_tokens': torch.zeros(1, 1, 4096).float(),
+        'img_masks': torch.zeros(1, 1).bool(),
+        'anchor_locs': torch.zeros(1, 3).float(),
+    }
+
+    # initialize prompt
+    prompt = ""
+    if 'Multi-round' in dropdown_conversation_mode:
+        # multi-round dialogue, with memory
+        for (q, a) in chatbot[:-1]:
+            prompt += f"USER: {q.strip()} ASSISTANT: {a.strip()}</s>"
+
+    prompt += f"USER: {chatbot[-1][0]} ASSISTANT:"
+    data_dict['prompt_after_obj'] = [prompt]
+
+    # anchor orientation
+    anchor_orient = torch.zeros(1, 4).float()
+    anchor_orient[:, -1] = 1
+    data_dict['anchor_orientation'] = anchor_orient
+
+    # load preprocessed scene features
+    data_dict.update(torch.load(os.path.join(OBJ_FEATS_DIR, f'{dropdown_scene}.pth'), map_location='cpu'))
+
+    # inference
+    for k, v in data_dict.items():
+        if isinstance(v, torch.Tensor):
+            data_dict[k] = v.to(DEVICE)
+
+    output = agent.generate(
+        data_dict,
+        repetition_penalty=float(repetition_penalty),
+        length_penalty=float(length_penalty),
+    )
+    output = output[0]
+
+    # display response
+    for out_len in range(1, len(output)-1):
+        chatbot[-1] = (chatbot[-1][0], output[:out_len] + '▌')
+        yield (chatbot,) + (DISABLE_BUTTON,) * 5
+        time.sleep(0.01)
+
+    chatbot[-1] = (chatbot[-1][0], output)
+    vote_response(chatbot, 'log', dropdown_scene, dropdown_conversation_mode)
+    yield (chatbot,) + (ENABLE_BUTTON,) * 5
+
+
+def vote_response(
+    chatbot: gr.Chatbot, vote_type: str,
+    dropdown_scene: gr.Dropdown,
+    dropdown_conversation_mode: gr.Dropdown
+):
+    t = datetime.datetime.now()
+    this_log = {
+        'time': f'{t.hour:02d}:{t.minute:02d}:{t.second:02d}',
+        'type': vote_type,
+        'scene': dropdown_scene,
+        'mode': dropdown_conversation_mode,
+        'dialogue': chatbot,
+    }
+    fname = get_log_fname()
+    if os.path.exists(fname):
+        with open(fname) as f:
+            logs = json.load(f)
+        logs.append(this_log)
+    else:
+        logs = [this_log]
+    with open(fname, 'w') as f:
+        json.dump(logs, f, indent=2)
+
+
+def upvote_response(
+    chatbot: gr.Chatbot,
+    dropdown_scene: gr.Dropdown,
+    dropdown_conversation_mode: gr.Dropdown
+):
+    vote_response(chatbot, 'upvote', dropdown_scene, dropdown_conversation_mode)
+    return ("",) + (DISABLE_BUTTON,) * 3
+
+
+def downvote_response(
+    chatbot: gr.Chatbot,
+    dropdown_scene: gr.Dropdown,
+    dropdown_conversation_mode: gr.Dropdown
+):
+    vote_response(chatbot, 'downvote', dropdown_scene, dropdown_conversation_mode)
+    return ("",) + (DISABLE_BUTTON,) * 3
+
+
+def flag_response(
+    chatbot: gr.Chatbot,
+    dropdown_scene: gr.Dropdown,
+    dropdown_conversation_mode: gr.Dropdown
+):
+    vote_response(chatbot, 'flag', dropdown_scene, dropdown_conversation_mode)
+    return ("",) + (DISABLE_BUTTON,) * 3
+
+
+def clear_history():
+    # reset chatbot history
+    return (None, "",) + (DISABLE_BUTTON,) * 4
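To make the dialogue handling in `generate_response()` concrete, a worked example of the prompt string it assembles in multi-round mode; the history and follow-up question are hypothetical:

```python
history = [('Is there a radio in the room?', 'Yes, there is a radio.')]  # hypothetical prior turn
question = 'Where is it located?'

prompt = ''
for q, a in history:  # replayed only in 'Multi-round mode'
    prompt += f"USER: {q.strip()} ASSISTANT: {a.strip()}</s>"
prompt += f"USER: {question} ASSISTANT:"
# prompt == "USER: Is there a radio in the room? ASSISTANT: Yes, there is a radio.</s>"
#           "USER: Where is it located? ASSISTANT:"
```

This full conversation string becomes `prompt_after_obj`, which `build_right_justified_sequence` left-truncates to `max_context_len` tokens, so the oldest turns drop out first when the history grows too long.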