Ojttt committed on
Commit 1de2f69 · verified · 1 Parent(s): 05fea73

Upload modeling_deepseek.py

Files changed (1)
  1. modeling_deepseek.py +36 -85
modeling_deepseek.py CHANGED
@@ -27,6 +27,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from torch.library import custom_op
 
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
@@ -469,11 +470,17 @@ class MoEGate(nn.Module):
 
         return topk_idx, topk_weight
 
-class DeepseekV3MoE(nn.Module):
-    """
-    A mixed expert module containing shared experts.
-    """
+@torch.library.custom_op("deepseek::moe_infer_op", mutates_args=())
+def moe_infer_fake(x: torch.Tensor, gate_proj_weight: torch.Tensor, up_proj_weight: torch.Tensor, down_proj_weight: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
+    final_out = torch.empty_like(x)
+    return final_out
+
+# Register the FakeTensor kernel
+@moe_infer_fake.register_fake
+def _(x, gate_proj_weight, up_proj_weight, down_proj_weight, topk_ids, topk_weight):
+    return torch.empty_like(x)
+
+class DeepseekV3MoE(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
@@ -487,9 +494,7 @@ class DeepseekV3MoE(nn.Module):
         self.experts = nn.ModuleList(
             [
                 (
-                    DeepseekV3MLP(
-                        config, intermediate_size=config.moe_intermediate_size
-                    )
+                    DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
                     if i >= self.ep_rank * self.experts_per_rank
                     and i < (self.ep_rank + 1) * self.experts_per_rank
                     else None
@@ -503,18 +508,14 @@ class DeepseekV3MoE(nn.Module):
             self.ep_rank = 0
             self.experts = nn.ModuleList(
                 [
-                    DeepseekV3MLP(
-                        config, intermediate_size=config.moe_intermediate_size
-                    )
+                    DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
                     for i in range(config.n_routed_experts)
                 ]
             )
         self.gate = MoEGate(config)
         if config.n_shared_experts is not None:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
-            self.shared_experts = DeepseekV3MLP(
-                config=config, intermediate_size=intermediate_size
-            )
+            self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=intermediate_size)
 
     def forward(self, hidden_states):
         identity = hidden_states
@@ -530,79 +531,29 @@ class DeepseekV3MoE(nn.Module):
 
     @torch.no_grad()
     def moe_infer(self, x, topk_ids, topk_weight):
-        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
-        cnts.scatter_(1, topk_ids, 1)
-        tokens_per_expert = cnts.sum(dim=0)
-        idxs = topk_ids.view(-1).argsort()
-        sorted_tokens = x[idxs // topk_ids.shape[1]]
-        sorted_tokens_shape = sorted_tokens.shape
-        if self.ep_size > 1:
-            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
-            tokens_per_expert_group = tokens_per_expert.new_empty(
-                tokens_per_expert.shape[0]
-            )
-            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
-            output_splits = (
-                tokens_per_expert_group.view(self.ep_size, -1)
-                .sum(1)
-                .cpu()
-                .numpy()
-                .tolist()
-            )
-            gathered_tokens = sorted_tokens.new_empty(
-                tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]
-            )
-            input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
-            dist.all_to_all(
-                list(gathered_tokens.split(output_splits)),
-                list(sorted_tokens.split(input_split_sizes)),
-            )
-            tokens_per_expert_post_gather = tokens_per_expert_group.view(
-                self.ep_size, self.experts_per_rank
-            ).sum(dim=0)
-            gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
-            s = 0
-            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
-                gatherd_idxs[s : s + k] = i % self.experts_per_rank
-                s += k
-            gatherd_idxs = gatherd_idxs.argsort()
-            sorted_tokens = gathered_tokens[gatherd_idxs]
-            tokens_per_expert = tokens_per_expert_post_gather
-        tokens_per_expert = tokens_per_expert.cpu().numpy()
-
-        outputs = []
-        start_idx = 0
-        for i, num_tokens in enumerate(tokens_per_expert):
-            end_idx = start_idx + num_tokens
-            if num_tokens == 0:
-                continue
-            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
-            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
-            expert_out = expert(tokens_for_this_expert)
-            outputs.append(expert_out)
-            start_idx = end_idx
-
-        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
-        if self.ep_size > 1:
-            new_x = torch.empty_like(outs)
-            new_x[gatherd_idxs] = outs
-            gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
-            dist.all_to_all(
-                list(gathered_tokens.split(input_split_sizes)),
-                list(new_x.split(output_splits)),
-            )
-            outs = gathered_tokens
-
-        new_x = torch.empty_like(outs)
-        new_x[idxs] = outs
-        final_out = (
-            new_x.view(*topk_ids.shape, -1)
-            .type(topk_weight.dtype)
-            .mul_(topk_weight.unsqueeze(dim=-1))
-            .sum(dim=1)
-            .type(new_x.dtype)
+        # Extract the weights from each expert MLP in self.experts
+        gate_proj_weight = []
+        up_proj_weight = []
+        down_proj_weight = []
+        for i in range(len(self.experts)):
+            expert = self.experts[i]
+            if expert is not None:
+                gate_proj_weight.append(expert.gate_proj.weight.unsqueeze(0))
+                up_proj_weight.append(expert.up_proj.weight.unsqueeze(0))
+                down_proj_weight.append(expert.down_proj.weight.unsqueeze(0))
+
+        gate_proj_weight = torch.cat(gate_proj_weight, dim=0)  # [num_experts, intermediate_size, hidden_size]
+        up_proj_weight = torch.cat(up_proj_weight, dim=0)  # [num_experts, intermediate_size, hidden_size]
+        down_proj_weight = torch.cat(down_proj_weight, dim=0)  # [num_experts, hidden_size, intermediate_size]
+
+        return moe_infer_fake(
+            x=x,
+            gate_proj_weight=gate_proj_weight,
+            up_proj_weight=up_proj_weight,
+            down_proj_weight=down_proj_weight,
+            topk_ids=topk_ids,
+            topk_weight=topk_weight
         )
-        return final_out
 
 
 # Copied from transformers.models.llama.modeling_llama.repeat_kv
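
Note: this commit swaps the data-dependent expert dispatch in moe_infer for a single torch.library.custom_op call. Custom ops are opaque to the tracer, so the kernel registered via register_fake is what FakeTensor-based tooling (torch.export, torch.compile) uses to propagate shapes and dtypes through the MoE block without executing the routing logic. Below is a minimal, self-contained sketch of the same pattern, assuming PyTorch >= 2.4; the demo::moe_stub name and toy shapes are illustrative only, not part of the commit.

import torch

# Eager kernel: an opaque stub, mirroring moe_infer_fake above.
@torch.library.custom_op("demo::moe_stub", mutates_args=())
def moe_stub(x: torch.Tensor, topk_ids: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)

# FakeTensor kernel: shape/dtype propagation only, no computation.
@moe_stub.register_fake
def _(x, topk_ids):
    return torch.empty_like(x)

class Block(torch.nn.Module):
    def forward(self, x, topk_ids):
        return moe_stub(x, topk_ids)

# The op appears as a single node in the exported graph, so the
# data-dependent dispatch no longer has to be traceable.
ep = torch.export.export(Block(), (torch.randn(4, 16), torch.zeros(4, 2, dtype=torch.long)))
print(ep.graph)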
 
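Note: the eager kernel registered for deepseek::moe_infer_op returns torch.empty_like(x), i.e. uninitialized memory, so as committed the op is a shape-only placeholder rather than a working MoE forward. For reference, here is a hedged sketch of one way an eager kernel could actually consume the stacked weights. It assumes all experts are local (ep_size == 1), the SwiGLU form down_proj(silu(gate_proj(x)) * up_proj(x)) used by DeepseekV3MLP, and the nn.Linear weight layout [out_features, in_features]; it is written for clarity rather than efficiency and is not the commit's implementation.

import torch
import torch.nn.functional as F

def moe_infer_dense(x, gate_w, up_w, down_w, topk_ids, topk_weight):
    # x:           [num_tokens, hidden]
    # gate_w/up_w: [num_experts, intermediate, hidden]
    # down_w:      [num_experts, hidden, intermediate]
    # topk_ids:    [num_tokens, top_k]; topk_weight: [num_tokens, top_k]
    top_k = topk_ids.shape[1]
    flat_ids = topk_ids.reshape(-1)                       # [num_tokens * top_k]
    # One column vector per (token, expert) pair, aligned with flat_ids.
    xt = x.repeat_interleave(top_k, dim=0).unsqueeze(-1)
    gw, uw, dw = gate_w[flat_ids], up_w[flat_ids], down_w[flat_ids]
    h = F.silu(torch.bmm(gw, xt)) * torch.bmm(uw, xt)     # SwiGLU hidden activations
    out = torch.bmm(dw, h).squeeze(-1)                    # [num_tokens * top_k, hidden]
    out = out.view(*topk_ids.shape, -1) * topk_weight.unsqueeze(-1)
    return out.sum(dim=1).to(x.dtype)                     # weighted sum over the top-k experts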