File size: 9,426 Bytes
def3000
 
5218da5
def3000
 
 
 
 
 
 
 
b24c174
def3000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b24c174
 
def3000
 
f75356d
5218da5
 
 
 
def3000
 
f75356d
 
def3000
 
 
 
 
 
 
 
 
 
 
 
 
b24c174
def3000
 
 
 
 
5218da5
def3000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b24c174
 
287f91c
def3000
 
 
b24c174
 
 
 
 
 
 
 
5218da5
b24c174
 
 
 
 
5218da5
b24c174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5218da5
 
b24c174
 
 
 
 
 
5218da5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b24c174
 
 
 
 
 
 
def3000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import json
import os
import shutil
import subprocess
import sys
import time
import uuid
import zipfile
from dataclasses import fields
from urllib.request import urlretrieve

import gradio as gr
import torch.multiprocessing as mp
import transformers
from legogpt.models import LegoGPT, LegoGPTConfig


def setup():
    # Set up Gurobi licence
    licence_filename = 'gurobi.lic'
    licence_lines = []
    for secret_name in ['WLSACCESSID', 'WLSSECRET', 'LICENSEID']:
        secret = os.environ.get(secret_name)
        if not secret:
            raise ValueError(f'Env variable {secret_name} not found. Please set it in the Hugging Face Space settings.')
        licence_lines.append(f'{secret_name}={secret}\n')
    with open(licence_filename, 'w') as f:
        f.writelines(licence_lines)
    os.environ['GRB_LICENSE_FILE'] = os.path.abspath(licence_filename)

    # Download LDraw part library and set LDraw library path
    ldraw_zip_url = 'https://library.ldraw.org/library/updates/complete.zip'
    ldraw_zip_filename = 'complete.zip'
    urlretrieve(ldraw_zip_url, ldraw_zip_filename)
    with zipfile.ZipFile(ldraw_zip_filename) as zip_ref:
        zip_ref.extractall()
    os.environ['LDRAW_LIBRARY_PATH'] = os.path.abspath('ldraw')


def main():
    """
    Build and launch the LegoGPT Gradio demo.

    If IS_HF_SPACE is '1', ``setup()`` runs first to prepare the Space environment. The queue's
    default concurrency limit is read from CONCURRENCY_LIMIT (2 when unset), and the app is
    launched with ``share=True``.
    """
    running_on_space = os.environ.get('IS_HF_SPACE') == '1'
    if running_on_space:
        print('Running in Hugging Face Space, setting up environment...')
        setup()

    config = LegoGPTConfig(max_regenerations=5)
    lego_generator = LegoGenerator(LegoGPT(config))

    # --- Primary inputs ---
    prompt_input = gr.Textbox(label='Prompt', placeholder='Enter a prompt to generate a LEGO model.', max_length=500)
    optout_input = gr.Checkbox(
        label='Do not save my data',
        info='We may collect model inputs and outputs to help us improve the model. '
             'Your data will never be shared or used for any other purpose. '
             'If you wish to opt out of data collection, please check the box below.',
    )

    # --- Advanced settings; defaults and help text come from the model config ---
    temperature_input = gr.Slider(
        0.01, 2.0, value=config.temperature, step=0.01,
        label='Temperature', info=get_help_string('temperature'),
    )
    seed_input = gr.Number(
        value=42, label='Seed', info='Random seed for generation.',
        precision=0, minimum=0, maximum=2 ** 32 - 1, step=1,
    )
    max_bricks_input = gr.Number(
        value=config.max_bricks, label='Max bricks', info=get_help_string('max_bricks'),
        precision=0, minimum=1, step=1,
    )
    rejections_input = gr.Number(
        value=config.max_brick_rejections, label='Max brick rejections',
        info=get_help_string('max_brick_rejections'), precision=0, minimum=0, step=1,
    )
    regenerations_input = gr.Number(
        value=config.max_regenerations, label='Max regenerations',
        info=get_help_string('max_regenerations'), precision=0, minimum=0, step=1,
    )

    # --- Outputs ---
    image_output = gr.Image(label='Output image', format='png')
    bricks_output = gr.Textbox(
        label='Output LEGO bricks', lines=5, max_lines=5, show_copy_button=True,
        info='The LEGO structure in text format. Each line of the form "hxw (x,y,z)" represents a '
             '1-unit-tall rectangular brick with dimensions hxw placed at coordinates (x,y,z).',
    )

    description = (
        'Official demo for [LegoGPT](https://avalovelace1.github.io/LegoGPT/), the first approach for generating physically stable LEGO brick models from text prompts.\n\n'
        'The model is restricted to creating structures made of 1-unit-tall cuboid bricks on a 20x20x20 grid. It was trained on a dataset of 21 object categories: '
        '*basket, bed, bench, birdhouse, bookshelf, bottle, bowl, bus, camera, car, chair, guitar, jar, mug, piano, pot, sofa, table, tower, train, vessel.* '
        'Performance on prompts from outside these categories may be limited. This demo does not include texturing or coloring.'
    )
    demo = gr.Interface(
        fn=lego_generator.generate_lego_subprocess,
        title='LegoGPT Demo',
        description=description,
        inputs=[prompt_input, optout_input],
        additional_inputs=[temperature_input, seed_input, max_bricks_input, rejections_input, regenerations_input],
        outputs=[image_output, bricks_output],
        flagging_mode='never',
    )

    # Precomputed examples. The click handler returns the example's stored image path and text
    # instead of calling the generator. The hidden name/image components carry the per-example
    # data into that handler.
    with demo:
        with gr.Row():
            example_map = get_examples()
            hidden_name = gr.Textbox(visible=False, label='Name')
            hidden_image = gr.Image(visible=False, label='Result')
            example_rows = []
            for example_name, example in example_map.items():
                example_rows.append([example_name, example['prompt'], example['temperature'],
                                     example['seed'], example['output_img']])
            gr.Examples(
                examples=example_rows,
                inputs=[hidden_name, prompt_input, temperature_input, seed_input, hidden_image],
                outputs=[image_output, bricks_output],
                fn=lambda *row: (row[-1], example_map[row[0]]['output_txt']),
                run_on_click=True,
            )

    demo.queue(default_concurrency_limit=int(os.environ.get('CONCURRENCY_LIMIT', '2')))
    demo.launch(share=True)


class LegoGenerator:
    """
    Runs the LegoGPT generate -> render -> (optionally) persist pipeline behind the Gradio demo.
    """

    def __init__(self, model: LegoGPT, data_dir: str = '/data/apun/legogpt_demo_out'):
        """
        :param model: LegoGPT model used for generation. When requests go through
            ``generate_lego_subprocess``, the bound method (and therefore this object, model
            included) is sent to a worker process.
        :param data_dir: Directory where opted-in outputs (PNG + JSON metadata) are stored.
        """
        self.model = model
        self.data_dir = data_dir
        # 'spawn' starts each worker from a fresh interpreter instead of forking this process.
        self.ctx = mp.get_context('spawn')

    def generate_lego(
            self,
            prompt: str,
            do_not_save_data: bool,
            temperature: float | None,
            seed: int | None,
            max_bricks: int | None,
            max_brick_rejections: int | None,
            max_regenerations: int | None,
    ) -> tuple[str, str]:
        """
        Generate a LEGO structure for a prompt, render it to PNG, and optionally persist the result.

        Any generation parameter passed as ``None`` leaves the model's current value in place.
        When this runs inside a worker process, the overrides affect only the worker's copy of the model.

        :param prompt: Text prompt describing the object to build.
        :param do_not_save_data: If True, skip copying the image and metadata to ``self.data_dir``.
        :param temperature: Override for ``model.temperature``.
        :param seed: Seed passed to ``transformers.set_seed`` before generation.
        :param max_bricks: Override for ``model.max_bricks``.
        :param max_brick_rejections: Override for ``model.max_brick_rejections``.
        :param max_regenerations: Override for ``model.max_regenerations``.
        :return: ``(path to rendered PNG, structure in text format)``.
        :raises subprocess.CalledProcessError: If the render script exits with a non-zero status.
        """
        # Apply per-request overrides
        if temperature is not None: self.model.temperature = temperature
        if max_bricks is not None: self.model.max_bricks = max_bricks
        if max_brick_rejections is not None: self.model.max_brick_rejections = max_brick_rejections
        if max_regenerations is not None: self.model.max_regenerations = max_regenerations
        if seed is not None: transformers.set_seed(seed)

        # Generate LEGO
        print(f'Generating LEGO for prompt: "{prompt}"')
        start_time = time.time()
        output = self.model(prompt)
        lego = output['lego']

        # Write output LDR under a UUID name so concurrent requests never collide
        output_dir = os.path.abspath('out')
        output_uuid = str(uuid.uuid4())
        os.makedirs(output_dir, exist_ok=True)
        ldr_filename = os.path.join(output_dir, f'{output_uuid}.ldr')
        with open(ldr_filename, 'w') as f:
            f.write(lego.to_ldr())
        generation_time = time.time() - start_time
        print(f'Finished generation in {generation_time:.1f}s!')

        # Render LEGO model to image. The render runs as a subprocess to prevent issues with Blender.
        # sys.executable (rather than a bare 'python' looked up on PATH) guarantees the same
        # interpreter and environment as this process.
        print('Rendering image...')
        img_filename = os.path.join(output_dir, f'{output_uuid}.png')
        subprocess.run([sys.executable, 'render_lego.py', '--in_file', ldr_filename, '--out_file', img_filename],
                       check=True)
        rendering_time = time.time() - start_time - generation_time
        print(f'Finished rendering in {rendering_time:.1f}s!')

        lego_txt = lego.to_txt()

        # Save data persistently unless the user opted out
        if not do_not_save_data:
            os.makedirs(self.data_dir, exist_ok=True)

            # Copy output image to persistent storage
            img_copy_filename = os.path.join(self.data_dir, f'{output_uuid}.png')
            shutil.copy(img_filename, img_copy_filename)

            # Save metadata (record the values actually used, including pre-existing model settings)
            metadata_filename = os.path.join(self.data_dir, f'{output_uuid}.json')
            with open(metadata_filename, 'w', encoding='utf-8') as f:
                json.dump({
                    'prompt': prompt,
                    'temperature': self.model.temperature,
                    'seed': seed,
                    'max_bricks': self.model.max_bricks,
                    'max_brick_rejections': self.model.max_brick_rejections,
                    'max_regenerations': self.model.max_regenerations,
                    'start_timestamp': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time)),
                    'generation_time': generation_time,
                    'rendering_time': rendering_time,
                    'output_txt': lego_txt,
                }, f)
            print(f'Saved data to {metadata_filename}.')

        return img_filename, lego_txt

    def generate_lego_subprocess(self, *args) -> tuple[str, str]:
        """
        Run ``generate_lego(*args)`` in a fresh single-worker process pool.

        Isolating each request in its own process lets multiple requests be handled concurrently.
        Exceptions raised in the worker are re-raised here.
        """
        with self.ctx.Pool(1) as pool:
            return pool.apply(self.generate_lego, args)


def get_help_string(field_name: str) -> str:
    """
    :param field_name: Name of a field in LegoGPTConfig.
    :return: Help string for the field.
    """
    data_fields = fields(LegoGPTConfig)
    name_field = next(f for f in data_fields if f.name == field_name)
    return name_field.metadata['help']


def get_examples(example_dir: str = os.path.abspath('examples')) -> dict[str, dict[str, str]]:
    examples_file = os.path.join(example_dir, 'examples.json')
    with open(examples_file) as f:
        examples = json.load(f)

    for example in examples.values():
        example['output_img'] = os.path.join(example_dir, example['output_img'])
    return examples


if __name__ == "__main__":
    # This guard is required. LegoGenerator uses the 'spawn' start method, which re-imports this
    # module in every worker process, and the demo must not be launched again there.
    main()