Tomtom84 commited on
Commit
28190c5
·
verified ·
1 Parent(s): 549abbb

Update engines/orpheus_engine.py

Browse files
Files changed (1) hide show
  1. engines/orpheus_engine.py +23 -22
engines/orpheus_engine.py CHANGED
@@ -29,6 +29,7 @@ import logging
29
  import struct
30
  import time
31
  import os
 
32
  from queue import Queue
33
  from typing import Generator, Iterable, List, Optional
34
 
@@ -97,13 +98,13 @@ def _fade_in_out(audio: np.ndarray, fade_ms: int = 50) -> np.ndarray:
97
  ###############################################################################
98
  # SNAC – lightweight wrapper #
99
  ###############################################################################
100
- try:
101
- from snac import SNAC
102
- _snac_model: Optional[SNAC] = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
103
- _snac_model = _snac_model.to("cuda" if _snac_model and _snac_model.torch.cuda.is_available() else "cpu")
104
- except Exception as exc: # pragma: no cover
105
- logging.warning("SNAC model could not be loaded – %s", exc)
106
- _snac_model = None
107
 
108
 
109
  def _codes_to_audio(codes: List[int]) -> bytes:
@@ -171,21 +172,21 @@ class OrpheusEngine(BaseEngine):
171
  OrpheusVoice("Lea", "f"),
172
  ]
173
  def _load_snac(self, model_name: str = SNAC_MODEL):
174
- """
175
- Lädt den SNAC-Decoder auf CPU/GPU.
176
- Fällt bei jedem Fehler sauber auf CPU zurück.
177
- """
178
- device = "cuda" if torch.cuda.is_available() else "cpu"
179
- try:
180
- snac = SNAC.from_pretrained(model_name).to(device)
181
- if device == "cuda": # half() nur auf GPU – ältere SNAC-Versionen haben keine .half()
182
- snac = snac.half()
183
- snac.eval()
184
- logging.info(f"SNAC {snac_version} loaded on {device}")
185
- return snac
186
- except Exception as e:
187
- logging.exception("SNAC load failed – running with silent fallback")
188
- return None
189
  # ---------------------------------------------------------------------
190
  def __init__(
191
  self,
 
29
  import struct
30
  import time
31
  import os
32
+ import torch
33
  from queue import Queue
34
  from typing import Generator, Iterable, List, Optional
35
 
 
98
  ###############################################################################
99
  # SNAC – lightweight wrapper #
100
  ###############################################################################
101
+ try:
102
+ from snac import SNAC
103
+ _snac_model: Optional[SNAC] = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
104
+ _snac_model = _snac_model.to("cuda" if _snac_model and _snac_model.torch.cuda.is_available() else "cpu")
105
+ except Exception as exc: # pragma: no cover
106
+ logging.warning("SNAC model could not be loaded – %s", exc)
107
+ _snac_model = None
108
 
109
 
110
  def _codes_to_audio(codes: List[int]) -> bytes:
 
172
  OrpheusVoice("Lea", "f"),
173
  ]
174
  def _load_snac(self, model_name: str = SNAC_MODEL):
175
+ """
176
+ Lädt den SNAC-Decoder auf CPU/GPU.
177
+ Fällt bei jedem Fehler sauber auf CPU zurück.
178
+ """
179
+ device = "cuda" if torch.cuda.is_available() else "cpu"
180
+ try:
181
+ snac = SNAC.from_pretrained(model_name).to(device)
182
+ if device == "cuda": # half() nur auf GPU – ältere SNAC-Versionen haben keine .half()
183
+ snac = snac.half()
184
+ snac.eval()
185
+ logging.info(f"SNAC {snac_version} loaded on {device}")
186
+ return snac
187
+ except Exception as e:
188
+ logging.exception("SNAC load failed – running with silent fallback")
189
+ return None
190
  # ---------------------------------------------------------------------
191
  def __init__(
192
  self,