Stable Diffusionの学習コードを作る:3.LoRA編
前回はフルファインチューニングをやりましたが、今回はLoRAの学習ができるようにします。
NetworkManager
モデルへ追加ネットワークを適用する処理を行うクラスです。今回はLoRAしかやりませんが、後々LoHAとかいろいろなものを実装する予定なので、NetworkManagerという名前になっています。そのくせ変数名がlora前提になっていたりとちょっと見直さなきゃいけない部分もありそう。
https://github.com/laksjdjf/sd-trainer/blob/main/networks/manager.py
コンストラクタ
LoRAはテキストエンコーダ/UNetのトランスフォーマー層/UNetのResNet層(Upsample, Downsample含む)の3つで処理を分けます。
moduleは文字列で指定します。"networks.lora.LoRAModule"とかね。module_argsはTransformer、conv_module_argsはResNet、text_module_argsはテキストエンコーダに適用するmoduleへの引き数で、LoRAの場合はrankやalphaを指定できます。なんかTransformerだけは必須になっていますが、変更するかもしれません。
modeはLoRAを元のモデルと分けて計算するか、元のモデルにマージしてしまうかの設定です。学習時は"apply"にしますが、生成時は"merge"の方が計算が早くなります。
unet_key_filtersは、文字列のリストで、LoRAの適用範囲をkey名で適用できます。["up_blocks"]にすればup_blockのみのLoRAができあがります。要素を増やすとORで計算します。ANDはないけどまあいいか。
UNET_TARGET_REPLACE_MODULE_TRANSFORMER = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_UNET = 'lora_unet'
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
LORA_PREFIX_TEXT_ENCODER_1 = 'lora_te1'
LORA_PREFIX_TEXT_ENCODER_2 = 'lora_te2'
class NetworkManager(nn.Module):
def __init__(
self,
text_model,
unet,
module,
module_args,
unet_key_filters=None,
conv_module_args=None,
text_module_args=None,
multiplier=1.0,
mode="apply", # select "apply" or "merge"
):
super().__init__()
self.multiplier = multiplier
self.apply_te = text_module_args is not None
self.module = get_attr_from_config(module)
# unetのloraを作る
self.unet_modules = []
self.unet_modules += self.create_modules(LORA_PREFIX_UNET, unet, UNET_TARGET_REPLACE_MODULE_TRANSFORMER, module_args, unet_key_filters)
if conv_module_args is not None:
self.unet_modules += self.create_modules(LORA_PREFIX_UNET, unet, UNET_TARGET_REPLACE_MODULE_CONV, conv_module_args, unet_key_filters)
if self.apply_te:
self.text_encoder_modules = []
if text_model.sdxl:
self.text_encoder_modules += self.create_modules(LORA_PREFIX_TEXT_ENCODER_1, text_model, TEXT_ENCODER_TARGET_REPLACE_MODULE, text_module_args)
self.text_encoder_modules += self.create_modules(LORA_PREFIX_TEXT_ENCODER_2, text_model, TEXT_ENCODER_TARGET_REPLACE_MODULE, text_module_args)
else:
self.text_encoder_modules += self.create_modules(LORA_PREFIX_TEXT_ENCODER, text_model, TEXT_ENCODER_TARGET_REPLACE_MODULE, text_module_args)
else:
self.text_encoder_modules = []
logger.info(f"UNetのモジュールは: {len(self.unet_modules)}個だよ。")
logger.info(f"TextEncoderのモジュールは: {len(self.text_encoder_modules)}個だよ。")
for lora in self.text_encoder_modules + self.unet_modules:
self.add_module(lora.lora_name, lora)
if mode == "apply":
self.apply_to()
elif mode == "merge":
self.merge_to()
else:
raise ValueError(f"mode {mode} is not supported.")
apply_toとか
いっぱい並んでいますが、各モジュールでapply_toなどを処理させるだけのメソッドです。
create_modules
モデルからLinear層やConv層を検索してLoRAを作成します。二重ループになってますが、一つ目のループは全モジュールを対象にします。その後target_replace_modulesで指定したモジュールを対象に二つ目のループに入ります。二つ目のループで適用対象のモジュールとその名前、適用強度や引数を渡してLoRAを作ってもらいます。
def create_modules(self, prefix, root_module, target_replace_modules, module_args, unet_limited_keys=None) -> list:
modules = []
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
if child_module.__class__.__name__ in ["Linear", "Conv2d", "LoRACompatibleLinear", "LoRACompatibleConv"]:
lora_name = prefix + '.' + name + '.' + child_name
lora_name = lora_name.replace('.', '_')
if is_key_allowed(lora_name, unet_limited_keys):
lora = self.module(lora_name, child_module, self.multiplier, **module_args)
modules.append(lora)
return modules
prepare_optimizer_params
最適化関数に渡すパラメータを出力します。学習率はテキストエンコーダとUNetで分かれています。前回紹介したTrainerのprepare_optimizerで使われています。
def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
self.requires_grad_(True)
all_params = []
if self.text_encoder_modules:
params = []
[params.extend(lora.parameters()) for lora in self.text_encoder_modules]
param_data = {'params': params}
if text_encoder_lr is not None:
param_data['lr'] = text_encoder_lr
all_params.append(param_data)
if self.unet_modules:
params = []
[params.extend(lora.parameters()) for lora in self.unet_modules]
param_data = {'params': params}
if unet_lr is not None:
param_data['lr'] = unet_lr
all_params.append(param_data)
return all_params
save_weights, load_weights
重みのセーブとかロードです。大したことないコードなので省略します。
set_temporary_multiplier
なんかcontextmanagerとかいうやつで、LoRAの強度を指定します。with構文で使えます。LCM-LoRAの学習で利用する予定で、今回は使いません。
@contextlib.contextmanager
def set_temporary_multiplier(self, multiplier):
for lora in self.text_encoder_modules + self.unet_modules:
lora.multiplier = multiplier
yield
for lora in self.text_encoder_modules + self.unet_modules:
lora.multiplier = 1.0
BaseModule
https://github.com/laksjdjf/sd-trainer/blob/main/networks/lora.py
LoRAなどを定義する際、apply_to等のコードはどのモジュールでも変わらないと思うので、その辺を定義する基底クラスを用意しておきます。
apply_toはforwardを置き換えてLoRAを適用します。一方merge_toは元のモデルの重みにLoRAをマージすることで適用します。
class BaseModule(torch.nn.Module):
def __init__(self):
super().__init__()
def apply_to(self, multiplier=None):
if multiplier is not None:
self.multiplier = multiplier
self.org_forward = self.org_module[0].forward
self.org_module[0].forward = self.forward
def unapply_to(self):
self.org_module[0].forward = self.org_forward
def merge_to(self, multiplier=None, sign=1):
lora_weight = self.get_weight(multiplier) * sign
# get org weight
org_sd = self.org_module[0].state_dict()
org_weight = org_sd["weight"]
weight = org_weight + lora_weight.to(org_weight)
# set weight to org_module
org_sd["weight"] = weight
self.org_module[0].load_state_dict(org_sd)
def restore_from(self, multiplier=None):
self.merge_to(multiplier=multiplier, sign=-1)
def get_weight(self, multiplier=None):
raise NotImplementedError
LoRAModule
LoRAはrankやalphaを指定して作ります。おれはalpha=1.0に固定してますけど。
Linear層の場合、入力次元⇒rankのdown層とrank⇒出力次元のup層に分けます。Conv層の場合も基本は同じですが、down層は適用対象と同じカーネルサイズ、up層は1×1カーネルになります。
LoRAはalpha/rankでスケーリングします(元論文がそうしている)。alphaはregister_bufferという勾配計算の対象にならない定数を作成する機能を利用しています。
down層の初期重みはよくわかんない、up層はゼロになります。up層がゼロなので学習前のLoRAを適用しても元のモデルと出力は変わりません。
class LoRAModule(BaseModule):
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, rank=4, alpha=1):
super().__init__()
self.lora_name = lora_name
self.rank = rank
if 'Linear' in org_module.__class__.__name__: # ["Linear", "LoRACompatibleLinear"]
in_dim = org_module.in_features
out_dim = org_module.out_features
self.lora_down = torch.nn.Linear(in_dim, rank, bias=False)
self.lora_up = torch.nn.Linear(rank, out_dim, bias=False)
elif 'Conv' in org_module.__class__.__name__: # ["Conv2d", "LoRACompatibleConv"]
in_dim = org_module.in_channels
out_dim = org_module.out_channels
self.rank = min(self.rank, in_dim, out_dim)
if self.rank != rank:
print(f"{lora_name} dim (rank) is changed to: {self.rank} because of in_dim or out_dim is smaller than rank")
kernel_size = org_module.kernel_size
stride = org_module.stride
padding = org_module.padding
self.lora_down = torch.nn.Conv2d(
in_dim, self.rank, kernel_size, stride, padding, bias=False)
self.lora_up = torch.nn.Conv2d(
self.rank, out_dim, (1, 1), (1, 1), bias=False)
self.shape = org_module.weight.shape
if type(alpha) == torch.Tensor:
alpha = alpha.detach().numpy()
alpha = rank if alpha is None or alpha == 0 else alpha
self.scale = alpha / self.rank
self.register_buffer('alpha', torch.tensor(alpha))
# same as microsoft's
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
torch.nn.init.zeros_(self.lora_up.weight)
self.multiplier = multiplier
self.org_module = [org_module] # moduleにならないようにlistに入れる
forward
別にむずかしいことはないですね。わざわざlora_forwardとforwardを分けて実装しているのは、特に意味ないです(こうなった経緯は複雑だけど)。
forwardのscale引数は使ってないじゃんって感じですがこれはDiffusersのせいで必要になります。**kwargsとかでもいい気がしますがどっちの方が安全なんやろ。
def lora_forward(self, x):
return self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
def forward(self, x, scale = None):
if self.multiplier == 0.0:
return self.org_forward(x)
else:
return self.org_forward(x) + self.lora_forward(x)
get_weight
学習には使いませんが、ΔWを計算するメソッドです。Linearの場合二つの行列の積なので分かりやすいと思いますが、Convの場合よく分からんですよね。1×1カーネルなんてサイズ(out, in, 1, 1)になっていて実質Linearだから要素1の次元を消せば同じような計算になります。3×3カーネルの場合は、入力次元が9倍になっただけと思えばいいです。(out, in, k, k)を(out, in * k * k)にreshapeしたら行列です。行列積を計算したら今度は元の重みと同じサイズにreshapeするだけでいいです。
# calculate lora weight (delta W)
def get_weight(self, multiplier=None):
if multiplier is None:
multiplier = self.multiplier
up_weight = self.lora_up.weight.view(-1, self.rank) # out_dim, rank
down_weight = self.lora_down.weight.view(self.rank, -1) # rank, in_dim*kernel*kernel
lora_weight = up_weight @ down_weight # out_dim, in_dim*kernel*kernel
lora_weight = lora_weight.view(self.shape) # out_dim, in_dim, [kernel, kernel]
return lora_weight * multiplier * self.scale
Trainer
prepare_network
前回省略しましたが、main.pyで使っています。
def prepare_network(self, config):
if config is None:
self.network = None
self.network_train = False
logger.info("ネットワークはないみたい。")
return
self.network = NetworkManager(
text_model=self.text_model,
unet=self.diffusion.unet,
**config.args
)
self.network_train = config.train
self.network.to(self.device, self.train_dtype if self.network_train else self.weight_dtype)
self.network.train(self.network_train)
self.network.requires_grad_(self.network_train)
logger.info("ネットワークを作ったよ!")
Config
以下のようにnetworkに関する設定を追加すれば、LoRAが適用されます。train_unetやtrain_text_encoderをfalseにすればLoRAの学習になります。
network:
train: true
args:
module: networks.lora.LoRAModule
module_args:
rank: 16
conv_module_args: null
text_module_args: null
さあLoRA学習だ!
上の通り設定ファイルを変えるだけでできます。やったああ。
よくある質問(妄想)
Q. 層別学習率はないんですか。
A. ないです。
Q. 層別マーz
A.ないです。
Q. LoRAのメタデータは・・
A. ないです。
Q. なんでDiffusersのLoRAは使わないんですか。
A. わかんないから。
Q. gradient_checkpointingを適用するとき、入力をrequired_grad_(True)になるようにしなくていいんですか。
A. よく分かんないけど実験したらLoRAが学習されないバグ起きなかったんですよね。どこかで直ったのかなあ。