# Hugging Face Space (status: Sleeping)
# File size: 1,677 bytes; revision 1ac39da
import functools

import gradio as gr
import torch
from diffusers import FluxPipeline
# Model loading and layer-listing helpers
def _format_layer_listing(layer_names):
    """Render *layer_names* as the numbered list shown in the UI.

    Args:
        layer_names: Iterable of dotted submodule names, in display order.

    Returns:
        A header line, a blank line, then one ``"<n>. <name>"`` line per
        entry, numbered from 1.
    """
    numbered = "\n".join(
        f"{i}. {name}" for i, name in enumerate(layer_names, start=1)
    )
    return f"Layers of FLUX.1-dev:\n\n{numbered}"


@functools.lru_cache(maxsize=1)
def _load_flux_layer_names():
    """Load FLUX.1-dev once and return its transformer's submodule names.

    The result is memoized, so repeated button clicks do not reload the
    pipeline. Only the tuple of names is cached; the pipeline object goes out
    of scope when this function returns. ``lru_cache`` does not cache raised
    exceptions, so a failed load is retried on the next call.

    Returns:
        Tuple of dotted module names from ``named_modules()``, excluding the
        unnamed root module.

    Raises:
        Whatever ``FluxPipeline.from_pretrained`` raises (e.g. network,
        authentication or disk errors).
    """
    # bfloat16 keeps memory usage down. No inference is run here, so
    # enable_model_cpu_offload() (device placement for forward passes) is
    # not needed; it only added an `accelerate`/device requirement that
    # could make this call fail on CPU-only hosts.
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
    )
    # named_modules() yields the root module first under the empty name "";
    # skip it so the listing does not start with a blank "1. " entry.
    return tuple(name for name, _ in pipe.transformer.named_modules() if name)


def list_flux_layers():
    """Return a numbered listing of the FLUX.1-dev transformer's layers.

    Used as the Gradio click handler. Failures while loading the model are
    reported as text in the output box instead of propagating to the UI.

    Returns:
        The formatted layer listing, or
        ``"Error loading model or listing layers: <message>"`` on failure.
    """
    try:
        layer_names = _load_flux_layer_names()
    except Exception as e:  # UI boundary: show the error rather than crash
        return f"Error loading model or listing layers: {e}"
    return _format_layer_listing(layer_names)
# Create Gradio interface: a button that fills a textbox with the layer list.
with gr.Blocks(title="FLUX.1-dev Layer Lister") as demo:
    gr.Markdown("# FLUX.1-dev Layer Lister")
    gr.Markdown(
        "Click the button below to list all layers in the "
        "black-forest-labs/FLUX.1-dev model."
    )
    # Button to trigger the layer listing
    btn = gr.Button("List Layers")
    # Output area for the layer names
    output = gr.Textbox(
        label="Model Layers",
        lines=20,
        placeholder="Layer names will appear here...",
    )
    # The handler takes no inputs and writes its returned string to `output`.
    btn.click(fn=list_flux_layers, inputs=None, outputs=output)

# Launch the app only when run as a script (e.g. `python app.py`), so the
# module can be imported without starting a server.
if __name__ == "__main__":
    demo.launch()