見出し画像

ComfyUIプラグインの書き方/DALL-EプラグインとMemeplexプラグイン

仕事で本格的にやろうって時にComfyUIから各種作画APIにアクセスできた方が便利なので以下のページを参考にやってみた。意外なところでハマったのでそこもメモ

ComfyUIプラグインのディレクトリ

ComfyUIのプラグインは、その名もcustom_nodeというディレクトリに置く。

ここに自分の作りたいプラグインの名前でフォルダを掘って__init.pyとnodes.pyを作る。

今回はMemeplexプラグインが仕事で必要になったのでmemeplexという名前にしたけど、なんでも良い。

ファイルの中身はこんな感じ

# __init__.py
from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']

これはお約束っぽい
本体はnodes.pyに書く。

ComfyUIノードの受け取りと受け渡し

ComfyUIはノードに入力を受けて他のノードに加工したデータを受け渡す。
いわゆるデータプローで作業を記述する。

先の例の一部抜粋で恐縮だが、テキストボックスの内容を出力するだけのコードはこうなる


class TextInput:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"text": ("STRING", {"multiline": True})}}
    RETURN_TYPES = ("STRING",)
    FUNCTION = "run"
    CATEGORY = "Memeplex"

    def run(self, text, seed = None):
        return (text,)

INPUT_TYPE(s)というメソッドの返り値として必須の引数(required)か、オプション引数(optional)を指定し、それぞれの型も指定できる。RETURN_TYPESはタプルで、"STRING"なら文字列を出力するし"IMAGE"なら画像を出力する。

他にもこんな感じで整数のパラメータを設定させることができる。


class MemeplexRender:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                             "prompt": ("STRING", {"forceInput": True}),
                             "width": ("INT", {"default": 512, "min": 512, "max": 1024, "step": 64}),
                             "height": ("INT", {"default": 512, "min": 512, "max": 1024, "step": 64}),
                             "qty": ("INT", {"default": 9, "min": 1, "max": 9, "step": 1}),
                             }
                             ,
                "optional":{
                            "negative": ("STRING", {"forceInput": True}),
                            }

                }
                

デフォルト値、最大値、最小値、ステップ数などが指定できる。

文字列を指定して選ばせたい場合は、タプルとして文字列リストを与える。

class MemeplexRender:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                             "prompt": ("STRING", {"forceInput": True}),
                             "width": ("INT", {"default": 512, "min": 512, "max": 1024, "step": 64}),
                             "height": ("INT", {"default": 512, "min": 512, "max": 1024, "step": 64}),
                             "qty": ("INT", {"default": 9, "min": 1, "max": 9, "step": 1}),
                             "model": (["trinart","StableDiffusion-v1-5","StableDiffusion-v2-0"],)
                             }
                             ,
                "optional":{
                            "negative": ("STRING", {"forceInput": True}),
                            }

                }

するとリストから選べるようになる。

あとは、実行されるときに呼び出されるメソッド名の指定や出力するかどうか、ノードはどのカテゴリーに入れるかを指定する。

    OUTPUT_NODE = True #出力ノードあり 
    RETURN_TYPES = ("IMAGE",) #画像を出力する

    FUNCTION  = "run" # runというメソッドを呼ぶ
    CATEGORY = "Memeplex" # カテゴリはMemeplexにする

ここで作ったクラスを最後に登録する。

NODE_CLASS_MAPPINGS = {
    "TextInput": TextInput,
    "MemeplexCustomSDXLRender": MemeplexCustomSDXLRender,
    "MemeplexRender": MemeplexRender,
    "DallERender":DallERender
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "TextInput": "TextInput",
    "MemeplexCustomSDXLRender": "MemeplexCustomSDXLRender",
    "MemeplexRender":"MemeplexRender",
    "DallERender":"DallERender"
}

クラス名とクラスのマッピングと表示されるノード名のマッピングを設定する。

ここまでは特にハマらない。ハマる要素がない。

画像を受け取り、画像を受け渡す

今回一番ハマったのは画像の受け取りと受け渡しだった。

画像の読み込みはいろんなやり方があるが、普通は[バッチ,高さ,幅,3]というバッチ形式で渡すのかと思っていたのだが、どうやら[1,バッチ,高さ,幅,3]という形式で渡さないといけないらしい。

以下はDALL-EのAPIを呼んで、受け取ったURLを何度か読みに行き、描画が完了したら次のノードに画像を渡すというシンプルなコード

def generate_image_DALLE(prompt):
    response = client.images.generate(
        model="dall-e-3",
        prompt=prompt,
        size="1024x1024",
        quality="standard",
        n=1,
    )
    return(response.data[0].url)

#中略

    def run(self, prompt):
        print(prompt)
        urls=[generate_image_DALLE(prompt)]
        print("request queing,wait 10sec")

        time.sleep(3)
        output_images=[]
        for url in urls:
            while True:
                try:
                    i=Image.open(io.BytesIO(requests.get(url).content))
                    if i.width>0:
                        break
                    print("retry")
                except Exception as e:
                    print(e)
                    print("retry in 30sec")
                time.sleep(30)
            i = ImageOps.exif_transpose(i)
            if i.mode == 'I':
                i = i.point(lambda i: i * (1 / 255))
            image = i.convert("RGB")
            image = np.array(image).astype(np.float32) / 255.0
            image = torch.from_numpy(image)[None,]
            print(image.shape)
            output_images.append(image)
        output_image = torch.cat(output_images, dim=0)
        return (output_image,)

output_imagesをcatしてoutput_imageになる(複数形が単数系になる)という奇妙な記法は、ComfyUIのお作法(?)に準じる。というかComfyUIのデフォルトのloadImageが(なぜか)そのようになってる。

これにハマって30分くらい唸ってしまった。バッチだけ渡すと白黒画像とみなされて変な絵が出てくる。

というわけで何かできた。
プラグインを作れるようになっておくと、プロンプトを自動生成するようなものを作ったり、色々使い道があると思うので作れるようになっといた方が得だし、書くのは癖はあるけどめちゃくちゃ簡単。ただし、ドキュメントがほぼないので手探りになる。ComfyUI/nodes.pyが一番参考になる。

ソースコードは以下