Spaces:
Running
on
A100
Running
on
A100
File size: 9,426 Bytes
def3000 5218da5 def3000 b24c174 def3000 b24c174 def3000 f75356d 5218da5 def3000 f75356d def3000 b24c174 def3000 5218da5 def3000 b24c174 287f91c def3000 b24c174 5218da5 b24c174 5218da5 b24c174 5218da5 b24c174 5218da5 b24c174 def3000 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
import json
import os
import shutil
import subprocess
import sys
import time
import uuid
import zipfile
from dataclasses import fields
from typing import Any
from urllib.request import urlretrieve

import gradio as gr
import torch.multiprocessing as mp
import transformers

from legogpt.models import LegoGPT, LegoGPTConfig
def setup():
# Set up Gurobi licence
licence_filename = 'gurobi.lic'
licence_lines = []
for secret_name in ['WLSACCESSID', 'WLSSECRET', 'LICENSEID']:
secret = os.environ.get(secret_name)
if not secret:
raise ValueError(f'Env variable {secret_name} not found. Please set it in the Hugging Face Space settings.')
licence_lines.append(f'{secret_name}={secret}\n')
with open(licence_filename, 'w') as f:
f.writelines(licence_lines)
os.environ['GRB_LICENSE_FILE'] = os.path.abspath(licence_filename)
# Download LDraw part library and set LDraw library path
ldraw_zip_url = 'https://library.ldraw.org/library/updates/complete.zip'
ldraw_zip_filename = 'complete.zip'
urlretrieve(ldraw_zip_url, ldraw_zip_filename)
with zipfile.ZipFile(ldraw_zip_filename) as zip_ref:
zip_ref.extractall()
os.environ['LDRAW_LIBRARY_PATH'] = os.path.abspath('ldraw')
def main() -> None:
    """Build and launch the LegoGPT Gradio demo.

    If the ``IS_HF_SPACE`` environment variable equals ``'1'``, ``setup()`` is run first.
    The model is wrapped in a ``LegoGenerator``, whose ``generate_lego_subprocess`` serves
    each request. The demo is queued with a default concurrency limit read from the
    ``CONCURRENCY_LIMIT`` environment variable (2 when unset), then launched.
    """
    if os.environ.get('IS_HF_SPACE') == '1':
        print('Running in Hugging Face Space, setting up environment...')
        setup()
    # The config's values also seed the initial values of the advanced-option widgets below.
    model_cfg = LegoGPTConfig(max_regenerations=5)
    generator = LegoGenerator(LegoGPT(model_cfg))
    # Define inputs and outputs
    in_prompt = gr.Textbox(label='Prompt', placeholder='Enter a prompt to generate a LEGO model.', max_length=500)
    in_optout = gr.Checkbox(label='Do not save my data',
                            info='We may collect model inputs and outputs to help us improve the model. '
                                 'Your data will never be shared or used for any other purpose. '
                                 'If you wish to opt out of data collection, please check the box below.')
    # Per-request generation overrides; help text comes from the LegoGPTConfig field metadata.
    in_temperature = gr.Slider(0.01, 2.0, value=model_cfg.temperature, step=0.01,
                               label='Temperature', info=get_help_string('temperature'))
    # Bounded to [0, 2**32 - 1]; the value is passed on to transformers.set_seed by the generator.
    in_seed = gr.Number(value=42, label='Seed', info='Random seed for generation.',
                        precision=0, minimum=0, maximum=2 ** 32 - 1, step=1)
    in_bricks = gr.Number(value=model_cfg.max_bricks, label='Max bricks', info=get_help_string('max_bricks'),
                          precision=0, minimum=1, step=1)
    in_rejections = gr.Number(value=model_cfg.max_brick_rejections, label='Max brick rejections',
                              info=get_help_string('max_brick_rejections'), precision=0, minimum=0, step=1)
    in_regenerations = gr.Number(value=model_cfg.max_regenerations, label='Max regenerations',
                                 info=get_help_string('max_regenerations'), precision=0, minimum=0, step=1)
    out_img = gr.Image(label='Output image', format='png')
    out_txt = gr.Textbox(label='Output LEGO bricks', lines=5, max_lines=5, show_copy_button=True,
                         info='The LEGO structure in text format. Each line of the form "hxw (x,y,z)" represents a '
                              '1-unit-tall rectangular brick with dimensions hxw placed at coordinates (x,y,z).')
    # Define Gradio interface
    # `inputs` followed by `additional_inputs` must line up with the positional parameters of
    # LegoGenerator.generate_lego: (prompt, do_not_save_data, temperature, seed, max_bricks,
    # max_brick_rejections, max_regenerations); generate_lego_subprocess forwards *args unchanged.
    demo = gr.Interface(
        fn=generator.generate_lego_subprocess,
        title='LegoGPT Demo',
        description='Official demo for [LegoGPT](https://avalovelace1.github.io/LegoGPT/), the first approach for generating physically stable LEGO brick models from text prompts.\n\n'
                    'The model is restricted to creating structures made of 1-unit-tall cuboid bricks on a 20x20x20 grid. It was trained on a dataset of 21 object categories: '
                    '*basket, bed, bench, birdhouse, bookshelf, bottle, bowl, bus, camera, car, chair, guitar, jar, mug, piano, pot, sofa, table, tower, train, vessel.* '
                    'Performance on prompts from outside these categories may be limited. This demo does not include texturing or coloring.',
        inputs=[in_prompt, in_optout],
        additional_inputs=[in_temperature, in_seed, in_bricks, in_rejections, in_regenerations],
        outputs=[out_img, out_txt],
        flagging_mode='never',
    )
    # NOTE(review): the original indentation was lost; the nesting of the statements below
    # under `with gr.Row():` is reconstructed — confirm against the upstream file.
    with demo:
        with gr.Row():
            examples = get_examples()
            # Hidden components: one carries the example's name (the key used to look up its
            # cached output text), the other its pre-rendered image path.
            dummy_name = gr.Textbox(visible=False, label='Name')
            dummy_out_img = gr.Image(visible=False, label='Result')
            # Each example row fills (name, prompt, temperature, seed, image). On click, `fn`
            # returns the stored image and brick text directly; the model is not invoked.
            gr.Examples(
                examples=[[name, example['prompt'], example['temperature'], example['seed'], example['output_img']]
                          for name, example in examples.items()],
                inputs=[dummy_name, in_prompt, in_temperature, in_seed, dummy_out_img],
                outputs=[out_img, out_txt],
                fn=lambda *args: (args[-1], examples[args[0]]['output_txt']),
                run_on_click=True,
            )
    # CONCURRENCY_LIMIT (env) overrides the default of 2; this is the queue's default
    # per-event concurrency limit. Each running request spawns its own worker process.
    concurrency_limit = 2 if os.environ.get('CONCURRENCY_LIMIT') is None else int(os.environ.get('CONCURRENCY_LIMIT'))
    demo.queue(default_concurrency_limit=concurrency_limit)
    # NOTE(review): share=True asks Gradio for a public share link when run outside Spaces — confirm intended.
    demo.launch(share=True)
class LegoGenerator:
    """Serve LegoGPT generation requests for the Gradio demo.

    Requests enter through :meth:`generate_lego_subprocess`. That method runs
    :meth:`generate_lego` in a fresh worker process started with the ``spawn`` method,
    so that multiple requests can be handled concurrently. The bound method, and
    therefore this object including its model, is pickled into the worker. As a
    result, the per-request parameter overrides applied there do not affect the
    parent's model.
    """

    # Default location where data from users who did not opt out is persisted.
    DEFAULT_DATA_DIR = '/data/apun/legogpt_demo_out'

    def __init__(self, model: LegoGPT, data_dir: str = DEFAULT_DATA_DIR):
        """
        :param model: The LegoGPT model used to generate structures.
        :param data_dir: Directory that receives the rendered image and JSON metadata of
            requests whose user did not opt out of data collection.
        """
        self.model = model
        self.data_dir = data_dir
        # 'spawn' starts each worker in a fresh interpreter rather than forking this process.
        self.ctx = mp.get_context('spawn')

    def generate_lego(
            self,
            prompt: str,
            do_not_save_data: bool,
            temperature: float | None,
            seed: int | None,
            max_bricks: int | None,
            max_brick_rejections: int | None,
            max_regenerations: int | None,
    ) -> tuple[str, str]:
        """Generate a LEGO structure for *prompt*, render it, and optionally persist it.

        A parameter passed as ``None`` leaves the model's current setting unchanged.

        :param prompt: Text prompt describing the object to build.
        :param do_not_save_data: If true, nothing is copied to ``self.data_dir``.
        :param temperature: Override for the model's ``temperature``.
        :param seed: Seed passed to ``transformers.set_seed``.
        :param max_bricks: Override for the model's ``max_bricks``.
        :param max_brick_rejections: Override for the model's ``max_brick_rejections``.
        :param max_regenerations: Override for the model's ``max_regenerations``.
        :return: ``(absolute path of the rendered PNG, structure in text format)``.
        :raises subprocess.CalledProcessError: If the render script exits with an error.
        """
        self._apply_overrides(temperature, seed, max_bricks, max_brick_rejections, max_regenerations)

        # Generate LEGO
        print(f'Generating LEGO for prompt: "{prompt}"')
        start_time = time.time()
        output = self.model(prompt)
        lego = output['lego']

        # Write output LDR to a uniquely named file so concurrent requests never collide.
        output_dir = os.path.abspath('out')
        os.makedirs(output_dir, exist_ok=True)
        output_uuid = str(uuid.uuid4())
        ldr_filename = os.path.join(output_dir, f'{output_uuid}.ldr')
        with open(ldr_filename, 'w') as f:
            f.write(lego.to_ldr())
        generation_time = time.time() - start_time
        print(f'Finished generation in {generation_time:.1f}s!')

        # Render LEGO model to image
        print('Rendering image...')
        img_filename = os.path.join(output_dir, f'{output_uuid}.png')
        self._render(ldr_filename, img_filename)
        rendering_time = time.time() - start_time - generation_time
        print(f'Finished rendering in {rendering_time:.1f}s!')

        lego_txt = lego.to_txt()
        if not do_not_save_data:
            self._save_data(output_uuid, img_filename, {
                'prompt': prompt,
                'temperature': self.model.temperature,
                'seed': seed,
                'max_bricks': self.model.max_bricks,
                'max_brick_rejections': self.model.max_brick_rejections,
                'max_regenerations': self.model.max_regenerations,
                'start_timestamp': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time)),
                'generation_time': generation_time,
                'rendering_time': rendering_time,
                'output_txt': lego_txt,
            })
        return img_filename, lego_txt

    def _apply_overrides(
            self,
            temperature: float | None,
            seed: int | None,
            max_bricks: int | None,
            max_brick_rejections: int | None,
            max_regenerations: int | None,
    ) -> None:
        """Copy each non-``None`` override onto the model; seed via ``transformers.set_seed``."""
        overrides = {
            'temperature': temperature,
            'max_bricks': max_bricks,
            'max_brick_rejections': max_brick_rejections,
            'max_regenerations': max_regenerations,
        }
        for attr_name, value in overrides.items():
            if value is not None:
                setattr(self.model, attr_name, value)
        if seed is not None:
            transformers.set_seed(seed)

    @staticmethod
    def _render(ldr_filename: str, img_filename: str) -> None:
        """Render the LDR file at *ldr_filename* to a PNG at *img_filename* via ``render_lego.py``."""
        # Run render as a subprocess to prevent issues with Blender. Use the interpreter that
        # is running this code rather than whatever `python` is first on PATH, so the render
        # script sees the same environment and installed packages.
        subprocess.run([sys.executable, 'render_lego.py', '--in_file', ldr_filename, '--out_file', img_filename],
                       check=True)

    def _save_data(self, output_uuid: str, img_filename: str, metadata: dict[str, Any]) -> None:
        """Copy the rendered image and dump *metadata* as JSON into ``self.data_dir``, keyed by *output_uuid*."""
        os.makedirs(self.data_dir, exist_ok=True)
        # Copy output image to persistent storage
        shutil.copy(img_filename, os.path.join(self.data_dir, f'{output_uuid}.png'))
        # Save metadata
        metadata_filename = os.path.join(self.data_dir, f'{output_uuid}.json')
        with open(metadata_filename, 'w') as f:
            json.dump(metadata, f)
        print(f'Saved data to {metadata_filename}.')

    def generate_lego_subprocess(self, *args) -> tuple[str, str]:
        """
        Run generation as a subprocess so that multiple requests can be handled concurrently.

        :param args: Positional arguments forwarded unchanged to :meth:`generate_lego`.
        :return: The ``(image path, text)`` tuple produced by :meth:`generate_lego` in the worker.
        """
        # One single-worker pool per request. `apply` blocks until the worker returns and
        # re-raises any exception raised there.
        with self.ctx.Pool(1) as pool:
            return pool.apply(self.generate_lego, args)
def get_help_string(field_name: str) -> str:
"""
:param field_name: Name of a field in LegoGPTConfig.
:return: Help string for the field.
"""
data_fields = fields(LegoGPTConfig)
name_field = next(f for f in data_fields if f.name == field_name)
return name_field.metadata['help']
def get_examples(example_dir: str = os.path.abspath('examples')) -> dict[str, dict[str, str]]:
examples_file = os.path.join(example_dir, 'examples.json')
with open(examples_file) as f:
examples = json.load(f)
for example in examples.values():
example['output_img'] = os.path.join(example_dir, example['output_img'])
return examples
# Script entry point: running this file directly builds and launches the demo.
if __name__ == '__main__':
    main()
|