Spaces:
Running
Running
File size: 3,499 Bytes
139d7b2 827021c 2e786fb 1dc389e 827021c 57fc91e 1dc389e 827021c 1dc389e 827021c fdbc146 1dc389e e49c48c 1dc389e 827021c fdbc146 2e786fb 1dc389e 827021c fdbc146 d288725 827021c d288725 fdbc146 2e786fb fdbc146 d288725 2e786fb fdbc146 2e786fb fdbc146 d288725 2e786fb d288725 827021c 57fc91e fdbc146 57fc91e fdbc146 57fc91e fdbc146 57fc91e fdbc146 57fc91e fdbc146 57fc91e fdbc146 1dc389e fdbc146 8088244 57fc91e fdbc146 57fc91e fdbc146 57fc91e fdbc146 57fc91e fdbc146 57fc91e fdbc146 57fc91e cf8fd7e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
import sys
import os
from PIL import Image
import torchvision.transforms as transforms
from safetensors.torch import load_file
# Folder scanned for the example pictures shown in the gallery.
photos_folder = "Photos"

# Fetch the trained weights and the module that defines the network from the
# Hugging Face Hub; both results are used below as local file paths.
repo_id = "Kiwinicki/sat2map-generator"
model_path = hf_hub_download(repo_id=repo_id, filename="generator.safetensors")
generator_code_path = hf_hub_download(repo_id=repo_id, filename="model.py")

# The downloaded model.py is not part of this project, so its directory has to
# be put on sys.path before the import below can succeed.
_generator_code_dir = os.path.dirname(generator_code_path)
sys.path.append(_generator_code_dir)
from model import Generator, GeneratorConfig
# Hyper-parameters handed to the generator; the weights loaded below have to
# fit the network these values describe.
cfg = GeneratorConfig(channels=3, num_features=64, num_residuals=12, depth=4)

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the network on the chosen device, restore the downloaded weights and
# put the model in evaluation mode.
generator = Generator(cfg).to(device)
state_dict = load_file(model_path)
generator.load_state_dict(state_dict)
generator.eval()
# Preprocessing applied to every input picture: resize to 256x256, convert to
# a tensor, then normalize each of the three channels with mean 0.5 / std 0.5
# (process_image reverses this with `* 0.5 + 0.5`).
_preprocessing_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
transform = transforms.Compose(_preprocessing_steps)
def process_image(image):
    """Run one picture through the generator and return the result.

    Args:
        image: PIL image to translate, or ``None`` when the input
            component is empty.

    Returns:
        The generated picture as a PIL image, or ``None`` when ``image``
        is ``None``.
    """
    if image is None:
        return None

    # The normalization step and the model are configured for exactly three
    # channels, so RGBA, palette and grayscale pictures are converted first
    # instead of failing inside the transform.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocess, add the batch dimension and move to the model's device.
    image_tensor = transform(image).unsqueeze(0).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output_tensor = generator(image_tensor)

    # Drop the batch dimension and bring the result back to the CPU.
    output_image = output_tensor.squeeze(0).cpu()
    # Undo the mean 0.5 / std 0.5 normalization, then clamp to [0, 1]:
    # ToPILImage turns float tensors into 8-bit values, and anything outside
    # that range would wrap around and show up as speckle artifacts.
    output_image = (output_image * 0.5 + 0.5).clamp(0.0, 1.0)
    return transforms.ToPILImage()(output_image)
def load_images_from_folder(folder):
    """Load every PNG/JPEG picture located directly inside ``folder``.

    When ``folder`` does not exist it is created and an empty list is
    returned. Files that cannot be opened or decoded are reported on stdout
    and skipped, so one broken file never prevents the others from loading.

    Args:
        folder: Path of the directory to scan (sub-directories are not
            searched).

    Returns:
        A list of ``(image, filename)`` tuples ordered by filename.
    """
    images = []
    if not os.path.exists(folder):
        os.makedirs(folder)
        return images

    # os.listdir returns entries in arbitrary order; sorting keeps the
    # resulting list stable between runs.
    for filename in sorted(os.listdir(folder)):
        if not filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            continue
        img_path = os.path.join(folder, filename)
        try:
            # Image.open only reads the header. Forcing load() here makes
            # truncated or corrupt files fail inside this try block, and the
            # with-statement releases the file handle while the decoded
            # pixel data stays usable.
            with Image.open(img_path) as img:
                img.load()
        except Exception as e:
            # Deliberately broad: decoding can fail in many ways and loading
            # is best-effort.
            print(f"Error loading {filename}: {e}")
        else:
            images.append((img, filename))
    return images
def app():
    """Build the Gradio interface and launch it.

    The page has an input picture with a "Clear" button, a gallery filled
    from ``photos_folder`` and a result picture. Selecting a gallery entry
    copies it into the input, every change of the input runs
    ``process_image``, and "Clear" empties the input.
    """
    catalog = load_images_from_folder(photos_folder)
    # The gallery only needs the pictures, not the file names.
    thumbnails = [picture for picture, _name in catalog]

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil")
                clear_button = gr.Button("Clear")

            with gr.Column():
                gallery = gr.Gallery(
                    label="Image Gallery",
                    value=thumbnails,
                    columns=3,
                    rows=2,
                    height="auto",
                )

            with gr.Column():
                output_image = gr.Image(label="Result Image", type="pil")

        def on_select(evt: gr.SelectData):
            """Return the loaded picture matching the selected gallery entry."""
            if not 0 <= evt.index < len(catalog):
                return None
            return catalog[evt.index][0]

        # A gallery selection becomes the new input picture.
        gallery.select(fn=on_select, outputs=input_image)

        # Any change of the input picture triggers a new translation.
        input_image.change(
            fn=process_image, inputs=input_image, outputs=output_image
        )

        # "Clear" resets the input picture to empty.
        clear_button.click(fn=lambda: None, outputs=input_image)

    demo.launch()
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    app()