# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple, Union

import torch

from diffusers.schedulers.scheduling_flow_match_euler_discrete import (
    FlowMatchEulerDiscreteScheduler,
    FlowMatchEulerDiscreteSchedulerOutput,
)


class AdditFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
        step_index: Optional[int] = None,
    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by taking one Euler step along the flow-matching trajectory.
        This function propagates the diffusion process from the learned model output (the predicted velocity).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`):
                Unused in this deterministic Euler step; kept for API compatibility with the parent scheduler.
            s_tmin (`float`):
                Unused in this deterministic Euler step; kept for API compatibility with the parent scheduler.
            s_tmax (`float`):
                Unused in this deterministic Euler step; kept for API compatibility with the parent scheduler.
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample. Unused in this deterministic Euler step.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a
                [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
            step_index (`int`, *optional*):
                If given, overrides the scheduler's internal step index before the step is taken.

        Returns:
            [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, a
                [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] is
                returned; otherwise a tuple is returned where the first element is the previous sample and the
                second element is the predicted clean sample `x_0`.
        """

        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `AdditFlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if step_index is not None:
            self._step_index = step_index

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        sigma = self.sigmas[self.step_index]
        sigma_next = self.sigmas[self.step_index + 1]

        # Euler step along the flow: x_{next} = x_t + (sigma_next - sigma) * v
        prev_sample = sample + (sigma_next - sigma) * model_output

        # Predicted clean sample x_0 implied by the flow-matching parameterization
        x_0 = sample - sigma * model_output

        # Cast samples back to the model-compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)
        x_0 = x_0.to(model_output.dtype)

        # Upon completion, increase the step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample, x_0)

        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
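

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the scheduler itself).
# A real pipeline would call its transformer to obtain the velocity; here a
# random tensor stands in for that prediction, and the latent shape and step
# counts are arbitrary. It shows how `step` is driven with an explicit
# `step_index` and how the tuple return exposes both the next latent and the
# predicted clean sample x_0.
if __name__ == "__main__":
    scheduler = AdditFlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
    scheduler.set_timesteps(num_inference_steps=4)

    latents = torch.randn(1, 16, 32, 32)
    for i, t in enumerate(scheduler.timesteps):
        # Stand-in for a model forward pass that predicts the flow velocity.
        velocity = torch.randn_like(latents)
        latents, x_0 = scheduler.step(
            velocity, t, latents, return_dict=False, step_index=i
        )

    print(latents.shape, x_0.shape)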