AvaLovelace commited on
Commit
b24c174
·
1 Parent(s): def3000

Add multiprocessing

Browse files
Files changed (1) hide show
  1. app.py +60 -42
app.py CHANGED
@@ -8,6 +8,7 @@ from dataclasses import fields
8
  from urllib.request import urlretrieve
9
 
10
  import gradio as gr
 
11
  import transformers
12
  from legogpt.models import LegoGPT, LegoGPTConfig
13
 
@@ -39,47 +40,8 @@ def main():
39
  print('Running in Hugging Face Space, setting up environment...')
40
  setup()
41
 
42
- model_cfg = LegoGPTConfig(max_regenerations=10)
43
- model = LegoGPT(model_cfg)
44
-
45
- def generate_lego(
46
- prompt: str,
47
- temperature: float | None,
48
- seed: int | None,
49
- max_bricks: int | None,
50
- max_brick_rejections: int | None,
51
- max_regenerations: int | None,
52
- ):
53
- # Set model parameters
54
- if temperature is not None: model.temperature = temperature
55
- if max_bricks is not None: model.max_bricks = max_bricks
56
- if max_brick_rejections is not None: model.max_brick_rejections = max_brick_rejections
57
- if max_regenerations is not None: model.max_regenerations = max_regenerations
58
- if seed is not None: transformers.set_seed(seed)
59
-
60
- # Generate LEGO
61
- print(f'Generating LEGO for prompt: "{prompt}"')
62
- start_time = time.time()
63
- output = model(prompt)
64
-
65
- # Write output LDR to file
66
- output_dir = os.path.abspath('out')
67
- output_uuid = str(uuid.uuid4())
68
- os.makedirs(output_dir, exist_ok=True)
69
- ldr_filename = os.path.join(output_dir, f'{output_uuid}.ldr')
70
- with open(ldr_filename, 'w') as f:
71
- f.write(output['lego'].to_ldr())
72
- print(f'Finished generation in {time.time() - start_time:.1f}s!')
73
-
74
- # Render LEGO model to image
75
- print('Rendering image...')
76
- start_time = time.time()
77
- img_filename = os.path.join(output_dir, f'{output_uuid}.png')
78
- subprocess.run(['python', 'render_lego.py', '--in_file', ldr_filename, '--out_file', img_filename],
79
- check=True) # Run render as a subprocess to prevent issues with Blender
80
- print(f'Finished rendering in {time.time() - start_time:.1f}s!')
81
-
82
- return img_filename, output['lego']
83
 
84
  # Define inputs and outputs
85
  in_prompt = gr.Textbox(label='Prompt', placeholder='Enter a prompt to generate a LEGO model.')
@@ -99,7 +61,7 @@ def main():
99
 
100
  # Define Gradio interface
101
  demo = gr.Interface(
102
- fn=generate_lego,
103
  title='LegoGPT Demo',
104
  description='Official demo for [LegoGPT](https://avalovelace1.github.io/LegoGPT/), the first approach for generating physically stable LEGO brick models from text prompts.\n\n'
105
  'The model is restricted to creating structures made of 1-unit-tall cuboid bricks on a 20x20x20 grid. It was trained on a dataset of 21 object categories: '
@@ -123,9 +85,65 @@ def main():
123
  fn=lambda *args: (args[-1], examples[args[0]]['output_txt']),
124
  run_on_click=True,
125
  )
 
 
 
126
  demo.launch(share=True)
127
 
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  def get_help_string(field_name: str) -> str:
130
  """
131
  :param field_name: Name of a field in LegoGPTConfig.
 
8
  from urllib.request import urlretrieve
9
 
10
  import gradio as gr
11
+ import torch.multiprocessing as mp
12
  import transformers
13
  from legogpt.models import LegoGPT, LegoGPTConfig
14
 
 
40
  print('Running in Hugging Face Space, setting up environment...')
41
  setup()
42
 
43
+ model_cfg = LegoGPTConfig(max_regenerations=5)
44
+ generator = LegoGenerator(LegoGPT(model_cfg))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  # Define inputs and outputs
47
  in_prompt = gr.Textbox(label='Prompt', placeholder='Enter a prompt to generate a LEGO model.')
 
61
 
62
  # Define Gradio interface
63
  demo = gr.Interface(
64
+ fn=generator.generate_lego_subprocess,
65
  title='LegoGPT Demo',
66
  description='Official demo for [LegoGPT](https://avalovelace1.github.io/LegoGPT/), the first approach for generating physically stable LEGO brick models from text prompts.\n\n'
67
  'The model is restricted to creating structures made of 1-unit-tall cuboid bricks on a 20x20x20 grid. It was trained on a dataset of 21 object categories: '
 
85
  fn=lambda *args: (args[-1], examples[args[0]]['output_txt']),
86
  run_on_click=True,
87
  )
88
+
89
+ concurrency_limit = 2 if os.environ.get('CONCURRENCY_LIMIT') is None else int(os.environ.get('CONCURRENCY_LIMIT'))
90
+ demo.queue(default_concurrency_limit=concurrency_limit)
91
  demo.launch(share=True)
92
 
93
 
94
+ class LegoGenerator:
95
+ def __init__(self, model: LegoGPT):
96
+ self.model = model
97
+ self.ctx = mp.get_context('spawn')
98
+
99
+ def generate_lego(
100
+ self,
101
+ prompt: str,
102
+ temperature: float | None,
103
+ seed: int | None,
104
+ max_bricks: int | None,
105
+ max_brick_rejections: int | None,
106
+ max_regenerations: int | None,
107
+ ):
108
+ # Set model parameters
109
+ if temperature is not None: self.model.temperature = temperature
110
+ if max_bricks is not None: self.model.max_bricks = max_bricks
111
+ if max_brick_rejections is not None: self.model.max_brick_rejections = max_brick_rejections
112
+ if max_regenerations is not None: self.model.max_regenerations = max_regenerations
113
+ if seed is not None: transformers.set_seed(seed)
114
+
115
+ # Generate LEGO
116
+ print(f'Generating LEGO for prompt: "{prompt}"')
117
+ start_time = time.time()
118
+ output = self.model(prompt)
119
+
120
+ # Write output LDR to file
121
+ output_dir = os.path.abspath('out')
122
+ output_uuid = str(uuid.uuid4())
123
+ os.makedirs(output_dir, exist_ok=True)
124
+ ldr_filename = os.path.join(output_dir, f'{output_uuid}.ldr')
125
+ with open(ldr_filename, 'w') as f:
126
+ f.write(output['lego'].to_ldr())
127
+ print(f'Finished generation in {time.time() - start_time:.1f}s!')
128
+
129
+ # Render LEGO model to image
130
+ print('Rendering image...')
131
+ start_time = time.time()
132
+ img_filename = os.path.join(output_dir, f'{output_uuid}.png')
133
+ subprocess.run(['python', 'render_lego.py', '--in_file', ldr_filename, '--out_file', img_filename],
134
+ check=True) # Run render as a subprocess to prevent issues with Blender
135
+ print(f'Finished rendering in {time.time() - start_time:.1f}s!')
136
+
137
+ return img_filename, output['lego']
138
+
139
+ def generate_lego_subprocess(self, *args):
140
+ """
141
+ Run generation as a subprocess so that multiple requests can be handled concurrently.
142
+ """
143
+ with self.ctx.Pool(1) as pool:
144
+ return pool.starmap(self.generate_lego, [args])[0]
145
+
146
+
147
  def get_help_string(field_name: str) -> str:
148
  """
149
  :param field_name: Name of a field in LegoGPTConfig.