Stable Diffusionの学習コードを作る:5.ControlNet編

 今回はControlNetの学習についてやっていきます。以下のような設定を増やすことで学習できるようにします。

controlnet:
  train: true
  resume: null # model file path
  transformer_layers_per_block: false # default = false
  global_average_pooling: false # default = false

DiffusionModel

create_controlnet

 diffusersのControlNetModelを使います。最初から作る場合はfrom_unetで作れます。
 真ん中らへんはTransformerを省略するためのものです。SDXLでモデルサイズを削減するためにhuggingfaceが提案したものです。空のTransformer省略ControlNetを作った後、ロード済みのControlNetのweightを適用します。
 global_average_poolingはshuffle用ですが、学習時には使わないと思います。

    def create_controlnet(self, config):
        if config.resume is not None:
            pre_controlnet = ControlNetModel.from_pretrained(config.resume)
        else:
            pre_controlnet = ControlNetModel.from_unet(self.unet)  

        if config.transformer_layers_per_block is not None:
            down_block_types = tuple(["DownBlock2D" if l == 0 else "CrossAttnDownBlock2D" for l in config.transformer_layers_per_block])
            transformer_layers_per_block = tuple([int(x) for x in config.transformer_layers_per_block])
            self.controlnet = ControlNetModel.from_config(
                pre_controlnet.config,
                down_block_types=down_block_types,
                transformer_layers_per_block=transformer_layers_per_block,
            )
            self.controlnet.load_state_dict(pre_controlnet.state_dict(), strict=False)
            del pre_controlnet
        else:
            self.controlnet = pre_controlnet
        
        self.controlnet.config.global_pool_conditions = config.global_average_pooling

forward

 UNetの推論と統合します。diffusersの機能をそのまま使うだけなので、別に難しいことはないですね。新たにcontrolnet_hintという入力が増えています。これは値が[0, 1]でサイズが画像と同じ[b, 3, h, w]のてんさーです。

    def forward(self, latents, timesteps, encoder_hidden_states, pooled_output, size_condition=None, controlnet_hint=None):
        if self.sdxl:
            if size_condition is None:
                h, w = latents.shape[2] * 8, latents.shape[3] * 8
                size_condition = torch.tensor([h, w, 0, 0, h, w]) # original_h/w. crop_top/left, target_h/w
                size_condition = size_condition.repeat(latents.shape[0], 1).to(latents)
            added_cond_kwargs = {"text_embeds": pooled_output, "time_ids": size_condition}
        else:
            added_cond_kwargs = None

        if self.controlnet is not None:
            assert controlnet_hint is not None, "controlnet_hint is required when controlnet is enabled"
            down_block_additional_residuals, mid_block_additional_residual = self.controlnet(
                latents,
                timesteps,
                encoder_hidden_states=encoder_hidden_states,
                controlnet_cond=controlnet_hint,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )
        else:
            down_block_additional_residuals = None
            mid_block_additional_residual = None

        model_output = self.unet(
            latents,
            timesteps,
            encoder_hidden_states,
            added_cond_kwargs=added_cond_kwargs,
            down_block_additional_residuals=down_block_additional_residuals,
            mid_block_additional_residual=mid_block_additional_residual,
        ).sample

データセット

 ControlNetのヒント画像を取り込むメソッドを増やします。上記にある通り[0, 1]のてんさーであり、ToTensor()をするだけです。引数も増やしてますが省略。

    def get_control(self, samples, dir="control"):
        images = []
        transform = transforms.ToTensor()
        for sample in samples:
            image = Image.open(os.path.join(self.path, dir, sample + f".png")).convert("RGB")
            images.append(transform(image))
        images_tensor = torch.stack(images).to(memory_format=torch.contiguous_format).float()
        return images_tensor

 canny edgeの場合、計算が軽いのでわざわざ前処理済み画像を保存せずとも、学習中に計算して取得することもできます。そういった場合はカスタムデータセットを作りましょう。式の中身は全然わかっていません。

from modules.dataset import BaseDataset
import cv2
import os
import torch
import numpy as np
from torchvision import transforms

class CannyDataset(BaseDataset):
    def get_control(self, samples, dir="control"):
        images = []
        transform = transforms.ToTensor()
        for sample in samples:
            # ref https://qiita.com/kotai2003/items/662c33c15915f2a8517e
            image = cv2.imread(os.path.join(self.path, dir, sample + f".png"))
            med_val = np.median(image)
            sigma = 0.33  # 0.33
            min_val = int(max(0, (1.0 - sigma) * med_val))
            max_val = int(max(255, (1.0 + sigma) * med_val))
            image = cv2.Canny(image, threshold1 = min_val, threshold2 = max_val)
            image = image[:, :, None] # add channel
            image = np.concatenate([image]*3, axis=2) # grayscale to rgb
            images.append(transform(image))
        images_tensor = torch.stack(images).to(memory_format=torch.contiguous_format).float()
        return images_tensor

とれーなー

 controlnetの準備はこんな感じ

    def prepare_controlnet(self, config):
        if config is None:
            self.controlnet = None
            self.controlnet_train = False
            logger.info("コントロールネットはないみたい。")
            return
        
        self.diffusion.create_controlnet(config)
        self.controlnet_train = config.train

        self.diffusion.controlnet.to(self.device, self.train_dtype if self.controlnet_train else self.weight_dtype)
        self.diffusion.controlnet.train(self.controlnet_train)
        self.diffusion.controlnet.requires_grad_(self.controlnet_train)

        logger.info("コントロールネットを作ったよ!")

  他にgradient checkpointingをcontrolnetにも適用できるようにするとかそういうどうでもいい変更があります。

 損失の計算ではヒント画像の読み込みを追加します。またdiffusionでcontrolnet_hintも追加で入力します。

        if "controlnet_hint" in batch:
            controlnet_hint = batch["controlnet_hint"].to(self.device)
        else:
            controlnet_hint = None

サンプル生成

 sampleにもcontrolnet_hint関連を追加します。ファイルパスを直接指定することもできるようにします。損失とどうようdiffusionへの引数にも追加します。

        if controlnet_hint is not None:
            if isinstance(controlnet_hint, str):
                controlnet_hint = Image.open(controlnet_hint).convert("RGB")
                controlnet_hint = transforms.ToTensor()(controlnet_hint).unsqueeze(0)
            controlnet_hint = controlnet_hint.to(self.device)
            if guidance_scale != 1.0:
                controlnet_hint = torch.cat([controlnet_hint] *2)

 学習設定でヒント画像のファイルパスを指定すれば任意の画像でテストできるようになります。

  validation_args:
    prompt: "1girl, solo, sitting, blonde hair, red eyes , sailor collar, blue skirt, black thighhighs, room"
    negative_prompt: "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name"
    width: 832
    height: 1216
    controlnet_hint: "data/pose.png"

モデルファイルについて

 diffusers形式になっています。sgm形式にするためには変換コードがありますがSD1/SD2用になります。SDXLはComfyUIだと中身のsafetensorsをそのまま使えるので変換コードはありません(SD1/SD2でも使えるのかな?)。

Control-LoRA

 SDXLのControlNetはそのままだと2.5GBとかいうくそでかファイルになってしまいますが、学習差分を特異値分解によってLoRAに変換することができます。ControlNet固有のモジュール(input hintやzero conv)はLoRAにせず、またbiasやLayerNormなどの重みがベクトルのものもLoRAにはしません(できません)。

変換コード用意しておきました(どこか別のところにもありそうだけど。。。)

https://github.com/laksjdjf/sd-trainer/blob/main/tools/create_control_lora.py