reachomk commited on
Commit
37798ad
·
verified ·
1 Parent(s): 353e8fc

Update gen2seg_sd_pipeline.py

Browse files
Files changed (1) hide show
  1. gen2seg_sd_pipeline.py +4 -4
gen2seg_sd_pipeline.py CHANGED
@@ -50,7 +50,7 @@ def zeros_tensor(
50
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
51
 
52
  @dataclass
53
- class Gen2SegSDSegOutput(BaseOutput):
54
  """
55
  Output class for gen2seg Instance Segmentation prediction pipeline.
56
 
@@ -67,7 +67,7 @@ class Gen2SegSDSegOutput(BaseOutput):
67
  latent: Union[None, torch.Tensor]
68
 
69
 
70
- class Gen2SegSDPipeline(DiffusionPipeline):
71
  """
72
  # add
73
  Pipeline for Instance Segmentation prediction using our Stable Diffusion model.
@@ -251,7 +251,7 @@ class Gen2SegSDPipeline(DiffusionPipeline):
251
  within the ensemble. These codes can be saved, modified, and used for subsequent calls with the
252
  `latents` argument.
253
  return_dict (`bool`, *optional*, defaults to `True`):
254
- Whether or not to return a [`Gen2SegSDSegOutput`] instead of a plain tuple.
255
 
256
  # add
257
  E2E FT models are deterministic single step models involving no ensembling, i.e. E=1.
@@ -397,7 +397,7 @@ class Gen2SegSDPipeline(DiffusionPipeline):
397
  if not return_dict:
398
  return (prediction, pred_latent)
399
 
400
- return Gen2SegSDSegOutput(
401
  prediction=prediction,
402
  latent=pred_latent,
403
  )
 
50
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
51
 
52
  @dataclass
53
+ class gen2segSDSegOutput(BaseOutput):
54
  """
55
  Output class for gen2seg Instance Segmentation prediction pipeline.
56
 
 
67
  latent: Union[None, torch.Tensor]
68
 
69
 
70
+ class gen2segSDPipeline(DiffusionPipeline):
71
  """
72
  # add
73
  Pipeline for Instance Segmentation prediction using our Stable Diffusion model.
 
251
  within the ensemble. These codes can be saved, modified, and used for subsequent calls with the
252
  `latents` argument.
253
  return_dict (`bool`, *optional*, defaults to `True`):
254
+ Whether or not to return a [`gen2segSDSegOutput`] instead of a plain tuple.
255
 
256
  # add
257
  E2E FT models are deterministic single step models involving no ensembling, i.e. E=1.
 
397
  if not return_dict:
398
  return (prediction, pred_latent)
399
 
400
+ return gen2segSDSegOutput(
401
  prediction=prediction,
402
  latent=pred_latent,
403
  )