Stable Diffusionの学習コードを作る:3.LoRA編

 前回はフルファインチューニングをやりましたが、今回はLoRAの学習ができるようにします。


NetworkManager

 モデルへ追加ネットワークを適用する処理を行うクラスです。今回はLoRAしかやりませんが、後々LoHAとかいろいろなものを実装する予定なので、NetworkManagerという名前になっています。そのくせ変数名がlora前提になっていたりとちょっと見直さなきゃいけない部分もありそう。
https://github.com/laksjdjf/sd-trainer/blob/main/networks/manager.py

コンストラクタ

 LoRAはテキストエンコーダ/UNetのトランスフォーマー層/UNetのResNet層(Upsample, Downsample含む)の3つで処理を分けます。
 moduleは文字列で指定します。"networks.lora.LoRAModule"とかね。module_argsはTransformer、conv_module_argsはResNet、text_module_argsはテキストエンコーダに適用するmoduleへの引き数で、LoRAの場合はrankやalphaを指定できます。なんかTransformerだけは必須になっていますが、変更するかもしれません。
 modeはLoRAを元のモデルと分けて計算するか、元のモデルにマージしてしまうかの設定です。学習時は"apply"にしますが、生成時は"merge"の方が計算が早くなります。
 unet_key_filtersは、文字列のリストで、LoRAの適用範囲をkey名で適用できます。["up_blocks"]にすればup_blockのみのLoRAができあがります。要素を増やすとORで計算します。ANDはないけどまあいいか。

UNET_TARGET_REPLACE_MODULE_TRANSFORMER = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_UNET = 'lora_unet'
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
LORA_PREFIX_TEXT_ENCODER_1 = 'lora_te1'
LORA_PREFIX_TEXT_ENCODER_2 = 'lora_te2'

class NetworkManager(nn.Module):
    def __init__(
        self, 
        text_model,
        unet, 
        module,
        module_args,
        unet_key_filters=None,
        conv_module_args=None,
        text_module_args=None,
        multiplier=1.0,
        mode="apply", # select "apply" or "merge"
    ):
        super().__init__()
        self.multiplier = multiplier
        self.apply_te = text_module_args is not None

        self.module = get_attr_from_config(module)
       
        # unetのloraを作る
        self.unet_modules = []
        self.unet_modules += self.create_modules(LORA_PREFIX_UNET, unet, UNET_TARGET_REPLACE_MODULE_TRANSFORMER, module_args, unet_key_filters)
        if conv_module_args is not None:
            self.unet_modules += self.create_modules(LORA_PREFIX_UNET, unet, UNET_TARGET_REPLACE_MODULE_CONV, conv_module_args, unet_key_filters)
        if self.apply_te:
            self.text_encoder_modules = []
            if text_model.sdxl:
                self.text_encoder_modules += self.create_modules(LORA_PREFIX_TEXT_ENCODER_1, text_model, TEXT_ENCODER_TARGET_REPLACE_MODULE, text_module_args)
                self.text_encoder_modules += self.create_modules(LORA_PREFIX_TEXT_ENCODER_2, text_model, TEXT_ENCODER_TARGET_REPLACE_MODULE, text_module_args)
            else:
                self.text_encoder_modules += self.create_modules(LORA_PREFIX_TEXT_ENCODER, text_model, TEXT_ENCODER_TARGET_REPLACE_MODULE, text_module_args)
        else:
            self.text_encoder_modules = []
        
        logger.info(f"UNetのモジュールは: {len(self.unet_modules)}個だよ。")
        logger.info(f"TextEncoderのモジュールは: {len(self.text_encoder_modules)}個だよ。")

        for lora in self.text_encoder_modules + self.unet_modules:
            self.add_module(lora.lora_name, lora)

        if mode == "apply":
            self.apply_to()
        elif mode == "merge":
            self.merge_to()
        else:
            raise ValueError(f"mode {mode} is not supported.")

apply_toとか

 いっぱい並んでいますが、各モジュールでapply_toなどを処理させるだけのメソッドです。

create_modules

 モデルからLinear層やConv層を検索してLoRAを作成します。二重ループになってますが、一つ目のループは全モジュールを対象にします。その後target_replace_modulesで指定したモジュールを対象に二つ目のループに入ります。二つ目のループで適用対象のモジュールとその名前、適用強度や引数を渡してLoRAを作ってもらいます。

def create_modules(self, prefix, root_module, target_replace_modules, module_args, unet_limited_keys=None) -> list:
    modules = []
    for name, module in root_module.named_modules():
        if module.__class__.__name__ in target_replace_modules:
            for child_name, child_module in module.named_modules():
                if child_module.__class__.__name__ in ["Linear", "Conv2d", "LoRACompatibleLinear", "LoRACompatibleConv"]:
                    lora_name = prefix + '.' + name + '.' + child_name
                    lora_name = lora_name.replace('.', '_')
                    if is_key_allowed(lora_name, unet_limited_keys):
                        lora = self.module(lora_name, child_module, self.multiplier, **module_args)
                        modules.append(lora)
    return modules

prepare_optimizer_params

 最適化関数に渡すパラメータを出力します。学習率はテキストエンコーダとUNetで分かれています。前回紹介したTrainerのprepare_optimizerで使われています。

def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
    self.requires_grad_(True)
    all_params = []

    if self.text_encoder_modules:
        params = []
        [params.extend(lora.parameters()) for lora in self.text_encoder_modules]
        param_data = {'params': params}
        if text_encoder_lr is not None:
            param_data['lr'] = text_encoder_lr
        all_params.append(param_data)

    if self.unet_modules:
        params = []
        [params.extend(lora.parameters()) for lora in self.unet_modules]
        param_data = {'params': params}
        if unet_lr is not None:
            param_data['lr'] = unet_lr
        all_params.append(param_data)

    return all_params

save_weights, load_weights

 重みのセーブとかロードです。大したことないコードなので省略します。

set_temporary_multiplier

 なんかcontextmanagerとかいうやつで、LoRAの強度を指定します。with構文で使えます。LCM-LoRAの学習で利用する予定で、今回は使いません。

@contextlib.contextmanager
def set_temporary_multiplier(self, multiplier):
    for lora in self.text_encoder_modules + self.unet_modules:
        lora.multiplier = multiplier
    yield
    for lora in self.text_encoder_modules + self.unet_modules:
        lora.multiplier = 1.0

BaseModule

https://github.com/laksjdjf/sd-trainer/blob/main/networks/lora.py

 LoRAなどを定義する際、apply_to等のコードはどのモジュールでも変わらないと思うので、その辺を定義する基底クラスを用意しておきます。
 apply_toはforwardを置き換えてLoRAを適用します。一方merge_toは元のモデルの重みにLoRAをマージすることで適用します。

class BaseModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def apply_to(self, multiplier=None):
        if multiplier is not None:
            self.multiplier = multiplier
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward

    def unapply_to(self):
        self.org_module[0].forward = self.org_forward

    def merge_to(self, multiplier=None, sign=1):
        lora_weight = self.get_weight(multiplier) * sign

        # get org weight
        org_sd = self.org_module[0].state_dict()
        org_weight = org_sd["weight"]
        weight = org_weight + lora_weight.to(org_weight)

        # set weight to org_module
        org_sd["weight"] = weight
        self.org_module[0].load_state_dict(org_sd)

    def restore_from(self, multiplier=None):
        self.merge_to(multiplier=multiplier, sign=-1)

    def get_weight(self, multiplier=None):
        raise NotImplementedError

LoRAModule

 LoRAはrankやalphaを指定して作ります。おれはalpha=1.0に固定してますけど。
 Linear層の場合、入力次元⇒rankのdown層とrank⇒出力次元のup層に分けます。Conv層の場合も基本は同じですが、down層は適用対象と同じカーネルサイズ、up層は1×1カーネルになります。
 LoRAはalpha/rankでスケーリングします(元論文がそうしている)。alphaはregister_bufferという勾配計算の対象にならない定数を作成する機能を利用しています。
 down層の初期重みはよくわかんない、up層はゼロになります。up層がゼロなので学習前のLoRAを適用しても元のモデルと出力は変わりません。

class LoRAModule(BaseModule):

    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, rank=4, alpha=1):
        super().__init__()
        self.lora_name = lora_name
        self.rank = rank

        if 'Linear' in org_module.__class__.__name__: # ["Linear", "LoRACompatibleLinear"]
            in_dim = org_module.in_features
            out_dim = org_module.out_features

            self.lora_down = torch.nn.Linear(in_dim, rank, bias=False)
            self.lora_up = torch.nn.Linear(rank, out_dim, bias=False)

        elif 'Conv' in org_module.__class__.__name__: # ["Conv2d", "LoRACompatibleConv"]
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels

            self.rank = min(self.rank, in_dim, out_dim)
            if self.rank != rank:
                print(f"{lora_name} dim (rank) is changed to: {self.rank} because of in_dim or out_dim is smaller than rank")

            kernel_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            self.lora_down = torch.nn.Conv2d(
                in_dim, self.rank, kernel_size, stride, padding, bias=False)
            self.lora_up = torch.nn.Conv2d(
                self.rank, out_dim, (1, 1), (1, 1), bias=False)

        self.shape = org_module.weight.shape

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().numpy()
        alpha = rank if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.rank
        self.register_buffer('alpha', torch.tensor(alpha))

        # same as microsoft's
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier
        self.org_module = [org_module] # moduleにならないようにlistに入れる

forward

 別にむずかしいことはないですね。わざわざlora_forwardとforwardを分けて実装しているのは、特に意味ないです(こうなった経緯は複雑だけど)。
 forwardのscale引数は使ってないじゃんって感じですがこれはDiffusersのせいで必要になります。**kwargsとかでもいい気がしますがどっちの方が安全なんやろ。

def lora_forward(self, x):
    return self.lora_up(self.lora_down(x)) * self.multiplier * self.scale

def forward(self, x, scale = None):
    if self.multiplier == 0.0:
        return self.org_forward(x)
    else:
        return self.org_forward(x) + self.lora_forward(x)

get_weight

 学習には使いませんが、ΔWを計算するメソッドです。Linearの場合二つの行列の積なので分かりやすいと思いますが、Convの場合よく分からんですよね。1×1カーネルなんてサイズ(out, in, 1, 1)になっていて実質Linearだから要素1の次元を消せば同じような計算になります。3×3カーネルの場合は、入力次元が9倍になっただけと思えばいいです。(out, in, k, k)を(out, in * k * k)にreshapeしたら行列です。行列積を計算したら今度は元の重みと同じサイズにreshapeするだけでいいです。

# calculate lora weight (delta W)
def get_weight(self, multiplier=None):
    if multiplier is None:
        multiplier = self.multiplier

    up_weight = self.lora_up.weight.view(-1, self.rank) # out_dim, rank
    down_weight = self.lora_down.weight.view(self.rank, -1) # rank, in_dim*kernel*kernel
        
    lora_weight = up_weight @ down_weight  # out_dim, in_dim*kernel*kernel
    lora_weight = lora_weight.view(self.shape)  # out_dim, in_dim, [kernel, kernel]

    return lora_weight * multiplier * self.scale

Trainer

prepare_network

前回省略しましたが、main.pyで使っています。

    def prepare_network(self, config):
        if config is None:
            self.network = None
            self.network_train = False
            logger.info("ネットワークはないみたい。")
            return 
        self.network = NetworkManager(
            text_model=self.text_model,
            unet=self.diffusion.unet,
            **config.args
        )
        self.network_train = config.train

        self.network.to(self.device, self.train_dtype if self.network_train else self.weight_dtype)
        self.network.train(self.network_train)
        self.network.requires_grad_(self.network_train)

        logger.info("ネットワークを作ったよ!")

Config

 以下のようにnetworkに関する設定を追加すれば、LoRAが適用されます。train_unetやtrain_text_encoderをfalseにすればLoRAの学習になります。

network:
  train: true
  args:
    module: networks.lora.LoRAModule
    module_args:
      rank: 16
    conv_module_args: null
    text_module_args: null

さあLoRA学習だ!

 上の通り設定ファイルを変えるだけでできます。やったああ。

よくある質問(妄想)

Q. 層別学習率はないんですか。
A. ないです。

Q. 層別マーz
A.ないです。

Q. LoRAのメタデータは・・
A. ないです。

Q. なんでDiffusersのLoRAは使わないんですか。
A. わかんないから。

Q. gradient_checkpointingを適用するとき、入力をrequired_grad_(True)になるようにしなくていいんですか。
A. よく分かんないけど実験したらLoRAが学習されないバグ起きなかったんですよね。どこかで直ったのかなあ。