asigalov61 commited on
Commit
7c19e53
·
verified ·
1 Parent(s): 871e02d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -33,7 +33,7 @@ def GenerateGroove():
33
 
34
  SEQ_LEN = 4096 # Models seq len
35
  PAD_IDX = 1664 # Models pad index
36
- DEVICE = 'cuda' # 'cuda'
37
 
38
  # instantiate the model
39
 
@@ -124,7 +124,7 @@ def GenerateGroove():
124
 
125
  while batch_value > 255 and nc < max_notes_per_chord:
126
 
127
- x = torch.tensor([seq] * num_samples, dtype=torch.long, device='cuda')
128
 
129
  with ctx:
130
  out = model.generate(x,
 
33
 
34
  SEQ_LEN = 4096 # Models seq len
35
  PAD_IDX = 1664 # Models pad index
36
+ DEVICE = 'cpu' # 'cuda'
37
 
38
  # instantiate the model
39
 
 
124
 
125
  while batch_value > 255 and nc < max_notes_per_chord:
126
 
127
+ x = torch.tensor([seq] * num_samples, dtype=torch.long, device=DEVICE)
128
 
129
  with ctx:
130
  out = model.generate(x,