Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -33,7 +33,7 @@ def GenerateGroove():
|
|
33 |
|
34 |
SEQ_LEN = 4096 # Models seq len
|
35 |
PAD_IDX = 1664 # Models pad index
|
36 |
-
DEVICE = '
|
37 |
|
38 |
# instantiate the model
|
39 |
|
@@ -124,7 +124,7 @@ def GenerateGroove():
|
|
124 |
|
125 |
while batch_value > 255 and nc < max_notes_per_chord:
|
126 |
|
127 |
-
x = torch.tensor([seq] * num_samples, dtype=torch.long, device=
|
128 |
|
129 |
with ctx:
|
130 |
out = model.generate(x,
|
|
|
33 |
|
34 |
SEQ_LEN = 4096 # Models seq len
|
35 |
PAD_IDX = 1664 # Models pad index
|
36 |
+
DEVICE = 'cpu' # 'cuda'
|
37 |
|
38 |
# instantiate the model
|
39 |
|
|
|
124 |
|
125 |
while batch_value > 255 and nc < max_notes_per_chord:
|
126 |
|
127 |
+
x = torch.tensor([seq] * num_samples, dtype=torch.long, device=DEVICE)
|
128 |
|
129 |
with ctx:
|
130 |
out = model.generate(x,
|