ComfyUIにおけるUNet改造ノードの作り方

誰が得するんだろうこんな記事


ComfyUIのカスタムノードについて

新しいノードを作るには、以下の記事が詳しいです。

 UNetを改造するとき、たとえばMODELを受け取って好き勝手いじって、MODELを出力するという話なら簡単です。しかし入力時にコピーされるわけではないので、適当に改変してしまうと入力側も変わってしまいます。ComfyUIはノードベースの生成UIであり、複数のノードにMODELを分岐させることがあります。たとえば新しく作ったカスタムノードを通ったMODELと通らなかったMODELの二つでそれぞれ生成して、効果を比較するみたいなことを自然にやりたくなるので、ちゃんと入力側は影響を受けないように改変することになります。ComfyUIにはそのための仕掛けが色々されています。

MODELについて

ComfyUIのMODELの正体は、ModelPatcherです。このクラスにはUNetの様々な部分にパッチをあてる機能があります。LoraLoaderノード等では、このパッチ部分だけをコピーすることで、巨大なUNetをコピーせずに様々なLoRA設定で並列に生成できるというわけです。ModelPatcherにBaseModelがあり、その中にUnetがあります。入力されたMODELからUNetにアクセスするためには、以下のような方法になります。

# 例として、SDXLかどうか判定するコード(他にもいい方法ありそうだけど・・・)
is_sdxl = hasattr(new_model.model.diffusion_model, "label_emb")

いっぱい改変するぞ

例として、DeepShrinkの実装ノードを見てみましょう。

m = model.clone()
if downscale_after_skip:
    m.set_model_input_block_patch_after_skip(input_block_patch)
else:
    m.set_model_input_block_patch(input_block_patch)
m.set_model_output_block_patch(output_block_patch)

model.clone()するといい感じにパッチだけコピーしたmodelが返ってくるっぽいです。これをせずにmodelに対して直接改変してしまうと、入力側の
modelも改変されてしまいます。
.set_model_*_patchみたいな関数がいっぱい用意されてるので、目的に応じてパッチを当てるという感じです。それでは各パッチについて説明していきます。
ちなみに渡すパッチは関数っぽく使えればいいので、__call__メソッドがあるクラスのインスタンス変数でもいいです。

set_model_attn1_patch, set_model_attn2_patch

 attn1はself attentionの直前で適用されるパッチです。self attntionのq,k,vそれぞれへの入力とextra_options(後述)を受け取って、改変したq,k,vへの入力を返します。self attentionですから、基本的に三つの入力は同じものになっているはずですね。
 attn2もほとんど同じです。k,vへの入力は基本的にテキストエンコーダの出力になっています。
 全てのattentionに適用されるため、限定するには後述のextra_optionsの情報を使います。

利用例:hypertile

set_model_attn(1or2)_output_patch

 これは出力に対するパッチです。さっきと同じ感じです。

set_model_attn(1or2)_replace

 こちらはパッチではなくAttentionを置き換えます。例えば意味ないけどそのままAttentionを実行するには以下のような関数を渡せばいいです。

from comfy.ldm.modules.attention import optimized_attention
def replace_function(q, k, v, extra_options):
    return optimized_attention(q, k, v, extra_options["n_heads"])

optimized_attentionはxformersなど、設定に応じて適切なattentionが呼び出されます。入力は既にto_q,to_k,to_vを通ったもので、出力はto_outの前です。先ほどまでのパッチと違い、複数適用はできません。そのためできればこちらは使いたくないですね(使いまくってるけど)。
 また適用には一工夫必要です。replaceはpatchと違い、各モジュールごとに適用します。block_nameとnumberが必要です。block_nameは"input", "middle", "output"のどれかで、numberは何番目のブロックかを表します。これはSDXLに対応していません・・・。SDXLには各ブロックごとにAttentionが複数個ありますからね。SDXLに対応するためには以下のような関数が必要です。

# attn2の場合
def set_model_patch_replace(model, patch, key):
    to = model.model_options["transformer_options"]
    if "patches_replace" not in to:
        to["patches_replace"] = {}
    if "attn2" not in to["patches_replace"]:
        to["patches_replace"]["attn2"] = {}
    to["patches_replace"]["attn2"][key] = patch

ここでkeyはSD1系の場合、(block_name, number)の二要素タプル、SDXLの場合そこからAttention層の何個目かを表す整数を加え、(block_name, number, , transformer_index)の形にする必要があります。

ここでkeyについての情報を書いておきます。
SD1系(SD2系も同様)のkeyは
("input", [1,2,4,5,7,8]), ("middle", 0), ("output", [3,4,5,6,7,8,9,10,11])の三種類計16個あります。

SDXL系のkeyは
("input", [4,5], [0,1]), ("input", [7,8], [0,…,9])
("middle", [0,…,9])
("output", [0,1,2], [0,…,9]), ("output", [3,4,5], [0,1])
となります。

Attentionのモジュールを直接参照したいときは、以下のようなコードになります。

# ("input", block_id, transformer_index)
new_model.model.diffusion_model.input_blocks[block_id][1].transformer_blocks[transformer_index].attn2

例としてattention_coupleではCross Attention層の全モジュールを置き換える必要があるので、以下のような実装を行っています。

self.sdxl = hasattr(new_model.model.diffusion_model, "label_emb")
if not self.sdxl:
    for id in [1,2,4,5,7,8]: # id of input_blocks that have cross attention
        set_model_patch_replace(new_model, self.make_patch(new_model.model.diffusion_model.input_blocks[id][1].transformer_blocks[0].attn2), ("input", id))
    set_model_patch_replace(new_model, self.make_patch(new_model.model.diffusion_model.middle_block[1].transformer_blocks[0].attn2), ("middle", 0))
    for id in [3,4,5,6,7,8,9,10,11]: # id of output_blocks that have cross attention
        set_model_patch_replace(new_model, self.make_patch(new_model.model.diffusion_model.output_blocks[id][1].transformer_blocks[0].attn2), ("output", id))
else:
    for id in [4,5,7,8]: # id of input_blocks that have cross attention
        block_indices = range(2) if id in [4, 5] else range(10) # transformer_depth
        for index in block_indices:
            set_model_patch_replace(new_model, self.make_patch(new_model.model.diffusion_model.input_blocks[id][1].transformer_blocks[index].attn2), ("input", id, index))
    for index in range(10):
        set_model_patch_replace(new_model, self.make_patch(new_model.model.diffusion_model.middle_block[1].transformer_blocks[index].attn2), ("middle", id, index))
    for id in range(6): # id of output_blocks that have cross attention
        block_indices = range(2) if id in [3, 4, 5] else range(10) # transformer_depth
        for index in block_indices:
            set_model_patch_replace(new_model, self.make_patch(new_model.model.diffusion_model.output_blocks[id][1].transformer_blocks[index].attn2), ("output", id, index))

いやいくらなんでももっといい方法はないのかよ。

set_model_input_block_patch

input_blockの各出力に対してパッチをあてます。KohyaさんのDeep Shrink実装のために用意されたものです。skip connectionにも適用されます。
使用例:Deep Shrink HiresFix

set_model_input_block_patch_after_skip

こちらはskip connecition側に適用されません。

set_model_output_block_patch

output_blockの各出力に対してパッチをあてます。前層からの入力とskip connection側の入力を受け取って、何らかの改変をして出力する感じです。skip connectionはcontrolnetがすでに適用されています。

extra_options, transformer_options

説明中に、リンクをいっぱい貼り付けましたが、引数としてこれらが使われていると思います。これはパッチに渡す情報を示す辞書です。内容を見てみましょう。二つはだいたい内容が同じですが、extra_optionsについてはAttention層特有の情報が追加されています(headの次元とか)

"block":(block_name, block_id)のタプルです。UNetの位置を確認するために使えます。
"block_index": Attentionの何番目かで、SDXLでは使うかもしれません。
"original_shape": UNetへの入力サイズです。基本的には(バッチサイズ, 4, latent_height, latent_width)になります。(バッチサイズはcfgの場合設定の2倍になります。)
"cond_or_uncond": バッチの要素がcond(=0)なのか、uncond(=1)なのかを示す列です。通常の生成では(1, 0)になっているはずです。ConditionalCombine等の特殊なワークフローの場合要素が3つになったりします。
"sigmas": ノイズの強さです。バッチサイズで拡張されたテンソルであるため、実数値を見たい場合はtransformer_options["sigmas"][0].item()のようにします。model.model.model_sampling.percent_to_sigmaという[0,1]の時刻からsigmaへ置き換える関数と組み合わせると、時刻を制限したパッチがつくれます。

他にもなんかありそうですが、使ったことないです。

set_model_unet_function_wrapper

最終手段として、UNetの関数をらっぷしちゃいます。適用位置はここapply_modelとUNetへの入力情報を受け取って、ノイズが除去された潜在変数を返します。ノイズ予測ではないことに注意。
apply_modelとUNet.forward()を自前で作ってしまえば、好き放題できることになります。ただComfyUIのアップデートに対応しづらい実装なので、なるべくやりたくないですけどね。
適用例:DeepCache

またこのラッパーを使ってモデルを改変⇒元に戻すことで安全に何らかの処理が適用できたりもします。cd-tunerではそれをやっています。