แชร์ผ่าน


สํารวจศิลปะข้ามวัฒนธรรมและสื่อด้วยอัลกอริทึมย่าน k-nearest Neighbors ที่รวดเร็วมีเงื่อนไข

บทความนี้อธิบายการหาค่าที่ตรงกันผ่านอัลกอริทึม k-nearest-Neighbors คุณสร้างทรัพยากรโค้ดที่อนุญาตให้มีการคิวรีที่เกี่ยวข้องกับวัฒนธรรมและสื่อการละทิ้งศิลปะจากพิพิธภัณฑ์ศิลปะมหานครนครนิวยอร์กและอัมสเตอร์ดัม Rijksmuseum

ข้อกำหนดเบื้องต้น

ภาพรวมของ BallTree

แบบจําลอง k-NN ขึ้นอยู่กับโครงสร้างข้อมูล BallTree BallTree เป็นแผนภูมิไบนารีแบบเรียกใช้ซ้ํา ซึ่งแต่ละโหนด (หรือ "ลูก") มีพาร์ติชันหรือชุดย่อยของจุดข้อมูลที่คุณต้องการคิวรี ในการสร้าง BallTree ให้กําหนดศูนย์ "ลูก" (ตามคุณลักษณะที่ระบุไว้) ที่ใกล้เคียงกับแต่ละจุดข้อมูลมากที่สุด จากนั้น กําหนดแต่ละจุดข้อมูลไปยัง "ลูกบอล" ที่ใกล้เคียงที่สุด งานที่ได้รับมอบหมายเหล่านั้นสร้างโครงสร้างที่ช่วยให้โครงการไบนารีต้นไม้เหมือนทางธรณีและให้ยืมตัวเองเพื่อหาเพื่อนบ้าน k-ที่ใกล้ที่สุดที่ใบ BallTree

ตั้งค่า

นําเข้าไลบรารี Python ที่จําเป็นและเตรียมชุดข้อมูล:

from synapse.ml.core.platform import *

if running_on_binder():
    from IPython import get_ipython
from pyspark.sql.types import BooleanType
from pyspark.sql.types import *
from pyspark.ml.feature import Normalizer
from pyspark.sql.functions import lit, array, array_contains, udf, col, struct
from synapse.ml.nn import ConditionalKNN, ConditionalKNNModel
from PIL import Image
from io import BytesIO

import requests
import numpy as np
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession

# Bootstrap Spark Session
spark = SparkSession.builder.getOrCreate()

ชุดข้อมูลมาจากตารางที่ประกอบด้วยข้อมูลงานศิลปะจากทั้งพิพิธภัณฑ์ Met และ Rijksmuseum ตารางมี Schema นี้:

  • ID: ตัวระบุที่ไม่ซ้ํากันสําหรับแต่ละส่วนของอาร์ต
    • ID ตัวอย่างที่พบ: 388395
    • ตัวอย่าง Rijks ID: SK-A-2344
  • หัวข้อ: ชื่อชิ้นศิลปะ ดังที่เขียนในฐานข้อมูลของพิพิธภัณฑ์
  • ศิลปิน: ศิลปะชิ้นงานศิลปะ ที่เขียนในฐานข้อมูลของพิพิธภัณฑ์
  • Thumbnail_Url: ตําแหน่งที่ตั้งของรูปขนาดย่อ JPEG ของชิ้นงานศิลปะ
  • Image_Url ตําแหน่งที่ตั้ง URL ของเว็บไซต์ของภาพชิ้นงานศิลปะที่โฮสต์บนเว็บไซต์ Met/Rijks
  • วัฒนธรรม: หมวดหมู่วัฒนธรรมของชิ้นงานศิลปะ
    • หมวดหมู่วัฒนธรรมตัวอย่าง: ละตินอเมริกา, อียิปต์ ฯลฯ
  • การจําแนกประเภท: ประเภทกลางของชิ้นงานศิลปะ
    • ประเภทกลางตัวอย่าง: งานไม้, ภาพวาด, ฯลฯ
  • Museum_Page: ลิงก์ URL ไปยังชิ้นงานศิลปะที่โฮสต์บนเว็บไซต์ Met/Rijks
  • Norm_Features: การฝังรูปภาพชิ้นงานศิลปะ
  • พิพิธภัณฑ์: พิพิธภัณฑ์ที่จัดงานศิลปะชิ้นจริง
# loads the dataset and the two trained conditional k-NN models for querying by medium and culture
df = spark.read.parquet(
    "wasbs://publicwasb@mmlspark.blob.core.windows.net/met_and_rijks.parquet"
)
display(df.drop("Norm_Features"))

หากต้องการสร้างคิวรี ให้กําหนดหมวดหมู่

ใช้แบบจําลอง k-NN สองแบบ: แบบหนึ่งแบบสําหรับวัฒนธรรม และอีกแบบหนึ่งคือแบบปานกลาง:

# mediums = ['prints', 'drawings', 'ceramics', 'textiles', 'paintings', "musical instruments","glass", 'accessories', 'photographs',  "metalwork",
#           "sculptures", "weapons", "stone", "precious", "paper", "woodwork", "leatherwork", "uncategorized"]

mediums = ["paintings", "glass", "ceramics"]

# cultures = ['african (general)', 'american', 'ancient american', 'ancient asian', 'ancient european', 'ancient middle-eastern', 'asian (general)',
#            'austrian', 'belgian', 'british', 'chinese', 'czech', 'dutch', 'egyptian']#, 'european (general)', 'french', 'german', 'greek',
#            'iranian', 'italian', 'japanese', 'latin american', 'middle eastern', 'roman', 'russian', 'south asian', 'southeast asian',
#            'spanish', 'swiss', 'various']

cultures = ["japanese", "american", "african (general)"]

# Uncomment the above for more robust and large scale searches!

classes = cultures + mediums

medium_set = set(mediums)
culture_set = set(cultures)
selected_ids = {"AK-RBK-17525-2", "AK-MAK-1204", "AK-RAK-2015-2-9"}

small_df = df.where(
    udf(
        lambda medium, culture, id_val: (medium in medium_set)
        or (culture in culture_set)
        or (id_val in selected_ids),
        BooleanType(),
    )("Classification", "Culture", "id")
)

small_df.count()

กําหนดและพอดีกับแบบจําลอง k-NN แบบมีเงื่อนไข

สร้างแบบจําลอง k-NN แบบมีเงื่อนไขสําหรับทั้งคอลัมน์สื่อและวัฒนธรรม แต่ละแบบจําลองใช้

  • คอลัมน์ผลลัพธ์
  • คอลัมน์คุณลักษณะ (เวกเตอร์คุณลักษณะ)
  • คอลัมน์ค่า (ค่าเซลล์ภายใต้คอลัมน์ผลลัพธ์)
  • คอลัมน์ป้ายชื่อ (คุณภาพของ k-NN ที่เกี่ยวข้องเป็นไปตามเงื่อนไข)
medium_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Classification")
    .fit(small_df)
)
culture_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Culture")
    .fit(small_df)
)

กําหนดวิธีการจับคู่และการแสดงภาพ

หลังจากตั้งค่าชุดข้อมูลและประเภทเริ่มต้นแล้ว ให้เตรียมวิธีการในการคิวรีและแสดงภาพผลลัพธ์ของ k-NN แบบมีเงื่อนไข:

addMatches() สร้าง Dataframe ด้วยคู่ที่ตรงกันสําหรับแต่ละหมวดหมู่:

def add_matches(classes, cknn, df):
    results = df
    for label in classes:
        results = cknn.transform(
            results.withColumn("conditioner", array(lit(label)))
        ).withColumnRenamed("Matches", "Matches_{}".format(label))
    return results

plot_urls() การโทร plot_img เพื่อแสดงภาพรายการที่ตรงกันที่สุดสําหรับแต่ละประเภทลงในเส้นตาราง:

def plot_img(axis, url, title):
    try:
        response = requests.get(url)
        img = Image.open(BytesIO(response.content)).convert("RGB")
        axis.imshow(img, aspect="equal")
    except:
        pass
    if title is not None:
        axis.set_title(title, fontsize=4)
    axis.axis("off")


def plot_urls(url_arr, titles, filename):
    nx, ny = url_arr.shape

    plt.figure(figsize=(nx * 5, ny * 5), dpi=1600)
    fig, axes = plt.subplots(ny, nx)

    # reshape required in the case of 1 image query
    if len(axes.shape) == 1:
        axes = axes.reshape(1, -1)

    for i in range(nx):
        for j in range(ny):
            if j == 0:
                plot_img(axes[j, i], url_arr[i, j], titles[i])
            else:
                plot_img(axes[j, i], url_arr[i, j], None)

    plt.savefig(filename, dpi=1600)  # saves the results as a PNG

    display(plt.show())

รวมทุกอย่างเข้าด้วยกัน

ที่จะเข้าไปรับ

  • ข้อมูล
  • แบบจําลอง k-NN แบบมีเงื่อนไข
  • ค่า ID ศิลปะที่จะคิวรี
  • เส้นทางของไฟล์ที่บันทึกการแสดงภาพผลลัพธ์

กําหนดฟังก์ชันที่เรียกว่า test_all()

แบบจําลองขนาดกลางและวัฒนธรรมได้รับการฝึกและโหลดมาก่อนหน้านี้

# main method to test a particular dataset with two conditional k-NN models and a set of art IDs, saving the result to filename.png

def test_all(data, cknn_medium, cknn_culture, test_ids, root):
    is_nice_obj = udf(lambda obj: obj in test_ids, BooleanType())
    test_df = data.where(is_nice_obj("id"))

    results_df_medium = add_matches(mediums, cknn_medium, test_df)
    results_df_culture = add_matches(cultures, cknn_culture, results_df_medium)

    results = results_df_culture.collect()

    original_urls = [row["Thumbnail_Url"] for row in results]

    culture_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in cultures
    ]
    culture_url_arr = np.array([original_urls] + culture_urls)[:, :]
    plot_urls(culture_url_arr, ["Original"] + cultures, root + "matches_by_culture.png")

    medium_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in mediums
    ]
    medium_url_arr = np.array([original_urls] + medium_urls)[:, :]
    plot_urls(medium_url_arr, ["Original"] + mediums, root + "matches_by_medium.png")

    return results_df_culture

เวอร์ชันสาธิต

เซลล์ต่อไปนี้ดําเนินการคิวรีแบบกลุ่ม โดยกําหนดรหัสรูปภาพที่ต้องการและชื่อไฟล์เพื่อบันทึกการแสดงภาพ

# sample query
result_df = test_all(small_df, medium_cknn, culture_cknn, selected_ids, root=".")