bugfix: model path
Browse files
app.py
CHANGED
@@ -41,17 +41,20 @@ def main():
|
|
41 |
"""
|
42 |
Main function to set up and run the Gradio demo.
|
43 |
"""
|
|
|
44 |
args = parse_arguments()
|
45 |
|
46 |
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
47 |
|
48 |
# Set up model
|
49 |
die_token = os.getenv("DIE_TOKEN")
|
50 |
-
|
|
|
51 |
repo_id="gabar92/die",
|
52 |
filename=args.die_model_path,
|
53 |
use_auth_token=die_token
|
54 |
)
|
|
|
55 |
die_model = UNetDIEModel(args=args)
|
56 |
|
57 |
# Prepare example images
|
@@ -105,9 +108,10 @@ def parse_arguments():
|
|
105 |
:return: argument namespace
|
106 |
"""
|
107 |
parser = argparse.ArgumentParser()
|
108 |
-
|
109 |
-
parser.add_argument("--
|
110 |
-
|
|
|
111 |
|
112 |
|
113 |
if __name__ == "__main__":
|
|
|
41 |
"""
|
42 |
Main function to set up and run the Gradio demo.
|
43 |
"""
|
44 |
+
|
45 |
args = parse_arguments()
|
46 |
|
47 |
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
48 |
|
49 |
# Set up model
|
50 |
die_token = os.getenv("DIE_TOKEN")
|
51 |
+
|
52 |
+
args.die_model_path = hf_hub_download(
|
53 |
repo_id="gabar92/die",
|
54 |
filename=args.die_model_path,
|
55 |
use_auth_token=die_token
|
56 |
)
|
57 |
+
|
58 |
die_model = UNetDIEModel(args=args)
|
59 |
|
60 |
# Prepare example images
|
|
|
108 |
:return: argument namespace
|
109 |
"""
|
110 |
parser = argparse.ArgumentParser()
|
111 |
+
|
112 |
+
parser.add_argument("--die_model_path", default="2024_08_09_model_epoch_89.pt")
|
113 |
+
|
114 |
+
parser.add_argument("--example_image_path", default="example_images")
|
115 |
|
116 |
|
117 |
if __name__ == "__main__":
|