jixin0101 commited on
Commit
e3ad07e
·
1 Parent(s): 327c80d

Add variant argument to support loading fp16 model

Browse files
Files changed (2) hide show
  1. app.py +1 -0
  2. pipeline_objectclear.py +2 -0
app.py CHANGED
@@ -283,6 +283,7 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
283
  pipe = ObjectClearPipeline.from_pretrained_with_custom_modules(
284
  "jixin0101/ObjectClear",
285
  torch_dtype=torch.float16,
 
286
  save_cross_attn=True,
287
  cache_dir="/home/jovyan/shared/jixinzhao/models",
288
  )
 
283
  pipe = ObjectClearPipeline.from_pretrained_with_custom_modules(
284
  "jixin0101/ObjectClear",
285
  torch_dtype=torch.float16,
286
+ variant='fp16',
287
  save_cross_attn=True,
288
  cache_dir="/home/jovyan/shared/jixinzhao/models",
289
  )
pipeline_objectclear.py CHANGED
@@ -469,6 +469,7 @@ class ObjectClearPipeline(
469
  pretrained_model_name_or_path,
470
  torch_dtype=torch.float32,
471
  cache_dir=None,
 
472
  **kwargs,
473
  ):
474
  from safetensors.torch import load_file
@@ -497,6 +498,7 @@ class ObjectClearPipeline(
497
  image_prompt_encoder=image_prompt_encoder,
498
  postfuse_module=postfuse_module,
499
  cache_dir=cache_dir,
 
500
  **kwargs,
501
  )
502
 
 
469
  pretrained_model_name_or_path,
470
  torch_dtype=torch.float32,
471
  cache_dir=None,
472
+ variant=None,
473
  **kwargs,
474
  ):
475
  from safetensors.torch import load_file
 
498
  image_prompt_encoder=image_prompt_encoder,
499
  postfuse_module=postfuse_module,
500
  cache_dir=cache_dir,
501
+ variant=variant,
502
  **kwargs,
503
  )
504