見出し画像

ControlNet-LLLite学習メモ①

2023/08/26追記
そもそも環境構築を間違っていたので、学習できてないのは当たり前でした。
環境構築に関しては
こちらの記事をご参照ください。

『ラフを線画にしてくれる夢のようなAIアプリが欲しい!』そう思ったことはありませんか?私はあります。
生成AIでどうこうせずとも別のメカニズムで似たようなものはあるのですがいまいち私の求めている精度のものは見つけられませんでした。
『ないならば、作って見せよう新モデル』という訳で無謀にも個人でそれにチャレンジした記録です。
上手くいく保証はないのでうまくいかなかったり途中で終わっていても生暖かい目で見守ってください

そもそもControlNet-LLLiteは何か

開発者のKohyaさん曰く、『画風や絵柄は覚えず制御のみを覚える軽量なControlNet』とのことです。
つまりどういうことかと言うと『モデルの中にある概念の制御のみ覚えられる』ということだそう、つまりどういうことだぜ?

以前私が作ったControlNetモデルは線画から陰影やノーマルマップを生成するものでした。

前回のControlNetは画風も覚えてくれるものでしたが今回のControlNet-LLLiteは『線画→陰影(ノーマルマップ)』に変換する制御部分だけ覚える、ということです。
私は今回、まず自分の描いた線画50枚用意し、有志のフォロワーを募って、フォロワーさんご自身で制作された388枚の線画データを用意しました。神フォロワーさんに感謝。
それでもControlNet-LLLiteの学習には足りなかったので(とりあえず1000枚くらいあるといいみたい)、生成AIから画像をランダムに出力しかさ増ししました。あと有料の背景の線画素材とかも購入してミックス。
そんな感じで線画1088枚を用意しました。
とは言え、もともと別作者が描いた線画、色んなブレがあるのでそれを極力同一規格にする必要があります。というわけで使わせてもらったのが以下のリポジトリ。

導入方法は割愛するとして以下のようなスクリプトで一括で線画を抽出します。

import os
import torch
import torch.nn as nn
from torch.utils.data.dataset import Dataset
from PIL import Image
import fnmatch
import cv2

import sys

import numpy as np
import gc #torch .set_printoptions(precision=10)


class _bn_relu_conv(nn.Module):
    def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
        super(_bn_relu_conv, self).__init__()
        self.model = nn.Sequential(
            nn.BatchNorm2d(in_filters, eps=1e-3),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw//2, fh//2), padding_mode='zeros')
        )

    def forward(self, x):
        return self.model(x)

        # the following are for debugs
        print("****", np.max(x.cpu().numpy()), np.min(x.cpu().numpy()), np.mean(x.cpu().numpy()), np.std(x.cpu().numpy()), x.shape)
        for i,layer in enumerate(self.model):
            if i != 2:
                x = layer(x)
            else:
                x = layer(x)
                #x  = nn.functional.pad(x, (1, 1, 1, 1), mode='constant', value=0)
            print("____", np.max(x.cpu().numpy()), np.min(x.cpu().numpy()), np.mean(x.cpu().numpy()), np.std(x.cpu().numpy()), x.shape)
            print(x[0])
        return x


class _u_bn_relu_conv(nn.Module):
    def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
        super(_u_bn_relu_conv, self).__init__()
        self.model = nn.Sequential(
            nn.BatchNorm2d(in_filters, eps=1e-3),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw//2, fh//2)),
            nn.Upsample(scale_factor=2, mode='nearest')
        )

    def forward(self, x):
        return self.model(x)



class _shortcut(nn.Module):
    def __init__(self, in_filters, nb_filters, subsample=1):
        super(_shortcut, self).__init__()
        self.process = False
        self.model = None
        if in_filters != nb_filters or subsample != 1:
            self.process = True
            self.model = nn.Sequential(
                    nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample)
                )

    def forward(self, x, y):
        #print (x.size(), y.size(), self.process)
        if self.process:
            y0 = self.model(x)
            #print ("merge+", torch.max(y0+y), torch.min(y0+y),torch.mean(y0+y), torch.std(y0+y), y0.shape)
            return y0 + y
        else:
            #print ("merge", torch.max(x+y), torch.min(x+y),torch.mean(x+y), torch.std(x+y), y.shape)
            return x + y

class _u_shortcut(nn.Module):
    def __init__(self, in_filters, nb_filters, subsample):
        super(_u_shortcut, self).__init__()
        self.process = False
        self.model = None
        if in_filters != nb_filters:
            self.process = True
            self.model = nn.Sequential(
                nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample, padding_mode='zeros'),
                nn.Upsample(scale_factor=2, mode='nearest')
            )

    def forward(self, x, y):
        if self.process:
            return self.model(x) + y
        else:
            return x + y


class basic_block(nn.Module):
    def __init__(self, in_filters, nb_filters, init_subsample=1):
        super(basic_block, self).__init__()
        self.conv1 = _bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample)
        self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
        self.shortcut = _shortcut(in_filters, nb_filters, subsample=init_subsample)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.residual(x1)
        return self.shortcut(x, x2)

class _u_basic_block(nn.Module):
    def __init__(self, in_filters, nb_filters, init_subsample=1):
        super(_u_basic_block, self).__init__()
        self.conv1 = _u_bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample)
        self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
        self.shortcut = _u_shortcut(in_filters, nb_filters, subsample=init_subsample)

    def forward(self, x):
        y = self.residual(self.conv1(x))
        return self.shortcut(x, y)


class _residual_block(nn.Module):
    def __init__(self, in_filters, nb_filters, repetitions, is_first_layer=False):
        super(_residual_block, self).__init__()
        layers = []
        for i in range(repetitions):
            init_subsample = 1
            if i == repetitions - 1 and not is_first_layer:
                init_subsample = 2
            if i == 0:
                l = basic_block(in_filters=in_filters, nb_filters=nb_filters, init_subsample=init_subsample)
            else:
                l = basic_block(in_filters=nb_filters, nb_filters=nb_filters, init_subsample=init_subsample)
            layers.append(l)

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class _upsampling_residual_block(nn.Module):
    def __init__(self, in_filters, nb_filters, repetitions):
        super(_upsampling_residual_block, self).__init__()
        layers = []
        for i in range(repetitions):
            l = None
            if i == 0: 
                l = _u_basic_block(in_filters=in_filters, nb_filters=nb_filters)#(input)
            else:
                l = basic_block(in_filters=nb_filters, nb_filters=nb_filters)#(input)
            layers.append(l)

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class res_skip(nn.Module):

    def __init__(self):
        super(res_skip, self).__init__()
        self.block0 = _residual_block(in_filters=1, nb_filters=24, repetitions=2, is_first_layer=True)#(input)
        self.block1 = _residual_block(in_filters=24, nb_filters=48, repetitions=3)#(block0)
        self.block2 = _residual_block(in_filters=48, nb_filters=96, repetitions=5)#(block1)
        self.block3 = _residual_block(in_filters=96, nb_filters=192, repetitions=7)#(block2)
        self.block4 = _residual_block(in_filters=192, nb_filters=384, repetitions=12)#(block3)
        
        self.block5 = _upsampling_residual_block(in_filters=384, nb_filters=192, repetitions=7)#(block4)
        self.res1 = _shortcut(in_filters=192, nb_filters=192)#(block3, block5, subsample=(1,1))

        self.block6 = _upsampling_residual_block(in_filters=192, nb_filters=96, repetitions=5)#(res1)
        self.res2 = _shortcut(in_filters=96, nb_filters=96)#(block2, block6, subsample=(1,1))

        self.block7 = _upsampling_residual_block(in_filters=96, nb_filters=48, repetitions=3)#(res2)
        self.res3 = _shortcut(in_filters=48, nb_filters=48)#(block1, block7, subsample=(1,1))

        self.block8 = _upsampling_residual_block(in_filters=48, nb_filters=24, repetitions=2)#(res3)
        self.res4 = _shortcut(in_filters=24, nb_filters=24)#(block0,block8, subsample=(1,1))

        self.block9 = _residual_block(in_filters=24, nb_filters=16, repetitions=2, is_first_layer=True)#(res4)
        self.conv15 = _bn_relu_conv(in_filters=16, nb_filters=1, fh=1, fw=1, subsample=1)#(block7)

    def forward(self, x):
        x0 = self.block0(x)
        x1 = self.block1(x0)
        x2 = self.block2(x1)
        x3 = self.block3(x2)
        x4 = self.block4(x3)

        x5 = self.block5(x4)
        res1 = self.res1(x3, x5)

        x6 = self.block6(res1)
        res2 = self.res2(x2, x6)

        x7 = self.block7(res2)
        res3 = self.res3(x1, x7)

        x8 = self.block8(res3)
        res4 = self.res4(x0, x8)

        x9 = self.block9(res4)
        y = self.conv15(x9)

        return y

class MyDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform
        
    def get_class_label(self, image_name):
        # your method here
        head, tail = os.path.split(image_name)
        #print (tail)
        return tail
        
    def __getitem__(self, index):
        image_path = self.image_paths[index]
        x = Image.open(image_path)
        y = self.get_class_label(image_path.split('/')[-1])
        if self.transform is not None:
            x = self.transform(x)
        return x, y
    
    def __len__(self):
        return len(self.image_paths)

def loadImages(folder):
    imgs = []
    matches = []
    for root, dirnames, filenames in os.walk(folder):
        for filename in fnmatch.filter(filenames, '*'):
            matches.append(os.path.join(root, filename))
   
    return matches

if __name__ == "__main__":
    model = res_skip()
    model.load_state_dict(torch.load("C:/Github/Rough2Line_SDXL/Models/manga_line/erika.pth"))
    is_cuda = torch.cuda.is_available()
    if is_cuda:
        model.cuda()
    else:
        model.cpu()
    model.eval()

    filelists = loadImages(sys.argv[1])
    batch_size = 5

    with torch.no_grad():
        for i, imname in enumerate(filelists):
            src = cv2.imread(imname, cv2.IMREAD_GRAYSCALE)
            
            # 短辺が1024になるようにリサイズ
            height, width = src.shape
            aspect_ratio = width / height
            if height < width:
                new_height = 1024
                new_width = int(new_height * aspect_ratio)
            else:
                new_width = 1024
                new_height = int(new_width / aspect_ratio)
            src = cv2.resize(src, (new_width, new_height))
                
            rows = int(np.ceil(src.shape[0]/16))*16
            cols = int(np.ceil(src.shape[1]/16))*16

            # manually construct a batch. You can change it based on your usecases.
            patch = np.ones((1, 1, rows, cols), dtype="float32")
            patch[0, 0, 0:src.shape[0], 0:src.shape[1]] = src

            if is_cuda:
                tensor = torch.from_numpy(patch).cuda()
            else:
                tensor = torch.from_numpy(patch).cpu()

            y = model(tensor)
            print(imname, torch.max(y), torch.min(y))

            yc = y.cpu().numpy()[0, 0, :, :]
            yc[yc > 255] = 255
            yc[yc < 0] = 0

            head, tail = os.path.split(imname)
            cv2.imwrite(sys.argv[2] + "/" + tail.replace(".jpg", ".png"), yc[0:src.shape[0], 0:src.shape[1]])

            if (i + 1) % batch_size == 0:
                torch.cuda.empty_cache()  # 5枚ごとにメモリを解放

画像の短辺の長さを1024にし、AIに線画を抽出させるだけのスクリプトです。

Line_resize.py "線画にしたい画像が入ったフォルダ" "線画出力フォルダ"

のように使います。こんな感じの絵が・・・

生成AIによって生成された画像
抽出された線画

こんな感じに画像のグレースケールな部分も白抜きにしてくれるた方が今回の用途だと道具としては使いやすい(あとから絵を描く人が影付をした方が良い為)ので、こんな感じで1088枚の画像の線画を抽出します。

次はこの画像をフラットなラインに変換します。何故そうしたかというといきなりラフスケッチをベースに推論するより簡単そう(主観)と感じたからです。というわけで以下のスクリプトで線画を変換。

import cv2
import numpy as np
from PIL import Image
import os
import glob


def denoise_image(image):
    # Convert image to RGB if it has an alpha channel
    if image.mode == 'RGBA':
        image = image.convert('RGB')

    image = np.array(image)

    # Ensure the data type is uint8
    if image.dtype != np.uint8:
        image = (image * 255).astype(np.uint8)

    denoised_image = cv2.GaussianBlur(image, (7, 7), 0)
    return Image.fromarray(denoised_image)

def binarize_image(image):
    image = image.convert('L')
    image_np = np.array(image)
    _, binarized = cv2.threshold(image_np, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return Image.fromarray(binarized)

def skeletonize_and_dilate_image(image):
    # Convert to grayscale and then binarize the image
    image = image.convert('L')
    image_np = np.array(image)
    _, binarized = cv2.threshold(image_np, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Invert colors (thinning function expects white lines on black background)
    inverted = cv2.bitwise_not(binarized)

    # Skeletonize
    skeleton = cv2.ximgproc.thinning(inverted, thinningType=cv2.ximgproc.THINNING_ZHANGSUEN)
    
    # Dilate to make the line 3 pixels wide
    kernel = np.ones((3, 3), np.uint8)
    dilated = cv2.dilate(skeleton, kernel, iterations=1)
    
    # Re-invert colors to black lines on white background
    dilated = cv2.bitwise_not(dilated)

    return Image.fromarray(dilated)

def process_images_in_directory(input_directory, output_directory):
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    for file_ext in ['png', 'jpg']:
        for input_path in glob.glob(f'{input_directory}/*.{file_ext}'):
            file_name = os.path.splitext(os.path.basename(input_path))[0]
            output_path = os.path.join(output_directory, file_name + '.png')

            image = Image.open(input_path)
            denoised_image = denoise_image(image)
                      
            binarized_image = binarize_image(denoised_image)

            skeletonized_image = skeletonize_and_dilate_image(binarized_image)

            skeletonized_image.save(output_path, format='PNG')

# Parameters
input_directory = 'D:/desktop/test3'   # Path of the input directory
output_directory = 'D:/desktop/test4'  # Path of the output directory

# Call the main function to execute the processing
process_images_in_directory(input_directory, output_directory)


うんうん、雑に単純化されています。
というわけで。二種類の線画(元線画とフラットな線画)を1088セット用意していざControlNet-LLLiteを学習させました。10epochで効果がでてくるとのことでまず10epochぶん回してみたのですが……。
結論から言うとモデル作成失敗でした。失敗した理由はおそらく『ControlNet-LLLiteはあくまで制御まで覚えるだけのもので絵柄概念自体は覚えない、そのモデルの中にない概念は覚えにくい』とのこと。
つまりどういうことか結論からいうと、
学習用の線画データセットはControlNet-LLLiteのベースにするモデルから生成したものが良い
とのことです。
つまり今回の用途だと自前で用意した線画が仇になったということです。そ、そんな~!!!!

つまりControlNet-LLLiteでラフ→線画ControlNet-LLLiteモデルを作るには以下のワークフローが必要になるということです。
①なんらかの手段で線画LoRAを学習(候補:LECOか今回不発に終わった線画画像ファイルでLoRA学習。あるいはその両方)
②↑のLoRAをSDモデルにマージ
③↑のマージ済みSDモデルから1000枚線画ベース画像出力
④↑の画像から線画を抽出
⑤↑の画像からラフ(あるいはフラットな線)を生成
⑥↑上の線画とラフからControlNet-LLLite学習

……まずは線画LoRAの作成が必要になりそうです。
LoRAの作成は別に必須ではないのですが、安定して線画にしやすい画像を出力するには作った方が良いと思われる為です。出力安定の為ですな。
(もしかしたら不要かもしれんが検証のためにやる!)
最近ファインチューニング関係のことを勉強していなかったので良い機会になりそうです。

以上、想像だらけの第一回報告でした!

この記事が気に入ったらサポートをしてみませんか?