|
|
|
import gradio as gr |
|
import os |
|
|
|
from inference_runner import MaIR_Upsampler |
|
|
|
# Loaded upsamplers keyed by model name. Filled lazily by get_model() so each
# model is constructed at most once per process.
model_cache = {}


def get_model(model_name):
    """Return the upsampler for *model_name*, loading and caching it on first use.

    Args:
        model_name: Model identifier forwarded to ``MaIR_Upsampler``.

    Returns:
        The cached ``MaIR_Upsampler`` instance for *model_name*.
    """
    try:
        # Fast path: this model was already loaded by an earlier call.
        return model_cache[model_name]
    except KeyError:
        pass

    # Construct outside the except clause so a load failure is not chained
    # to the (expected) cache-miss KeyError.
    print(f"Loading model {model_name} into cache...")
    upsampler = MaIR_Upsampler(model_name=model_name)
    model_cache[model_name] = upsampler
    return upsampler
|
|
|
def inference_api(image, model_name):
    """Restore *image* with the MaIR model named *model_name*.

    This is the handler Gradio invokes for both the web UI and programmatic
    API requests.

    Args:
        image: The input picture as delivered by ``gr.Image(type="numpy")``
            (a NumPy array), or ``None`` when nothing was supplied.
        model_name: Identifier of the model to run; looked up through
            ``get_model``.

    Returns:
        The value produced by the upsampler's ``process`` method (presumably
        a NumPy image array, matching the numpy output component; confirm
        against ``MaIR_Upsampler``).

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    if image is not None:
        return get_model(model_name).process(image)
    raise ValueError("No image provided.")
|
|
|
# Model identifiers offered in the UI dropdown; the first entry is the default.
MODEL_CHOICES = ["MaIR-SRx4", "MaIR-SRx2", "MaIR-CDN-s50"]

# Gradio front end: a numpy image and a model choice go in, and the restored
# numpy image comes out via inference_api.
interface = gr.Interface(
    fn=inference_api,
    inputs=[
        gr.Image(type="numpy", label="Input Image"),
        gr.Dropdown(
            choices=MODEL_CHOICES,
            # Derive the default from the list so the two cannot drift apart.
            value=MODEL_CHOICES[0],
            label="Select Model",
        ),
    ],
    outputs=gr.Image(type="numpy", label="Output Image"),
    title="MaIR: Image Restoration API",
    description="API for MaIR models. Use the '/api' endpoint for programmatic access.",
)

# Start the server only when executed as a script (e.g. `python app.py`).
# Importing this module, from tests or to mount `interface` elsewhere,
# must not launch and block on a web server.
if __name__ == "__main__":
    interface.launch()