Spaces:
Running
Running
File size: 3,320 Bytes
139d7b2 827021c 2e786fb 35a938f 827021c 57fc91e 35a938f 827021c 35a938f 827021c fdbc146 35a938f 1dc389e e49c48c fdbc146 35a938f 2e786fb 35a938f 827021c fdbc146 d288725 827021c 35a938f d288725 fdbc146 2e786fb fdbc146 d288725 2e786fb fdbc146 2e786fb fdbc146 d288725 827021c 57fc91e fdbc146 57fc91e fdbc146 57fc91e fdbc146 57fc91e fdbc146 57fc91e fdbc146 57fc91e fdbc146 1dc389e fdbc146 8088244 57fc91e fdbc146 57fc91e fdbc146 57fc91e fdbc146 57fc91e fdbc146 57fc91e fdbc146 57fc91e cf8fd7e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
import sys
import os
from PIL import Image
import torchvision.transforms as transforms
# Folder of sample photos offered in the gallery.
photos_folder = "Photos"

# Download the generator weights and the model definition from the Hub.
repo_id = "Kiwinicki/sat2map-generator"
generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
model_path = hf_hub_download(repo_id=repo_id, filename="model.py")

# Make the downloaded model.py importable as the module ``model``.
sys.path.append(os.path.dirname(model_path))
from model import Generator, GeneratorConfig  # noqa: E402 (needs the sys.path entry above)

# Build the generator on the best available device and load its weights.
cfg = GeneratorConfig()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator(cfg).to(device)
# map_location remaps stored tensors onto ``device`` so a checkpoint saved on
# a GPU still loads when this host falls back to the CPU.
# NOTE(review): torch.load unpickles a file downloaded from the Hub; consider
# passing weights_only=True (the default since torch 2.6) if the checkpoint is
# a plain state dict -- confirm before changing.
generator.load_state_dict(torch.load(generator_path, map_location=device))
generator.eval()

# Preprocessing: resize to 256x256 (presumably the generator's training
# resolution -- confirm), convert to a [0, 1] tensor, then map each of the
# three channels to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
def process_image(image):
    """Run the generator on one input image and return the generated image.

    Args:
        image: PIL image from the Gradio input component, or ``None`` when the
            input has been cleared.

    Returns:
        The generated PIL image, or ``None`` when ``image`` is ``None``.
    """
    if image is None:
        return None
    # ``transform`` normalizes with three per-channel means, so RGBA PNGs or
    # grayscale inputs must become RGB first; otherwise Normalize fails on a
    # channel-count mismatch.
    image_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)
    # Inference only: no autograd graph needed.
    with torch.no_grad():
        output_tensor = generator(image_tensor)
    # Drop the batch dimension and move back to host memory for PIL.
    output_image = output_tensor.squeeze(0).cpu()
    # Invert Normalize(mean=0.5, std=0.5). Clamp to [0, 1] because
    # ToPILImage multiplies floats by 255 and casts to uint8, so out-of-range
    # values would overflow instead of saturating.
    output_image = (output_image * 0.5 + 0.5).clamp(0.0, 1.0)
    return transforms.ToPILImage()(output_image)
def load_images_from_folder(folder, extensions=(".png", ".jpg", ".jpeg")):
    """Load the image files in ``folder`` for the sample gallery.

    If ``folder`` does not exist it is created and an empty list is returned.
    Files that fail to open or decode are reported on stdout and skipped.

    Args:
        folder: Directory to scan (non-recursive).
        extensions: Lower-case filename suffix, or tuple of suffixes, that
            mark a file as an image. Matching is case-insensitive on the
            filename.

    Returns:
        A list of ``(PIL image, filename)`` tuples, ordered by filename.
    """
    if not os.path.isdir(folder):
        # exist_ok avoids a race if the folder appears between check and create.
        os.makedirs(folder, exist_ok=True)
        return []
    images = []
    # Sort so the gallery order is stable (os.listdir order is arbitrary).
    for filename in sorted(os.listdir(folder)):
        if not filename.lower().endswith(extensions):
            continue
        img_path = os.path.join(folder, filename)
        try:
            # Image.open is lazy and keeps the file open; copy() decodes the
            # pixels into memory so the handle is closed on leaving ``with``
            # and decode errors are caught here rather than at inference time.
            with Image.open(img_path) as src:
                img = src.copy()
        except Exception as e:  # best-effort: one bad file must not break startup
            print(f"Error loading {filename}: {e}")
        else:
            images.append((img, filename))
    return images
def app():
    """Assemble the Gradio interface and launch it.

    Three columns are laid out side by side:
    - the input image with a Clear button,
    - a gallery of the images loaded from ``photos_folder``,
    - the generated result.

    Selecting a gallery tile copies that image into the input. Every change to
    the input re-runs ``process_image`` to refresh the result.
    """
    samples = load_images_from_folder(photos_folder)
    # The gallery shows only the images; filenames are not passed along.
    thumbnails = [picture for picture, _name in samples]

    with gr.Blocks() as interface:
        with gr.Row():
            with gr.Column():
                source_view = gr.Image(label="Input Image", type="pil")
                reset_btn = gr.Button("Clear")
            with gr.Column():
                sample_gallery = gr.Gallery(
                    label="Image Gallery",
                    value=thumbnails,
                    columns=3,
                    rows=2,
                    height="auto",
                )
            with gr.Column():
                result_view = gr.Image(label="Result Image", type="pil")

        def on_select(evt: gr.SelectData):
            # Map the clicked tile back to its loaded image; an index outside
            # the loaded list yields None, which empties the input.
            position = evt.index
            if position < 0 or position >= len(samples):
                return None
            return samples[position][0]

        sample_gallery.select(fn=on_select, outputs=source_view)

        # A cleared input (None) makes process_image return None, so the
        # result is emptied as well.
        source_view.change(
            fn=process_image,
            inputs=source_view,
            outputs=result_view,
        )

        reset_btn.click(fn=lambda: None, outputs=source_view)

    interface.launch()
# Script entry point: build and launch the UI only when run directly.
if __name__ == "__main__":
    app()