import copy
from functools import lru_cache

import torch
import torch.nn as nn


@lru_cache(maxsize=1)
def get_cached_zeros(numel, device="cpu", dtype=torch.float32):
    """Return a 1-D zero tensor of length ``numel``, reusing the last one built.

    Only the most recent ``(numel, device, dtype)`` combination is cached, so
    repeated calls with identical arguments return the *same* tensor object.
    Callers must therefore never modify the result in place.
    """
    return torch.zeros(numel, device=device, dtype=dtype)


def _cumulative_seqlens(seq_len):
    """Return ``[0, cumsum(seq_len)...]`` as an int32 tensor on ``seq_len``'s device."""
    cumsum = torch.cumsum(seq_len, dim=0)
    return torch.cat([cumsum.new_zeros(1), cumsum], dim=0).int()


class StreamingODEWrapperForPrefix(nn.Module):
    """Wrap ``net`` as a ``forward(t, x)`` callable that is fed chunk by chunk.

    ``set_conditions`` installs the conditioning for the current chunk and
    derives position ids plus cumulative sequence lengths for queries
    (``cu_seqlens``) and keys (``cu_seqlens_k``, which also counts tokens kept
    from earlier chunks). ``update_incremental_state`` promotes the per-layer
    ``cur_k``/``cur_v`` entries written into ``incremental_state`` to
    ``prev_k``/``prev_v`` and trims them to a bounded size.
    Classifier-free guidance (``use_cfg``) is rejected at construction time.
    """

    # Attributes captured by ``state_dict`` / restored by ``load_state_dict``.
    _STREAMING_STATE_KEYS = (
        "incremental_state",
        "kv_cache_tokens",
        "cu_seqlens",
        "cu_maxlen",
        "cu_seqlens_k",
        "cu_maxlen_k",
        "previous_seqlen",
    )

    def __init__(self, net, x_mask, x_cond, use_cfg=False, use_cfg_rescale=True, cfg_init=1.0,
                 cfg_scale=4.0, cfg_schedule="linear", cfg_token_id=0):
        super().__init__()
        self.net = net
        self.x_mask = x_mask
        self.x_cond = x_cond
        assert not use_cfg, "cfg is not supported in streaming detokenizer"
        self.use_cfg = use_cfg
        # The cfg_* settings are stored but not read anywhere in this class.
        self.use_cfg_rescale = use_cfg_rescale
        self.cfg_init = cfg_init
        self.cfg_scale = cfg_scale
        self.cfg_token_id = cfg_token_id
        self.cfg_schedule = cfg_schedule
        self.position_ids = None
        self.seq_len = None
        self.clear_all_states()

    def clear_all_states(self):
        """Drop the KV cache and all streaming sequence-length bookkeeping."""
        self.incremental_state = {}
        self.kv_cache_tokens = 0
        self.cu_seqlens = None
        self.cu_maxlen = None
        self.cu_seqlens_k = None
        self.cu_maxlen_k = None
        self.previous_seqlen = None

    def state_dict(self):
        """Return a deep-copied snapshot of the streaming state.

        NOTE(review): this shadows ``nn.Module.state_dict`` with an
        incompatible signature and does not include ``net``'s parameters.
        """
        return {key: copy.deepcopy(getattr(self, key)) for key in self._STREAMING_STATE_KEYS}

    def load_state_dict(self, state_dict):
        """Restore a snapshot produced by ``state_dict``.

        The snapshot is deep-copied so later in-place cache updates (which pop
        and reassign entries in ``incremental_state``) cannot corrupt it; the
        same snapshot can therefore be restored more than once.
        """
        for key in self._STREAMING_STATE_KEYS:
            setattr(self, key, copy.deepcopy(state_dict[key]))

    def set_conditions(self, x_mask, x_cond, start_position_id, cache=None):
        """Install the conditioning for the next chunk and update length metadata.

        Args:
            x_mask: mask for the chunk; stored as-is (duplicated along dim 0 with cfg).
            x_cond: conditioning tensor; ``x_cond.shape[1]`` is the chunk's token count.
            start_position_id: position id assigned to the chunk's first token.
            cache: dict returned by the previous call (and possibly updated by
                ``update_incremental_state``). If omitted, ``self.previous_seqlen``
                is used instead.

        Returns:
            dict with key ``"previous_seqlen"``: per-sample key length including
            this chunk.

        Raises:
            KeyError: if a previous chunk exists but no previous length is known.
        """
        if self.use_cfg:
            self.x_cond = torch.cat((x_cond, x_cond), dim=0)
            self.x_mask = torch.cat((x_mask, x_mask), dim=0)
        else:
            self.x_mask = x_mask
            self.x_cond = x_cond

        device = self.x_cond.device
        num_tokens = self.x_cond.shape[1]
        batch = 2 if self.use_cfg else 1

        position_ids = torch.arange(start_position_id, start_position_id + num_tokens, dtype=torch.long)
        self.position_ids = position_ids.unsqueeze(0).repeat(batch, 1).to(device)
        self.seq_len = torch.full((batch,), num_tokens, dtype=torch.long, device=device)
        self.cu_seqlens = _cumulative_seqlens(self.seq_len)
        self.cu_maxlen = self.seq_len.cpu().max()

        if self.cu_seqlens_k is None:
            # First chunk: keys are exactly this chunk's tokens.
            self.cu_seqlens_k = self.cu_seqlens
            self.cu_maxlen_k = self.cu_maxlen
            previous_seqlen = self.seq_len
        else:
            if cache is not None and "previous_seqlen" in cache:
                previous_seqlen_old = cache["previous_seqlen"]
            else:
                # Fall back to our own record, kept in sync with the returned cache.
                previous_seqlen_old = self.previous_seqlen
            if previous_seqlen_old is None:
                raise KeyError("previous_seqlen")
            # Keys cover cached tokens from earlier chunks plus this chunk.
            previous_seqlen = previous_seqlen_old + self.seq_len
            self.cu_seqlens_k = _cumulative_seqlens(previous_seqlen)
            self.cu_maxlen_k = previous_seqlen.cpu().max()

        self.previous_seqlen = previous_seqlen
        return {"previous_seqlen": previous_seqlen}

    @staticmethod
    def _trim_kv(cache_tensor, reserve_tokens, recent_tokens):
        """Keep the first ``reserve_tokens`` and last ``recent_tokens`` entries along dim 1."""
        # Checked first: ``x[:, -0:]`` would keep *everything*, not nothing.
        if recent_tokens == 0:
            return cache_tensor[:, :reserve_tokens]
        if reserve_tokens == 0:
            return cache_tensor[:, -recent_tokens:]
        return torch.cat([cache_tensor[:, :reserve_tokens], cache_tensor[:, -recent_tokens:]], dim=1)

    def update_incremental_state(self, reserve_kv_cache_tokens=0, max_kv_cache_tokens=900,
                                 condition_cache=None):
        """Promote each layer's ``cur_k``/``cur_v`` to ``prev_k``/``prev_v`` and bound its size.

        When a layer holds more than ``max_kv_cache_tokens`` tokens (dim 1),
        the first ``reserve_kv_cache_tokens`` tokens and the most recent
        ``max_kv_cache_tokens - reserve_kv_cache_tokens`` tokens are kept.
        After trimming, ``self.previous_seqlen`` is reset to the kept length and
        written into ``condition_cache["previous_seqlen"]`` (a dict the caller
        passes back to ``set_conditions``). ``cur_k``/``cur_v`` are removed.
        """
        assert reserve_kv_cache_tokens <= max_kv_cache_tokens, "reserve_kv_cache_tokens should be less than or equal to max_kv_cache_tokens"
        if condition_cache is None:
            condition_cache = {}
        recent_tokens = max_kv_cache_tokens - reserve_kv_cache_tokens

        for layer_cache in self.incremental_state.values():
            kv = layer_cache["attn_kvcache"]
            kv["prev_k"] = kv["cur_k"]
            kv["prev_v"] = kv["cur_v"]
            self.kv_cache_tokens = kv["prev_k"].shape[1]

            if self.kv_cache_tokens > max_kv_cache_tokens:
                kv["prev_k"] = self._trim_kv(kv["prev_k"], reserve_kv_cache_tokens, recent_tokens)
                kv["prev_v"] = self._trim_kv(kv["prev_v"], reserve_kv_cache_tokens, recent_tokens)
                prev_k = kv["prev_k"]
                self.previous_seqlen = torch.full(
                    (prev_k.shape[0],), prev_k.shape[1], dtype=torch.long, device=prev_k.device
                )
                condition_cache["previous_seqlen"] = self.previous_seqlen
                self.kv_cache_tokens = prev_k.shape[1]

            # The current-step entries have been folded into prev_*.
            kv.pop("cur_k")
            kv.pop("cur_v")

    def forward(self, t, x, args=None):
        """Evaluate ``net`` at time ``t`` (a tensor or Python number) for state ``x``.

        ``t`` is scaled by 1000, truncated to an integer and broadcast to one
        long value per sample in ``x``'s first dimension. ``args`` is accepted
        for call-signature compatibility and ignored.
        """
        if torch.is_tensor(t):
            t_steps = (t * 1000).long()
        else:
            t_steps = int(t * 1000)
        t = get_cached_zeros(x.shape[0], device=x.device, dtype=torch.long) + t_steps

        if self.use_cfg:
            raise NotImplementedError("cfg is not supported in streaming detokenizer.")
        pred_noise = self.net(
            x=x,
            condition=self.x_cond,
            t=t,
            position_ids=self.position_ids,
            cu_seqlens=self.cu_seqlens,
            cu_maxlen=self.cu_maxlen,
            cu_seqlens_k=self.cu_seqlens_k,
            cu_maxlen_k=self.cu_maxlen_k,
            incremental_state=self.incremental_state,
            nopadding=True,
            mask=None,
            seq_len=None,
        )
        return pred_noise