File size: 3,842 Bytes
ef334e8
 
e27d04f
ef334e8
4b77478
e27d04f
 
4b77478
e27d04f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b77478
 
e27d04f
 
4b77478
 
 
e27d04f
4b77478
 
e27d04f
ef334e8
4b77478
e27d04f
4b77478
 
0839bbf
4b77478
 
e27d04f
 
 
0839bbf
4b77478
e27d04f
4b77478
e27d04f
4b77478
e27d04f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
557154a
 
 
e27d04f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import gc
import os
import warnings

import gradio as gr
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from diffusers import StableDiffusionPipeline
from huggingface_hub import login
from PIL import Image

# Initialize model once and reuse
def init_pipeline():
    """Load the Stable Diffusion v1-4 pipeline onto the best available device.

    Returns:
        tuple: ``(pipeline, device)`` where ``device`` is ``"cuda"`` when torch
        reports an available GPU, otherwise ``"cpu"``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision only on GPU; float16 is poorly supported on CPU.
    torch_dtype = torch.float16 if device == "cuda" else torch.float32

    # Authenticate only when a token is actually configured. login(token=None)
    # falls back to an interactive prompt, which would block a headless server.
    hf_token = os.environ.get("HF_AUTH_TOKEN")
    if hf_token:
        login(token=hf_token)

    # No `use_auth_token=True` (deprecated): once login() has stored a token,
    # from_pretrained picks it up by default.
    pipeline = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        torch_dtype=torch_dtype,
    ).to(device)

    # Optimize for performance
    if device == "cuda":
        pipeline.enable_attention_slicing()
        # xformers is an optional dependency; keep running (with attention
        # slicing only) instead of failing at startup when it is unavailable.
        try:
            pipeline.enable_xformers_memory_efficient_attention()
        except (ImportError, ValueError) as err:
            warnings.warn(
                f"xformers memory-efficient attention unavailable: {err}"
            )

    return pipeline, device

# Initialize pipeline at startup.
# NOTE(review): this executes at import time, so merely importing this module
# runs init_pipeline() (model load + optional HF login). `pipe` and `device`
# are module-level globals that generate_images reads on every request.
pipe, device = init_pipeline()

# Define your original loss functions
def edge_loss(image_tensor):
    """Return the negative mean Sobel gradient magnitude of an image.

    The image is reduced to one grayscale channel by averaging over the
    channel axis (the third-from-last dimension), so both unbatched
    ``(C, H, W)`` and batched ``(N, C, H, W)`` tensors are accepted.
    Stronger edges give a more negative value.
    """
    # Average the channel axis, not dim 0: for (N, C, H, W) input dim=0 is the
    # batch axis, which would leave C channels and make conv2d reject the
    # single-channel Sobel kernel.
    grayscale = image_tensor.mean(dim=-3, keepdim=True)
    # Build the kernel on the input's own device/dtype rather than a global.
    sobel_x = torch.tensor(
        [[1, 0, -1], [2, 0, -2], [1, 0, -1]],
        device=image_tensor.device,
        dtype=grayscale.dtype,
    ).unsqueeze(0).unsqueeze(0)
    sobel_y = sobel_x.transpose(2, 3)
    gx = F.conv2d(grayscale, sobel_x, padding=1)
    gy = F.conv2d(grayscale, sobel_y, padding=1)
    # Epsilon keeps sqrt finite-gradient where the image is flat.
    return -torch.mean(torch.sqrt(gx**2 + gy**2 + 1e-6))

def texture_loss(image_tensor):
    """Return the mean-squared error between ``image_tensor`` and uniform noise.

    The reference noise comes from ``torch.rand_like`` (values in [0, 1)), so
    the result depends on the global torch RNG state at call time.
    """
    reference_noise = torch.rand_like(image_tensor)
    return F.mse_loss(image_tensor, reference_noise)

def entropy_loss(image_tensor):
    """Return the Shannon entropy (natural log) of the tensor's value histogram.

    Values are binned into 256 equal-width bins over [0, 1]. ``torch.histc``
    ignores elements outside that range, so callers should map their data
    into [0, 1] first. Despite the name, the result is the non-negative
    entropy itself. A 1e-7 offset inside the log guards against log(0).

    Returns 0 when no element falls inside [0, 1].
    """
    hist = torch.histc(image_tensor, bins=256, min=0, max=1)
    # An all-zero histogram would otherwise give 0/0 = NaN; clamping the
    # count (an integer-valued float, so >= 1 whenever non-empty) yields 0.
    hist = hist / hist.sum().clamp_min(1)
    return -torch.sum(hist * torch.log(hist + 1e-7))

def symmetry_loss(image_tensor):
    """Return the MSE between the left half and the mirrored right half.

    Works along the last (width) dimension. For odd widths the centre column
    is excluded so both halves have the same shape. Tensors narrower than two
    columns have nothing to compare and yield 0.
    """
    width = image_tensor.shape[-1]
    half = width // 2
    if half == 0:
        # mse_loss on empty tensors would return NaN.
        return image_tensor.new_zeros(())
    left = image_tensor[..., :half]
    # Start at width - half so the right slice has exactly `half` columns
    # (skipping the centre column when width is odd).
    right = torch.flip(image_tensor[..., width - half:], dims=[-1])
    return F.mse_loss(left, right)

def contrast_loss(image_tensor):
    """Return the negated standard deviation over all elements.

    Uses torch's default (Bessel-corrected) std, so higher contrast gives a
    more negative value.
    """
    spread = image_tensor.std()
    return spread.neg()

# Modified generation function
def generate_images(seed):
    """Generate one image for ``seed`` and report image-statistic "losses".

    Args:
        seed: any value accepted by ``int()``; seeds the diffusion generator.

    Returns:
        tuple: ``(grid, text)``. ``grid`` is a 512x256 RGB PIL image with the
        generated image's thumbnail on the left and a white placeholder on the
        right. ``text`` holds one ``"Name: value"`` line per metric.
    """
    # Clear memory before generation
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    # Generate base image
    generator = torch.Generator(device).manual_seed(int(seed))
    image = pipe(
        "A futuristic city skyline at sunset",
        generator=generator,
        num_inference_steps=50,
        guidance_scale=7.5,
    ).images[0]

    # ToTensor yields a (C, H, W) tensor in [0, 1]; Normalize maps it to
    # [-1, 1]. No batch axis is added: edge_loss's grayscale reduction must
    # collapse the colour channels before its single-channel Sobel conv, and
    # the other metrics do not need a batch axis.
    unit_tensor = T.ToTensor()(image).to(device)
    image_tensor = T.Normalize([0.5], [0.5])(unit_tensor)

    losses = {
        "Edge": edge_loss(image_tensor),
        "Texture": texture_loss(image_tensor),
        # entropy_loss histograms the [0, 1] range; feeding it [-1, 1] data
        # would silently drop every pixel below mid-grey.
        "Entropy": entropy_loss(unit_tensor),
        "Symmetry": symmetry_loss(image_tensor),
        "Contrast": contrast_loss(image_tensor),
    }

    # Create thumbnail grid: thumbnail on the left, placeholder on the right.
    thumb = image.copy()
    thumb.thumbnail((256, 256))
    grid = Image.new('RGB', (256 * 2, 256))
    grid.paste(thumb, (0, 0))

    # Add visualization placeholder
    vis = Image.new('RGB', (256, 256), color='white')
    grid.paste(vis, (256, 0))

    report = "\n".join(
        f"{name}: {value.item():.4f}" for name, value in losses.items()
    )
    return grid, report

# Gradio interface
def create_interface():
    """Build the Gradio UI mapping a numeric seed to an image grid and loss text."""
    seed_field = gr.Number(label="Seed", value=42)
    grid_view = gr.Image(label="Generated Image & Visualizations")
    loss_view = gr.Textbox(label="Loss Values")
    return gr.Interface(
        fn=generate_images,
        inputs=seed_field,
        outputs=[grid_view, loss_view],
        title="Stable Diffusion Loss Analysis",
        description="Generate images and visualize different loss metrics",
    )

if __name__ == "__main__":
    from sys import version_info

    # NOTE(review): this guard only runs after the whole module body has
    # executed (including the init_pipeline() call at import time), so it
    # cannot stop an unsupported interpreter before the model is loaded.
    if version_info < (3, 10):
        raise RuntimeError("Python 3.10 or later required")

    interface = create_interface()
    interface.launch(server_port=7860, share=False)