Update app.py
Browse files
app.py
CHANGED
|
@@ -89,7 +89,7 @@ def GenerateGroove():
|
|
| 89 |
|
| 90 |
print('Sample input events', drums_score[:5])
|
| 91 |
print('=' * 70)
|
| 92 |
-
print('
|
| 93 |
|
| 94 |
num_prime_chords = 7
|
| 95 |
|
|
@@ -98,15 +98,54 @@ def GenerateGroove():
|
|
| 98 |
for d in drums_score[:num_prime_chords]:
|
| 99 |
|
| 100 |
outy.extend(d)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
-
for i in
|
| 103 |
|
| 104 |
outy.extend(drums_score[i])
|
| 105 |
|
| 106 |
if i == num_prime_chords:
|
| 107 |
-
outy.append(256+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
-
out =
|
| 110 |
|
| 111 |
outy.extend(out)
|
| 112 |
|
|
|
|
| 89 |
|
| 90 |
print('Sample input events', drums_score[:5])
|
| 91 |
print('=' * 70)
|
| 92 |
+
print('Prepping drums track...')
|
| 93 |
|
| 94 |
num_prime_chords = 7
|
| 95 |
|
|
|
|
| 98 |
for d in drums_score[:num_prime_chords]:
|
| 99 |
|
| 100 |
outy.extend(d)
|
| 101 |
+
|
| 102 |
+
print('Generating...')
|
| 103 |
+
|
| 104 |
+
max_notes_per_chord=8,
|
| 105 |
+
num_samples=4,
|
| 106 |
+
num_memory_tokens = 4096,
|
| 107 |
+
temperature=1.0):
|
| 108 |
|
| 109 |
+
for i in range(num_prime_chords, len(drums_score)):
|
| 110 |
|
| 111 |
outy.extend(drums_score[i])
|
| 112 |
|
| 113 |
if i == num_prime_chords:
|
| 114 |
+
outy.append(256+12)
|
| 115 |
+
|
| 116 |
+
input_seq = outy[-num_memory_tokens:]
|
| 117 |
+
|
| 118 |
+
seq = copy.deepcopy(input_seq)
|
| 119 |
+
|
| 120 |
+
batch_value = 256
|
| 121 |
+
|
| 122 |
+
nc = 0
|
| 123 |
+
|
| 124 |
+
while batch_value > 255 and nc < max_notes_per_chord:
|
| 125 |
+
|
| 126 |
+
x = torch.tensor([seq] * num_samples, dtype=torch.long, device='cuda')
|
| 127 |
+
|
| 128 |
+
with ctx:
|
| 129 |
+
out = model.generate(x,
|
| 130 |
+
1,
|
| 131 |
+
temperature=temperature,
|
| 132 |
+
return_prime=False,
|
| 133 |
+
verbose=False)
|
| 134 |
+
|
| 135 |
+
out1 = [o[0] for o in out.tolist() if o[0] > 255]
|
| 136 |
+
|
| 137 |
+
if not out1:
|
| 138 |
+
out1 = [-1]
|
| 139 |
+
|
| 140 |
+
batch_value = random.choice(out1)
|
| 141 |
+
|
| 142 |
+
if batch_value > 255:
|
| 143 |
+
seq.append(batch_value)
|
| 144 |
+
|
| 145 |
+
if batch_value > 383:
|
| 146 |
+
nc += 1
|
| 147 |
|
| 148 |
+
out = seq[len(input_seq):]
|
| 149 |
|
| 150 |
outy.extend(out)
|
| 151 |
|