from dataclasses import dataclass

import loralib as lora


@dataclass
class LoraConfig:
    r: int
    lora_alpha: float
    lora_dropout: float = 0.0


def setup_lora(model, lora_config):
    # Replace the token embedding layer with a LoRA embedding
    model.embeddings = lora.Embedding(
        num_embeddings=model.embeddings.num_embeddings,
        embedding_dim=model.embeddings.embedding_dim,
        padding_idx=model.embeddings.padding_idx,
        r=lora_config.r,
        lora_alpha=lora_config.lora_alpha,
    )

    # Replace the codebook embedding layer with a LoRA embedding
    model.codebook_embeddings = lora.Embedding(
        num_embeddings=model.codebook_embeddings.num_embeddings,
        embedding_dim=model.codebook_embeddings.embedding_dim,
        padding_idx=model.codebook_embeddings.padding_idx,
        r=lora_config.r,
        lora_alpha=lora_config.lora_alpha,
    )

    # Collect the output layer ...
    linears = [(model, "output")]

    # ... and every attention / feed-forward linear layer; all of them are
    # replaced with LoRA layers in the loop further below.
    for layer in model.layers:
        linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
        linears.extend(
            [
                (layer.feed_forward, "w1"),
                (layer.feed_forward, "w2"),
                (layer.feed_forward, "w3"),
            ]
        )

    # Dual-AR model: also adapt the fast transformer, if present
    if hasattr(model, "fast_layers"):
        model.fast_embeddings = lora.Embedding(
            num_embeddings=model.fast_embeddings.num_embeddings,
            embedding_dim=model.fast_embeddings.embedding_dim,
            padding_idx=model.fast_embeddings.padding_idx,
            r=lora_config.r,
            lora_alpha=lora_config.lora_alpha,
        )

        linears.append((model, "fast_output"))

        for layer in model.fast_layers:
            linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
            linears.extend(
                [
                    (layer.feed_forward, "w1"),
                    (layer.feed_forward, "w2"),
                    (layer.feed_forward, "w3"),
                ]
            )

    # Swap every collected linear layer for a LoRA linear of the same shape
    for module, name in linears:
        old_linear = getattr(module, name)
        updated_linear = lora.Linear(
            in_features=old_linear.in_features,
            out_features=old_linear.out_features,
            bias=old_linear.bias is not None,
            r=lora_config.r,
            lora_alpha=lora_config.lora_alpha,
            lora_dropout=lora_config.lora_dropout,
        )
        setattr(module, name, updated_linear)

    # Mark only the LoRA parameters as trainable; the base weights stay frozen
    lora.mark_only_lora_as_trainable(model, bias="none")


def get_merged_state_dict(model):
    # Switching to eval mode makes loralib merge the LoRA weights into the
    # base weights (merge_weights=True by default), so the state dict below
    # already contains the merged parameters.
    model.eval()

    # The LoRA parameters themselves are no longer needed, so drop them.
    state_dict = model.state_dict()
    for name in list(state_dict.keys()):
        if "lora" in name:
            state_dict.pop(name)

    return state_dict
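

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): one way the two helpers above
# might be wired together. `train_step_fn`, the hyperparameters, and the
# checkpoint path are illustrative placeholders; any model exposing the
# attributes accessed in setup_lora (embeddings, codebook_embeddings, layers,
# output, and optionally the fast_* attributes of the Dual-AR variant) should
# work the same way.
# ---------------------------------------------------------------------------
import torch


def finetune_with_lora(model, train_step_fn, num_steps=1000):
    """Inject LoRA adapters, train only them, and return a merged state dict."""
    lora_config = LoraConfig(r=8, lora_alpha=16.0, lora_dropout=0.05)
    setup_lora(model, lora_config)

    # After setup_lora only the lora_* parameters require gradients, so the
    # optimizer sees a small fraction of the full parameter count.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=1e-4)

    model.train()
    for _ in range(num_steps):
        loss = train_step_fn(model)  # user-supplied forward pass + loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # get_merged_state_dict folds the adapters back into the base weights, so
    # the saved checkpoint can be loaded without any LoRA-aware code.
    return get_merged_state_dict(model)


# Example: torch.save(finetune_with_lora(model, train_step_fn), "lora_merged.pth")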