Upload folder using huggingface_hub
- fabric_state/checkpoint.pt +1 -1
- generation_config.json +4 -0
- model.safetensors +1 -1
- pico_decoder.py +268 -5
fabric_state/checkpoint.pt CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:f78990840c8b2a26e89eea5f5414a84a9e8a0c76b9637d3cac17ec22e5486678
 size 135543171
generation_config.json ADDED
@@ -0,0 +1,4 @@
+{
+  "transformers_version": "4.48.3",
+  "vocab_size": 50304
+}
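For context: when a checkpoint containing this file is loaded, transformers reads it back into a GenerationConfig, and the vocab_size entry here is written because the updated pico_decoder.py below copies it onto the model's generation config. A minimal sketch of inspecting the file (the local path is illustrative):

    from transformers import GenerationConfig

    # Reads generation_config.json from a repo id or local directory.
    gen_config = GenerationConfig.from_pretrained("./checkpoint")
    print(gen_config.transformers_version)  # "4.48.3"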
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:3084d44929c019203e308a3f500b8792ca69ff273c69edb7eb6a433268e540f9
 size 45143592
pico_decoder.py CHANGED
@@ -31,7 +31,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.attention import SDPBackend, sdpa_kernel
-from transformers import PretrainedConfig, PreTrainedModel
+from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
+from transformers.generation import GenerationConfig
 from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast

 try:
@@ -134,7 +135,7 @@ class RoPE(nn.Module):
         Initializes the complex frequency tensor that is used to compute the RoPE embeddings.

         Note other implementations will use cos and sin directly, but using the complex
-        number representation is (probably
+        number representation is (probably) more efficient:

             e^(theta * i * t) = cos(theta * t) + i * sin(theta * t) [Euler's formula]
         """
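The Euler identity in this docstring is the whole trick: rotating each (real, imaginary) feature pair by an angle theta * t is a single complex multiplication. A minimal, self-contained sketch of the complex-number formulation (function names and shapes are illustrative, not this module's actual API):

    import torch

    def precompute_freqs_cis(head_dim: int, seq_len: int, theta: float = 10000.0) -> torch.Tensor:
        # One complex number e^(i * theta_j * t) per (position t, frequency j) pair.
        freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
        angles = torch.outer(torch.arange(seq_len).float(), freqs)   # (seq_len, head_dim // 2)
        return torch.polar(torch.ones_like(angles), angles)          # cos + i*sin, via Euler

    def apply_rope(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, n_heads, head_dim); adjacent dims pair up as complex numbers.
        x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        x_rotated = x_complex * freqs_cis[None, :, None, :]          # broadcast batch and heads
        return torch.view_as_real(x_rotated).flatten(-2).type_as(x)

    queries = torch.randn(1, 16, 4, 32)
    queries_rot = apply_rope(queries, precompute_freqs_cis(32, 16))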
@@ -314,7 +315,7 @@ class Attention(nn.Module):
             queries.contiguous(),
             keys.contiguous(),
             values.contiguous(),
-            attn_mask=mask.to(queries.dtype),
+            attn_mask=mask.to(queries.dtype) if mask is not None else None,
             enable_gqa=apply_gqa,
         )

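The new None guard matters because F.scaled_dot_product_attention happily accepts attn_mask=None (e.g. plain causal or cached decoding), whereas calling .to(...) on None raises AttributeError. A toy illustration of the fixed call (shapes are made up; enable_gqa is omitted since it requires a recent PyTorch):

    import torch
    import torch.nn.functional as F

    q = k = v = torch.randn(1, 4, 8, 16)   # (batch, n_heads, seq_len, head_dim)
    mask = None                            # e.g. no padding mask for this step

    # Before the fix, mask.to(q.dtype) raised whenever mask was None.
    out = F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=mask.to(q.dtype) if mask is not None else None,
    )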
@@ -556,9 +557,9 @@ class PicoDecoderHFConfig(PretrainedConfig):
         return cls.from_dict(asdict(model_config))


-class PicoDecoderHF(PreTrainedModel):
+class PicoDecoderHF(PreTrainedModel, GenerationMixin):
     """
-    HuggingFace wrapper for the Pico model.
+    HuggingFace wrapper for the Pico model with generation support.

     Many evaluation frameworks require a model be setup as a HuggingFace model, so we provide a simple
     wrapper that does just that. When we save checkpoints of the Pico model, we save both the normal
@@ -571,10 +572,18 @@ class PicoDecoderHF(PreTrainedModel):

     config_class = PicoDecoderHFConfig
     _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
+    main_input_name = "input_ids"

     def __init__(self, config: PicoDecoderHFConfig):
         super().__init__(config)
         self.pico_decoder = PicoDecoder(config)
+        # Initialize generation config with defaults
+        self.generation_config = GenerationConfig()
+        # Set some reasonable defaults for the model
+        if hasattr(config, "max_position_embeddings"):
+            self.generation_config.max_length = config.max_position_embeddings
+        if hasattr(config, "vocab_size"):
+            self.generation_config.vocab_size = config.vocab_size

     def forward(
         self,
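These __init__ defaults are presumably also where the new generation_config.json in this commit comes from: save_pretrained serializes model.generation_config next to the weights for any model whose can_generate() returns True. A sketch, assuming a fully populated config object (the path is illustrative):

    # config: a populated PicoDecoderHFConfig from the training run
    model = PicoDecoderHF(config)
    model.save_pretrained("./checkpoint")
    # ./checkpoint now contains model.safetensors, config.json,
    # and generation_config.json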
@@ -601,8 +610,262 @@ class PicoDecoderHF(PreTrainedModel):
             logits=logits,
         )

+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        """
+        Prepare inputs for generation.
+
+        Args:
+            input_ids: Input token IDs
+            past_key_values: Cached key-value pairs from previous forward passes
+            attention_mask: Attention mask for the input
+            **kwargs: Additional arguments
+
+        Returns:
+            Dictionary containing prepared inputs
+        """
+        # If we have past_key_values, we only need the last token
+        if past_key_values is not None:
+            input_ids = input_ids[:, -1:]
+
+        return {
+            "input_ids": input_ids,
+            "past_key_values": past_key_values,
+            "use_cache": True,
+        }
+
+    def get_input_embeddings(self):
+        """Get the input embeddings layer."""
+        return self.pico_decoder.embedding_proj
+
+    def set_input_embeddings(self, value):
+        """Set the input embeddings layer."""
+        self.pico_decoder.embedding_proj = value
+
+    def get_output_embeddings(self):
+        """Get the output embeddings layer."""
+        return self.pico_decoder.de_embedding_proj
+
+    def set_output_embeddings(self, value):
+        """Set the output embeddings layer."""
+        self.pico_decoder.de_embedding_proj = value
+
+    def get_lm_head(self):
+        """Get the language model head."""
+        return self.pico_decoder.de_embedding_proj
+
+    def can_generate(self) -> bool:
+        """Check if the model can generate text."""
+        return True
+
+    @property
+    def is_encoder_decoder(self) -> bool:
+        """Check if the model is an encoder-decoder model."""
+        return False
+
+    @property
+    def can_use_cache(self) -> bool:
+        """Check if the model can use KV cache."""
+        return True
+
+    def resize_token_embeddings(
+        self, new_num_tokens: Optional[int] = None
+    ) -> torch.nn.Embedding:
+        """Resize token embeddings."""
+        old_embeddings = self.get_input_embeddings()
+        if new_num_tokens is None:
+            new_num_tokens = old_embeddings.num_embeddings
+
+        new_embeddings = torch.nn.Embedding(
+            new_num_tokens, old_embeddings.embedding_dim
+        )
+        new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
+            old_embeddings.weight.data
+        )
+
+        self.pico_decoder.embedding_proj = new_embeddings
+        self.pico_decoder.de_embedding_proj = torch.nn.Linear(
+            old_embeddings.embedding_dim, new_num_tokens, bias=False
+        )
+
+        return new_embeddings
+

 # Register for auto classes
 PicoDecoderHFConfig.register_for_auto_class()
 PicoDecoderHF.register_for_auto_class("AutoModel")
 PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
+
+
+########################################################
+#
+# New PicoDecoderForCausalLM class for generation support
+#
+########################################################
+
+
+class PicoDecoderForCausalLM(PreTrainedModel, GenerationMixin):
+    """
+    PicoDecoderForCausalLM: A HuggingFace-compatible model that properly supports generation.
+
+    This class is designed to work with existing checkpoints and provides full generation support.
+    It inherits from the right base classes that HuggingFace expects for text generation.
+    """
+
+    config_class = PicoDecoderHFConfig
+    _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
+    main_input_name = "input_ids"
+
+    def __init__(self, config: PicoDecoderHFConfig):
+        super().__init__(config)
+        self.pico_decoder = PicoDecoder(config)
+        # Initialize generation config with defaults
+        self.generation_config = GenerationConfig()
+        # Set some reasonable defaults for the model
+        if hasattr(config, "max_position_embeddings"):
+            self.generation_config.max_length = config.max_position_embeddings
+        if hasattr(config, "vocab_size"):
+            self.generation_config.vocab_size = config.vocab_size
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
+        """Forward pass for text generation."""
+        logits, past_key_values = self.pico_decoder(
+            input_ids, past_key_values, use_cache
+        )
+        if use_cache:
+            return CausalLMOutputWithPast(
+                logits=logits,
+                past_key_values=past_key_values,
+            )
+        else:
+            return CausalLMOutput(
+                logits=logits,
+            )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        """Prepare inputs for generation."""
+        # If we have past_key_values, we only need the last token
+        if past_key_values is not None:
+            input_ids = input_ids[:, -1:]
+
+        return {
+            "input_ids": input_ids,
+            "past_key_values": past_key_values,
+            "use_cache": True,
+        }
+
+    def get_input_embeddings(self):
+        """Get the input embeddings layer."""
+        return self.pico_decoder.embedding_proj
+
+    def set_input_embeddings(self, value):
+        """Set the input embeddings layer."""
+        self.pico_decoder.embedding_proj = value
+
+    def get_output_embeddings(self):
+        """Get the output embeddings layer."""
+        return self.pico_decoder.de_embedding_proj
+
+    def set_output_embeddings(self, value):
+        """Set the output embeddings layer."""
+        self.pico_decoder.de_embedding_proj = value
+
+    def get_lm_head(self):
+        """Get the language model head."""
+        return self.pico_decoder.de_embedding_proj
+
+    def can_generate(self) -> bool:
+        """Check if the model can generate text."""
+        return True
+
+    @property
+    def is_encoder_decoder(self) -> bool:
+        """Check if the model is an encoder-decoder model."""
+        return False
+
+    @property
+    def can_use_cache(self) -> bool:
+        """Check if the model can use KV cache."""
+        return True
+
+    def resize_token_embeddings(
+        self, new_num_tokens: Optional[int] = None
+    ) -> torch.nn.Embedding:
+        """Resize token embeddings."""
+        old_embeddings = self.get_input_embeddings()
+        if new_num_tokens is None:
+            new_num_tokens = old_embeddings.num_embeddings
+
+        new_embeddings = torch.nn.Embedding(
+            new_num_tokens, old_embeddings.embedding_dim
+        )
+        new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
+            old_embeddings.weight.data
+        )
+
+        self.pico_decoder.embedding_proj = new_embeddings
+        self.pico_decoder.de_embedding_proj = torch.nn.Linear(
+            old_embeddings.embedding_dim, new_num_tokens, bias=False
+        )
+
+        return new_embeddings
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        """
+        Load a pretrained model from a checkpoint.
+
+        This method handles loading from both the old PicoDecoderHF format and the new format.
+        """
+        # First try to load with the new class
+        try:
+            return super().from_pretrained(
+                pretrained_model_name_or_path, *model_args, **kwargs
+            )
+        except Exception as e:
+            print(f"Failed to load with new class: {e}")
+            print("Attempting to load with legacy class and convert...")
+
+            # Try to load with the old class and convert
+            try:
+                from transformers import AutoModel
+
+                old_model = AutoModel.from_pretrained(
+                    pretrained_model_name_or_path,
+                    trust_remote_code=True,
+                    *model_args,
+                    **kwargs,
+                )
+
+                # Create new model instance
+                new_model = cls(old_model.config)
+
+                # Copy state dict
+                new_model.load_state_dict(old_model.state_dict(), strict=False)
+
+                return new_model
+
+            except Exception as e2:
+                print(f"Failed to convert from legacy format: {e2}")
+                raise e
+
+
+# Register the new class
+PicoDecoderForCausalLM.register_for_auto_class("AutoModelForCausalLM")
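With these registrations in place, the intended end-to-end flow looks roughly like this (the repo id is hypothetical; trust_remote_code=True is needed because the classes live in the checkpoint's own pico_decoder.py):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo = "user/pico-decoder-checkpoint"  # hypothetical repo id
    tokenizer = AutoTokenizer.from_pretrained(repo)
    model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

    inputs = tokenizer("The quick brown fox", return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))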