justus-tobias commited on
Commit
60aba0c
·
1 Parent(s): 3d2b6af

updated spaces

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -1,14 +1,16 @@
 
1
  import gradio as gr
2
  import torch
3
  from huggingface_hub import hf_hub_download
4
  from moshi.models import loaders, LMGen
5
  import numpy as np
6
- import spaces
7
 
8
 
9
  @spaces.GPU
10
  def process_wav_new(in_wav):
11
  """wav = torch.randn(1, 1, 24000 * 10) # should be [B, C=1, T]"""
 
12
 
13
  mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
14
  mimi = loaders.get_mimi(mimi_weight, device='cpu')
@@ -28,7 +30,7 @@ def process_wav_new(in_wav):
28
  assert codes.shape[-1] == 1, codes.shape
29
  all_codes.append(codes)
30
 
31
- mimi.cuda()
32
  moshi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MOSHI_NAME)
33
  moshi = loaders.get_moshi_lm(moshi_weight, device='cuda')
34
  lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7) # this handles sampling params etc.
@@ -38,7 +40,7 @@ def process_wav_new(in_wav):
38
  with torch.no_grad(), lm_gen.streaming(1), mimi.streaming(1):
39
  for idx, code in enumerate(all_codes):
40
  print("CODE: ", code.shape)
41
- tokens_out = lm_gen.step(code.cuda())
42
  # tokens_out is [B, 1 + 8, 1], with tokens_out[:, 1] representing the text token.
43
  if tokens_out is not None:
44
  wav_chunk = mimi.decode(tokens_out[:, 1:])
 
1
+ import spaces
2
  import gradio as gr
3
  import torch
4
  from huggingface_hub import hf_hub_download
5
  from moshi.models import loaders, LMGen
6
  import numpy as np
7
+
8
 
9
 
10
  @spaces.GPU
11
  def process_wav_new(in_wav):
12
  """wav = torch.randn(1, 1, 24000 * 10) # should be [B, C=1, T]"""
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
 
15
  mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
16
  mimi = loaders.get_mimi(mimi_weight, device='cpu')
 
30
  assert codes.shape[-1] == 1, codes.shape
31
  all_codes.append(codes)
32
 
33
+ mimi.to(device)
34
  moshi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MOSHI_NAME)
35
  moshi = loaders.get_moshi_lm(moshi_weight, device='cuda')
36
  lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7) # this handles sampling params etc.
 
40
  with torch.no_grad(), lm_gen.streaming(1), mimi.streaming(1):
41
  for idx, code in enumerate(all_codes):
42
  print("CODE: ", code.shape)
43
+ tokens_out = lm_gen.step(code.to(device))
44
  # tokens_out is [B, 1 + 8, 1], with tokens_out[:, 1] representing the text token.
45
  if tokens_out is not None:
46
  wav_chunk = mimi.decode(tokens_out[:, 1:])