Upload modeling_deepseek.py
modeling_deepseek.py  CHANGED  (+36 -85)
@@ -27,6 +27,7 @@ import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
@@ -469,11 +470,17 @@ class MoEGate(nn.Module):

        return topk_idx, topk_weight

-class DeepseekV3MoE(nn.Module):
-    """
-    A mixed expert module containing shared experts.
-    """

    def __init__(self, config):
        super().__init__()
        self.config = config
@@ -487,9 +494,7 @@ class DeepseekV3MoE(nn.Module):
            self.experts = nn.ModuleList(
                [
                    (
-                        DeepseekV3MLP(
-                            config, intermediate_size=config.moe_intermediate_size
-                        )
                        if i >= self.ep_rank * self.experts_per_rank
                        and i < (self.ep_rank + 1) * self.experts_per_rank
                        else None
@@ -503,18 +508,14 @@ class DeepseekV3MoE(nn.Module):
            self.ep_rank = 0
            self.experts = nn.ModuleList(
                [
-                    DeepseekV3MLP(
-                        config, intermediate_size=config.moe_intermediate_size
-                    )
                    for i in range(config.n_routed_experts)
                ]
            )
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
-            self.shared_experts = DeepseekV3MLP(
-                config=config, intermediate_size=intermediate_size
-            )

    def forward(self, hidden_states):
        identity = hidden_states
@@ -530,79 +531,29 @@ class DeepseekV3MoE(nn.Module):

    @torch.no_grad()
    def moe_infer(self, x, topk_ids, topk_weight):
-        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
-        cnts.scatter_(1, topk_ids, 1)
-        tokens_per_expert = cnts.sum(dim=0)
-        idxs = topk_ids.view(-1).argsort()
-        sorted_tokens = x[idxs // topk_ids.shape[1]]
-        sorted_tokens_shape = sorted_tokens.shape
-        if self.ep_size > 1:
-            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
-            tokens_per_expert_group = tokens_per_expert.new_empty(
-                tokens_per_expert.shape[0]
-            )
-            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
-            output_splits = (
-                tokens_per_expert_group.view(self.ep_size, -1)
-                .sum(1)
-                .cpu()
-                .numpy()
-                .tolist()
-            )
-            gathered_tokens = sorted_tokens.new_empty(
-                tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]
-            )
-            input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
-            dist.all_to_all(
-                list(gathered_tokens.split(output_splits)),
-                list(sorted_tokens.split(input_split_sizes)),
-            )
-            tokens_per_expert_post_gather = tokens_per_expert_group.view(
-                self.ep_size, self.experts_per_rank
-            ).sum(dim=0)
-            gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
-            s = 0
-            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
-                gatherd_idxs[s : s + k] = i % self.experts_per_rank
-                s += k
-            gatherd_idxs = gatherd_idxs.argsort()
-            sorted_tokens = gathered_tokens[gatherd_idxs]
-            tokens_per_expert = tokens_per_expert_post_gather
-        tokens_per_expert = tokens_per_expert.cpu().numpy()
-
-        outputs = []
-        start_idx = 0
-        for i, num_tokens in enumerate(tokens_per_expert):
-            end_idx = start_idx + num_tokens
-            if num_tokens == 0:
-                continue
-            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
-            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
-            expert_out = expert(tokens_for_this_expert)
-            outputs.append(expert_out)
-            start_idx = end_idx
-
-        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
-        if self.ep_size > 1:
-            new_x = torch.empty_like(outs)
-            new_x[gatherd_idxs] = outs
-            gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
-            dist.all_to_all(
-                list(gathered_tokens.split(input_split_sizes)),
-                list(new_x.split(output_splits)),
-            )
-            outs = gathered_tokens
-
-        new_x = torch.empty_like(outs)
-        new_x[idxs] = outs
-        final_out = (
-            new_x.view(*topk_ids.shape, -1)
-            .type(topk_weight.dtype)
-            .mul_(topk_weight.unsqueeze(dim=-1))
-            .sum(dim=1)
-            .type(new_x.dtype)
        )
-        return final_out


# Copied from transformers.models.llama.modeling_llama.repeat_kv
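Note: the removed moe_infer above implements the standard MoE dispatch pattern: sort the flattened (token, expert) assignments so each expert sees a contiguous slice of tokens, run each expert on its slice, then undo the sort and combine the top-k expert outputs with the routing weights (with an extra all_to_all token exchange when expert parallelism is enabled). A minimal single-rank sketch of that pattern, with illustrative names and no expert parallelism, is given here for orientation; the custom-op replacement appears in the added lines further below.

import torch

def moe_dispatch_single_rank(x, topk_ids, topk_weight, experts):
    # x: [num_tokens, hidden]; topk_ids / topk_weight: [num_tokens, top_k]
    # Sort the flattened (token, expert) pairs so each expert gets a contiguous slice.
    idxs = topk_ids.view(-1).argsort()
    sorted_tokens = x[idxs // topk_ids.shape[1]]
    tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=len(experts))

    outputs, start = [], 0
    for i, n in enumerate(tokens_per_expert.tolist()):
        if n == 0:
            continue
        outputs.append(experts[i](sorted_tokens[start:start + n]))
        start += n
    outs = torch.cat(outputs, dim=0) if outputs else sorted_tokens.new_empty(0)

    # Undo the sort, weight each of the top-k expert outputs, and sum them per token.
    unsorted = torch.empty_like(outs)
    unsorted[idxs] = outs
    return (
        unsorted.view(*topk_ids.shape, -1)
        .type(topk_weight.dtype)
        .mul_(topk_weight.unsqueeze(-1))
        .sum(dim=1)
        .type(x.dtype)
    )
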
@@ -27,6 +27,7 @@ import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from torch.library import custom_op

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
@@ -469,11 +470,17 @@ class MoEGate(nn.Module):

        return topk_idx, topk_weight

+@torch.library.custom_op("deepseek::moe_infer_op", mutates_args=())
+def moe_infer_fake(x: torch.Tensor, gate_proj_weight: torch.Tensor, up_proj_weight: torch.Tensor, down_proj_weight: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
+    final_out = torch.empty_like(x)
+    return final_out

+# Register the FakeTensor kernel
+@moe_infer_fake.register_fake
+def _(x, gate_proj_weight, up_proj_weight, down_proj_weight, topk_ids, topk_weight):
+    return torch.empty_like(x)
+
+class DeepseekV3MoE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
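Note: as committed, both the op body and the registered fake kernel only allocate an uninitialized tensor shaped like x, so deepseek::moe_infer_op behaves as an opaque, shape-correct node during tracing (FakeTensor propagation for torch.compile / torch.export) rather than computing the MoE output in eager mode. For reference, a dense eager implementation of the computation the op stands for might look like the sketch below; the SiLU activation and the [out_features, in_features] nn.Linear weight layout are assumptions, not taken from this commit.

import torch
import torch.nn.functional as F

def moe_infer_reference(x, gate_proj_weight, up_proj_weight, down_proj_weight,
                        topk_ids, topk_weight):
    # x: [num_tokens, hidden]; *_weight: [num_experts, out_features, in_features]
    out = torch.zeros_like(x)
    for e in range(gate_proj_weight.shape[0]):
        # Tokens routed to expert e, and the top-k slot they used.
        token_idx, slot = (topk_ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        h = x[token_idx]
        # SwiGLU-style expert MLP: down_proj(silu(gate_proj(h)) * up_proj(h))
        gated = F.silu(h @ gate_proj_weight[e].t()) * (h @ up_proj_weight[e].t())
        expert_out = gated @ down_proj_weight[e].t()
        weight = topk_weight[token_idx, slot].unsqueeze(-1)
        out.index_add_(0, token_idx, (expert_out * weight).to(out.dtype))
    return out
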
@@ -487,9 +494,7 @@ class DeepseekV3MoE(nn.Module):
            self.experts = nn.ModuleList(
                [
                    (
+                        DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
                        if i >= self.ep_rank * self.experts_per_rank
                        and i < (self.ep_rank + 1) * self.experts_per_rank
                        else None
@@ -503,18 +508,14 @@ class DeepseekV3MoE(nn.Module):
            self.ep_rank = 0
            self.experts = nn.ModuleList(
                [
+                    DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
                    for i in range(config.n_routed_experts)
                ]
            )
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
+            self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=intermediate_size)

    def forward(self, hidden_states):
        identity = hidden_states
@@ -530,79 +531,29 @@ class DeepseekV3MoE(nn.Module):

    @torch.no_grad()
    def moe_infer(self, x, topk_ids, topk_weight):
+        # Extract the weights from each expert MLP module in self.experts
+        gate_proj_weight = []
+        up_proj_weight = []
+        down_proj_weight = []
+        for i in range(len(self.experts)):
+            expert = self.experts[i]
+            if expert is not None:
+                gate_proj_weight.append(expert.gate_proj.weight.unsqueeze(0))
+                up_proj_weight.append(expert.up_proj.weight.unsqueeze(0))
+                down_proj_weight.append(expert.down_proj.weight.unsqueeze(0))
+
+        gate_proj_weight = torch.cat(gate_proj_weight, dim=0)  # [num_experts, intermediate_size, hidden_size]
+        up_proj_weight = torch.cat(up_proj_weight, dim=0)  # [num_experts, intermediate_size, hidden_size]
+        down_proj_weight = torch.cat(down_proj_weight, dim=0)  # [num_experts, hidden_size, intermediate_size]
+
+        return moe_infer_fake(
+            x=x,
+            gate_proj_weight=gate_proj_weight,
+            up_proj_weight=up_proj_weight,
+            down_proj_weight=down_proj_weight,
+            topk_ids=topk_ids,
+            topk_weight=topk_weight
        )


# Copied from transformers.models.llama.modeling_llama.repeat_kv
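Note: if a concrete kernel is wired up later, it can be attached to the same op without changing the traced graph. A minimal sketch, assuming the illustrative moe_infer_reference helper from the earlier note and a CUDA backend (neither is part of this commit):

# Hypothetical: attach an actual CUDA implementation to the op defined in the diff.
# The fake kernel registered above still supplies the shape/dtype rule used while tracing.
@moe_infer_fake.register_kernel("cuda")
def _(x, gate_proj_weight, up_proj_weight, down_proj_weight, topk_ids, topk_weight):
    return moe_infer_reference(
        x, gate_proj_weight, up_proj_weight, down_proj_weight, topk_ids, topk_weight
    )

# Once registered, the op is also reachable through the dispatcher, e.g.
#   out = torch.ops.deepseek.moe_infer_op(x, gate_w, up_w, down_w, topk_ids, topk_weight)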