da03 commited on
Commit
4013838
·
1 Parent(s): a9d3da5
Files changed (1) hide show
  1. main.py +7 -2
main.py CHANGED
@@ -44,6 +44,8 @@ MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-54k"
44
 
45
  MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-unfreezernn-160k"
46
  MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-368k"
 
 
47
 
48
 
49
  print (f'setting: DEBUG_MODE: {DEBUG_MODE}, DEBUG_MODE_2: {DEBUG_MODE_2}, NUM_MAX_FRAMES: {NUM_MAX_FRAMES}, NUM_SAMPLING_STEPS: {NUM_SAMPLING_STEPS}, MODEL_NAME: {MODEL_NAME}')
@@ -58,8 +60,11 @@ LATENT_DIMS = (16, SCREEN_HEIGHT // 8, SCREEN_WIDTH // 8)
58
  #model = initialize_model("config_rnn.yaml", "yuntian-deng/computer-model")
59
  #model = initialize_model("config_final_model.yaml", "yuntian-deng/computer-model-noss")
60
  #model = initialize_model("config_final_model.yaml", "yuntian-deng/computer-model")
61
- #model = initialize_model("config_final_model.yaml", MODEL_NAME)
62
- model = initialize_model("config_final_model_origunet_nospatial.yaml", MODEL_NAME)
 
 
 
63
 
64
  model = model.to(device)
65
  #model = torch.compile(model)
 
44
 
45
  MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-unfreezernn-160k"
46
  MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-368k"
47
+ MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-unfreezernn-198k"
48
+ MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-674k"
49
 
50
 
51
  print (f'setting: DEBUG_MODE: {DEBUG_MODE}, DEBUG_MODE_2: {DEBUG_MODE_2}, NUM_MAX_FRAMES: {NUM_MAX_FRAMES}, NUM_SAMPLING_STEPS: {NUM_SAMPLING_STEPS}, MODEL_NAME: {MODEL_NAME}')
 
60
  #model = initialize_model("config_rnn.yaml", "yuntian-deng/computer-model")
61
  #model = initialize_model("config_final_model.yaml", "yuntian-deng/computer-model-noss")
62
  #model = initialize_model("config_final_model.yaml", "yuntian-deng/computer-model")
63
+
64
+ if 'origunet' in MODEL_NAME:
65
+ model = initialize_model("config_final_model_origunet_nospatial.yaml", MODEL_NAME)
66
+ else:
67
+ model = initialize_model("config_final_model.yaml", MODEL_NAME)
68
 
69
  model = model.to(device)
70
  #model = torch.compile(model)