# 画像検索

# はじめに

画像検索は、特徴や視覚的なコンテンツをマッチングして類似の画像を見つけることができるようになり、人気で強力なアプリケーションとなっています。コンピュータビジョンとディープラーニングの急速な発展により、この機能は大幅に向上しました。

このガイドは、最新の技術とツールを活用して画像検索を行うための手助けをすることを目的としています。このガイドでは、以下のことを学ぶことができます:

  • パブリックデータセットとモデルを使用してベクトル埋め込みを持つデータセットを作成する方法
  • MyScaleを使用して画像の類似性検索を行う方法。MyScaleは、検索プロセスを効率化し、高速かつ正確な結果を提供する強力なプラットフォームです。

もしMyScaleの機能を探索することに興味がある場合は、データセットの作成セクションをスキップして、MyScaleへのデータの追加セクションに直接進むこともできます。

このデータセットは、MyScaleコンソールでデータのインポートセクションで提供される手順に従ってインポートすることができます。インポートが完了したら、MyScaleのクエリセクションに直接進んで、このサンプルアプリケーションをお楽しみください。

# 前提条件

始める前に、clickhouse python client (opens new window)とHuggingFaceのdatasetsライブラリをインストールする必要があります。

pip install datasets clickhouse-connect

データセットの作成セクションで説明されている手順に従うためには、transformersおよびその他の必要な依存関係をインストールする必要があります。

pip install requests transformers torch tqdm

# データセットの作成

# データのダウンロードと処理

unsplash dataset (opens new window)からデータをダウンロードし、Liteデータセットを使用します。

wget https://unsplash-datasets.s3.amazonaws.com/lite/latest/unsplash-research-dataset-lite-latest.zip
# ダウンロードしたファイルを一時ディレクトリに展開する
unzip unsplash-research-dataset-lite-latest.zip -d tmp

ダウンロードしたデータを読み込み、Pandasのデータフレームに変換します。

import numpy as np
import pandas as pd
import glob
documents = ['photos', 'conversions']
datasets = {}
for doc in documents:
    files = glob.glob("tmp/" + doc + ".tsv*")
    subsets = []
    for filename in files:
        df = pd.read_csv(filename, sep='\t', header=0)
        subsets.append(df)
    datasets[doc] = pd.concat(subsets, axis=0, ignore_index=True)
df_photos = datasets['photos']
df_conversions = datasets['conversions']

# 画像の埋め込みの生成

画像から埋め込みを抽出するために、clip-vit-base-patch32 (opens new window)モデルを使用するextract_image_features関数を定義します。生成される埋め込みは512次元のベクトルです。

import torch
from transformers import CLIPProcessor, CLIPModel
model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
def extract_image_features(image):
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model.get_image_features(**inputs)
        outputs = outputs / outputs.norm(dim=-1, keepdim=True)
    return outputs.squeeze(0).tolist()

その後、df_photosデータフレームから最初の1000個の写真IDを選択し、対応する画像をダウンロードし、extract_image_features関数を使用して画像の埋め込みを抽出します。

from PIL import Image
import requests
from tqdm.auto import tqdm
# 最初の1000個の写真IDを選択する
photo_ids = df_photos['photo_id'][:1000].tolist()
# 選択された写真IDのみを持つ新しいデータフレームを作成する
df_photos = df_photos[df_photos['photo_id'].isin(photo_ids)].reset_index(drop=True)
# データフレーム内の'photo_id'と'photo_image_url'の列のみを保持する
df_photos = df_photos[['photo_id', 'photo_image_url']]
# データフレームに新しい'photo_embed'列を追加する
df_photos['photo_embed'] = None
# 画像をダウンロードし、'extract_image_features'関数を使用して埋め込みを抽出する
for i, row in tqdm(df_photos.iterrows(), total=len(df_photos)):
    # 画像URLを変更して、サイズを小さくしてダウンロードするためのURLを構築する
    url = row['photo_image_url'] + "?q=75&fm=jpg&w=200&fit=max"
    try:
        res = requests.get(url, stream=True).raw
        image = Image.open(res)
    except:
        # 画像のダウンロードに失敗した場合は写真を削除する
        photo_ids.remove(row['photo_id'])
        continue
    # 埋め込みを抽出する
    df_photos.at[i, 'photo_embed'] = extract_image_features(image)

# データセットの作成

これで、写真情報と埋め込みを持つデータフレームと、変換情報を持つ別のデータフレームの2つのデータフレームが得られました。

df_photos = df_photos[df_photos['photo_id'].isin(photo_ids)].reset_index().rename(columns={'index': 'id'})
df_conversions = df_conversions[df_conversions['photo_id'].isin(photo_ids)].reset_index(drop=True)
df_conversions = df_conversions[['photo_id', 'keyword']].reset_index().rename(columns={'index': 'id'})

最後に、データフレームをParquetファイルに変換し、その後、myscale/unsplash-examples (opens new window)というHugging Faceのリポジトリにアップロードします。アップロードするための手順 (opens new window)に従うことで、データへの簡単なアクセスと共有が可能になります。

import pyarrow as pa
import pyarrow.parquet as pq
import numpy as np
# データとスキーマからTableオブジェクトを作成する
photos_table = pa.Table.from_pandas(df_photos)
conversion_table = pa.Table.from_pandas(df_conversions)
# テーブルをParquetファイルに書き込む
pq.write_table(photos_table, 'photos.parquet')
pq.write_table(conversion_table, 'conversions.parquet')

# MyScaleへのデータの追加

# データのロード

MyScaleにデータを追加するためには、まず前のセクションで作成したHuggingFaceのデータセットmyscale/unsplash-examples (opens new window)からデータをロードします。以下のコードスニペットは、データをロードしてPandasのデータフレームに変換する方法を示しています。

注意:photo_embedは、CLIP (opens new window)モデルを使用して画像から抽出された画像特徴を表す512次元の浮動小数点ベクトルです。

from datasets import load_dataset
photos = load_dataset("myscale/unsplash-examples", data_files="photos-all.parquet", split="train")
conversions = load_dataset("myscale/unsplash-examples", data_files="conversions-all.parquet", split="train")
# データセットをPandasのデータフレームに変換する
photo_df = photos.to_pandas()
conversion_df = conversions.to_pandas()
# photo_embedをnp配列からリストに変換する
photo_df['photo_embed'] = photo_df['photo_embed'].apply(lambda x: x.tolist())

# テーブルの作成

次に、MyScaleでテーブルを作成します。始める前に、MyScaleコンソールからクラスタのホスト、ユーザー名、パスワード情報を取得する必要があります。以下のコードスニペットは、写真情報用のテーブルと変換情報用のテーブルの2つのテーブルを作成する方法を示しています。

import clickhouse_connect
# クライアントの初期化
client = clickhouse_connect.get_client(host='YOUR_CLUSTER_HOST', port=443, username='YOUR_USERNAME', password='YOUR_CLUSTER_PASSWORD')
# 既に存在する場合はテーブルを削除する
client.command("DROP TABLE IF EXISTS default.myscale_photos")
client.command("DROP TABLE IF EXISTS default.myscale_conversions")
# 写真のテーブルを作成する
client.command("""
CREATE TABLE default.myscale_photos
(
    id UInt64,
    photo_id String,
    photo_image_url String,
    photo_embed Array(Float32),
    CONSTRAINT vector_len CHECK length(photo_embed) = 512
)
ORDER BY id
""")
# 変換のテーブルを作成する
client.command("""
CREATE TABLE default.myscale_conversions
(
    id UInt64,
    photo_id String,
    keyword String
)
ORDER BY id
""")

# データのアップロード

テーブルを作成した後、データセットからロードしたデータをテーブルに挿入し、後のベクトル検索クエリを高速化するためのベクトルインデックスを作成します。以下のコードスニペットは、データをテーブルに挿入し、コサイン距離メトリックを使用してベクトルインデックスを作成する方法を示しています。

# データセットからデータをアップロードする
client.insert("default.myscale_photos", photo_df.to_records(index=False).tolist(),
              column_names=photo_df.columns.tolist())
client.insert("default.myscale_conversions", conversion_df.to_records(index=False).tolist(),
              column_names=conversion_df.columns.tolist())
# 挿入されたデータの数を確認する
print(f"photos count: {client.command('SELECT count(*) FROM default.myscale_photos')}")
print(f"conversions count: {client.command('SELECT count(*) FROM default.myscale_conversions')}")
# コサイン距離メトリックを使用してベクトルインデックスを作成する
client.command("""
ALTER TABLE default.myscale_photos 
ADD VECTOR INDEX photo_embed_index photo_embed
TYPE MSTG
('metric_type=Cosine')
""")
# ベクトルインデックスのステータスを確認する。ベクトルインデックスが'Ready'のステータスで準備完了であることを確認する
get_index_status="SELECT status FROM system.vector_indices WHERE name='photo_embed_index'"
print(f"index build status: {client.command(get_index_status)}")

# MyScaleへのクエリ

# 上位K個の類似画像の検索

ベクトル検索を使用して、以下の手順に従って上位K個の類似画像を見つけることができます。

まず、ランダムに画像を選択し、show_image()関数を使用して表示します。

import requests
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO
# URLから画像をダウンロードする
def download(url):
    response = requests.get(url)
    return Image.open(BytesIO(response.content))
# URLを使用してオンラインの画像を表示するメソッドを定義する
def show_image(url, title=None):
    img = download(url)
    fig = plt.figure(figsize=(4, 4))
    plt.imshow(img)
    plt.show()
# 各テーブルの行数を表示する
print(f"photos count: {client.command('SELECT count(*) FROM default.myscale_photos')}")
print(f"conversions count: {client.command('SELECT count(*) FROM default.myscale_conversions')}")
# ターゲットとしてランダムな画像を選択する
random_image = client.query("SELECT * FROM default.myscale_photos ORDER BY rand() LIMIT 1")
assert random_image.row_count == 1
target_image_id = random_image.first_item["photo_id"]
target_image_url = random_image.first_item["photo_image_url"]
target_image_embed = random_image.first_item["photo_embed"]
print("currently selected image id={}, url={}".format(target_image_id, target_image_url))
# ターゲット画像を表示する
print("Loading target image...")
show_image(target_image_url)

サンプル画像:

次に、ベクトル検索を使用して選択した画像に最も類似した上位K個の候補を特定し、これらの候補を表示します。

# データベースにクエリを実行して、与えられた画像に最も類似した上位K個の画像を見つける
top_k = 10
results = client.query(f"""
SELECT photo_id, photo_image_url, distance(photo_embed, {target_image_embed}) as dist
FROM default.myscale_photos
WHERE photo_id != '{target_image_id}'
ORDER BY dist
LIMIT {top_k}
""")
# 画像をダウンロードしてリストに追加する
images_url = []
for r in results.named_results():
    # 画像URLを変更して、サイズを小さくしてダウンロードするためのURLを構築する
    url = r['photo_image_url'] + "?q=75&fm=jpg&w=200&fit=max"
    images_url.append(download(url))
# 候補画像を表示する
print("Loading candidate images...")
for row in range(int(top_k / 5)):
    fig, axs = plt.subplots(1, 5, figsize=(20, 4))
    for i, img in enumerate(images_url[row * 5:row * 5 + 5]):
        axs[i % 5].imshow(img)
    plt.show()

類似の候補画像:

# 各候補画像の変換情報の分析

上位K個の類似画像を特定した後、構造化フィールドとベクトルフィールドを組み合わせたSQLクエリを使用して、各候補の変換情報を分析することができます。

各候補画像の総変換数を計算するには、以下のSQLクエリを使用して画像検索の結果とconversionsテーブルを結合します。

# 各候補画像の総ダウンロード数を表示する
results = client.query(f"""
SELECT photo_id, count(*) as count
FROM default.myscale_conversions
JOIN (
    SELECT photo_id, distance(photo_embed, {target_image_embed}) as dist
    FROM default.myscale_photos
    ORDER BY dist ASC
    LIMIT {top_k}
    ) AS target_photos
ON default.myscale_conversions.photo_id = target_photos.photo_id
GROUP BY photo_id
ORDER BY count DESC
""")
print("Total downloads for each candidate")
for r in results.named_results():
    print("- {}: {}".format(r['photo_id'], r['count']))

サンプル出力:

Total downloads for each candidate
- Qgb9urMZ8lw: 1729
- f0OL01IHbCM: 1444
- Bgae-sqbe_g: 313
- XYg2zLjxxa0: 207
- BkW8I1n354I: 184
- 5yFOvJZp7Rg: 63
- sKPPBn6OkJg: 48
- joL0nSbZ-lI: 20
- fzDtQWW8dV4: 8
- DCAERnaj31U: 3

各候補画像の総変換数を計算した後、最もダウンロード数の多い候補画像を特定し、その画像のダウンロードキーワードごとの詳細な変換情報を調べることができます。以下のSQLクエリを使用します。

# 最も人気のある候補画像と関連するトップ5のダウンロードキーワードを表示する
most_popular_candidate = results.first_item['photo_id']
# 最も人気のある画像を表示する
candidate_url = client.command(f"""
SELECT photo_image_url FROM default.myscale_photos WHERE photo_id = '{most_popular_candidate}'
""")
print("Loading the most popular candidate image...")
show_image(candidate_url)
# 上位5つのダウンロードキーワードを検索する
results = client.query(f"""
SELECT keyword, count(*) as count
FROM default.myscale_conversions
WHERE photo_id='{most_popular_candidate}'
GROUP BY keyword
ORDER BY count DESC
LIMIT 5
""")
print("Related keywords and download counts for most popular candidate")
for r in results.named_results():
    print(f"- {r['keyword']}: {r['count']}") 

上位10の中で最も人気のある候補画像:

サンプル出力:

Related keywords and download counts for most popular candidate
- bee: 1615
- bees: 21
- bumblebee: 13
- honey: 13
- honey bee: 12
Last Updated: Sat Apr 13 2024 10:45:55 GMT+0000