jixin0101 commited on
Commit
327c80d
·
1 Parent(s): 459ff92

Pipeline add torch dtype code

Browse files
Files changed (1) hide show
  1. pipeline_objectclear.py +3 -1
pipeline_objectclear.py CHANGED
@@ -463,7 +463,6 @@ class ObjectClearPipeline(
463
  )
464
 
465
 
466
- @classmethod
467
  @classmethod
468
  def from_pretrained_with_custom_modules(
469
  cls,
@@ -500,6 +499,9 @@ class ObjectClearPipeline(
500
  cache_dir=cache_dir,
501
  **kwargs,
502
  )
 
 
 
503
 
504
  return pipe
505
 
 
463
  )
464
 
465
 
 
466
  @classmethod
467
  def from_pretrained_with_custom_modules(
468
  cls,
 
499
  cache_dir=cache_dir,
500
  **kwargs,
501
  )
502
+
503
+ if torch_dtype is not None:
504
+ pipe.to(dtype=torch_dtype)
505
 
506
  return pipe
507