Fix(model): Linear detected and added to target module with rope linear (#738)
Browse files* Fix(model): Linear detected and added to target module with rope linear
* fix: exclude layer instead
src/axolotl/utils/models.py
CHANGED
|
@@ -507,7 +507,11 @@ def find_all_linear_names(model):
|
|
| 507 |
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
|
| 508 |
lora_module_names = set()
|
| 509 |
for name, module in model.named_modules():
|
| 510 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
| 511 |
names = name.split(".")
|
| 512 |
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
| 513 |
|
|
|
|
| 507 |
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
|
| 508 |
lora_module_names = set()
|
| 509 |
for name, module in model.named_modules():
|
| 510 |
+
if (
|
| 511 |
+
isinstance(module, cls)
|
| 512 |
+
or "Linear" in module.__class__.__name__
|
| 513 |
+
and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",)
|
| 514 |
+
):
|
| 515 |
names = name.split(".")
|
| 516 |
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
| 517 |
|