move to cuda is device is available
Browse files
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import argparse
|
2 |
import os
|
3 |
from functools import partial
|
|
|
4 |
|
5 |
import gradio as gr
|
6 |
from PIL import Image
|
@@ -41,6 +42,8 @@ def main():
|
|
41 |
Main function to set up and run the Gradio demo.
|
42 |
"""
|
43 |
args = parse_arguments()
|
|
|
|
|
44 |
|
45 |
# Set up model
|
46 |
die_token = os.getenv("DIE_TOKEN")
|
@@ -49,8 +52,7 @@ def main():
|
|
49 |
filename=args.die_model_path,
|
50 |
use_auth_token=die_token
|
51 |
)
|
52 |
-
die_model = UNetDIEModel(args=
|
53 |
-
device = args.device
|
54 |
|
55 |
# Prepare example images
|
56 |
example_image_list = [
|
@@ -68,7 +70,7 @@ def main():
|
|
68 |
"""
|
69 |
|
70 |
# Partial function for inference with model and device arguments
|
71 |
-
partial_die_inference = partial(die_inference, die_model=die_model, device=device)
|
72 |
|
73 |
with gr.Blocks() as demo:
|
74 |
with gr.Row():
|
@@ -104,7 +106,6 @@ def parse_arguments():
|
|
104 |
"""
|
105 |
parser = argparse.ArgumentParser()
|
106 |
parser.add_argument("--die_model_path", default="2024_08_09_model_epoch_89.pt", help="Path to the DIE model checkpoint")
|
107 |
-
parser.add_argument("--device", default="cpu", choices=["cpu", "cuda"], help="Device to run the model on")
|
108 |
parser.add_argument("--example_image_path", default="example_images", help="Path to directory with example images")
|
109 |
return parser.parse_args()
|
110 |
|
|
|
1 |
import argparse
|
2 |
import os
|
3 |
from functools import partial
|
4 |
+
import torch
|
5 |
|
6 |
import gradio as gr
|
7 |
from PIL import Image
|
|
|
42 |
Main function to set up and run the Gradio demo.
|
43 |
"""
|
44 |
args = parse_arguments()
|
45 |
+
|
46 |
+
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
47 |
|
48 |
# Set up model
|
49 |
die_token = os.getenv("DIE_TOKEN")
|
|
|
52 |
filename=args.die_model_path,
|
53 |
use_auth_token=die_token
|
54 |
)
|
55 |
+
die_model = UNetDIEModel(args=args)
|
|
|
56 |
|
57 |
# Prepare example images
|
58 |
example_image_list = [
|
|
|
70 |
"""
|
71 |
|
72 |
# Partial function for inference with model and device arguments
|
73 |
+
partial_die_inference = partial(die_inference, die_model=die_model, device=args.device)
|
74 |
|
75 |
with gr.Blocks() as demo:
|
76 |
with gr.Row():
|
|
|
106 |
"""
|
107 |
parser = argparse.ArgumentParser()
|
108 |
parser.add_argument("--die_model_path", default="2024_08_09_model_epoch_89.pt", help="Path to the DIE model checkpoint")
|
|
|
109 |
parser.add_argument("--example_image_path", default="example_images", help="Path to directory with example images")
|
110 |
return parser.parse_args()
|
111 |
|