jhansss commited on
Commit
6f349df
·
1 Parent(s): 33b7ea8

Refactor load_song_database to accept config and update data loading logic; add lyric_word_length calculation in create_features

Browse files
data/song2word_lengths.json ADDED
@@ -0,0 +1,871 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "kising_421": [
3
+ 8,
4
+ 8,
5
+ 4,
6
+ 5,
7
+ 6,
8
+ 2,
9
+ 10,
10
+ 10,
11
+ 4,
12
+ 6,
13
+ 6,
14
+ 8,
15
+ 3,
16
+ 2,
17
+ 4,
18
+ 4,
19
+ 2,
20
+ 3,
21
+ 4,
22
+ 6,
23
+ 14,
24
+ 7,
25
+ 7,
26
+ 8,
27
+ 9,
28
+ 10,
29
+ 4,
30
+ 6,
31
+ 6,
32
+ 8,
33
+ 7,
34
+ 7,
35
+ 7,
36
+ 7,
37
+ 7,
38
+ 8,
39
+ 7,
40
+ 8,
41
+ 7,
42
+ 6,
43
+ 7,
44
+ 6,
45
+ 7,
46
+ 6,
47
+ 12,
48
+ 13,
49
+ 5,
50
+ 12,
51
+ 12
52
+ ],
53
+ "kising_422": [
54
+ 10,
55
+ 10,
56
+ 10,
57
+ 10,
58
+ 10,
59
+ 10,
60
+ 10,
61
+ 10,
62
+ 9,
63
+ 9,
64
+ 9,
65
+ 7,
66
+ 9,
67
+ 9,
68
+ 10,
69
+ 6,
70
+ 6,
71
+ 11,
72
+ 10,
73
+ 10,
74
+ 10,
75
+ 9,
76
+ 9,
77
+ 9,
78
+ 7,
79
+ 9,
80
+ 9,
81
+ 9,
82
+ 6,
83
+ 6,
84
+ 9,
85
+ 9,
86
+ 9,
87
+ 7,
88
+ 9,
89
+ 9,
90
+ 9,
91
+ 6,
92
+ 6,
93
+ 6,
94
+ 3
95
+ ],
96
+ "kising_423": [
97
+ 1,
98
+ 3,
99
+ 4,
100
+ 2,
101
+ 1,
102
+ 3,
103
+ 3,
104
+ 4,
105
+ 1,
106
+ 3,
107
+ 3,
108
+ 3,
109
+ 4,
110
+ 3,
111
+ 4,
112
+ 17,
113
+ 16,
114
+ 17,
115
+ 13,
116
+ 2,
117
+ 3,
118
+ 6,
119
+ 11,
120
+ 2,
121
+ 1,
122
+ 3,
123
+ 3,
124
+ 2,
125
+ 1,
126
+ 3,
127
+ 3,
128
+ 2,
129
+ 3,
130
+ 3,
131
+ 3,
132
+ 1,
133
+ 3,
134
+ 3,
135
+ 3,
136
+ 17,
137
+ 16,
138
+ 17,
139
+ 13,
140
+ 2,
141
+ 2,
142
+ 6,
143
+ 10,
144
+ 2,
145
+ 5,
146
+ 13,
147
+ 2,
148
+ 4,
149
+ 3,
150
+ 2
151
+ ],
152
+ "kising_424": [
153
+ 3,
154
+ 5,
155
+ 3,
156
+ 5,
157
+ 3,
158
+ 5,
159
+ 5,
160
+ 3,
161
+ 5,
162
+ 3,
163
+ 5,
164
+ 3,
165
+ 5,
166
+ 4,
167
+ 5,
168
+ 3,
169
+ 2,
170
+ 3,
171
+ 2,
172
+ 5,
173
+ 5,
174
+ 5,
175
+ 5,
176
+ 5,
177
+ 6,
178
+ 3,
179
+ 5,
180
+ 3,
181
+ 5,
182
+ 3,
183
+ 5,
184
+ 4,
185
+ 5,
186
+ 3,
187
+ 2,
188
+ 5,
189
+ 5,
190
+ 5,
191
+ 8,
192
+ 5,
193
+ 5,
194
+ 6,
195
+ 3,
196
+ 2,
197
+ 3,
198
+ 3,
199
+ 5,
200
+ 5,
201
+ 5,
202
+ 5,
203
+ 5,
204
+ 7,
205
+ 5,
206
+ 5,
207
+ 5,
208
+ 6,
209
+ 3,
210
+ 5
211
+ ],
212
+ "kising_425": [
213
+ 3,
214
+ 4,
215
+ 3,
216
+ 3,
217
+ 6,
218
+ 2,
219
+ 3,
220
+ 8,
221
+ 2,
222
+ 2,
223
+ 4,
224
+ 4,
225
+ 2,
226
+ 4,
227
+ 5,
228
+ 5,
229
+ 7,
230
+ 2,
231
+ 5,
232
+ 5,
233
+ 8,
234
+ 5,
235
+ 6,
236
+ 1,
237
+ 3,
238
+ 4,
239
+ 1,
240
+ 4,
241
+ 3,
242
+ 4,
243
+ 1,
244
+ 3,
245
+ 8,
246
+ 2,
247
+ 2,
248
+ 4,
249
+ 4,
250
+ 3,
251
+ 4,
252
+ 5,
253
+ 5,
254
+ 7,
255
+ 2,
256
+ 5,
257
+ 6,
258
+ 8,
259
+ 5,
260
+ 5,
261
+ 9,
262
+ 5,
263
+ 5,
264
+ 8,
265
+ 5,
266
+ 5,
267
+ 7,
268
+ 2,
269
+ 5,
270
+ 5,
271
+ 8,
272
+ 5,
273
+ 5,
274
+ 5,
275
+ 2,
276
+ 1
277
+ ],
278
+ "kising_426": [
279
+ 3,
280
+ 8,
281
+ 8,
282
+ 12,
283
+ 3,
284
+ 8,
285
+ 8,
286
+ 12,
287
+ 29,
288
+ 6,
289
+ 24,
290
+ 12,
291
+ 10,
292
+ 9,
293
+ 6,
294
+ 3,
295
+ 5,
296
+ 3,
297
+ 16,
298
+ 3,
299
+ 8,
300
+ 8,
301
+ 12,
302
+ 3,
303
+ 8,
304
+ 8,
305
+ 12,
306
+ 29,
307
+ 6,
308
+ 24,
309
+ 10,
310
+ 8,
311
+ 9,
312
+ 6,
313
+ 3,
314
+ 4
315
+ ],
316
+ "kising_427": [
317
+ 7,
318
+ 7,
319
+ 9,
320
+ 6,
321
+ 7,
322
+ 7,
323
+ 8,
324
+ 7,
325
+ 4,
326
+ 4,
327
+ 11,
328
+ 15,
329
+ 15,
330
+ 15,
331
+ 10,
332
+ 6,
333
+ 4,
334
+ 6,
335
+ 10,
336
+ 2,
337
+ 2,
338
+ 7,
339
+ 7,
340
+ 15,
341
+ 4,
342
+ 4,
343
+ 10,
344
+ 15,
345
+ 15,
346
+ 15,
347
+ 10,
348
+ 5,
349
+ 4,
350
+ 6,
351
+ 10,
352
+ 2,
353
+ 17,
354
+ 15,
355
+ 14,
356
+ 10,
357
+ 5,
358
+ 4,
359
+ 16,
360
+ 1,
361
+ 2,
362
+ 15,
363
+ 15,
364
+ 8,
365
+ 7,
366
+ 10,
367
+ 5,
368
+ 4,
369
+ 16,
370
+ 2,
371
+ 2,
372
+ 10,
373
+ 2,
374
+ 2,
375
+ 10,
376
+ 7,
377
+ 2
378
+ ],
379
+ "kising_428": [
380
+ 5,
381
+ 1,
382
+ 10,
383
+ 8,
384
+ 6,
385
+ 5,
386
+ 10,
387
+ 9,
388
+ 6,
389
+ 5,
390
+ 13,
391
+ 2,
392
+ 5,
393
+ 2,
394
+ 13,
395
+ 3,
396
+ 7,
397
+ 8,
398
+ 7,
399
+ 12,
400
+ 7,
401
+ 8,
402
+ 7,
403
+ 6,
404
+ 2,
405
+ 1,
406
+ 10,
407
+ 9,
408
+ 11,
409
+ 13,
410
+ 2,
411
+ 7,
412
+ 1,
413
+ 12,
414
+ 3,
415
+ 7,
416
+ 8,
417
+ 7,
418
+ 6,
419
+ 5,
420
+ 7,
421
+ 8,
422
+ 7,
423
+ 5,
424
+ 6,
425
+ 9,
426
+ 7,
427
+ 7,
428
+ 6,
429
+ 5,
430
+ 7,
431
+ 8,
432
+ 7,
433
+ 5,
434
+ 5,
435
+ 4,
436
+ 2
437
+ ],
438
+ "kising_429": [
439
+ 4,
440
+ 3,
441
+ 4,
442
+ 3,
443
+ 4,
444
+ 3,
445
+ 4,
446
+ 3,
447
+ 6,
448
+ 5,
449
+ 3,
450
+ 1,
451
+ 5,
452
+ 4,
453
+ 2,
454
+ 1,
455
+ 7,
456
+ 8,
457
+ 19,
458
+ 8,
459
+ 8,
460
+ 19,
461
+ 5,
462
+ 1,
463
+ 4,
464
+ 4,
465
+ 1,
466
+ 4,
467
+ 3,
468
+ 4,
469
+ 3,
470
+ 4,
471
+ 3,
472
+ 4,
473
+ 3,
474
+ 2,
475
+ 5,
476
+ 4,
477
+ 3,
478
+ 7,
479
+ 6,
480
+ 19,
481
+ 7,
482
+ 7,
483
+ 19,
484
+ 7,
485
+ 7,
486
+ 27,
487
+ 7,
488
+ 19,
489
+ 3,
490
+ 3,
491
+ 8,
492
+ 5,
493
+ 7,
494
+ 7,
495
+ 7,
496
+ 19,
497
+ 8,
498
+ 6,
499
+ 19,
500
+ 5,
501
+ 5,
502
+ 5,
503
+ 5,
504
+ 5,
505
+ 5,
506
+ 8,
507
+ 1,
508
+ 1
509
+ ],
510
+ "kising_430": [
511
+ 5,
512
+ 7,
513
+ 4,
514
+ 7,
515
+ 15,
516
+ 3,
517
+ 5,
518
+ 5,
519
+ 7,
520
+ 4,
521
+ 7,
522
+ 15,
523
+ 3,
524
+ 4,
525
+ 2,
526
+ 6,
527
+ 6,
528
+ 8,
529
+ 3,
530
+ 6,
531
+ 7,
532
+ 21,
533
+ 6,
534
+ 13,
535
+ 6,
536
+ 17,
537
+ 12,
538
+ 2,
539
+ 6,
540
+ 6,
541
+ 6,
542
+ 12,
543
+ 6,
544
+ 7,
545
+ 19,
546
+ 7,
547
+ 12,
548
+ 28,
549
+ 4,
550
+ 1,
551
+ 1,
552
+ 6,
553
+ 7,
554
+ 12,
555
+ 7,
556
+ 17,
557
+ 3,
558
+ 5,
559
+ 5,
560
+ 2,
561
+ 2
562
+ ],
563
+ "kising_431": [
564
+ 7,
565
+ 5,
566
+ 7,
567
+ 5,
568
+ 9,
569
+ 5,
570
+ 8,
571
+ 5,
572
+ 4,
573
+ 4,
574
+ 5,
575
+ 3,
576
+ 3,
577
+ 2,
578
+ 2,
579
+ 4,
580
+ 9,
581
+ 6,
582
+ 3,
583
+ 2,
584
+ 7,
585
+ 5,
586
+ 7,
587
+ 5,
588
+ 9,
589
+ 5,
590
+ 8,
591
+ 5,
592
+ 8,
593
+ 5,
594
+ 8,
595
+ 2,
596
+ 4,
597
+ 9,
598
+ 3,
599
+ 3,
600
+ 3,
601
+ 2,
602
+ 6,
603
+ 1,
604
+ 2,
605
+ 2
606
+ ],
607
+ "kising_432": [
608
+ 7,
609
+ 5,
610
+ 7,
611
+ 5,
612
+ 9,
613
+ 5,
614
+ 9,
615
+ 5,
616
+ 4,
617
+ 4,
618
+ 5,
619
+ 3,
620
+ 3,
621
+ 2,
622
+ 2,
623
+ 4,
624
+ 4,
625
+ 5,
626
+ 3,
627
+ 3,
628
+ 3,
629
+ 2,
630
+ 7,
631
+ 5,
632
+ 7,
633
+ 5,
634
+ 9,
635
+ 6,
636
+ 8,
637
+ 5,
638
+ 4,
639
+ 4,
640
+ 5,
641
+ 3,
642
+ 3,
643
+ 2,
644
+ 2,
645
+ 4,
646
+ 4,
647
+ 5,
648
+ 3,
649
+ 3,
650
+ 3,
651
+ 3,
652
+ 3,
653
+ 4,
654
+ 2,
655
+ 2
656
+ ],
657
+ "kising_433": [
658
+ 11,
659
+ 12,
660
+ 12,
661
+ 24,
662
+ 11,
663
+ 7,
664
+ 6,
665
+ 8,
666
+ 7,
667
+ 6,
668
+ 9,
669
+ 7,
670
+ 6,
671
+ 9,
672
+ 13,
673
+ 8,
674
+ 8,
675
+ 7,
676
+ 9,
677
+ 7,
678
+ 12,
679
+ 12,
680
+ 12,
681
+ 11,
682
+ 12,
683
+ 10,
684
+ 7,
685
+ 6,
686
+ 9,
687
+ 7,
688
+ 6,
689
+ 9,
690
+ 13,
691
+ 9,
692
+ 13,
693
+ 9,
694
+ 13,
695
+ 8,
696
+ 13,
697
+ 9,
698
+ 13,
699
+ 9,
700
+ 13,
701
+ 8,
702
+ 8,
703
+ 2,
704
+ 8,
705
+ 2
706
+ ],
707
+ "kising_434": [
708
+ 9,
709
+ 2,
710
+ 8,
711
+ 11,
712
+ 5,
713
+ 9,
714
+ 14,
715
+ 8,
716
+ 11,
717
+ 11,
718
+ 6,
719
+ 6,
720
+ 2,
721
+ 7,
722
+ 2,
723
+ 6,
724
+ 7,
725
+ 2,
726
+ 7,
727
+ 4,
728
+ 7,
729
+ 6,
730
+ 8,
731
+ 10,
732
+ 5,
733
+ 9,
734
+ 7,
735
+ 6,
736
+ 9,
737
+ 11,
738
+ 11,
739
+ 2,
740
+ 6,
741
+ 6,
742
+ 9,
743
+ 2,
744
+ 6,
745
+ 7,
746
+ 2,
747
+ 8,
748
+ 4,
749
+ 6,
750
+ 7,
751
+ 2,
752
+ 7,
753
+ 2,
754
+ 6,
755
+ 7,
756
+ 2,
757
+ 6,
758
+ 5
759
+ ],
760
+ "kising_441-2": [
761
+ 11,
762
+ 4,
763
+ 9,
764
+ 4,
765
+ 5,
766
+ 4,
767
+ 5,
768
+ 14,
769
+ 10,
770
+ 23,
771
+ 12,
772
+ 4
773
+ ],
774
+ "kising_441": [
775
+ 2,
776
+ 6,
777
+ 6,
778
+ 6,
779
+ 18,
780
+ 4,
781
+ 24,
782
+ 6,
783
+ 12,
784
+ 12,
785
+ 12,
786
+ 4,
787
+ 7,
788
+ 4,
789
+ 9,
790
+ 4,
791
+ 4,
792
+ 4,
793
+ 13,
794
+ 7,
795
+ 1,
796
+ 7,
797
+ 13,
798
+ 18,
799
+ 4,
800
+ 24,
801
+ 6,
802
+ 12,
803
+ 12,
804
+ 12,
805
+ 4
806
+ ],
807
+ "kising_442": [
808
+ 9,
809
+ 8,
810
+ 8,
811
+ 11,
812
+ 9,
813
+ 8,
814
+ 19,
815
+ 8,
816
+ 8,
817
+ 8,
818
+ 11,
819
+ 9,
820
+ 8,
821
+ 8,
822
+ 11,
823
+ 9,
824
+ 8,
825
+ 19,
826
+ 8,
827
+ 8,
828
+ 8,
829
+ 11,
830
+ 8,
831
+ 8,
832
+ 19,
833
+ 9,
834
+ 8,
835
+ 8,
836
+ 12,
837
+ 8,
838
+ 8,
839
+ 8,
840
+ 11,
841
+ 8,
842
+ 8,
843
+ 19
844
+ ],
845
+ "kising_443": [
846
+ 7,
847
+ 16,
848
+ 7,
849
+ 8,
850
+ 9,
851
+ 3,
852
+ 1,
853
+ 3,
854
+ 3,
855
+ 33,
856
+ 33,
857
+ 7,
858
+ 17,
859
+ 1,
860
+ 7,
861
+ 8,
862
+ 7,
863
+ 2,
864
+ 33,
865
+ 33,
866
+ 1,
867
+ 5,
868
+ 2,
869
+ 1
870
+ ]
871
+ }
offline_process/create_features.py CHANGED
@@ -9,7 +9,19 @@ combined = combined.filter(lambda x: x["singer"] == "barber")
9
 
10
  # 3. create a new column, which counts the nonzero numbers in the list in the note_midi column
11
  combined = combined.map(
12
- lambda x: {"note_midi_length": len([n for n in x["note_midi"] if n != 0])}
 
 
 
 
 
 
 
 
 
 
 
 
13
  )
14
 
15
  # 4. sort by segment_id
@@ -19,6 +31,7 @@ combined = combined.sort("segment_id")
19
  prev_songid = None
20
  prev_song_segment_id = None
21
  song2note_lengths = {}
 
22
  for row in combined:
23
  # segment_id: kising_barber_{songid}_{song_segment_id}
24
  _, _, songid, song_segment_id = row["segment_id"].split("_")
@@ -28,30 +41,26 @@ for row in combined:
28
  song_segment_id == "001"
29
  ), f"prev_songid: {prev_songid}, songid: {songid}, song_segment_id: {song_segment_id}"
30
  song2note_lengths[f"kising_{songid}"] = [row["note_midi_length"]]
 
31
  else:
32
  assert (
33
  int(song_segment_id) >= int(prev_song_segment_id) + 1
34
  ), f"prev_song_segment_id: {prev_song_segment_id}, song_segment_id: {song_segment_id}"
35
  song2note_lengths[f"kising_{songid}"].append(row["note_midi_length"])
 
36
  prev_songid = songid
37
  prev_song_segment_id = song_segment_id
38
 
39
  # 6. write to json
40
  import json
41
 
42
- with open("song2note_lengths.json", "w") as f:
43
  json.dump(song2note_lengths, f, indent=4)
44
 
45
- # 7. convert to pandas DataFrame
46
- import pandas as pd
47
-
48
- df = pd.DataFrame.from_dict(combined)
49
- df = df.drop(columns=["audio", "singer"])
50
- df["segment_id"] = df["segment_id"].str.replace("kising_barber_", "kising_")
51
- # export to csv
52
- df.to_csv("song_db.csv", index=False)
53
 
54
- # 8. push score segments to hub
55
  # remove audio and singer columns
56
  combined = combined.remove_columns(["audio", "singer"])
57
  # replace kising_barber_ with kising_
 
9
 
10
  # 3. create a new column, which counts the nonzero numbers in the list in the note_midi column
11
  combined = combined.map(
12
+ lambda x: {
13
+ "note_midi_length": len([n for n in x["note_midi"] if n != 0]),
14
+ "lyric_word_length": len(
15
+ [word for word in x["note_lyrics"] if word not in ["<AP>", "<SP>", "-"]]
16
+ ), # counts the number of actual words (or characters for, e.g., Chinese/Japanese)
17
+ }
18
+ )
19
+ combined = combined.map(
20
+ lambda x: {
21
+ "lyric_word_length": len(
22
+ [word for word in x["note_lyrics"] if word not in ["<AP>", "<SP>", "-"]]
23
+ )
24
+ } # counts the number of actual words (or characters for, e.g., Chinese/Japanese)
25
  )
26
 
27
  # 4. sort by segment_id
 
31
  prev_songid = None
32
  prev_song_segment_id = None
33
  song2note_lengths = {}
34
+ song2word_lengths = {}
35
  for row in combined:
36
  # segment_id: kising_barber_{songid}_{song_segment_id}
37
  _, _, songid, song_segment_id = row["segment_id"].split("_")
 
41
  song_segment_id == "001"
42
  ), f"prev_songid: {prev_songid}, songid: {songid}, song_segment_id: {song_segment_id}"
43
  song2note_lengths[f"kising_{songid}"] = [row["note_midi_length"]]
44
+ song2word_lengths[f"kising_{songid}"] = [row["lyric_word_length"]]
45
  else:
46
  assert (
47
  int(song_segment_id) >= int(prev_song_segment_id) + 1
48
  ), f"prev_song_segment_id: {prev_song_segment_id}, song_segment_id: {song_segment_id}"
49
  song2note_lengths[f"kising_{songid}"].append(row["note_midi_length"])
50
+ song2word_lengths[f"kising_{songid}"].append(row["lyric_word_length"])
51
  prev_songid = songid
52
  prev_song_segment_id = song_segment_id
53
 
54
  # 6. write to json
55
  import json
56
 
57
+ with open("data/song2note_lengths.json", "w") as f:
58
  json.dump(song2note_lengths, f, indent=4)
59
 
60
+ with open("data/song2word_lengths.json", "w") as f:
61
+ json.dump(song2word_lengths, f, indent=4)
 
 
 
 
 
 
62
 
63
+ # 7. push score segments to hub
64
  # remove audio and singer columns
65
  combined = combined.remove_columns(["audio", "singer"])
66
  # replace kising_barber_ with kising_
svs_utils.py CHANGED
@@ -310,14 +310,17 @@ def song_segment_iterator(song_db, metadata):
310
  raise NotImplementedError(f"song name {song_name} not supported")
311
 
312
 
313
- def load_song_database():
314
  song_db = load_dataset(
315
  "jhansss/kising_score_segments", cache_dir="cache", split="train"
316
  ).to_pandas()
317
  song_db.set_index("segment_id", inplace=True)
318
-
319
- with open("data/song2note_lengths.json", "r") as f:
320
- song2note_lengths = json.load(f)
 
 
 
321
  return song2note_lengths, song_db
322
 
323
 
@@ -342,7 +345,7 @@ if __name__ == "__main__":
342
  if config.melody_source.startswith("random_select"):
343
  # load song database: jhansss/kising_score_segments
344
  from datasets import load_dataset
345
- song2note_lengths, song_db = load_song_database()
346
 
347
  # get song_name and phrase_length
348
  phrase_length, metadata = estimate_sentence_length(None, config, song2note_lengths)
 
310
  raise NotImplementedError(f"song name {song_name} not supported")
311
 
312
 
313
+ def load_song_database(config):
314
  song_db = load_dataset(
315
  "jhansss/kising_score_segments", cache_dir="cache", split="train"
316
  ).to_pandas()
317
  song_db.set_index("segment_id", inplace=True)
318
+ if ".take_lyric_continuation" in config.melody_source:
319
+ with open("data/song2word_lengths.json", "r") as f:
320
+ song2note_lengths = json.load(f)
321
+ else:
322
+ with open("data/song2note_lengths.json", "r") as f:
323
+ song2note_lengths = json.load(f)
324
  return song2note_lengths, song_db
325
 
326
 
 
345
  if config.melody_source.startswith("random_select"):
346
  # load song database: jhansss/kising_score_segments
347
  from datasets import load_dataset
348
+ song2note_lengths, song_db = load_song_database(config)
349
 
350
  # get song_name and phrase_length
351
  phrase_length, metadata = estimate_sentence_length(None, config, song2note_lengths)