"""Gradio app: generate a satellite-style image with FLUX and turn it into a GLB mesh."""

import os
import tempfile
import traceback

import cv2
import gradio as gr
import numpy as np
import torch
from diffusers import DiffusionPipeline
from PIL import Image

from script import SatelliteModelGenerator

# Initialize models and device
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16

repo_id = "black-forest-labs/FLUX.1-dev"
adapter_id = "jbilcke-hf/flux-satellite"

# The pipeline is loaded once at import time so every request reuses it.
flux_pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype)
flux_pipe.load_lora_weights(adapter_id)
# Fixed: the original called `.to()` on an undefined name (`pipeline`),
# which raised NameError before the app could start.
flux_pipe = flux_pipe.to(device)


def generate_and_process_map(prompt: str) -> str | None:
    """Generate a satellite image from a prompt and convert it to a 3D model.

    Args:
        prompt: Text description passed to the FLUX pipeline.

    Returns:
        Path to the exported ``output.glb`` file inside a freshly created
        temporary directory, or ``None`` if any step raised an exception
        (the error and traceback are printed).
    """
    try:
        # Set dimensions
        width = height = 1024

        # Draw a fresh random seed for each request and apply it to the
        # global torch/numpy RNGs as well as the pipeline's own generator.
        seed = int(np.random.randint(0, np.iinfo(np.int32).max))
        torch.manual_seed(seed)
        np.random.seed(seed)
        torch_generator = torch.Generator(device=device).manual_seed(seed)

        # Generate satellite image using FLUX
        generated_image = flux_pipe(
            prompt=prompt,
            width=width,
            height=height,
            num_inference_steps=30,
            generator=torch_generator,
            guidance_scale=7.5,
        ).images[0]

        # Convert PIL Image (RGB) to OpenCV channel order (BGR)
        cv_image = cv2.cvtColor(np.array(generated_image), cv2.COLOR_RGB2BGR)

        # Kept under a separate name so it does not shadow the torch generator.
        # NOTE(review): SatelliteModelGenerator comes from the local `script`
        # module; the meaning of its arguments is not visible here.
        model_generator = SatelliteModelGenerator(building_height=0.09)

        # Process image
        print("Segmenting image...")
        segmented_img = model_generator.segment_image(cv_image, window_size=5)

        print("Estimating heights...")
        height_map = model_generator.estimate_heights(cv_image, segmented_img)

        # Generate mesh
        print("Generating mesh...")
        mesh = model_generator.generate_mesh(height_map, cv_image, add_walls=True)

        # Export to GLB. The temp directory is not removed here because the
        # returned path is handed to the Gradio output component.
        temp_dir = tempfile.mkdtemp()
        output_path = os.path.join(temp_dir, 'output.glb')
        mesh.export(output_path)

        return output_path

    except Exception as e:
        # Top-level boundary for the UI callback: report and signal failure
        # with None rather than propagating the exception.
        print(f"Error during generation: {str(e)}")
        traceback.print_exc()
        return None


# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Text to Map")
    gr.Markdown("Generate 3D maps from text descriptions using FLUX and mesh generation.")

    with gr.Row():
        prompt_input = gr.Text(
            label="Enter your prompt",
            placeholder="eg. satellite view of downtown Manhattan"
        )

    with gr.Row():
        generate_btn = gr.Button("Generate", variant="primary")

    with gr.Row():
        model_output = gr.Model3D(
            label="Generated 3D Map",
            clear_color=[0.0, 0.0, 0.0, 0.0],
        )

    # Event handler
    generate_btn.click(
        fn=generate_and_process_map,
        inputs=[prompt_input],
        outputs=[model_output],
        api_name="generate"
    )

if __name__ == "__main__":
    demo.queue().launch()