gabar92 commited on
Commit
b79dbcf
·
1 Parent(s): 814657b

move to cuda is device is available

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import argparse
2
  import os
3
  from functools import partial
 
4
 
5
  import gradio as gr
6
  from PIL import Image
@@ -41,6 +42,8 @@ def main():
41
  Main function to set up and run the Gradio demo.
42
  """
43
  args = parse_arguments()
 
 
44
 
45
  # Set up model
46
  die_token = os.getenv("DIE_TOKEN")
@@ -49,8 +52,7 @@ def main():
49
  filename=args.die_model_path,
50
  use_auth_token=die_token
51
  )
52
- die_model = UNetDIEModel(args=model_path)
53
- device = args.device
54
 
55
  # Prepare example images
56
  example_image_list = [
@@ -68,7 +70,7 @@ def main():
68
  """
69
 
70
  # Partial function for inference with model and device arguments
71
- partial_die_inference = partial(die_inference, die_model=die_model, device=device)
72
 
73
  with gr.Blocks() as demo:
74
  with gr.Row():
@@ -104,7 +106,6 @@ def parse_arguments():
104
  """
105
  parser = argparse.ArgumentParser()
106
  parser.add_argument("--die_model_path", default="2024_08_09_model_epoch_89.pt", help="Path to the DIE model checkpoint")
107
- parser.add_argument("--device", default="cpu", choices=["cpu", "cuda"], help="Device to run the model on")
108
  parser.add_argument("--example_image_path", default="example_images", help="Path to directory with example images")
109
  return parser.parse_args()
110
 
 
1
  import argparse
2
  import os
3
  from functools import partial
4
+ import torch
5
 
6
  import gradio as gr
7
  from PIL import Image
 
42
  Main function to set up and run the Gradio demo.
43
  """
44
  args = parse_arguments()
45
+
46
+ args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
47
 
48
  # Set up model
49
  die_token = os.getenv("DIE_TOKEN")
 
52
  filename=args.die_model_path,
53
  use_auth_token=die_token
54
  )
55
+ die_model = UNetDIEModel(args=args)
 
56
 
57
  # Prepare example images
58
  example_image_list = [
 
70
  """
71
 
72
  # Partial function for inference with model and device arguments
73
+ partial_die_inference = partial(die_inference, die_model=die_model, device=args.device)
74
 
75
  with gr.Blocks() as demo:
76
  with gr.Row():
 
106
  """
107
  parser = argparse.ArgumentParser()
108
  parser.add_argument("--die_model_path", default="2024_08_09_model_epoch_89.pt", help="Path to the DIE model checkpoint")
 
109
  parser.add_argument("--example_image_path", default="example_images", help="Path to directory with example images")
110
  return parser.parse_args()
111