jbilcke-hf HF Staff committed on
Commit 216ab8e · verified · 1 Parent(s): b4d5059

Update app_legacy.py

Files changed (1)
  1. app_legacy.py +105 -1
app_legacy.py CHANGED
@@ -1 +1,105 @@
- ..
+ import os
+ import tempfile
+ import torch
+ import numpy as np
+ import gradio as gr
+ from PIL import Image
+ import cv2
+ from diffusers import DiffusionPipeline
+ from script import SatelliteModelGenerator
+
+ # Initialize models and device
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ dtype = torch.bfloat16
+
+ repo_id = "black-forest-labs/FLUX.1-dev"
+ adapter_id = "jbilcke-hf/flux-satellite"
+
+ flux_pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
+ flux_pipe.load_lora_weights(adapter_id)
+ flux_pipe = flux_pipe.to(device)
+
+ def generate_and_process_map(prompt: str) -> str | None:
+     """Generate satellite image from prompt and convert to 3D model."""
+     try:
+         # Set dimensions
+         width = height = 1024
+
+         # Generate random seed
+         seed = np.random.randint(0, np.iinfo(np.int32).max)
+
+         # Set random seeds
+         torch.manual_seed(seed)
+         np.random.seed(seed)
+
+         # Generate satellite image using FLUX
+         generator = torch.Generator(device=device).manual_seed(seed)
+         generated_image = flux_pipe(
+             prompt=prompt,
+             width=width,
+             height=height,
+             num_inference_steps=30,
+             generator=generator,
+             guidance_scale=7.5
+         ).images[0]
+
+         # Convert PIL Image to OpenCV format
+         cv_image = cv2.cvtColor(np.array(generated_image), cv2.COLOR_RGB2BGR)
+
+         # Initialize SatelliteModelGenerator
+         generator = SatelliteModelGenerator(building_height=0.09)
+
+         # Process image
+         print("Segmenting image...")
+         segmented_img = generator.segment_image(cv_image, window_size=5)
+
+         print("Estimating heights...")
+         height_map = generator.estimate_heights(cv_image, segmented_img)
+
+         # Generate mesh
+         print("Generating mesh...")
+         mesh = generator.generate_mesh(height_map, cv_image, add_walls=True)
+
+         # Export to GLB
+         temp_dir = tempfile.mkdtemp()
+         output_path = os.path.join(temp_dir, 'output.glb')
+         mesh.export(output_path)
+
+         return output_path
+
+     except Exception as e:
+         print(f"Error during generation: {str(e)}")
+         import traceback
+         traceback.print_exc()
+         return None
+
+ # Create Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# Text to Map")
+     gr.Markdown("Generate 3D maps from text descriptions using FLUX and mesh generation.")
+
+     with gr.Row():
+         prompt_input = gr.Text(
+             label="Enter your prompt",
+             placeholder="e.g. satellite view of downtown Manhattan"
+         )
+
+     with gr.Row():
+         generate_btn = gr.Button("Generate", variant="primary")
+
+     with gr.Row():
+         model_output = gr.Model3D(
+             label="Generated 3D Map",
+             clear_color=[0.0, 0.0, 0.0, 0.0],
+         )
+
+     # Event handler
+     generate_btn.click(
+         fn=generate_and_process_map,
+         inputs=[prompt_input],
+         outputs=[model_output],
+         api_name="generate"
+     )
+
+ if __name__ == "__main__":
+     demo.queue().launch()