gabar92 commited on
Commit
951d58b
·
1 Parent(s): b79dbcf

bugfix: model path

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -41,17 +41,20 @@ def main():
41
  """
42
  Main function to set up and run the Gradio demo.
43
  """
 
44
  args = parse_arguments()
45
 
46
  args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
47
 
48
  # Set up model
49
  die_token = os.getenv("DIE_TOKEN")
50
- model_path = hf_hub_download(
 
51
  repo_id="gabar92/die",
52
  filename=args.die_model_path,
53
  use_auth_token=die_token
54
  )
 
55
  die_model = UNetDIEModel(args=args)
56
 
57
  # Prepare example images
@@ -105,9 +108,10 @@ def parse_arguments():
105
  :return: argument namespace
106
  """
107
  parser = argparse.ArgumentParser()
108
- parser.add_argument("--die_model_path", default="2024_08_09_model_epoch_89.pt", help="Path to the DIE model checkpoint")
109
- parser.add_argument("--example_image_path", default="example_images", help="Path to directory with example images")
110
- return parser.parse_args()
 
111
 
112
 
113
  if __name__ == "__main__":
 
41
  """
42
  Main function to set up and run the Gradio demo.
43
  """
44
+
45
  args = parse_arguments()
46
 
47
  args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
48
 
49
  # Set up model
50
  die_token = os.getenv("DIE_TOKEN")
51
+
52
+ args.die_model_path = hf_hub_download(
53
  repo_id="gabar92/die",
54
  filename=args.die_model_path,
55
  use_auth_token=die_token
56
  )
57
+
58
  die_model = UNetDIEModel(args=args)
59
 
60
  # Prepare example images
 
108
  :return: argument namespace
109
  """
110
  parser = argparse.ArgumentParser()
111
+
112
+ parser.add_argument("--die_model_path", default="2024_08_09_model_epoch_89.pt")
113
+
114
+ parser.add_argument("--example_image_path", default="example_images")
115
 
116
 
117
  if __name__ == "__main__":