# Hugging Face Space (page status when captured: "Sleeping")
import functools
import logging

import gradio as gr
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
# Hugging Face Hub repository whose transformer layers are listed.
FLUX_REPO_ID = "black-forest-labs/FLUX.1-dev"

logger = logging.getLogger(__name__)


@functools.lru_cache(maxsize=1)
def _flux_transformer_layer_names():
    """Return the dotted names of every submodule of the FLUX.1-dev transformer.

    Only the transformer's config is fetched from the Hub. The model is then
    instantiated on PyTorch's ``meta`` device, so no weights are downloaded and
    no parameter memory is allocated. Module names depend only on the
    architecture, so they are the same as for the fully loaded model.

    The result is cached after the first successful call. Exceptions are not
    cached, so a failed attempt is retried on the next call.

    Returns:
        tuple[str, ...]: Submodule names in ``named_modules()`` order, without
        the root module's empty name.
    """
    config = FluxTransformer2DModel.load_config(FLUX_REPO_ID, subfolder="transformer")
    # Tensors created inside this context live on the meta device: they have
    # shapes but no storage (torch.device as a context manager needs PyTorch 2.0+).
    with torch.device("meta"):
        model = FluxTransformer2DModel.from_config(config)
    # named_modules() yields the root module first under the empty name "".
    # Skip it so the listing does not start with a blank entry.
    return tuple(name for name, _ in model.named_modules() if name)


def list_flux_layers():
    """Return a numbered, human-readable list of the FLUX.1-dev transformer's layers.

    This is the Gradio button callback, so it never raises. Any failure (for
    example, the repository cannot be reached or accessed) is logged with its
    traceback and returned as text for the output box.

    Returns:
        str: ``"Layers of FLUX.1-dev:"`` followed by one ``"<n>. <name>"`` line
        per submodule, or an ``"Error loading model or listing layers: ..."``
        message on failure.
    """
    try:
        layer_names = _flux_transformer_layer_names()
    except Exception as e:  # UI boundary: report the error instead of crashing the callback
        logger.exception("Failed to list FLUX.1-dev layers")
        return f"Error loading model or listing layers: {e}"
    output = "\n".join(f"{i}. {name}" for i, name in enumerate(layer_names, start=1))
    return f"Layers of FLUX.1-dev:\n\n{output}"
# Build the Gradio UI: a title, a one-line description, a button that runs
# list_flux_layers, and a 20-line textbox that receives the returned text.
with gr.Blocks(title="FLUX.1-dev Layer Lister") as demo:
    gr.Markdown(value="# FLUX.1-dev Layer Lister")
    gr.Markdown(
        value="Click the button below to list all layers in the black-forest-labs/FLUX.1-dev model."
    )

    list_button = gr.Button(value="List Layers")
    layers_box = gr.Textbox(
        placeholder="Layer names will appear here...",
        lines=20,
        label="Model Layers",
    )

    # The callback takes no inputs; the string it returns fills the textbox.
    list_button.click(
        fn=list_flux_layers,
        inputs=None,
        outputs=layers_box,
    )

# Start the Gradio server for the UI defined above.
demo.launch()