da03 commited on
Commit
3a94a55
·
1 Parent(s): b64a257
Files changed (3) hide show
  1. config_rnn_measure_latency.yaml +2 -4
  2. main.py +1 -1
  3. utils.py +1 -1
config_rnn_measure_latency.yaml CHANGED
@@ -26,9 +26,7 @@ model:
26
  in_channels: 20
27
  out_channels: 4
28
  model_channels: 256
29
- attention_resolutions:
30
- - 4
31
- - 2
32
  num_res_blocks: 2
33
  channel_mult:
34
  - 1
@@ -41,7 +39,7 @@ model:
41
  target: ldm.modules.encoders.temporal_encoder.TemporalEncoder
42
  params:
43
  input_channels: 6
44
- hidden_size: 4096
45
  num_layers: 1
46
  dropout: 0.1
47
  output_channels: 16
 
26
  in_channels: 20
27
  out_channels: 4
28
  model_channels: 256
29
+ attention_resolutions: []
 
 
30
  num_res_blocks: 2
31
  channel_mult:
32
  - 1
 
39
  target: ldm.modules.encoders.temporal_encoder.TemporalEncoder
40
  params:
41
  input_channels: 6
42
+ hidden_size: 2048
43
  num_layers: 1
44
  dropout: 0.1
45
  output_channels: 16
main.py CHANGED
@@ -28,7 +28,7 @@ LATENT_DIMS = (4, SCREEN_HEIGHT // 8, SCREEN_WIDTH // 8)
28
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
29
  # Initialize the model at the start of your application
30
  #model = initialize_model("config_csllm.yaml", "yuntian-deng/computer-model")
31
- model = initialize_model("config_rnn.yaml", "yuntian-deng/computer-model")
32
 
33
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
34
  model = model.to(device)
 
28
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
29
  # Initialize the model at the start of your application
30
  #model = initialize_model("config_csllm.yaml", "yuntian-deng/computer-model")
31
+ model = initialize_model("config_rnn_measure_latency.yaml", "yuntian-deng/computer-model")
32
 
33
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
34
  model = model.to(device)
utils.py CHANGED
@@ -10,7 +10,7 @@ import os
10
  import time
11
  DEBUG = False
12
 
13
- def load_model_from_config(config_path, model_name, device='cuda', load=True):
14
  # Load the config file
15
  config = OmegaConf.load(config_path)
16
 
 
10
  import time
11
  DEBUG = False
12
 
13
+ def load_model_from_config(config_path, model_name, device='cuda', load=False):
14
  # Load the config file
15
  config = OmegaConf.load(config_path)
16