| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						""" | 
					
					
						
						| 
							 | 
Streaming module API that should be implemented by all Streaming components.
					
					
						
						| 
							 | 
						""" | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						from contextlib import contextmanager | 
					
					
						
						| 
							 | 
						import typing as tp | 
					
					
						
						| 
							 | 
						from torch import nn | 
					
					
						
						| 
							 | 
						import torch | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
# Streaming state of one module: maps a key to a tensor. Per the StreamingModule
# docstring, keys must not contain dots and each tensor's first dim is, by
# convention, the batch size.
State = tp.Dict[str, torch.Tensor]
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						class StreamingModule(nn.Module): | 
					
					
						
						| 
							 | 
						    """Common API for streaming components. | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						    Each streaming component has a streaming state, which is just a dict[str, Tensor]. | 
					
					
						
						| 
							 | 
						    By convention, the first dim of each tensor must be the batch size. | 
					
					
						
						| 
							 | 
						    Don't use dots in the key names, as this would clash with submodules | 
					
					
						
						| 
							 | 
						    (like in state_dict). | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						    If `self._is_streaming` is True, the component should use and remember | 
					
					
						
						| 
							 | 
						    the proper state inside `self._streaming_state`. | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						    To set a streaming component in streaming state, use | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						        with module.streaming(): | 
					
					
						
						| 
							 | 
						            ... | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						    This will automatically reset the streaming state when exiting the context manager. | 
					
					
						
						| 
							 | 
						    This also automatically propagates to all streaming children module. | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						    Some module might also implement the `StreamingModule.flush` method, although | 
					
					
						
						| 
							 | 
						    this one is trickier, as all parents module must be StreamingModule and implement | 
					
					
						
						| 
							 | 
						    it as well for it to work properly. See `StreamingSequential` after. | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						    def __init__(self) -> None: | 
					
					
						
						| 
							 | 
						        super().__init__() | 
					
					
						
						| 
							 | 
						        self._streaming_state: State = {} | 
					
					
						
						| 
							 | 
						        self._is_streaming = False | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def _apply_named_streaming(self, fn: tp.Any): | 
					
					
						
						| 
							 | 
						        for name, module in self.named_modules(): | 
					
					
						
						| 
							 | 
						            if isinstance(module, StreamingModule): | 
					
					
						
						| 
							 | 
						                fn(name, module) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def _set_streaming(self, streaming: bool): | 
					
					
						
						| 
							 | 
						        def _set_streaming(name, module): | 
					
					
						
						| 
							 | 
						            module._is_streaming = streaming | 
					
					
						
						| 
							 | 
						        self._apply_named_streaming(_set_streaming) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    @contextmanager | 
					
					
						
						| 
							 | 
						    def streaming(self): | 
					
					
						
						| 
							 | 
						        """Context manager to enter streaming mode. Reset streaming state on exit. | 
					
					
						
						| 
							 | 
						        """ | 
					
					
						
						| 
							 | 
						        self._set_streaming(True) | 
					
					
						
						| 
							 | 
						        try: | 
					
					
						
						| 
							 | 
						            yield | 
					
					
						
						| 
							 | 
						        finally: | 
					
					
						
						| 
							 | 
						            self._set_streaming(False) | 
					
					
						
						| 
							 | 
						            self.reset_streaming() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def reset_streaming(self): | 
					
					
						
						| 
							 | 
						        """Reset the streaming state. | 
					
					
						
						| 
							 | 
						        """ | 
					
					
						
						| 
							 | 
						        def _reset(name: str, module: StreamingModule): | 
					
					
						
						| 
							 | 
						            module._streaming_state.clear() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        self._apply_named_streaming(_reset) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def get_streaming_state(self) -> State: | 
					
					
						
						| 
							 | 
						        """Return the streaming state, including that of sub-modules. | 
					
					
						
						| 
							 | 
						        """ | 
					
					
						
						| 
							 | 
						        state: State = {} | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        def _add(name: str, module: StreamingModule): | 
					
					
						
						| 
							 | 
						            if name: | 
					
					
						
						| 
							 | 
						                name += "." | 
					
					
						
						| 
							 | 
						            for key, value in module._streaming_state.items(): | 
					
					
						
						| 
							 | 
						                state[name + key] = value | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        self._apply_named_streaming(_add) | 
					
					
						
						| 
							 | 
						        return state | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def set_streaming_state(self, state: State): | 
					
					
						
						| 
							 | 
						        """Set the streaming state, including that of sub-modules. | 
					
					
						
						| 
							 | 
						        """ | 
					
					
						
						| 
							 | 
						        state = dict(state) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        def _set(name: str, module: StreamingModule): | 
					
					
						
						| 
							 | 
						            if name: | 
					
					
						
						| 
							 | 
						                name += "." | 
					
					
						
						| 
							 | 
						            module._streaming_state.clear() | 
					
					
						
						| 
							 | 
						            for key, value in list(state.items()): | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                if key.startswith(name): | 
					
					
						
						| 
							 | 
						                    local_key = key[len(name):] | 
					
					
						
						| 
							 | 
						                    if '.' not in local_key: | 
					
					
						
						| 
							 | 
						                        module._streaming_state[local_key] = value | 
					
					
						
						| 
							 | 
						                        del state[key] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        self._apply_named_streaming(_set) | 
					
					
						
						| 
							 | 
						        assert len(state) == 0, list(state.keys()) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def flush(self, x: tp.Optional[torch.Tensor] = None): | 
					
					
						
						| 
							 | 
						        """Flush any remaining outputs that were waiting for completion. | 
					
					
						
						| 
							 | 
						        Typically, for convolutions, this will add the final padding | 
					
					
						
						| 
							 | 
						        and process the last buffer. | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						        This should take an optional argument `x`, which will be provided | 
					
					
						
						| 
							 | 
						        if a module before this one in the streaming pipeline has already | 
					
					
						
						| 
							 | 
						        spitted out a flushed out buffer. | 
					
					
						
						| 
							 | 
						        """ | 
					
					
						
						| 
							 | 
						        if x is None: | 
					
					
						
						| 
							 | 
						            return None | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            return self(x) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						class StreamingSequential(StreamingModule, nn.Sequential): | 
					
					
						
						| 
							 | 
						    """A streaming compatible alternative of `nn.Sequential`. | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						    def flush(self, x: tp.Optional[torch.Tensor] = None): | 
					
					
						
						| 
							 | 
						        for module in self: | 
					
					
						
						| 
							 | 
						            if isinstance(module, StreamingModule): | 
					
					
						
						| 
							 | 
						                x = module.flush(x) | 
					
					
						
						| 
							 | 
						            elif x is not None: | 
					
					
						
						| 
							 | 
						                x = module(x) | 
					
					
						
						| 
							 | 
						        return x | 
					
					
						
						| 
							 | 
						
 |