# app.py
"""Gradio front end exposing MaIR image-restoration models.

Builds a ``gr.Interface`` around :func:`inference_api` and launches it when
the file is executed as a script.
"""
import os

import gradio as gr

from inference_runner import MaIR_Upsampler

# Model names offered in the UI dropdown; the first entry is the default.
MODEL_CHOICES = ['MaIR-SRx4', 'MaIR-SRx2', 'MaIR-CDN-s50']
DEFAULT_MODEL = MODEL_CHOICES[0]

# Maps a model name to its already-constructed MaIR_Upsampler so that each
# model is instantiated at most once per process.
model_cache = {}


def get_model(model_name: str) -> MaIR_Upsampler:
    """Return the upsampler for *model_name*, constructing it on first use.

    Args:
        model_name: Name passed through to ``MaIR_Upsampler(model_name=...)``.

    Returns:
        The cached ``MaIR_Upsampler`` instance for that name.
    """
    if model_name not in model_cache:
        print(f"Loading model {model_name} into cache...")
        model_cache[model_name] = MaIR_Upsampler(model_name=model_name)
    return model_cache[model_name]


def inference_api(image, model_name: str):
    """Run the selected model on *image*; this is the Gradio handler.

    Args:
        image: Input image. The interface is configured with
            ``gr.Image(type="numpy")``, so Gradio supplies a NumPy array, or
            ``None`` when no image was submitted.
        model_name: Name of the model to run.

    Returns:
        Whatever ``MaIR_Upsampler.process`` returns for the image
        (presumably a NumPy array, as the output component is
        ``gr.Image(type="numpy")`` -- confirm against ``inference_runner``).

    Raises:
        ValueError: If *image* is ``None``.
    """
    if image is None:
        raise ValueError("No image provided.")

    upsampler = get_model(model_name)
    return upsampler.process(image)


interface = gr.Interface(
    fn=inference_api,
    inputs=[
        gr.Image(type="numpy", label="Input Image"),
        gr.Dropdown(
            choices=MODEL_CHOICES,
            value=DEFAULT_MODEL,
            label="Select Model",
        ),
    ],
    outputs=gr.Image(type="numpy", label="Output Image"),
    title="MaIR: Image Restoration API",
    description="API for MaIR models. Use the '/api' endpoint for programmatic access.",
)

# Only start the web server when run as a script, so importing this module
# (e.g. to reuse `interface` or `inference_api`) does not block on launch().
if __name__ == "__main__":
    interface.launch()