Update app.py
Browse files
app.py
CHANGED
@@ -27,28 +27,40 @@ snac_model = None
|
|
27 |
@spaces.GPU()
|
28 |
def load_model():
|
29 |
global model, tokenizer, snac_model
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
@spaces.GPU()
|
54 |
def generate_podcast_script(api_key, content, uploaded_file, duration, num_hosts):
|
@@ -245,7 +257,9 @@ with gr.Blocks() as demo:
|
|
245 |
|
246 |
if __name__ == "__main__":
|
247 |
try:
|
248 |
-
|
249 |
-
|
|
|
|
|
250 |
except Exception as e:
|
251 |
-
|
|
|
27 |
@spaces.GPU()
|
28 |
def load_model():
|
29 |
global model, tokenizer, snac_model
|
30 |
+
|
31 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
32 |
+
|
33 |
+
print("Loading SNAC model...")
|
34 |
+
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
|
35 |
+
snac_model = snac_model.to(device)
|
36 |
+
|
37 |
+
print("Loading Orpheus model...")
|
38 |
+
model_name = "canopylabs/orpheus-3b-0.1-ft"
|
39 |
+
|
40 |
+
snapshot_download(
|
41 |
+
repo_id=model_name,
|
42 |
+
allow_patterns=[
|
43 |
+
"config.json",
|
44 |
+
"*.safetensors",
|
45 |
+
"model.safetensors.index.json",
|
46 |
+
"tokenizer.json",
|
47 |
+
"tokenizer_config.json",
|
48 |
+
"special_tokens_map.json",
|
49 |
+
"vocab.json",
|
50 |
+
"merges.txt",
|
51 |
+
],
|
52 |
+
ignore_patterns=[
|
53 |
+
"optimizer.pt",
|
54 |
+
"pytorch_model.bin",
|
55 |
+
"training_args.bin",
|
56 |
+
"scheduler.pt",
|
57 |
+
]
|
58 |
+
)
|
59 |
+
|
60 |
+
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
|
61 |
+
model.to(device)
|
62 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
63 |
+
print(f"Orpheus model and tokenizer loaded to {device}")
|
64 |
|
65 |
@spaces.GPU()
|
66 |
def generate_podcast_script(api_key, content, uploaded_file, duration, num_hosts):
|
|
|
257 |
|
258 |
if __name__ == "__main__":
|
259 |
try:
|
260 |
+
print("Loading models...")
|
261 |
+
load_model() # This function should be defined to load all necessary models
|
262 |
+
print("Models loaded successfully. Launching the interface...")
|
263 |
+
demo.queue().launch(share=False, ssr_mode=False)
|
264 |
except Exception as e:
|
265 |
+
print(f"Error during startup: {e}")
|