cocktailpeanut commited on
Commit
6f7e0d5
·
1 Parent(s): 0b0eadb
Files changed (1) hide show
  1. OmniGen/pipeline.py +2 -0
OmniGen/pipeline.py CHANGED
@@ -60,6 +60,8 @@ class OmniGenPipeline:
60
 
61
  @classmethod
62
  def from_pretrained(cls, model_name, vae_path: str=None):
 
 
63
  if not os.path.exists(model_name):
64
  logger.info("Model not found, downloading...")
65
  cache_folder = os.getenv('HF_HUB_CACHE')
 
60
 
61
  @classmethod
62
  def from_pretrained(cls, model_name, vae_path: str=None):
63
+ device = devicetorch.get(torch)
64
+ print(f">Device={device}")
65
  if not os.path.exists(model_name):
66
  logger.info("Model not found, downloading...")
67
  cache_folder = os.getenv('HF_HUB_CACHE')