SreyanG-NVIDIA committed on
Commit
816092e
·
verified ·
1 Parent(s): 542f3ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -12,13 +12,13 @@ from huggingface_hub import snapshot_download
12
  MODEL_BASE_SINGLE = snapshot_download(repo_id="nvidia/audio-flamingo-3")
13
  MODEL_BASE_THINK = os.path.join(MODEL_BASE_SINGLE, 'stage35')
14
 
15
- model_single = llava.load(MODEL_BASE_SINGLE, model_base=None, devices=[0,1])
16
  generation_config_single = model_single.default_generation_config
17
 
18
  model_think = PeftModel.from_pretrained(
19
  model_single,
20
  MODEL_BASE_THINK,
21
- device_map="cuda:2",
22
  torch_dtype=torch.float16,
23
  )
24
 
@@ -26,7 +26,7 @@ model_think = PeftModel.from_pretrained(
26
  # MULTI-TURN MODEL SETUP
27
  # ---------------------------------
28
  MODEL_BASE_MULTI = snapshot_download(repo_id="nvidia/audio-flamingo-3-chat")
29
- model_multi = llava.load(MODEL_BASE_MULTI, model_base=None, devices=[3])
30
  generation_config_multi = model_multi.default_generation_config
31
 
32
 
 
12
  MODEL_BASE_SINGLE = snapshot_download(repo_id="nvidia/audio-flamingo-3")
13
  MODEL_BASE_THINK = os.path.join(MODEL_BASE_SINGLE, 'stage35')
14
 
15
+ model_single = llava.load(MODEL_BASE_SINGLE, model_base=None, devices=[0])
16
  generation_config_single = model_single.default_generation_config
17
 
18
  model_think = PeftModel.from_pretrained(
19
  model_single,
20
  MODEL_BASE_THINK,
21
+ device_map="auto",
22
  torch_dtype=torch.float16,
23
  )
24
 
 
26
  # MULTI-TURN MODEL SETUP
27
  # ---------------------------------
28
  MODEL_BASE_MULTI = snapshot_download(repo_id="nvidia/audio-flamingo-3-chat")
29
+ model_multi = llava.load(MODEL_BASE_MULTI, model_base=None, devices=[0])
30
  generation_config_multi = model_multi.default_generation_config
31
 
32