|
import json |
|
import os |
|
import subprocess |
|
import time |
|
import uuid |
|
import zipfile |
|
from dataclasses import fields |
|
from urllib.request import urlretrieve |
|
|
|
import gradio as gr |
|
import torch.multiprocessing as mp |
|
import transformers |
|
from legogpt.models import LegoGPT, LegoGPTConfig |
|
|
|
|
|
def setup(): |
|
|
|
licence_filename = 'gurobi.lic' |
|
licence_lines = [] |
|
for secret_name in ['WLSACCESSID', 'WLSSECRET', 'LICENSEID']: |
|
secret = os.environ.get(secret_name) |
|
if not secret: |
|
raise ValueError(f'Env variable {secret_name} not found. Please set it in the Hugging Face Space settings.') |
|
licence_lines.append(f'{secret_name}={secret}\n') |
|
with open(licence_filename, 'w') as f: |
|
f.writelines(licence_lines) |
|
os.environ['GRB_LICENSE_FILE'] = os.path.abspath(licence_filename) |
|
|
|
|
|
ldraw_zip_url = 'https://library.ldraw.org/library/updates/complete.zip' |
|
ldraw_zip_filename = 'complete.zip' |
|
urlretrieve(ldraw_zip_url, ldraw_zip_filename) |
|
with zipfile.ZipFile(ldraw_zip_filename) as zip_ref: |
|
zip_ref.extractall() |
|
os.environ['LDRAW_LIBRARY_PATH'] = os.path.abspath('ldraw') |
|
|
|
|
|
def get_concurrency_limit(default: int = 2) -> int:
    """
    Read the Gradio queue's default concurrency limit from the CONCURRENCY_LIMIT environment variable.

    :param default: Value used when the variable is unset, empty, or whitespace-only.
    :return: The configured concurrency limit.
    :raises ValueError: If the variable is set to a non-integer value.
    """
    raw = os.environ.get('CONCURRENCY_LIMIT', '').strip()
    # An empty value is treated like "unset" instead of crashing on int('').
    return int(raw) if raw else default


def main():
    """
    Build and launch the LegoGPT Gradio demo.

    If IS_HF_SPACE is '1', the Hugging Face Space environment is prepared first via setup().
    The queue's default concurrency limit comes from get_concurrency_limit() (CONCURRENCY_LIMIT, default 2).
    """
    if os.environ.get('IS_HF_SPACE') == '1':
        print('Running in Hugging Face Space, setting up environment...')
        setup()

    # Model and generator; max_regenerations is fixed to 5 for the demo config.
    model_cfg = LegoGPTConfig(max_regenerations=5)
    generator = LegoGenerator(LegoGPT(model_cfg))

    # Input components. Defaults and help texts are taken from LegoGPTConfig.
    in_prompt = gr.Textbox(label='Prompt', placeholder='Enter a prompt to generate a LEGO model.')
    in_temperature = gr.Slider(0.01, 2.0, value=model_cfg.temperature, step=0.01,
                               label='Temperature', info=get_help_string('temperature'))
    in_seed = gr.Number(value=42, label='Seed', info='Random seed for generation.', precision=0, step=1)
    in_bricks = gr.Number(value=model_cfg.max_bricks, label='Max bricks', info=get_help_string('max_bricks'),
                          precision=0, minimum=1, step=1)
    in_rejections = gr.Number(value=model_cfg.max_brick_rejections, label='Max brick rejections',
                              info=get_help_string('max_brick_rejections'), precision=0, minimum=0, step=1)
    in_regenerations = gr.Number(value=model_cfg.max_regenerations, label='Max regenerations',
                                 info=get_help_string('max_regenerations'), precision=0, minimum=0, step=1)

    # Output components.
    out_img = gr.Image(label='Output image', format='png')
    out_txt = gr.Textbox(label='Output LEGO bricks', lines=5, max_lines=5, show_copy_button=True,
                         info='The LEGO structure in text format. Each line of the form "hxw (x,y,z)" represents a '
                              '1-unit-tall rectangular brick with dimensions hxw placed at coordinates (x,y,z).')

    # Main interface. The positional order of inputs + additional_inputs must match
    # LegoGenerator.generate_lego's parameters (prompt, temperature, seed, max_bricks, ...).
    demo = gr.Interface(
        fn=generator.generate_lego_subprocess,
        title='LegoGPT Demo',
        description='Official demo for [LegoGPT](https://avalovelace1.github.io/LegoGPT/), the first approach for generating physically stable LEGO brick models from text prompts.\n\n'
                    'The model is restricted to creating structures made of 1-unit-tall cuboid bricks on a 20x20x20 grid. It was trained on a dataset of 21 object categories: '
                    '*basket, bed, bench, birdhouse, bookshelf, bottle, bowl, bus, camera, car, chair, guitar, jar, mug, piano, pot, sofa, table, tower, train, vessel.* '
                    'Performance on prompts from outside these categories may be limited. This demo does not include texturing or coloring.',
        inputs=[in_prompt],
        additional_inputs=[in_temperature, in_seed, in_bricks, in_rejections, in_regenerations],
        outputs=[out_img, out_txt],
        flagging_mode='never',
    )

    # Precomputed examples: clicking one fills the prompt/settings and shows its stored outputs.
    with demo:
        with gr.Row():
            examples = get_examples()
            # Hidden components that carry the example's name and stored image through the click handler.
            dummy_name = gr.Textbox(visible=False, label='Name')
            dummy_out_img = gr.Image(visible=False, label='Result')

            def show_cached_example(name, prompt, temperature, seed, output_img):
                # Return the stored image and brick text instead of re-running the model.
                return output_img, examples[name]['output_txt']

            gr.Examples(
                examples=[[name, example['prompt'], example['temperature'], example['seed'], example['output_img']]
                          for name, example in examples.items()],
                inputs=[dummy_name, in_prompt, in_temperature, in_seed, dummy_out_img],
                outputs=[out_img, out_txt],
                fn=show_cached_example,
                run_on_click=True,
            )

    demo.queue(default_concurrency_limit=get_concurrency_limit())
    demo.launch(share=True)
|
|
|
|
|
class LegoGenerator: |
|
def __init__(self, model: LegoGPT): |
|
self.model = model |
|
self.ctx = mp.get_context('spawn') |
|
|
|
def generate_lego( |
|
self, |
|
prompt: str, |
|
temperature: float | None, |
|
seed: int | None, |
|
max_bricks: int | None, |
|
max_brick_rejections: int | None, |
|
max_regenerations: int | None, |
|
): |
|
|
|
if temperature is not None: self.model.temperature = temperature |
|
if max_bricks is not None: self.model.max_bricks = max_bricks |
|
if max_brick_rejections is not None: self.model.max_brick_rejections = max_brick_rejections |
|
if max_regenerations is not None: self.model.max_regenerations = max_regenerations |
|
if seed is not None: transformers.set_seed(seed) |
|
|
|
|
|
print(f'Generating LEGO for prompt: "{prompt}"') |
|
start_time = time.time() |
|
output = self.model(prompt) |
|
|
|
|
|
output_dir = os.path.abspath('out') |
|
output_uuid = str(uuid.uuid4()) |
|
os.makedirs(output_dir, exist_ok=True) |
|
ldr_filename = os.path.join(output_dir, f'{output_uuid}.ldr') |
|
with open(ldr_filename, 'w') as f: |
|
f.write(output['lego'].to_ldr()) |
|
print(f'Finished generation in {time.time() - start_time:.1f}s!') |
|
|
|
|
|
print('Rendering image...') |
|
start_time = time.time() |
|
img_filename = os.path.join(output_dir, f'{output_uuid}.png') |
|
subprocess.run(['python', 'render_lego.py', '--in_file', ldr_filename, '--out_file', img_filename], |
|
check=True) |
|
print(f'Finished rendering in {time.time() - start_time:.1f}s!') |
|
|
|
return img_filename, output['lego'] |
|
|
|
def generate_lego_subprocess(self, *args): |
|
""" |
|
Run generation as a subprocess so that multiple requests can be handled concurrently. |
|
""" |
|
with self.ctx.Pool(1) as pool: |
|
return pool.starmap(self.generate_lego, [args])[0] |
|
|
|
|
|
def get_help_string(field_name: str) -> str: |
|
""" |
|
:param field_name: Name of a field in LegoGPTConfig. |
|
:return: Help string for the field. |
|
""" |
|
data_fields = fields(LegoGPTConfig) |
|
name_field = next(f for f in data_fields if f.name == field_name) |
|
return name_field.metadata['help'] |
|
|
|
|
|
def get_examples(example_dir: str = os.path.abspath('examples')) -> dict[str, dict[str, str]]: |
|
examples_file = os.path.join(example_dir, 'examples.json') |
|
with open(examples_file) as f: |
|
examples = json.load(f) |
|
|
|
for example in examples.values(): |
|
example['output_img'] = os.path.join(example_dir, example['output_img']) |
|
return examples |
|
|
|
|
|
if __name__ == '__main__':
    # Build and launch the demo only when this file is run as a script, not when it is imported.
    main()
|
|