P-rateek commited on
Commit
ac7a019
·
verified ·
1 Parent(s): 2ed935f

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -0
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import gradio as gr
3
+ import os
4
+
5
+ # Import the core logic from your "normal inference file"
6
+ from NoiseFilter.MaIR.inference_runner import MaIR_Upsampler
7
+
8
+ # --- Global model cache for performance ---
9
+ # This dictionary will store loaded models to avoid reloading on every API call.
10
+ model_cache = {}
11
+
12
+ def get_model(model_name):
13
+ """Loads a model into the cache if it's not already there."""
14
+ if model_name not in model_cache:
15
+ print(f"Loading model {model_name} into cache...")
16
+ model_cache[model_name] = MaIR_Upsampler(model_name=model_name)
17
+ return model_cache[model_name]
18
+
19
+ # --- API Function ---
20
+ def inference_api(image, model_name):
21
+ """
22
+ This is the function that the API will call.
23
+ It takes a NumPy array and a model name string as input.
24
+ """
25
+ if image is None:
26
+ # Gradio handles this by not running, but good practice for raw API calls.
27
+ raise ValueError("No image provided.")
28
+
29
+ upsampler = get_model(model_name)
30
+ output_image = upsampler.process(image)
31
+ return output_image
32
+
33
+ # --- Create the Gradio Interface (for API generation) ---
34
+ # We define a minimal interface. The primary goal is API exposure.
35
+ interface = gr.Interface(
36
+ fn=inference_api,
37
+ inputs=[
38
+ gr.Image(type="numpy", label="Input Image"),
39
+ gr.Dropdown(
40
+ choices=['MaIR-SRx4', 'MaIR-SRx2', 'MaIR-CDN-s50'],
41
+ value='MaIR-SRx4',
42
+ label="Select Model"
43
+ ),
44
+ ],
45
+ outputs=gr.Image(type="numpy", label="Output Image"),
46
+ title="MaIR: Image Restoration API",
47
+ description="API for MaIR models. Use the '/api' endpoint for programmatic access."
48
+ )
49
+
50
+ # Launch the app. This will start the web server and create the API.
51
+ interface.launch()