Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
class LoRALayer():
    """Mixin holding the hyper-parameters and merge state shared by LoRA layers.

    It is not an ``nn.Module`` itself: the LoRA layers in this file inherit
    from a torch layer first and then call ``LoRALayer.__init__`` on the
    already-initialised module.
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        """Store the LoRA hyper-parameters.

        Args:
            r: rank of the low-rank update.
            lora_alpha: LoRA scaling numerator (the layers in this file scale
                the update by ``lora_alpha / r``).
            lora_dropout: dropout probability for the LoRA branch input;
                ``<= 0`` means no dropout.
            merge_weights: whether the update may be merged into the frozen
                weight when the layer is switched to eval mode.
        """
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout. nn.Identity (rather than a lambda) keeps the layer
        # picklable, so e.g. torch.save(model) works with the default 0. dropout.
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = nn.Identity()
        # Mark the weight as unmerged
        self.merged = False
        self.merge_weights = merge_weights
class Linear(nn.Linear, LoRALayer):
    """LoRA implemented in a dense (fully connected) layer.

    When ``r > 0`` the pre-trained ``weight`` is frozen and a trainable
    low-rank update ``lora_B @ lora_A``, scaled by ``lora_alpha / r``, is
    added to the layer output. With ``merge_weights``, switching to eval mode
    folds the update into ``weight`` and switching back to train mode removes
    it again.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        merge_weights: bool = True,
        **kwargs
    ):
        """
        Args:
            in_features, out_features: as for ``nn.Linear``.
            r: LoRA rank; ``r <= 0`` creates no adapter parameters.
            lora_alpha: scaling numerator; the update is multiplied by ``lora_alpha / r``.
            lora_dropout: dropout probability applied only to the LoRA branch input.
            fan_in_fan_out: True if ``weight`` should be stored as (fan_in, fan_out);
                it is then transposed on every use.
            merge_weights: merge the update into ``weight`` in eval mode.
            **kwargs: forwarded to ``nn.Linear`` (e.g. ``bias``).
        """
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
        """Re-initialise weight/bias like nn.Linear, A Kaiming-uniform and B to zero.

        A zero ``lora_B`` makes the initial LoRA update exactly zero. During
        ``nn.Linear.__init__`` the LoRA parameters do not exist yet, hence the
        ``hasattr`` check.
        """
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def _maybe_transpose(self, w):
        """Return ``w`` transposed when the weight is stored as (fan_in, fan_out)."""
        return w.transpose(0, 1) if self.fan_in_fan_out else w

    def train(self, mode: bool = True):
        """Set train/eval mode, un-merging (train) or merging (eval) the LoRA update.

        Returns ``self``, matching ``nn.Module.train``, so ``layer.eval()`` and
        ``layer.train()`` can be chained or assigned.
        """
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= self._maybe_transpose(self.lora_B @ self.lora_A) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += self._maybe_transpose(self.lora_B @ self.lora_A) * self.scaling
                self.merged = True
        return self

    def forward(self, x: torch.Tensor):
        """Apply the (possibly frozen) base layer plus the un-merged LoRA update."""
        result = F.linear(x, self._maybe_transpose(self.weight), bias=self.bias)
        if self.r > 0 and not self.merged:
            # Update not folded into weight yet: add it explicitly. Dropout
            # only affects the LoRA branch, not the base layer.
            result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
        return result
class Conv2d(nn.Conv2d, LoRALayer):
    """LoRA implemented in a 2-D convolution layer.

    When ``r > 0`` the pre-trained ``weight`` is frozen and a trainable
    low-rank product ``lora_B @ lora_A`` is reshaped to ``weight.shape``,
    scaled by ``lora_alpha / r`` and added to the kernel. With
    ``merge_weights``, switching to eval mode folds the update into
    ``weight`` and switching back to train mode removes it again.

    NOTE: ``lora_dropout`` is accepted for signature parity but is not used
    by ``forward``: the update is added to the kernel, so there is no separate
    LoRA input path to apply dropout to.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        merge_weights: bool = True,
        **kwargs
    ):
        """
        Args:
            in_channels, out_channels: as for ``nn.Conv2d``.
            kernel_size: int or (kh, kw) tuple, as for ``nn.Conv2d``.
            r: LoRA rank; ``r <= 0`` creates no adapter parameters.
            lora_alpha: scaling numerator; the update is multiplied by ``lora_alpha / r``.
            lora_dropout: stored but not applied (see class note).
            merge_weights: merge the update into ``weight`` in eval mode.
            **kwargs: forwarded to ``nn.Conv2d`` (stride, padding, groups, bias, ...).
        """
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        # Actual trainable parameters
        if r > 0:
            # nn.Conv2d stores kernel_size as a (kh, kw) tuple even when an int
            # was given, so this handles int, tuple and non-square kernels.
            kh, kw = self.kernel_size
            # weight has out_channels * (in_channels // groups) * kh * kw
            # elements; a (out_channels // groups * kh) x (in_channels * kw)
            # product has the same count, so B @ A can be viewed as weight.
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r * kh, in_channels * kw))
            )
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels // self.groups * kh, r * kh))
            )
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise weight/bias like nn.Conv2d, A Kaiming-uniform and B to zero.

        A zero ``lora_B`` makes the initial LoRA update exactly zero. During
        ``nn.Conv2d.__init__`` the LoRA parameters do not exist yet, hence the
        ``hasattr`` check.
        """
        nn.Conv2d.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def _lora_delta(self):
        """Return the scaled low-rank update reshaped to ``weight.shape``."""
        return (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling

    def train(self, mode: bool = True):
        """Set train/eval mode, un-merging (train) or merging (eval) the LoRA update.

        Without LoRA parameters (``r <= 0``) only the ``merged`` flag changes.
        Returns ``self``, matching ``nn.Module.train``.
        """
        nn.Conv2d.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= self._lora_delta()
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += self._lora_delta()
                self.merged = True
        return self

    def forward(self, x: torch.Tensor):
        """Convolve with ``weight`` plus the LoRA update unless it is already merged."""
        if self.r > 0 and not self.merged:
            # _conv_forward is what nn.Conv2d.forward uses internally; unlike a
            # bare F.conv2d call it also honours padding_mode.
            return self._conv_forward(x, self.weight + self._lora_delta(), self.bias)
        return nn.Conv2d.forward(self, x)
def wrap_model_with_lora(module, rank=4):
    """Recursively replace ``nn.Linear`` / ``nn.Conv2d`` children with LoRA layers.

    Each replacement keeps the original layer's configuration and its
    pre-trained ``weight``/``bias`` values, and is moved to the original
    weight's device and dtype; only the new low-rank parameters are freshly
    initialised. Children that are already LoRA ``Linear``/``Conv2d`` layers
    are left as they are. ``module`` is modified in place.

    Args:
        module: root ``nn.Module`` whose descendants are replaced.
        rank: LoRA rank ``r`` given to every new layer.
    """
    # Snapshot the children because entries are reassigned while walking them.
    for name, child in list(module.named_children()):
        if isinstance(child, (Linear, Conv2d)):
            continue
        if isinstance(child, nn.Linear):
            new_child = Linear(
                in_features=child.in_features,
                out_features=child.out_features,
                bias=child.bias is not None,
                r=rank,
            )
        elif isinstance(child, nn.Conv2d):
            new_child = Conv2d(
                in_channels=child.in_channels,
                out_channels=child.out_channels,
                kernel_size=child.kernel_size,
                stride=child.stride,
                padding=child.padding,
                dilation=child.dilation,
                groups=child.groups,
                bias=child.bias is not None,
                padding_mode=child.padding_mode,
                r=rank,
            )
        else:
            wrap_model_with_lora(child, rank)
            continue
        # The new layer starts on CPU/float32 with random weights: move it next
        # to the layer it replaces and carry over the pre-trained parameters,
        # which are the ones LoRA is meant to freeze and adapt.
        new_child.to(device=child.weight.device, dtype=child.weight.dtype)
        with torch.no_grad():
            new_child.weight.copy_(child.weight)
            if child.bias is not None:
                new_child.bias.copy_(child.bias)
        setattr(module, name, new_child)