bluenevus committed on
Commit 1374f14 · verified · 1 Parent(s): 0128620

Update app.py

Files changed (1)
  1. app.py +39 -25
app.py CHANGED
@@ -27,28 +27,40 @@ snac_model = None
 @spaces.GPU()
 def load_model():
     global model, tokenizer, snac_model
-    try:
-        print("Loading SNAC model...")
-        snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
-        snac_model = snac_model.to(device)
-
-        print("Loading Orpheus model...")
-        model_name = "canopylabs/orpheus-3b-0.1-ft"
-
-        snapshot_download(
-            repo_id=model_name,
-            use_auth_token=os.environ.get("HUGGINGFACE_TOKEN"),
-            allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json", "vocab.json", "merges.txt", "tokenizer.json"],
-            ignore_patterns=["optimizer.pt", "pytorch_model.bin", "training_args.bin", "scheduler.pt"]
-        )
-
-        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
-        model.to(device)
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
-        print(f"Orpheus model and tokenizer loaded to {device}")
-    except Exception as e:
-        print(f"Error loading model: {str(e)}")
-        raise
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    print("Loading SNAC model...")
+    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
+    snac_model = snac_model.to(device)
+
+    print("Loading Orpheus model...")
+    model_name = "canopylabs/orpheus-3b-0.1-ft"
+
+    snapshot_download(
+        repo_id=model_name,
+        allow_patterns=[
+            "config.json",
+            "*.safetensors",
+            "model.safetensors.index.json",
+            "tokenizer.json",
+            "tokenizer_config.json",
+            "special_tokens_map.json",
+            "vocab.json",
+            "merges.txt",
+        ],
+        ignore_patterns=[
+            "optimizer.pt",
+            "pytorch_model.bin",
+            "training_args.bin",
+            "scheduler.pt",
+        ]
+    )
+
+    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
+    model.to(device)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    print(f"Orpheus model and tokenizer loaded to {device}")

 @spaces.GPU()
 def generate_podcast_script(api_key, content, uploaded_file, duration, num_hosts):
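
Editor's note: the updated load_model() drops the explicit use_auth_token argument (huggingface_hub will still fall back to a locally saved login or the HF_TOKEN environment variable on its own) and widens allow_patterns to cover the full tokenizer file set. Below is a minimal, standalone sketch of how such a filtered download can be verified; it is illustrative only, not part of app.py, and it assumes the repository is reachable from the current environment.

import os
from huggingface_hub import snapshot_download

# Mirror the filters used in load_model(): fetch config, weights, and
# tokenizer files while skipping training-only artifacts.
local_dir = snapshot_download(
    repo_id="canopylabs/orpheus-3b-0.1-ft",
    allow_patterns=[
        "config.json", "*.safetensors", "model.safetensors.index.json",
        "tokenizer.json", "tokenizer_config.json", "special_tokens_map.json",
        "vocab.json", "merges.txt",
    ],
    ignore_patterns=["optimizer.pt", "pytorch_model.bin", "training_args.bin", "scheduler.pt"],
)

# List the files that actually landed in the local snapshot directory.
for root, _dirs, files in os.walk(local_dir):
    for name in files:
        print(os.path.relpath(os.path.join(root, name), local_dir))
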
@@ -245,7 +257,9 @@ with gr.Blocks() as demo:

 if __name__ == "__main__":
     try:
-        load_model()
-        demo.queue().launch()
+        print("Loading models...")
+        load_model()  # This function should be defined to load all necessary models
+        print("Models loaded successfully. Launching the interface...")
+        demo.queue().launch(share=False, ssr_mode=False)
     except Exception as e:
-        logger.error(f"Error launching the application: {str(e)}")
+        print(f"Error during startup: {e}")
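
Editor's note: the second hunk moves error handling out of load_model() and into the __main__ guard, switches from logger.error to print, and passes share=False and ssr_mode=False to launch() explicitly. A minimal, self-contained sketch of that startup pattern follows, with a placeholder Blocks app standing in for the real demo defined in app.py (ssr_mode requires a reasonably recent Gradio release).

import gradio as gr

# Placeholder Blocks app standing in for the real `demo` built in app.py.
with gr.Blocks() as demo:
    gr.Markdown("Placeholder UI")

if __name__ == "__main__":
    try:
        print("Loading models...")
        # In app.py, load_model() runs here; any exception it raises is now
        # caught by this outer try/except instead of inside load_model().
        print("Models loaded successfully. Launching the interface...")
        # share=False keeps the app local; ssr_mode=False opts out of
        # server-side rendering, a flag available in recent Gradio releases.
        demo.queue().launch(share=False, ssr_mode=False)
    except Exception as e:
        print(f"Error during startup: {e}")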