AdaLoRA

mindnlp.peft.tuners.adalora.config.AdaLoraConfig dataclass

Bases: LoraConfig

This is the configuration class to store the configuration of a [~peft.AdaLora].

PARAMETER DESCRIPTION
target_r

The target average rank of incremental matrix.

TYPE: `int` DEFAULT: 8

init_r

The initial rank for each incremental matrix.

TYPE: `int` DEFAULT: 12

tinit

The steps of initial fine-tuning warmup.

TYPE: `int` DEFAULT: 0

tfinal

The number of final fine-tuning steps.

TYPE: `int` DEFAULT: 0

deltaT

The time interval between two budget allocations.

TYPE: `int` DEFAULT: 1

beta1

The hyperparameter of EMA for sensitivity smoothing.

TYPE: `float` DEFAULT: 0.85

beta2

The hyperparameter of EMA for uncertainty quantification.

TYPE: `float` DEFAULT: 0.85

orth_reg_weight

The coefficient of orthogonal regularization.

TYPE: `float` DEFAULT: 0.5

total_step

The total training steps that should be specified before training.

TYPE: `int` DEFAULT: None

rank_pattern

The allocated rank for each weight matrix by RankAllocator.

TYPE: `dict` DEFAULT: None
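
The interaction between `tinit`, `tfinal`, `deltaT` and `total_step` is easier to see in code. The following is a minimal sketch of a cubic budget schedule in the spirit of the AdaLoRA paper. The function name `budget_schedule` and the arguments `init_budget` and `target_budget` (rank multiplied by the number of adapted matrices) are assumptions made for illustration; the schedule actually used by `RankAllocator` is not shown on this page and may differ in detail.

```python
def budget_schedule(step, init_budget, target_budget, tinit, tfinal, total_step, deltaT):
    """Return (budget, mask_now) for a training step (illustrative sketch)."""
    # Warmup: keep the full initial budget and do not mask anything.
    if step <= tinit:
        return init_budget, False
    # Final fine-tuning: the budget stays at the target value.
    if step > total_step - tfinal:
        return target_budget, True
    # In between: shrink the budget from init_budget to target_budget cubically.
    progress = (step - tinit) / (total_step - tfinal - tinit)
    budget = int((init_budget - target_budget) * (1 - progress) ** 3 + target_budget)
    # Reallocate only every deltaT steps.
    return budget, step % deltaT == 0


# Example: 10 adapted matrices with init_r=12 and target_r=8.
print(budget_schedule(0, 120, 80, tinit=100, tfinal=200, total_step=1000, deltaT=10))    # (120, False)
print(budget_schedule(450, 120, 80, tinit=100, tfinal=200, total_step=1000, deltaT=10))  # (85, True)
print(budget_schedule(900, 120, 80, tinit=100, tfinal=200, total_step=1000, deltaT=10))  # (80, True)
```

With the defaults `tinit=0` and `tfinal=0`, a schedule of this form starts pruning immediately and reaches the target budget only at the last step, so in practice these values are usually set explicitly together with `total_step`.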

Source code in mindnlp/peft/tuners/adalora/config.py
@dataclass
class AdaLoraConfig(LoraConfig):
    """
    This is the configuration class to store the configuration of a [`~peft.AdaLora`].

    Args:
        target_r (`int`): The target average rank of incremental matrix.
        init_r (`int`): The initial rank for each incremental matrix.
        tinit (`int`): The steps of initial fine-tuning warmup.
        tfinal (`int`): The number of final fine-tuning steps.
        deltaT (`int`): The time interval between two budget allocations.
        beta1 (`float`): The hyperparameter of EMA for sensitivity smoothing.
        beta2 (`float`): The hyperparameter of EMA for uncertainty quantification.
        orth_reg_weight (`float`): The coefficient of orthogonal regularization.
        total_step (`int`): The total training steps that should be specified before training.
        rank_pattern (`list`): The allocated rank for each weight matrix by RankAllocator.
    """

    target_r: int = field(default=8, metadata={"help": "Target Lora matrix dimension."})
    init_r: int = field(default=12, metadata={"help": "Initial Lora matrix dimension."})
    tinit: int = field(default=0, metadata={"help": "The steps of initial warmup."})
    tfinal: int = field(default=0, metadata={"help": "The steps of final warmup."})
    deltaT: int = field(default=1, metadata={"help": "Step interval of rank allocation."})
    beta1: float = field(default=0.85, metadata={"help": "Hyperparameter of EMA."})
    beta2: float = field(default=0.85, metadata={"help": "Hyperparameter of EMA."})
    orth_reg_weight: float = field(default=0.5, metadata={"help": "The orthogonal regularization coefficient."})
    total_step: Optional[int] = field(default=None, metadata={"help": "The total training steps."})
    rank_pattern: Optional[dict] = field(default=None, metadata={"help": "The saved rank pattern."})

    def __post_init__(self):
        self.peft_type = PeftType.ADALORA
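
The `__post_init__` shown above only sets `peft_type`, while `total_step` defaults to `None` and is later used in expressions such as `total_step - tfinal`. A quick check before training can catch an incomplete schedule early. The sketch below is hypothetical: `ScheduleFields` and `validate_schedule` are not part of the library, and the rules are one reasonable reading of the parameter descriptions rather than checks the library performs.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ScheduleFields:
    """Hypothetical stand-in holding only the schedule fields of AdaLoraConfig."""
    tinit: int = 0
    tfinal: int = 0
    deltaT: int = 1
    total_step: Optional[int] = None


def validate_schedule(cfg: ScheduleFields) -> ScheduleFields:
    """Raise ValueError if the schedule cannot be evaluated (illustrative checks only)."""
    if cfg.total_step is None or cfg.total_step <= 0:
        raise ValueError("total_step must be set to a positive integer before training")
    if cfg.tinit < 0 or cfg.tfinal < 0:
        raise ValueError("tinit and tfinal must be non-negative")
    if cfg.deltaT <= 0:
        raise ValueError("deltaT must be positive")
    if cfg.tinit >= cfg.total_step - cfg.tfinal:
        raise ValueError("tinit must be smaller than total_step - tfinal")
    return cfg


validate_schedule(ScheduleFields(tinit=100, tfinal=200, deltaT=10, total_step=1000))
```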

mindnlp.peft.tuners.adalora.model.AdaLoraModel

Bases: LoraModel

Creates AdaLoRA (Adaptive LoRA) model from a pretrained transformers model. Paper: https://openreview.net/forum?id=lq62uWRJjiY

PARAMETER DESCRIPTION
model

The model to be adapted.

TYPE: [`mindspore.nn.Cell`]

config

The configuration of the AdaLora model.

TYPE: [`AdaLoraConfig`]

adapter_name

The name of the adapter, defaults to "default".

TYPE: `str`

RETURNS DESCRIPTION
AdaLoraModel

The AdaLora model.

TYPE: [`mindspore.nn.Cell`]

>>> from transformers import AutoModelForSeq2SeqLM
>>> from peft import AdaLoraModel, AdaLoraConfig
>>> config = AdaLoraConfig(
...     peft_type="ADALORA", task_type="SEQ_2_SEQ_LM", r=8, lora_alpha=32, target_modules=["q", "v"],
...     lora_dropout=0.01,
... )
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
>>> model = AdaLoraModel(model, config, "default")

Attributes:

  • model ([transformers.PreTrainedModel]): The model to be adapted.

  • peft_config ([AdaLoraConfig]): The configuration of the AdaLora model.
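
`AdaLoraModel` accepts several adapters but allows only one of them to be trainable; all others must have `inference_mode` set to `True`. The rule can be sketched in isolation as follows. `AdapterFlags` and `check_single_trainable` are hypothetical names used only for this illustration.

```python
from dataclasses import dataclass


@dataclass
class AdapterFlags:
    """Hypothetical stand-in for the single config field the check reads."""
    inference_mode: bool = False


def count_trainable(peft_configs):
    """Count adapters that are not in inference mode."""
    return sum(1 for cfg in peft_configs.values() if not cfg.inference_mode)


def check_single_trainable(peft_configs):
    """Raise ValueError when more than one adapter is trainable."""
    if count_trainable(peft_configs) > 1:
        raise ValueError(
            "Only 1 trainable adapter is supported; "
            "set inference_mode=True for every adapter except the one to train."
        )


# One trainable adapter plus one frozen adapter passes the check.
check_single_trainable({
    "default": AdapterFlags(inference_mode=False),
    "frozen": AdapterFlags(inference_mode=True),
})
```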

Source code in mindnlp/peft/tuners/adalora/model.py
class AdaLoraModel(LoraModel):
    """
    Creates AdaLoRA (Adaptive LoRA) model from a pretrained transformers model. Paper:
    https://openreview.net/forum?id=lq62uWRJjiY

    Args:
        model ([`mindspore.nn.Cell`]): The model to be adapted.
        config ([`AdaLoraConfig`]): The configuration of the AdaLora model.
        adapter_name (`str`): The name of the adapter, defaults to `"default"`.

    Returns:
        AdaLoraModel ([`mindspore.nn.Cell`]): The AdaLora model.

    Example::

        >>> from transformers import AutoModelForSeq2SeqLM
        >>> from peft import AdaLoraModel, AdaLoraConfig
        >>> config = AdaLoraConfig(
        ...     peft_type="ADALORA", task_type="SEQ_2_SEQ_LM", r=8, lora_alpha=32, target_modules=["q", "v"],
        ...     lora_dropout=0.01,
        ... )
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
        >>> model = AdaLoraModel(model, config, "default")

    > **Attributes**:  

    >   - **model** ([`transformers.PreTrainedModel`]): The model to be adapted.

    >   - **peft_config** ([`AdaLoraConfig`]): The configuration of the AdaLora model. 
    """

    # Note: don't redefine prefix here, it should be inherited from LoraModel

    def __init__(self, model, config, adapter_name):
        super().__init__(model, config, adapter_name)

        traininable_mode_counter = 0
        for peft_config in self.peft_config.values():
            if not peft_config.inference_mode:
                traininable_mode_counter += 1

        if traininable_mode_counter > 1:
            raise ValueError(
                "AdaLoraModel supports only 1 trainable adapter. "
                "When using multiple adapters, set inference_mode to True for all adapters except the one you want to train."
            )

        if self.peft_config[adapter_name].inference_mode:
            _freeze_adapter(self.model, adapter_name)
        else:
            self.trainable_adapter_name = adapter_name
            self.rankallocator = RankAllocator(self.model, self.peft_config[adapter_name], self.trainable_adapter_name)

    def _check_new_adapter_config(self, config: LoraConfig) -> None:
        """
        A helper method to check the config when a new adapter is being added.

        Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters.

        """
        super()._check_new_adapter_config(config)

        traininable_mode_counter = 0
        for config_ in self.peft_config.values():
            if not config_.inference_mode:
                traininable_mode_counter += 1

        if traininable_mode_counter > 1:
            raise ValueError(
                f"{self.__class__.__name__} supports only 1 trainable adapter. "
                "When using multiple adapters, set inference_mode to True for all adapters except the one "
                "you want to train."
            )

    def _mark_only_adapters_as_trainable(self, model: nn.Cell) -> None:
        for n, p in model.parameters_and_names():
            if "lora_" not in n:
                p.requires_grad = False

        for active_adapter in self.active_adapters:
            bias = self.peft_config[active_adapter].bias
            if bias == "none":
                continue

            if bias == "all":
                for n, p in model.parameters_and_names():
                    if "bias" in n:
                        p.requires_grad = True
            elif bias == "lora_only":
                for m in model.cells():
                    if isinstance(m, AdaLoraLayer) and hasattr(m, "bias") and m.bias is not None:
                        m.bias.requires_grad = True
            else:
                raise NotImplementedError(f"Requested bias: {bias}, is not implemented.")

    def _create_and_replace(
        self,
        lora_config,
        adapter_name,
        target,
        target_name,
        parent,
        current_key,
        **optionnal_kwargs,
    ):
        kwargs = {
            "r": lora_config.init_r,
            "lora_alpha": lora_config.lora_alpha,
            "lora_dropout": lora_config.lora_dropout,
            "fan_in_fan_out": lora_config.fan_in_fan_out,
            "init_lora_weights": lora_config.init_lora_weights,
        }
        kwargs["loaded_in_8bit"] = optionnal_kwargs.pop("loaded_in_8bit", False)
        kwargs["loaded_in_4bit"] = optionnal_kwargs.pop("loaded_in_4bit", False)
        # if (kwargs["loaded_in_8bit"] or kwargs["loaded_in_4bit"]) and not is_bnb_available():
        #     raise ImportError(
        #         "To use AdaLora with 8-bit quantization, please install the `bitsandbytes` package. "
        #         "You can install it with `pip install bitsandbytes`."
        #     )
        # quantization_config = get_quantization_config(self.model, method="gptq")
        # if quantization_config is not None:
        #     kwargs["gptq_quantization_config"] = quantization_config

        # If it is not an AdaLoraLayer, create a new module, else update it with new adapters
        if not isinstance(target, AdaLoraLayer):
            new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
            self._replace_module(parent, target_name, new_module, target)
        else:
            target.update_layer(
                adapter_name,
                lora_config.init_r,
                lora_config.lora_alpha,
                lora_config.lora_dropout,
                lora_config.init_lora_weights,
            )

    @staticmethod
    def _create_new_module(lora_config, adapter_name, target, **kwargs):
        # avoid eager bnb import
        # if is_bnb_available():
        #     import bitsandbytes as bnb

        #     from .bnb import SVDLinear8bitLt
        # if is_bnb_4bit_available():
        #     from .bnb import SVDLinear4bit

        # gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
        # AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)

        # loaded_in_8bit = kwargs.pop("loaded_in_8bit", False)
        # loaded_in_4bit = kwargs.pop("loaded_in_4bit", False)

        if isinstance(target, BaseTunerLayer):
            target_base_layer = target.get_base_layer()
        else:
            target_base_layer = target

        # if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt):
        #     kwargs.update(
        #         {
        #             "has_fp16_weights": target_base_layer.state.has_fp16_weights,
        #             "memory_efficient_backward": target_base_layer.state.memory_efficient_backward,
        #             "threshold": target_base_layer.state.threshold,
        #             "index": target_base_layer.index,
        #         }
        #     )
        #     new_module = SVDLinear8bitLt(target, adapter_name, **kwargs)
        # elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit):
        #     fourbit_kwargs = kwargs.copy()
        #     fourbit_kwargs.update(
        #         {
        #             "compute_dtype": target_base_layer.compute_dtype,
        #             "compress_statistics": target_base_layer.weight.compress_statistics,
        #             "quant_type": target_base_layer.weight.quant_type,
        #         }
        #     )
        #     new_module = SVDLinear4bit(target, adapter_name, **fourbit_kwargs)
        # elif AutoGPTQQuantLinear is not None and isinstance(target, AutoGPTQQuantLinear):
        #     new_module = SVDQuantLinear(target, adapter_name, **kwargs)
        if isinstance(target_base_layer, nn.Dense):
            if kwargs["fan_in_fan_out"]:
                warnings.warn(
                    "fan_in_fan_out is set to True but the target module is `nn.Dense`. "
                    "Setting fan_in_fan_out to False."
                )
                kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
        elif isinstance(target_base_layer, Conv1D):
            if not kwargs["fan_in_fan_out"]:
                warnings.warn(
                    "fan_in_fan_out is set to False but the target module is `Conv1D`. "
                    "Setting fan_in_fan_out to True."
                )
                kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
        else:
            raise ValueError(
                f"Target module {target} is not supported. "
                "Currently, only `nn.Dense` and `Conv1D` are supported."
            )
        new_module = SVDLinear(target, adapter_name, **kwargs)

        return new_module

    def _replace_module(self, parent, child_name, new_module, child):
        setattr(parent, child_name, new_module)

        # child layer wraps the original module, unpack it
        if hasattr(child, "base_layer"):
            child = child.base_layer

        # layers with base_layer don't need the weight to be copied, as they have a reference already
        if not hasattr(new_module, "base_layer"):
            new_module.weight = child.weight
            if hasattr(child, "bias"):
                new_module.bias = child.bias

        if getattr(child, "state", None) is not None:
            if hasattr(new_module, "base_layer"):
                new_module.base_layer.state = child.state
            else:
                new_module.state = child.state

    @staticmethod
    def _prepare_adapter_config(peft_config, model_config):
        if peft_config.target_modules is None:
            if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING:
                raise ValueError("Please specify `target_modules` in `peft_config`")
            peft_config.target_modules = TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING[
                model_config["model_type"]
            ]
        return peft_config

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Cell's logic
        except AttributeError:
            return getattr(self.model, name)

    def construct(self, *args, **kwargs):
        """The construct method of the model"""
        outputs = self.model(*args, **kwargs)

        if (getattr(outputs, "loss", None) is not None) and isinstance(outputs.loss, Tensor):
            # Calculate the orthogonal regularization
            orth_reg_weight = self.peft_config[self.trainable_adapter_name].orth_reg_weight

            if orth_reg_weight <= 0:
                raise ValueError("orth_reg_weight should be greater than 0. ")

            regu_loss = 0
            num_param = 0
            for n, p in self.model.parameters_and_names():
                if ("lora_A" in n or "lora_B" in n) and self.trainable_adapter_name in n:
                    para_cov = p @ p.T if "lora_A" in n else p.T @ p
                    I = ops.eye(*para_cov.shape)  # noqa: E741
                    I = ops.stop_gradient(I)
                    num_param += 1
                    regu_loss += ops.norm(para_cov - I, ord="fro")
            if num_param > 0:
                regu_loss = regu_loss / num_param
            else:
                regu_loss = 0
            outputs.loss += orth_reg_weight * regu_loss
        return outputs

    def resize_modules_by_rank_pattern(self, rank_pattern, adapter_name):
        "resize the modules by rank pattern"
        lora_config = self.peft_config[adapter_name]
        for name, rank_idx in rank_pattern.items():
            if isinstance(rank_idx, list):
                rank = sum(rank_idx)
                rank_idx = Tensor(rank_idx).view(-1)
            elif isinstance(rank_idx, Tensor):
                rank_idx = rank_idx.view(-1)
                rank = rank_idx.sum().item()
            else:
                raise ValueError("Unexpected type of rank_idx")
            key = ".".join(name.split(".")[0:-2]) if adapter_name in name else ".".join(name.split(".")[0:-1])
            _, target, _ = _get_submodules(self.model, key)
            lora_E_weights = target.lora_E[adapter_name][rank_idx]
            lora_A_weights = target.lora_A[adapter_name][rank_idx]
            lora_B_weights = target.lora_B[adapter_name][:, rank_idx]
            ranknum = target.ranknum[adapter_name]
            target.update_layer(
                adapter_name,
                rank,
                lora_config.lora_alpha,
                lora_config.lora_dropout,
                lora_config.init_lora_weights,
            )
            if rank > 0:
                target.lora_E.update({adapter_name: Parameter(lora_E_weights)})
                target.lora_A.update({adapter_name: Parameter(lora_A_weights)})
                target.lora_B.update({adapter_name: Parameter(lora_B_weights)})
                # The scaling is exactly as the previous
                target.ranknum.update({adapter_name: Parameter(ranknum)})


    def resize_state_dict_by_rank_pattern(self, rank_pattern, state_dict, adapter_name):
        "resize the state_dict by rank pattern"
        for name, rank_idx in rank_pattern.items():
            rank = sum(rank_idx)
            prefix = ".".join(name.split(".")[0:-2]) if adapter_name in name else ".".join(name.split(".")[0:-1])
            for layer in ["lora_E", "lora_A", "lora_B"]:
                key = f"base_model.model.{prefix}.{layer}.{adapter_name}"
                if layer != "lora_B":
                    if rank != state_dict[key][2].reshape(state_dict[key][0]).shape[0]:
                        dims = []
                        data = state_dict[key][2].reshape(state_dict[key][0])
                        data = data[rank_idx]
                        state_dict[key][2] = data.reshape(-1)
                        for dim in data.shape:
                            dims.append(dim)
                        state_dict[key][0] = dims
                else:
                    if rank != state_dict[key][2].reshape(state_dict[key][0]).shape[1]:
                        dims = []
                        data = state_dict[key][2].reshape(state_dict[key][0])
                        data = data[:, rank_idx]
                        state_dict[key][2] = data.reshape(-1)
                        for dim in data.shape:
                            dims.append(dim)
                        state_dict[key][0] = dims
        return state_dict

    def update_and_allocate(self, global_step, gradient):
        """
        This method updates Adalora budget and mask.

        This should be called in every training step after `loss.backward()` and before `zero_grad()`.

        `tinit`, `tfinal` and `deltaT` are handled within the method.

        Args:
            global_step (`int`): The current training step, used to calculate the AdaLoRA budget.
            gradient: The gradients of the current step, passed through to `RankAllocator.update_and_allocate`.

        Example:

        ```python
        >>> loss = model(**input).loss
        >>> loss.backward()
        >>> optimizer.step()
        >>> model.base_model.update_and_allocate(i_step)
        >>> optimizer.zero_grad()
        ```
        """
        lora_config = self.peft_config[self.trainable_adapter_name]
        # Update the importance score and allocate the budget
        if global_step < lora_config.total_step - lora_config.tfinal:
            _, rank_pattern = self.rankallocator.update_and_allocate(self.model, global_step, gradient)
            if rank_pattern:
                lora_config.rank_pattern = rank_pattern
        # Finalize the budget allocation
        elif global_step == lora_config.total_step - lora_config.tfinal:
            _, rank_pattern = self.rankallocator.update_and_allocate(self.model, global_step, gradient, force_mask=True)
            # for some reason, this freezes the trainable parameters and nothing gets updated
            # self.resize_modules_by_rank_pattern(rank_pattern, self.trainable_adapter_name)
            lora_config.rank_pattern = rank_pattern
            self.rankallocator.reset_ipt()
        # Currently using inefficient way to mask the unimportant weights using the rank pattern
        #  due to problem mentioned above
        elif global_step > lora_config.total_step - lora_config.tfinal:
            self.rankallocator.mask_using_rank_pattern(self.model, lora_config.rank_pattern)

mindnlp.peft.tuners.adalora.model.AdaLoraModel.__getattr__(name)

Forward missing attributes to the wrapped module.

Source code in mindnlp/peft/tuners/adalora/model.py
def __getattr__(self, name: str):
    """Forward missing attributes to the wrapped module."""
    try:
        return super().__getattr__(name)  # defer to nn.Cell's logic
    except AttributeError:
        return getattr(self.model, name)
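
The forwarding pattern can be tried without MindSpore. The sketch below uses plain Python classes (`Inner` and `Wrapper` are hypothetical names) and shows only the fallback step: `__getattr__` runs when normal attribute lookup fails, and the lookup is then retried on the wrapped object. The real method first defers to the parent class, which is omitted here.

```python
class Inner:
    """Hypothetical wrapped model exposing an attribute and a method."""

    def __init__(self):
        self.config = {"model_type": "t5"}

    def generate(self):
        return "generated"


class Wrapper:
    """Forwards attributes it does not define to the wrapped object."""

    def __init__(self, model):
        self.model = model

    def __getattr__(self, name):
        # Only reached when normal lookup on the wrapper has failed.
        if name == "model":
            # Avoid infinite recursion if `model` itself is not set yet.
            raise AttributeError(name)
        return getattr(self.model, name)


wrapped = Wrapper(Inner())
print(wrapped.generate())  # generated
print(wrapped.config["model_type"])  # t5
```

Attributes missing on both objects still raise `AttributeError`, because the final `getattr` on the wrapped object fails in the usual way.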

mindnlp.peft.tuners.adalora.model.AdaLoraModel.construct(*args, **kwargs)

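Runs the wrapped model and, when the output carries a tensor `loss`, adds the orthogonal regularization term weighted by `orth_reg_weight` to it.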

Source code in mindnlp/peft/tuners/adalora/model.py
def construct(self, *args, **kwargs):
    """The construct method of the model"""
    outputs = self.model(*args, **kwargs)

    if (getattr(outputs, "loss", None) is not None) and isinstance(outputs.loss, Tensor):
        # Calculate the orthogonal regularization
        orth_reg_weight = self.peft_config[self.trainable_adapter_name].orth_reg_weight

        if orth_reg_weight <= 0:
            raise ValueError("orth_reg_weight should be greater than 0. ")

        regu_loss = 0
        num_param = 0
        for n, p in self.model.parameters_and_names():
            if ("lora_A" in n or "lora_B" in n) and self.trainable_adapter_name in n:
                para_cov = p @ p.T if "lora_A" in n else p.T @ p
                I = ops.eye(*para_cov.shape)  # noqa: E741
                I = ops.stop_gradient(I)
                num_param += 1
                regu_loss += ops.norm(para_cov - I, ord="fro")
        if num_param > 0:
            regu_loss = regu_loss / num_param
        else:
            regu_loss = 0
        outputs.loss += orth_reg_weight * regu_loss
    return outputs
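
The regularization term penalizes how far `lora_A` and `lora_B` are from having orthonormal rows and columns respectively. A minimal NumPy sketch of the same computation follows; NumPy is used here only so the example runs without MindSpore, and the parameter names in the dictionary are hypothetical.

```python
import numpy as np


def orthogonal_regularization(named_params, adapter_name="default"):
    """Mean Frobenius norm of (Gram matrix - I) over the adapter's lora_A/lora_B."""
    total, count = 0.0, 0
    for name, p in named_params.items():
        if ("lora_A" in name or "lora_B" in name) and adapter_name in name:
            # lora_A is (r, in): rows should be orthonormal, so use A @ A.T.
            # lora_B is (out, r): columns should be orthonormal, so use B.T @ B.
            gram = p @ p.T if "lora_A" in name else p.T @ p
            total += np.linalg.norm(gram - np.eye(gram.shape[0]), ord="fro")
            count += 1
    return total / count if count else 0.0


params = {
    "layer.lora_A.default": 2.0 * np.eye(2, 4),  # rows orthogonal but of length 2
    "layer.lora_B.default": np.eye(4, 2),        # columns already orthonormal
    "layer.weight": np.ones((4, 4)),             # ignored: not a LoRA factor
}
print(orthogonal_regularization(params))
```

In this example only `lora_A` contributes: its Gram matrix is `4 * I`, so the deviation from the identity is `3 * I` and the average over the two factors is `3 * sqrt(2) / 2`.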

mindnlp.peft.tuners.adalora.model.AdaLoraModel.resize_modules_by_rank_pattern(rank_pattern, adapter_name)

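Resizes the adapted modules of `adapter_name` according to `rank_pattern`, keeping only the rows of `lora_E` and `lora_A` and the columns of `lora_B` selected by each rank mask.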

Source code in mindnlp/peft/tuners/adalora/model.py
def resize_modules_by_rank_pattern(self, rank_pattern, adapter_name):
    "resize the modules by rank pattern"
    lora_config = self.peft_config[adapter_name]
    for name, rank_idx in rank_pattern.items():
        if isinstance(rank_idx, list):
            rank = sum(rank_idx)
            rank_idx = Tensor(rank_idx).view(-1)
        elif isinstance(rank_idx, Tensor):
            rank_idx = rank_idx.view(-1)
            rank = rank_idx.sum().item()
        else:
            raise ValueError("Unexpected type of rank_idx")
        key = ".".join(name.split(".")[0:-2]) if adapter_name in name else ".".join(name.split(".")[0:-1])
        _, target, _ = _get_submodules(self.model, key)
        lora_E_weights = target.lora_E[adapter_name][rank_idx]
        lora_A_weights = target.lora_A[adapter_name][rank_idx]
        lora_B_weights = target.lora_B[adapter_name][:, rank_idx]
        ranknum = target.ranknum[adapter_name]
        target.update_layer(
            adapter_name,
            rank,
            lora_config.lora_alpha,
            lora_config.lora_dropout,
            lora_config.init_lora_weights,
        )
        if rank > 0:
            target.lora_E.update({adapter_name: Parameter(lora_E_weights)})
            target.lora_A.update({adapter_name: Parameter(lora_A_weights)})
            target.lora_B.update({adapter_name: Parameter(lora_B_weights)})
            # The scaling is exactly as the previous
            target.ranknum.update({adapter_name: Parameter(ranknum)})
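
Slicing the factors with a boolean rank mask gives the same incremental weight as keeping all ranks and zeroing the discarded entries of `lora_E`. The NumPy sketch below illustrates this under assumed shapes, `lora_A` as `(r, in)`, `lora_E` as `(r, 1)` and `lora_B` as `(out, r)`, which are inferred from the indexing in the source. The helper `delta_weight` is hypothetical, and the scaling factor is held fixed, as it is in the method.

```python
import numpy as np


def delta_weight(lora_A, lora_E, lora_B, scaling=1.0):
    """Incremental weight B @ (A * E) for SVD-style LoRA factors."""
    return scaling * lora_B @ (lora_A * lora_E)


rng = np.random.default_rng(0)
r, d_in, d_out = 4, 5, 3
A = rng.normal(size=(r, d_in))
E = rng.normal(size=(r, 1))
B = rng.normal(size=(d_out, r))

# Boolean mask in the style of a rank_pattern entry: keep ranks 0 and 2.
rank_idx = np.array([True, False, True, False])

# Option 1: keep every rank but zero the masked singular values.
masked = delta_weight(A, E * rank_idx[:, None], B)
# Option 2: physically drop the masked rows and columns.
pruned = delta_weight(A[rank_idx], E[rank_idx], B[:, rank_idx])

print(np.allclose(masked, pruned))  # True
```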

mindnlp.peft.tuners.adalora.model.AdaLoraModel.resize_state_dict_by_rank_pattern(rank_pattern, state_dict, adapter_name)

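Resizes the `lora_E`, `lora_A` and `lora_B` entries of `state_dict` for `adapter_name` according to `rank_pattern` and returns the updated `state_dict`.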

Source code in mindnlp/peft/tuners/adalora/model.py
def resize_state_dict_by_rank_pattern(self, rank_pattern, state_dict, adapter_name):
    "resize the state_dict by rank pattern"
    for name, rank_idx in rank_pattern.items():
        rank = sum(rank_idx)
        prefix = ".".join(name.split(".")[0:-2]) if adapter_name in name else ".".join(name.split(".")[0:-1])
        for layer in ["lora_E", "lora_A", "lora_B"]:
            key = f"base_model.model.{prefix}.{layer}.{adapter_name}"
            if layer != "lora_B":
                if rank != state_dict[key][2].reshape(state_dict[key][0]).shape[0]:
                    dims = []
                    data = state_dict[key][2].reshape(state_dict[key][0])
                    data = data[rank_idx]
                    state_dict[key][2] = data.reshape(-1)
                    for dim in data.shape:
                        dims.append(dim)
                    state_dict[key][0] = dims
            else:
                if rank != state_dict[key][2].reshape(state_dict[key][0]).shape[1]:
                    dims = []
                    data = state_dict[key][2].reshape(state_dict[key][0])
                    data = data[:, rank_idx]
                    state_dict[key][2] = data.reshape(-1)
                    for dim in data.shape:
                        dims.append(dim)
                    state_dict[key][0] = dims
    return state_dict
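
The method treats each state-dict entry as a list whose element `0` is the shape and whose element `2` is the flattened data. A small NumPy sketch of the two steps, deriving the module prefix and slicing one entry, is given below. The helpers `module_prefix` and `resize_entry` are hypothetical, and the meaning of element `1` (shown here as a dtype string) is an assumption, since the source never reads it.

```python
import numpy as np


def module_prefix(name, adapter_name):
    """Strip the trailing 'lora_E[.adapter]' part from a rank_pattern key."""
    parts = name.split(".")
    return ".".join(parts[:-2]) if adapter_name in name else ".".join(parts[:-1])


def resize_entry(entry, rank_idx, axis):
    """Keep only the ranks selected by rank_idx along `axis` (0 for A/E, 1 for B)."""
    data = np.asarray(entry[2]).reshape(entry[0])
    mask = np.asarray(rank_idx, dtype=bool)
    if data.shape[axis] != int(mask.sum()):
        data = np.compress(mask, data, axis=axis)
        entry[0] = list(data.shape)   # updated shape
        entry[2] = data.reshape(-1)   # data stored flattened again
    return entry


print(module_prefix("encoder.block.0.q.lora_E.default", "default"))  # encoder.block.0.q

entry_A = [[4, 2], "float32", np.arange(8.0)]
resize_entry(entry_A, [True, False, True, False], axis=0)
print(entry_A[0], entry_A[2])
```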

mindnlp.peft.tuners.adalora.model.AdaLoraModel.update_and_allocate(global_step, gradient)

This method updates Adalora budget and mask.

This should be called in every training step after loss.backward() and before zero_grad().

tinit, tfinal and deltaT are handled within the method.

PARAMETER DESCRIPTION
global_step

The current training step, used to calculate the AdaLoRA budget.

TYPE: `int`

gradient

The gradients of the current step, passed through to `RankAllocator.update_and_allocate`.

>>> loss = model(**input).loss
>>> loss.backward()
>>> optimizer.step()
>>> model.base_model.update_and_allocate(i_step)
>>> optimizer.zero_grad()

Source code in mindnlp/peft/tuners/adalora/model.py
def update_and_allocate(self, global_step, gradient):
    """
    This method updates Adalora budget and mask.

    This should be called in every training step after `loss.backward()` and before `zero_grad()`.

    `tinit`, `tfinal` and `deltaT` are handled within the method.

    Args:
        global_step (`int`): The current training step, used to calculate the AdaLoRA budget.
        gradient: The gradients of the current step, passed through to `RankAllocator.update_and_allocate`.

    Example:

    ```python
    >>> loss = model(**input).loss
    >>> loss.backward()
    >>> optimizer.step()
    >>> model.base_model.update_and_allocate(i_step)
    >>> optimizer.zero_grad()
    ```
    """
    lora_config = self.peft_config[self.trainable_adapter_name]
    # Update the importance score and allocate the budget
    if global_step < lora_config.total_step - lora_config.tfinal:
        _, rank_pattern = self.rankallocator.update_and_allocate(self.model, global_step, gradient)
        if rank_pattern:
            lora_config.rank_pattern = rank_pattern
    # Finalize the budget allocation
    elif global_step == lora_config.total_step - lora_config.tfinal:
        _, rank_pattern = self.rankallocator.update_and_allocate(self.model, global_step, gradient, force_mask=True)
        # for some reason, this freezes the trainable parameters and nothing gets updated
        # self.resize_modules_by_rank_pattern(rank_pattern, self.trainable_adapter_name)
        lora_config.rank_pattern = rank_pattern
        self.rankallocator.reset_ipt()
    # Currently using inefficient way to mask the unimportant weights using the rank pattern
    #  due to problem mentioned above
    elif global_step > lora_config.total_step - lora_config.tfinal:
        self.rankallocator.mask_using_rank_pattern(self.model, lora_config.rank_pattern)
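
The method branches on where `global_step` falls relative to `total_step - tfinal`. The sketch below isolates that decision. The function name `allocation_phase` and the returned labels are hypothetical and only name the three branches of the source.

```python
def allocation_phase(global_step, total_step, tfinal):
    """Name the branch update_and_allocate would take at this step."""
    boundary = total_step - tfinal
    if global_step < boundary:
        # Importance scores are updated and the budget is reallocated.
        return "update_and_allocate"
    if global_step == boundary:
        # The final rank pattern is fixed with force_mask=True.
        return "finalize"
    # Afterwards the stored rank pattern is re-applied as a mask each step.
    return "mask_with_final_pattern"


for step in (0, 799, 800, 801, 999):
    print(step, allocation_phase(step, total_step=1000, tfinal=200))
```

Because `total_step` enters the comparison directly, it must be set on the config before the first call; with the default of `None` the subtraction fails.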