Stable Diffusionの学習コードを作る:1.生成編

 学習コードに変な機能がいっぱい増えてわけわからなくなっちゃったので作り直すことにしました。せっかくなので記事にしてみます。完全なオレオレ学習コードなので全くゆうざあふれんどりぃでないものになる予定です!
 方針として、モデルの定義以外は自前で実装します。モデルの定義はhuggingfaceとの連携辺りがめんどくさいのでやりたくないです。
 ※Noteに貼り付けているコードはGithubのものより古かったり簡略化している可能性があります。

https://github.com/laksjdjf/sd-trainer

記事予定
1. 生成編
2. 学習編
3. LoRA編
4. LCM-LoRA編
以下願望
?.学習速度・VRAM使用量確認編
?. ControlNet編
?. LECO編
?. Stable Cascaded編
?. LoHA, LoKR編
?. IP-Adapter編
?. Textual Inversion編

まずはStable DIffusionの生成ができるようにします。生成コードは学習時にはサンプル画像の生成でしか使いませんが、とりあえず生成できるようにしないとコードの正しさがテストできないのでそうしました。

以下のようなクラスを定義します。

TextModel(プロンプトを処理)
DiffusionModel(ノイズ予測モデル全般を処理)
BaseScheduler(サンプラー・スケジューラーを実装)
BaseDataset(データセットを実装)
BaseTrainer(学習・簡単な生成を実装)

言うまでもなくsd-scriptsを参考にしたりパクったりトレースしたりしています。


プロンプトの処理(TextModel)

CLIPの構造、あんまり詳しくないけどこんなイメージでしょ。

 プロンプトを入力し、テキスト埋め込みを得るTextModelクラスを作成します。SDではテキスト関連はプロンプトをトークンID列にするトークナイザーと、それを解析するCLIPテキストエンコーダに分かれます。SDXLではさらにそれぞれ二種類あるので、4つもあって面倒なので、統一できるようにしましょう。

https://github.com/laksjdjf/sd-trainer/blob/main/modules/text_model.py

コンストラクタとか

 コンストラクタやモデルのロードについてはこんな感じ。text_encoder_2だけCLIPTextModelWithProjectionとかいう変なクラスになっていますが、上の画像にあるpooled_outputを得るためのものですね(結局この機能は使いませんが)。クラスメソッドでhuggingfaceのモデルを直接ロードできるようにもしておきます。

class TextModel(nn.Module):
    def __init__(
        self, 
        tokenizer:CLIPTokenizer, 
        tokenizer_2:CLIPTokenizer, 
        text_encoder:CLIPTextModel, 
        text_encoder_2:CLIPTextModelWithProjection, 
        clip_skip:int=-1
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.tokenizer_2 = tokenizer_2

        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2

        self.clip_skip = clip_skip
        self.sdxl = tokenizer_2 is not None

    @classmethod
    def from_pretrained(cls, path, sdxl=False, clip_skip=-1):
        tokenizer = CLIPTokenizer.from_pretrained(path, subfolder='tokenizer')
        text_encoder = CLIPTextModel.from_pretrained(path, subfolder='text_encoder')
        if sdxl:
            tokenizer_2 = CLIPTokenizer.from_pretrained(path, subfolder='tokenizer_2')
            text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(path, subfolder='text_encoder_2')
        else:
            tokenizer_2 = None
            text_encoder_2 = None
        return cls(tokenizer, tokenizer_2, text_encoder, text_encoder_2, clip_skip=clip_skip)

トークナイザー

 トークンはテキストエンコーダモデルの限度である77個(BOS, EOSトークン含む)までで切り捨てられます。長すぎるpromptは最後の方無視されちゃいますね。その辺を解決する方法もありますが、今回のコードではやりません。逆に足りない場合はEOSもしくはPADトークンでパディングされます。その辺の設定がpaddingやtruncationの意味ですね。

def tokenize(self, texts):
    tokens = self.tokenizer(
        texts, 
        max_length=self.tokenizer.model_max_length, # 77
        padding="max_length",
        truncation=True,
        return_tensors='pt'
    ).input_ids.to(self.text_encoder.device)

    if self.sdxl:
        tokens_2 = self.tokenizer_2(
            texts, 
            max_length=self.tokenizer_2.model_max_length, 
            padding="max_length",
            truncation=True, 
            return_tensors='pt'
        ).input_ids.to(self.text_encoder_2.device)
    else:
        tokens_2 = None

    return tokens, tokens_2

引数のtextsは文字列のリストです。

テキストエンコーダー

 SD1/SD2ではテキストエンコーダーのlast_hidden_stateにlayernormを通したものが使われます。clip_skipを設定する場合代わりにlast_hidden_statesより手前の層の出力を使います。clip_skipは-2のように設定すると最後から2番目になります。
SDXLでは二つのテキストエンコーダー出力を結合します。また二つ目のテキストエンコーダーのpooled_outputが必要ですが、自前で計算します。CLIPTextModelWithProjectionクラスには、pooled_outputを直接取り出す機能がありますが、Textual inversionでトークンの種類を増やすとEOSの位置がずれてしまうバグが起こる可能性があります。この辺の計算はsd-scriptsを参考にしています。ちなみに空文の場合はpooled_outputを0にする処理を行います。

def get_hidden_states(self, tokens, tokens_2=None):
    encoder_hidden_states = self.text_encoder(tokens, output_hidden_states=True).hidden_states[self.clip_skip]
    if self.sdxl:
        encoder_output_2 = self.text_encoder_2(tokens_2, output_hidden_states=True)
        last_hidden_state = encoder_output_2.last_hidden_state

        # calculate pooled_output
        eos_token_index = torch.where(tokens_2 == self.tokenizer_2.eos_token_id)[1].to(device=last_hidden_state.device)
        pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),eos_token_index]
        pooled_output = self.text_encoder_2.text_projection(pooled_output)

        encoder_hidden_states_2 = encoder_output_2.hidden_states[self.clip_skip]

        # (b, n, 768) + (b, n, 1280) -> (b, n, 2048)
        encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_2], dim=2)

        # pooled_output is zero vector for empty text            
        for i, token in enumerate(tokens_2):
            if token[1].item() == self.tokenizer_2.eos_token_id: # 二番目がEOSなら空文
                pooled_output[i] = 0
    else:
        encoder_hidden_states = self.text_encoder.text_model.final_layer_norm(encoder_hidden_states)
        pooled_output = None

    return encoder_hidden_states, pooled_output

forward

二つのメソッドをまとめてプロンプトをそのまま処理できるようにします。

def forward(self, prompts):
    tokens, tokens_2 = self.tokenize(prompts)
    encoder_hidden_states, pooled_output = self.get_hidden_states(tokens, tokens_2)
    return encoder_hidden_states, pooled_output

ノイズ予測モデル(DiffusionModel)

 ノイズ予測をするモデルです。現状UNetをかぶせているだけであまり意味がないんですが、後々ControlNetやIP-Adapterを追加することを考えています。

https://github.com/laksjdjf/sd-trainer/blob/main/modules/diffusion_model.py

class DiffusionModel(nn.Module):
    def __init__(
        self, 
        unet:UNet2DConditionModel,
        sdxl:bool=False,
    ):
        super().__init__()
        self.unet = unet
        self.sdxl = sdxl
    
    def forward(self, latents, timesteps, encoder_hidden_states, pooled_output, size_condition=None):
        if self.sdxl:
            if size_condition is None:
                h, w = latents.shape[2] * 8, latents.shape[3] * 8
                size_condition = torch.tensor([h, w, 0, 0, h, w]) # original_h/w. crop_top/left, target_h/w
                size_condition = size_condition.repeat(latents.shape[0], 1).to(latents)
            added_cond_kwargs = {"text_embeds": pooled_output, "time_ids": size_condition}
        else:
            added_cond_kwargs = None

        model_output = self.unet(
            latents,
            timesteps,
            encoder_hidden_states,
            added_cond_kwargs=added_cond_kwargs,
        ).sample

        return model_output

 latentsはノイズ付き潜在変数、encoder_hidden_statesとpooled_outputはTextModelの出力です。size_conditionはSDXLで必要な画像の解像度情報です。ここの話は複雑なので学習編でやろうかな。簡単に説明すると学習データのリサイズ前解像度縦横、学習データの切り抜きサイズ縦横、学習データのリサイズ後解像度縦横の6要素のTensorになってます。生成時は入力データサイズから決めることができます。

スケジューラー(BaseScheduler)

 スケジューラーというよりサンプラーといった方が正しいような気がしますが、Diffusersではスケジューラーというのでなんかそうすることにしました。Diffusersのコードはあまりにもぐちゃぐちゃなので自前で用意します。ただし学習コードなので最低限の実装にします。

https://github.com/laksjdjf/sd-trainer/blob/main/modules/scheduler.py

コンストラクタ

 コンストラクタでは時刻ごとのノイズの大きさをスケジューリングします。SD1/SD2/SDXLで全部共通なので特に設定はありません。この辺はこの記事が詳しいです。v_predictionはUNetがノイズ予測ではなくvelocity予測モデルになっているSD2用の設定です。
最終的に時刻tでxにノイズを加えたxtは、
xt = sqrt_alpha_bar[t] * x + sqrt_beta_bar[t] * noise
になります。
※一般に拡散モデルではt=0がノイズのない状態で、t=[1,1000]でノイズが大きくなっていくよう定義されますが、ここではUNetの仕様上tが1ずれており、t=0からすでにノイズが加えられていることに注意してください。

class BaseScheduler:
    def __init__(self, v_prediction=False):
        self.v_prediction = v_prediction  # velocity予測かどうか
        self.make_alpha_beta()

    def make_alpha_beta(self, beta_start=0.00085, beta_end=0.012, num_timesteps=1000):
        self.num_timesteps = num_timesteps

        # beta_1, ... , beta_T
        self.betas = (torch.linspace(beta_start**0.5, beta_end**0.5, num_timesteps, dtype=torch.float32) ** 2)

        # alpha_1, ... , alpha_T
        self.alphas = 1 - self.betas

        # with bar
        self.alphas_bar = torch.cumprod(self.alphas, dim=0)
        self.betas_bar = 1 - self.alphas_bar

        self.sqrt_alphas_bar = torch.sqrt(self.alphas_bar)
        self.sqrt_betas_bar = torch.sqrt(self.betas_bar)

 さてここで定義されたself.alphasとかはtorchのTensorになっています。これに時刻を入れて潜在変数やノイズと積をとったりしたいんですが、deviceがcpuかcudaかとかブロードキャストで問題が起きます。生成時はバッチの各要素でtが同じであり、スカラーになるんで問題ないですが、学習時は各要素でtをばらばらにするのでその辺の処理が必要です。そのため以下のような便利関数を作っておきます。timestepsがスカラーのときは要素数分拡張して、constantsをtimestepsと同じデバイスに移動したのち指定したtimestepの値を得ます。最後の[:, None, None, None]は潜在変数との計算でブロードキャストするためのものですね。

def substitution_t(constants, timesteps, batch_size): 
    if timesteps.dim() == 0: # 全要素同じtの場合(生成用)
        timesteps = timesteps.repeat(batch_size)
    device = timesteps.device
    constants = constants.to(device)[timesteps][:,None, None, None] # 4dims for latents
    return constants

時刻スケジュール

 生成用の時刻は単純なものにしています。999から0までステップ数に1を加えた数だけ線形に区分けします。最後に0を切り捨てます。0を切り捨てるのは、t=0はノイズがほとんどない状態でそこからさらにノイズ除去したところであんまり意味がないからです。ノイズ時刻スケジュールにはいろいろな実装によって細かい違いがあるんですが、どうせ学習時のサンプル用に使うだけなんでこれが一番楽でいいと思います。

def set_timesteps(self, num_inference_steps, device="cuda"):
    self.num_inference_steps = num_inference_steps
    timesteps = torch.linspace(0, self.num_timesteps-1, num_inference_steps+1, dtype=float).round()
    return timesteps.flip(0)[:-1].clone().long().to(device) # [999, ... , n]

ノイズ付与

 式通りにノイズを付与して時刻tの潜在変数xtを得ます。

# x0 -> xt    
def add_noise(self, sample, noise, t):
    sqrt_alphas_bar = substitution_t(self.sqrt_alphas_bar, t, sample.shape[0])
    sqrt_betas_bar = substitution_t(self.sqrt_betas_bar, t, sample.shape[0])
    
    return sqrt_alphas_bar * sample + sqrt_betas_bar * noise

元の潜在予測とノイズ予測

 ノイズ付き潜在変数xtとモデルの出力から、元の潜在変数及びノイズの予測結果を計算するメソッドです。計算式は、ノイズ予測の場合は割と単純ですね。元の潜在変数は上のadd_noiseの関係式を変形すればいいし、ノイズはモデルの出力そのものです。v_predictionモデルの場合は…論文を参考にしてください。

    def pred_original_sample(self, sample, model_output, t):
        sqrt_alphas_bar = substitution_t(self.sqrt_alphas_bar, t, sample.shape[0])
        sqrt_betas_bar = substitution_t(self.sqrt_betas_bar, t, sample.shape[0])
        
        if self.v_prediction:
            return sqrt_alphas_bar * sample - sqrt_betas_bar * model_output
        else: # noise_prediction
            return (sample - sqrt_betas_bar * model_output) / sqrt_alphas_bar
        
    def pred_noise(self, sample, model_output, t):
        sqrt_alphas_bar = substitution_t(self.sqrt_alphas_bar, t, model_output.shape[0])
        sqrt_betas_bar = substitution_t(self.sqrt_betas_bar, t, model_output.shape[0])
        
        if self.v_prediction:
            return sqrt_alphas_bar * model_output + sqrt_betas_bar * sample
        else: # noise_prediction
            return model_output

生成1ステップ分の計算

 DDIMの場合、上の二つのメソッドから簡単に計算できます。ちなみにEulerとはアルゴリズムが違いますが、数学上は同じです。

    # x_t -> x_prev_t
    def step(self, sample, model_output, t, prev_t):
        original_sample = self.pred_original_sample(sample, model_output, t)
        noise_pred = self.pred_noise(sample, model_output, t)

        return self.add_noise(original_sample, noise_pred, prev_t)

Trainer(生成部分のみ)

 Trainerという名前だけど生成部分のみやっていきます。あくまで学習コードなので、生成はそのことを前提とした実装になります。

https://github.com/laksjdjf/sd-trainer/blob/main/modules/trainer.py

モデルのロード

 その前にモデルのロード方法についてやります。StabilityAI式のckpt(safetensors)ファイルにもDiffusers式のフォルダにも対応できるようにするっていうかこの辺はDiffusersに頼っているだけです。めんどくさいことにSD1/SD2とSDXLでぱいぷらいんが違うので場合分けします(DiffusionPipelineクラスを使えばいいだけかもしれない)。ついでにTextModelのclip_skipについても設定しておきます。

def load_model(path, sdxl=False, clip_skip=-1):
    if sdxl:
        if os.path.isfile(path):
            pipe = StableDiffusionXLPipeline.from_single_file(path, scheduler_type="ddim")
            tokenizer = pipe.tokenizer
            tokenizer_2 = pipe.tokenizer_2
            text_encoder = pipe.text_encoder
            text_encoder_2 = pipe.text_encoder_2
            unet = pipe.unet
            vae = pipe.vae
            scheduler = pipe.scheduler
            text_model = TextModel(tokenizer, tokenizer_2, text_encoder, text_encoder_2)
            del pipe
        else:
            text_model = TextModel.from_pretrained(path, sdxl=True)
            unet = UNet2DConditionModel.from_pretrained(path, subfolder='unet')
            vae = AutoencoderKL.from_pretrained(path, subfolder='vae')
            scheduler = DDPMScheduler.from_pretrained(path, subfolder='scheduler')
    else:
        if os.path.isfile(path):
            pipe = StableDiffusionPipeline.from_single_file(path, scheduler_type="ddim")
            tokenizer = pipe.tokenizer
            text_encoder = pipe.text_encoder
            unet = pipe.unet
            vae = pipe.vae
            scheduler = pipe.scheduler
            text_model = TextModel(tokenizer, None, text_encoder, None)
            del pipe
        else:
            text_model = TextModel.from_pretrained(path)
            unet = UNet2DConditionModel.from_pretrained(path, subfolder='unet')
            vae = AutoencoderKL.from_pretrained(path, subfolder='vae')
            scheduler = DDPMScheduler.from_pretrained(path, subfolder='scheduler')
            
    text_model.clip_skip = clip_skip
    return text_model, vae, unet, scheduler

コンストラクタ・ロードメソッド

 まあ大したことはしていません。SDXLのclip_skipはデフォルトで-2なのでそうなるようにしています。

class BaseTrainer:
    def __init__(self, config, diffusion:DiffusionModel, text_model:TextModel, vae:AutoencoderKL, scheduler, network:NetworkManager):
        self.config = config
        self.diffusion = diffusion
        self.text_model = text_model
        self.vae = vae
        self.network = network
        self.diffusers_scheduler = scheduler # モデルのセーブ次にのみ利用
        self.scheduler = BaseScheduler(scheduler.config.prediction_type == "v_prediction")
        self.sdxl = text_model.sdxl

    @classmethod
    def from_pretrained(cls, path, sdxl, clip_skip=None, config=None, network=None):
        if clip_skip is None:
            clip_skip = -2 if sdxl else -1
        text_model, vae, unet, scheduler = load_model(path, sdxl, clip_skip)
        diffusion = DiffusionModel(unet, sdxl)
        return cls(config, diffusion, text_model, vae, scheduler, network)

deviceやdtypeの移動

 これは生成用のメソッドです。vaeだけ別の型を指定できるのは、SDXLのVAEがtorch.float16で利用すると(NaN;)になるからです。生成用なのでeval()にしちゃいます。

def to(self, device="cuda", dtype=torch.float16, vae_dtype=None):
    self.device = device
    self.te_device = device
    self.vae_device = device
    
    self.autocast_dtype = dtype
    self.vae_dtype = vae_dtype or dtype

    self.diffusion.unet.to(device, dtype=dtype).eval()
    self.text_model.to(device, dtype=dtype).eval()
    self.vae.to(device, dtype=self.vae_dtype).eval()

VAEの計算(潜在変数のエンコード、デコード)

 エンコードでは画像を[-1,1]のTensorにして、VAEのエンコーダに渡します。デコードでは潜在変数をVAEデコーダに渡して、画像を[0, 255]のPIL.Imageリストにします。デコーダー側は学習時のサンプル生成に使うので、VRAM使用量のピークにならないよう各要素はまとめずに1枚ずつやります。エンコーダは潜在変数のキャッシュをすれば学習時に使わず済むのでまとめてやってます。
 学習時にVAEをCPUに置いておくモードを実装するため、最初と最後にいちいちデバイスを移すコードがあります。

@torch.no_grad()
def encode_latents(self, images):
    self.vae.to("cuda")
    to_tensor_norm = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    images = torch.stack([to_tensor_norm(image) for image in images]).to(self.vae.device)
    latents = self.vae.encode(images).latent_dist.sample()
    self.vae.to(self.vae.device)
    return latents 

@torch.no_grad()
def decode_latents(self, latents):
    self.vae.to("cuda")
    images = []

    for i in range(latents.shape[0]):
        image = self.vae.decode(latents[i].unsqueeze(0)).sample
        images.append(image)
    images = torch.cat(images, dim=0)
    images = (images / 2 + 0.5).clamp(0, 1)
    images = images.cpu().permute(0, 2, 3, 1).float().numpy()
    images = (images * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    self.vae.to(self.vae.device)
    return pil_images

生成メソッド

 おなじみの設定を入力して、生成画像をPILのImageリストとして返します。どうせ使わないんですが(だからテストもしてない)img2imgもできるようにしています。
 生成はまず解像度に合わせてゼロ潜在変数を作ります。次に同じサイズでノイズを作成して、時刻t=999でそのノイズを加えます。img2imgの場合はゼロ潜在変数ではなく、入力画像をエンコードした潜在変数を用意して、時刻t<999でノイズを加えるわけですね。あとは時刻スケジューリングに基づいてどんどんノイズを除去していきます。
※txt2imgはimg2imgの特殊ケースに過ぎないという発想ですね。この方式はComfyUIを参考にしています。t=999だとわずかにノイズが弱まるのでちょっとずれていますけどね(sqrt_beta_bar[999]倍される)。ComfyUIの場合Eulerのような分散発散型xt=x+σεを採用しているのでずれは起きません。
 guidance_scaleが1.0のときはネガティブプロンプト側の計算が必要ないので省略するようにしています。

@torch.no_grad()
def sample(self, prompt="", negative_prompt="", batch_size=1, height=768, width=768, num_inference_steps=30, guidance_scale=7.0, denoise=1.0, seed=4545, images=None):
    rng_state = torch.get_rng_state()
    cuda_rng_state = torch.cuda.get_rng_state()

    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    if guidance_scale != 1.0:
        prompt = [negative_prompt] * batch_size + [prompt] * batch_size
    else:
        prompt = [prompt] * batch_size

    timesteps = self.scheduler.set_timesteps(num_inference_steps, self.device)
    timesteps = timesteps[int(num_inference_steps*(1-denoise)):]

    if images is None:
        latents = torch.zeros(batch_size, 4, height // 8, width // 8, device=self.device, dtype=self.autocast_dtype)
    else:
        with torch.autocast("cuda", dtype=self.vae_dtype):
            latents = self.encode_latents(images) * self.vae.scaling_factor
        latents.to(dtype=self.autocast_dtype)

    noise = torch.randn_like(latents)
    latents = self.scheduler.add_noise(latents, noise, timesteps[0])

    self.text_model.to("cuda")
    with torch.autocast("cuda", dtype=self.autocast_dtype):
        encoder_hidden_states, pooled_output = self.text_model(prompt)
    self.text_model.to(self.te_device)

    progress_bar = tqdm(timesteps, desc="Sampling", leave=False, total=len(timesteps))

    for i, t in enumerate(timesteps):
        with torch.autocast("cuda", dtype=self.autocast_dtype):
            latents_input = torch.cat([latents] * (2 if guidance_scale != 1.0 else 1), dim=0)
            model_output = self.diffusion(latents_input, t, encoder_hidden_states, pooled_output)

        if guidance_scale != 1.0:
            uncond, cond = model_output.chunk(2)
            model_output = uncond + guidance_scale * (cond - uncond)

        if i+1 < len(timesteps):
            latents = self.scheduler.step(latents, model_output, t, timesteps[i+1])
        else:
            latents = self.scheduler.pred_original_sample(latents, model_output, t)
        progress_bar.update(1)

    with torch.autocast("cuda", dtype=self.vae_dtype):
        images = self.decode_latents(latents / self.vae.scaling_factor)

    torch.set_rng_state(rng_state)
    torch.cuda.set_rng_state(cuda_rng_state)

    return images

さあ生成だ!

 ここまで用意すれば生成は難しくありません。

import torch
from modules.trainer import BaseTrainer

model_path = <model_path>
sdxl = True # or False
prompt = "1girl,solo"
negative_prompt = "lowres, bad anatomy, bad hands, text, error"
batch_size = 4
height, width = 896, 640
guidance_scale = 7.0
step = 30
seed = 2424

trainer = BaseTrainer.from_pretrained("model_path", sdxl)
trainer.to(device="cuda", dtype=torch.bfloat16)

images = trainer.sample(prompt, negative_prompt, batch_size, height, width, step, guidance_scale, seed=seed, denoise=1)

たぶんできる。

よくある質問(妄想)

Q. xformersは使わないんですかあ。
A.Pytorch2.2からFlash Attention2が適用されるようになったため、いらないかもしれない。

Q. トークン長の拡張はしないんですかあ。
A. よくわからないのでやりません。