import gradio as gr
import torch
import numpy as np
from PIL import Image
import yaml
import os

from utils import load_img, plot_all, dict2namespace, seed_everything
from models import instructir
from text.models import LanguageModel, LMHead

# Setup
SEED = 42
seed_everything(SEED=SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

CONFIG = "configs/eval5d.yml"
LM_MODEL = "models/lm_instructir-7d.pt"
MODEL_NAME = "models/im_instructir-7d.pt"

# Load config
with open(CONFIG, "r") as f:
    config = yaml.safe_load(f)
cfg = dict2namespace(config)

# Load image restoration model
model = instructir.create_model(
    input_channels=cfg.model.in_ch,
    width=cfg.model.width,
    enc_blks=cfg.model.enc_blks,
    middle_blk_num=cfg.model.middle_blk_num,
    dec_blks=cfg.model.dec_blks,
    txtdim=cfg.model.textdim,
).to(device)
model.load_state_dict(torch.load(MODEL_NAME, map_location=device), strict=True)
model.eval()

# Load language model and projection head for the text prompt
language_model = LanguageModel(model=cfg.llm.model)
lm_head = LMHead(
    embedding_dim=cfg.llm.model_dim,
    hidden_dim=cfg.llm.embd_dim,
    num_classes=cfg.llm.nclasses,
)
lm_head.load_state_dict(torch.load(LM_MODEL, map_location=device), strict=True)
lm_head.eval()


def process(image, prompt):
    # Normalize the input image to [0, 1] and convert to an NCHW tensor
    image = image.convert("RGB")
    image = np.array(image).astype(np.float32) / 255.0
    y = torch.tensor(image).permute(2, 0, 1).unsqueeze(0).to(device)

    # Run inference without tracking gradients
    with torch.no_grad():
        # Embed the instruction, project it, and restore the image conditioned on it
        lm_embd = language_model(prompt)
        text_embd, _ = lm_head(lm_embd)
        x_hat = model(y, text_embd)

    # Convert the output back to an 8-bit PIL image
    restored = x_hat[0].permute(1, 2, 0).cpu().detach().numpy()
    restored = np.clip(restored, 0.0, 1.0)
    restored = (restored * 255).astype(np.uint8)
    return Image.fromarray(restored)


interface = gr.Interface(
    fn=process,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="Prompt", placeholder="Describe the restoration..."),
    ],
    outputs=gr.Image(type="pil"),
    title="swiftlens: Prompt-Guided Image Restoration",
)

if __name__ == "__main__":
    interface.launch()
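
# Usage sketch (not part of the original script; the script filename, image path,
# and prompt below are assumptions used for illustration): run the script, e.g.
# `python app.py`, and open the local URL Gradio prints. The restoration can also
# be exercised without the UI by calling `process` directly:
#
#   from PIL import Image
#   out = process(Image.open("example.png"), "Remove the noise from this photo")
#   out.save("restored.png")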