見出し画像

CLIPモデルで画像特徴点の抽出とElasticsearchで類似画像検索

類似画像検索システムを検討するにあたってCLIP(2021年2月にOpenAIによって公開された,言語と画像のマルチモーダルモデル)を試してみました。

1.Elasticsearchのマッピング定義

import json
from elasticsearch import Elasticsearch
es = Elasticsearch("http://0.0.0.0:9200")
# インデックス名
index_name = "test_index"
# インデックスを削除
#response = es.indices.delete(index=index_name)
mapping = {
    "mappings": {
      "properties": {
        "metadata": {
          "properties": {
            "image_code": {
              "type": "text",
              "fields": {
                "keyword": {
                  "type": "keyword",
                  "ignore_above": 256
                }
              }
            },  
              
          }
        },
        "vector": {
          "type": "dense_vector",
          "dims": 512,
          "index": True,
          "similarity": "cosine"
        }
      }
    }
  }
# マッピングを作成
response = es.indices.create(index=index_name, body=mapping)

2.CLIPで画像のベクトル化とElasticsearchへのデータ投入

%%time
import os
import numpy as np
from PIL import Image
import torch
from clip import clip
import json
from elasticsearch import Elasticsearch

# CLIPモデルのロード
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# 画像のベクトル化
def extract_features(image_path):
    # 画像の前処理
    image = Image.open(image_path)
    image = preprocess(image).unsqueeze(0).to(device)
    # 特徴量の抽出
    with torch.no_grad():
        image_features = model.encode_image(image).cpu().numpy()
    return image_features / np.linalg.norm(image_features)

es = Elasticsearch("http://0.0.0.0:9200")

# rowsは画像パスのコードが格納された配列
for row in rows:
    image_path = "../image/"+row["image_code"]+".jpg"
    if os.path.isfile(image_path):        
        features = extract_features(image_path)

        # 登録するデータ
        data = {
            "metadata":{
                "image_code": str(row["image_code"]),
            },
            "vector":features[0]
        }
        response = es.index(index="test_index",  body=data)

3.類似画像検索

%%time
import pprint
import sys
import os
import numpy as np
from PIL import Image
import torch
from clip import clip
import json
from elasticsearch import Elasticsearch

# CLIPモデルのロード
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# 画像のベクトル化
def extract_features(image_path):
    # 画像の前処理
    image = Image.open(image_path)
    image = preprocess(image).unsqueeze(0).to(device)
    # 特徴量の抽出
    with torch.no_grad():
        image_features = model.encode_image(image).cpu().numpy()
    return image_features / np.linalg.norm(image_features)

image_path = "./test.jpg"
query_features = extract_features(image_path)

#==========================================================
es = Elasticsearch("http://0.0.0.0:9200")

# ドキュメントを検索するためのクエリ
# Elasticsearchが負のスコアを許容しないため+1
q = {
    "size": 20,
    "query": {
        "script_score": {
          "query": {"match_all": {}},
          "script": {
            "source": "cosineSimilarity(params.query_vector, 'vector') + 1.0 ",
              "params": {"query_vector": query_features[0]}
          }
        }
    }
}


# ドキュメントを検索
result = es.search(index="test_index", body=q)

# 検索結果からドキュメントの内容のみ表示
docs=[]
content = []
for document in result["hits"]["hits"]:
    content.append([document["_score"], document["_source"]["metadata"]["image_code"]])
        
es.close()
print(content)