Creating a PEFT Module: A Guideline

Overview

The most common way to "hack on" PEFT inside your own repository is:

  1. Clone the official peft source
  2. Create a new module directory under tuners/
  3. Use an editable install (pip install -e) so your project picks up your modified PEFT
  4. Write config.py, layer.py, and model.py as needed

Below I use LoRA (low-rank adapters) as the running example and walk through how to build a reusable PEFT module from scratch. Wherever something may be hard to follow, I add a small 👀 note to help you get up to speed quickly.


I. Preparation

  1. Create the repo and pull the source

    mkdir peft-mogai
    cd peft-mogai
    git clone https://github.com/huggingface/peft.git
    pip install -e ./peft --config-settings editable_mode=compat

    The -e (editable) mode means that changes you make to the source take effect in your project automatically.

  2. Set up the module directory
    Under peft/src/peft/tuners/, create your own folder, e.g. lora_mogai/

    peft/src/peft/tuners/
    └── lora_mogai/
        ├── __init__.py
        ├── config.py
        ├── layer.py
        └── model.py
    • __init__.py: exports your Config, Layer, and Model (the note after the snippet covers registering them with PEFT itself).
    # lora_mogai/__init__.py
    from .config import MogaiConfig
    from .layer import MogaiLayer, Linear
    from .model import MogaiModel

    __all__ = ["MogaiConfig", "MogaiLayer", "Linear", "MogaiModel"]
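
    One more step that later sections rely on: for get_peft_model (used in the Mix Mode section at the end) to find the new tuner, it also has to be registered with PEFT itself, not just exported from the package-local __init__.py. The exact hooks differ between PEFT versions, so treat the sketch below as a pointer to where to look rather than exact code; in recent releases the relevant places are peft/src/peft/tuners/__init__.py and the mapping dictionaries in peft/src/peft/mapping.py.

    # peft/src/peft/tuners/__init__.py -- export the new tuner next to the built-in ones
    from .lora_mogai import MogaiConfig, MogaiLayer, MogaiModel

    # peft/src/peft/mapping.py -- assumption: these dicts exist under these names in
    # your checkout (older releases keep the model mapping in peft_model.py instead)
    PEFT_TYPE_TO_CONFIG_MAPPING["MOGAI"] = MogaiConfig
    PEFT_TYPE_TO_TUNER_MAPPING["MOGAI"] = MogaiModel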

II. Writing the config (config.py)

# lora_mogai/config.py
from dataclasses import dataclass, field
from typing import List, Optional, Union

from peft.config import PeftConfig
from peft.utils import PeftType


@dataclass
class MogaiConfig(PeftConfig):
    r: int = field(default=256, metadata={"help": "Rank of the low-rank projection"})
    target_modules: Optional[Union[List[str], str]] = field(
        default=None,
        metadata={"help": "Module names or a regex to replace, e.g. ['q', 'v'] or '.*SelfAttention.*'"},
    )
    mogai_alpha: int = field(default=8, metadata={"help": "LoRA scaling factor alpha"})
    mogai_dropout: float = field(default=0.0, metadata={"help": "LoRA dropout probability"})
    fan_in_fan_out: bool = field(default=False, metadata={"help": "Whether to enable fan_in_fan_out mode"})
    bias: str = field(default="none", metadata={"help": "Bias type: 'none' | 'all' | 'mogai_only'"})
    modules_to_save: Optional[List[str]] = field(
        default=None,
        metadata={"help": "Extra randomly initialized layers that also need to be saved"},
    )
    layers_to_transform: Optional[Union[List[int], int]] = field(
        default=None,
        metadata={"help": "Only transform the layers with these indices"},
    )
    layers_pattern: Optional[str] = field(
        default=None,
        metadata={"help": "Custom layer-name regex (used together with layers_to_transform)"},
    )

    def __post_init__(self):
        self.peft_type = PeftType.MOGAI
        self.target_modules = (
            set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
        )

👀 Why inherit from PeftConfig?
It lets your config plug into the main PEFT machinery seamlessly, e.g. automatic CLI parsing, serialization to disk, and so on.
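
One thing the snippet glosses over: PeftType.MOGAI does not exist in upstream PEFT, so __post_init__ above would raise on an unmodified install. Since we are working inside an editable clone anyway, the fix is to add a member to the PeftType enum. A minimal sketch, assuming the enum sits where recent releases put it (peft/src/peft/utils/peft_types.py); check your checkout:

# peft/src/peft/utils/peft_types.py (the exact location is version-dependent)
import enum

class PeftType(str, enum.Enum):
    # ... keep the existing members (LORA, PREFIX_TUNING, ...) unchanged ...
    MOGAI = "MOGAI"  # new member, so PeftType.MOGAI resolves in MogaiConfig.__post_init__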


III. Implementing the base adapter layer (layer.py)

The core idea:

  1. Keep the original layer as base_layer
  2. Dynamically create the low-rank matrices A and B
  3. In forward, add delta = B(A(x)) on top of the base output
# lora_mogai/layer.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from peft.tuners.tuners_utils import BaseTunerLayer


class MogaiLayer(BaseTunerLayer):
    adapter_layer_names = ("mogai_A", "mogai_B")
    other_param_names = ("r", "mogai_alpha", "scaling", "mogai_dropout")

    def __init__(self, base_layer: nn.Module, **kwargs):
        self.base_layer = base_layer
        self.r = {}
        self.mogai_alpha = {}
        self.scaling = {}
        self.mogai_dropout = nn.ModuleDict()
        self.mogai_A = nn.ModuleDict()
        self.mogai_B = nn.ModuleDict()
        self._disable_adapters = False
        self.merged_adapters = []

        base = self.get_base_layer()
        if isinstance(base, nn.Linear):
            self.in_features, self.out_features = base.in_features, base.out_features
        self.kwargs = kwargs

    def update_layer(self, adapter_name, module_name, r, mogai_alpha, mogai_dropout):
        if r <= 0:
            raise ValueError(f"`r` must be positive, got {r}")
        self.r[adapter_name] = r
        self.mogai_alpha[adapter_name] = mogai_alpha
        self.scaling[adapter_name] = mogai_alpha / r
        drop = nn.Dropout(p=mogai_dropout) if mogai_dropout > 0 else nn.Identity()
        self.mogai_dropout[adapter_name] = drop

        self.mogai_A[adapter_name] = nn.Linear(self.in_features, r, bias=False)
        self.mogai_B[adapter_name] = nn.Linear(r, self.out_features, bias=False)
        nn.init.kaiming_uniform_(self.mogai_A[adapter_name].weight, a=math.sqrt(5))
        nn.init.zeros_(self.mogai_B[adapter_name].weight)

        self._move_adapter_to_device_of_base_layer(adapter_name)
        self.set_adapter(self.active_adapter)


class Linear(nn.Linear, MogaiLayer):
    def __init__(self, base_layer, adapter_name, module_name, **kwargs):
        # Skip nn.Linear.__init__ (this wrapper owns no weights of its own) and
        # only run nn.Module's init, then set up the adapter bookkeeping.
        super(nn.Linear, self).__init__()
        MogaiLayer.__init__(self, base_layer, **kwargs)
        # kwargs may also carry fan_in_fan_out / bias; only pass what update_layer accepts
        self.update_layer(
            adapter_name,
            module_name,
            r=kwargs["r"],
            mogai_alpha=kwargs["mogai_alpha"],
            mogai_dropout=kwargs["mogai_dropout"],
        )

    def merge(self, *args, **kwargs):
        raise NotImplementedError

    def unmerge(self, *args, **kwargs):
        raise NotImplementedError

    def get_delta_weight(self, *args, **kwargs):
        raise NotImplementedError

    def forward(self, x: torch.Tensor, *args, **kwargs):
        orig_type = x.dtype
        if self.disable_adapters or self.merged:
            result = self.base_layer(x, *args, **kwargs)
        else:
            result = self.base_layer(x, *args, **kwargs)
            for name in self.active_adapters:
                if name not in self.mogai_A:
                    continue
                A, B = self.mogai_A[name], self.mogai_B[name]
                drop, scale = self.mogai_dropout[name], self.scaling[name]
                y = drop(x.to(A.weight.dtype))
                result = result + F.linear(F.linear(y, A.weight), B.weight) * scale
        return result.to(orig_type)

    def __repr__(self):
        return "mogai." + super().__repr__()

👀 Tips

  • Putting every adapter parameter into a ModuleDict makes it easy to manage them by name;
  • Initializing B to zero guarantees that early in training the model behaves exactly like the original model (see the quick check below).
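
A quick sanity check of that last point (a hypothetical snippet, just for illustration, assuming the package was created under peft/src/peft/tuners/lora_mogai as above), wrapping a plain nn.Linear with the layer we just defined:

import torch
import torch.nn as nn

from peft.tuners.lora_mogai.layer import Linear  # the custom layer from this section

base = nn.Linear(16, 32)
wrapped = Linear(base, adapter_name="default", module_name="demo",
                 r=4, mogai_alpha=8, mogai_dropout=0.0)

x = torch.randn(2, 16)
# mogai_B starts at zero, so the adapter delta is zero and the wrapped layer
# reproduces the base layer's output exactly at initialization.
assert torch.allclose(wrapped(x), base(x))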

IV. Assembling the model (model.py)

In model.py, MogaiModel inherits from PEFT's BaseTuner. Its core jobs are:

  1. Scan and replace: in the _create_and_replace method, walk the user-specified target_modules and swap the original nn.Linear / Conv1D layers for the lora_mogai.layer.Linear we wrote above.
  2. Dynamic registration: based on the r, mogai_alpha, mogai_dropout, etc. configured in MogaiConfig, call update_layer on Linear to initialize or update that layer's LoRA parameters.
  3. Save/restore state: _replace_module also preserves any .state the replaced layer may carry (e.g. caches, RNN hidden state), so the model behaves the same after the swap.
  4. Toggling and merging: methods such as set_adapter, enable_adapter_layers / disable_adapter_layers, and merge_and_unload let you flexibly enable, disable, or bake the LoRA deltas into the base model during training, inference, and deployment.
# Core snippets from lora_mogai/model.py
import re
import warnings

import torch
import torch.nn as nn
from transformers.pytorch_utils import Conv1D

from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer

from .layer import Linear, MogaiLayer


class MogaiModel(BaseTuner):
    prefix = "mogai_"

    def __init__(self, model, config, adapter_name):
        super().__init__(model, config, adapter_name)

    def _create_and_replace(
        self,
        mogai_config,
        adapter_name,
        target,
        target_name,
        parent,
        current_key,
        **optional_kwargs,
    ):
        if current_key is None:
            raise ValueError("Current Key shouldn't be `None`")

        pattern = re.compile(r"layers\.(\d+)\.(.+)")
        match = pattern.search(current_key)
        if match:
            layer_id = int(match.group(1))
            module_name = match.group(2).replace(".", "__")
        else:
            raise ValueError("Invalid target module type")

        r = mogai_config.r
        kwargs = {
            "r": r,
            "mogai_alpha": mogai_config.mogai_alpha,
            "mogai_dropout": mogai_config.mogai_dropout,
            "fan_in_fan_out": mogai_config.fan_in_fan_out,
            "bias": mogai_config.bias,
        }

        if isinstance(target, Linear):
            target.update_layer(
                adapter_name,
                module_name,
                r,
                mogai_config.mogai_alpha,
                mogai_config.mogai_dropout,
            )
        else:
            new_module = self._create_new_module(mogai_config, adapter_name, target, module_name, **kwargs)
            if adapter_name not in self.active_adapter:
                new_module.requires_grad_(False)
            self._replace_module(parent, target_name, new_module, target)

    @staticmethod
    def _create_new_module(mogai_config, adapter_name, target, module_name, **kwargs):
        if isinstance(target, BaseTunerLayer):
            target_base_layer = target.get_base_layer()
        else:
            target_base_layer = target

        if isinstance(target_base_layer, torch.nn.Linear):
            if kwargs["fan_in_fan_out"]:
                warnings.warn(
                    "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
                    "Setting fan_in_fan_out to False."
                )
                kwargs["fan_in_fan_out"] = mogai_config.fan_in_fan_out = False
        elif isinstance(target_base_layer, Conv1D):
            kwargs["is_target_conv_1d_layer"] = True
            if not kwargs["fan_in_fan_out"]:
                warnings.warn(
                    "fan_in_fan_out is set to False but the target module is `Conv1D`. "
                    "Setting fan_in_fan_out to True."
                )
                kwargs["fan_in_fan_out"] = mogai_config.fan_in_fan_out = True
        else:
            raise ValueError(
                f"Target module {target} is not supported. Currently, only the following modules are supported: "
                "`torch.nn.Linear`, `transformers.pytorch_utils.Conv1D`."
            )
        new_module = Linear(
            target,
            adapter_name,
            module_name,
            **kwargs,
        )

        return new_module

    @staticmethod
    def _replace_module(parent, child_name, new_module, child):
        # Keep child.state (if any) and copy it to new_module,
        # then move new_module onto the original device.
        setattr(parent, child_name, new_module)
        # It's not necessary to set requires_grad here, as that is handled by
        # _mark_only_adapters_as_trainable

        # child layer wraps the original module, unpack it
        if hasattr(child, "base_layer"):
            child = child.base_layer

        if not hasattr(new_module, "base_layer"):
            new_module.weight = child.weight
            if hasattr(child, "bias"):
                new_module.bias = child.bias

        if getattr(child, "state", None) is not None:
            if hasattr(new_module, "base_layer"):
                new_module.base_layer.state = child.state
            else:
                new_module.state = child.state
            new_module.to(child.weight.device)

        # dispatch to correct device
        for name, module in new_module.named_modules():
            if "mogai_" in name:
                module.to(child.weight.device)

The function below marks which parameters are trainable and which are frozen. It is called from PeftModel's __init__.

# lora_mogai/model.py (continued) -- another method of MogaiModel
def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
    for n, p in model.named_parameters():
        if self.prefix not in n:
            p.requires_grad = False

    for active_adapter in self.active_adapters:
        bias = self.peft_config[active_adapter].bias
        if bias == "none":
            continue

        if bias == "all":
            for n, p in model.named_parameters():
                if "bias" in n:
                    p.requires_grad = True
        elif bias == "mogai_only":
            for m in model.modules():
                if isinstance(m, MogaiLayer) and hasattr(m, "bias") and m.bias is not None:
                    m.bias.requires_grad = True
        else:
            raise NotImplementedError(f"Requested bias: {bias}, is not implemented.")

Running a complete example

The following shows how to fine-tune with MogaiModel directly:

from transformers import AutoModelForCausalLM

from peft.tuners.lora_mogai import MogaiConfig, MogaiModel

# 1. Prepare the base model and the config
# (use a model whose attention projections are actually named q_proj / v_proj)
base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base")
config = MogaiConfig(
    r=16,
    target_modules=["q_proj", "v_proj"],
    mogai_alpha=32,
    mogai_dropout=0.05,
)

# 2. Wrap the model with MogaiModel (the targeted layers are replaced automatically)
mogai = MogaiModel(base_model, config, adapter_name="default")

# 3. Keep using Trainer or a plain PyTorch training loop
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./mogai_output",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    learning_rate=1e-4,
)
trainer = Trainer(
    model=mogai.model,            # the adapted model
    args=training_args,
    train_dataset=your_dataset,   # a dataset you prepare yourself
)
trainer.train()
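
As a quick sanity check after wrapping (a small snippet assuming the setup above), you can confirm that only the mogai_ adapter weights are left trainable:

# Count trainable vs. total parameters; with bias="none", only the mogai_A / mogai_B
# weights should still require gradients.
trainable = sum(p.numel() for p in mogai.model.parameters() if p.requires_grad)
total = sum(p.numel() for p in mogai.model.parameters())
print(f"trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")

for name, param in mogai.model.named_parameters():
    if param.requires_grad:
        assert "mogai_" in name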

Mix Mode

get_peft_model can return two kinds of classes: PeftModel and PeftMixedModel. To get the latter, pass mixed=True when calling get_peft_model. You also need to register the newly created module as a compatible tuner in /src/peft/tuners/mixed/model.py (see the sketch after the example below).

from transformers import AutoModelForCausalLM

from peft import get_peft_model
from peft.tuners.lora_mogai import MogaiConfig

mogai_config = MogaiConfig(r=16, target_modules=["q_proj", "v_proj"], mogai_alpha=32)

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base")
# model.add_adapter("adapters1", mogai_config)
peft_model = get_peft_model(model, mogai_config, "adapters1", mixed=True)
peft_model.add_adapter("adapters2", mogai_config)
peft_model.add_adapter("adapters3", mogai_config)
peft_model.add_adapter("adapters4", mogai_config)

peft_model.set_adapter(["adapters1", "adapters2", "adapters3", "adapters4"])
# peft_model.set_adapter("adapters1")

print(peft_model.active_adapters)
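
What "registering as a compatible tuner" means in practice: src/peft/tuners/mixed/model.py in your checkout keeps a few module-level constants listing the tuner types, parameter prefixes, and layer classes that the mixed model knows how to combine, and the new module has to be added to them. A rough sketch, assuming the constant names below match your PEFT version (they may differ between releases):

# src/peft/tuners/mixed/model.py -- extend the existing registration constants.
# COMPATIBLE_TUNER_TYPES, PREFIXES and Layers are assumed to already be defined
# in this file; check your checkout for the exact names and shapes.
from peft.tuners import lora_mogai
from peft.utils import PeftType

COMPATIBLE_TUNER_TYPES = COMPATIBLE_TUNER_TYPES + (PeftType.MOGAI,)   # accepted tuner types
PREFIXES = PREFIXES + [lora_mogai.MogaiModel.prefix]                  # parameter-name prefixes
Layers = Layers + (lora_mogai.MogaiLayer,)                            # layer classes to recognize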