Consistency Models を Diffusers から試してみる

Consistency Models とは


OpenAI の発表した、1ステップや数ステップでいい感じに生成できるモデルのことです。OpenAI が珍しくちゃんとモデルをオープンにしたモデルです。

diffusers v0.18.0 で対応されたので試していきます。

本家 GitHub:

Diffusers の解説:

生成してみる


まあほぼ解説ページに載っているものと同じです。

無料版の Colab で動かしてみます。

%pip install diffusers
import torch
from diffusers import ConsistencyModelPipeline
device = "cuda"
model_id_or_path = "openai/diffusers-cd_cat256_lpips"

# Load the cd_cat256_lpips checkpoint.
pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
pipe.to(device)

# Onestep Sampling
image = pipe(num_inference_steps=1).images[0]

image.show()

# Multistep sampling
# Timesteps can be explicitly specified; the particular timesteps below are from the original Github repo:
# https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L83
image = pipe(num_inference_steps=None, timesteps=[17, 0]).images[0]

image.show()

ここで使っているモデルは、`openai/diffusers-cd_cat256_lpips` です。猫の256x256解像度画像が生成できます。Stable Diffusion のように、テキストでの指示はできません。

生成にかかった時間を見てもらうとわかると思いますが、マジで爆速で生成されています。

消費 VRAM は 3 GB 前後でした。

従来みたいな拡散モデルであれば、事前に決めたタイムステップから順にデノイズ(1000ステップなら、999, 998, 997… 2, 1, 0 という感じ)を行いますが、先程の生成では一気に 0 の時まで飛ばして画像を生成しています。

また、diffusers 版でも Multistep sampling が可能です。

Multistep sampling というのは、指定ステップから指定ステップまで飛ばして生成する機能です。

`num_inference_steps` を None に、`timesteps` に行きたいタイムステップを降順で指定すると、そのステップ順にデノイズしていくことができます。

たとえば、 39→38 にデノイズした結果を予測すると、

このようにノイズまみれのものが返ってきますが、次のようにちゃんと 0 までデノイズすればしっかりと猫の画像が返ってきます。

すべて同じシード

また、[39, 5, 0] のように指定することもできます。

PyTorch 2.0 の compile という機能を使うと、さらに高速に生成できるようです。

pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

初回だけ (おそらく実際にコンパイルが実行されているため) 数分待たされましたが、二回目以降は高速に生成されました。

しかし、もとから割りと速いのでどれくらい速くなったのかいまいちわかりません。

そこで、1000 回の 1 ステップ生成を行ってかかった時間を比較してみました。

import time

start_time = time.time()

for _ in range(1000):
    image = pipe(num_inference_steps=1).images[0]

end_time = time.time()
elapsed_time = end_time - start_time  
print(f'Elapsed time: {elapsed_time} seconds')  

compile() 前では約 211.63 秒でしたが、compile() したものは、 約 183.88 秒になりました。

元が高速すぎてあまり体感できるレベルではないですが、生成回数が増えれば大きな差になりそうですね。



この記事が気に入ったらサポートをしてみませんか?