Use ai.classify with PySpark

The ai.classify function uses generative AI to categorize input text according to custom labels you choose, with a single line of code.

Overview

The ai.classify function is available for Spark DataFrames. You must specify the name of an existing input column as a parameter, along with a list of classification labels.

The function returns a new DataFrame with labels that match each row of input text, stored in an output column.
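You may want to prototype a pipeline offline before calling the real function. To do that, you could use a stand-in that follows the same per-row contract: each input string maps to one label from labels, or to null (None in Python) when nothing fits. The sketch below is a hypothetical mock (mock_classify and its keywords mapping are illustrative names, not part of any library). It uses simple keyword matching rather than a generative model, so it only mirrors the shape of the output, not how ai.classify decides.

```python
# Hypothetical offline stand-in for ai.classify, for prototyping only.
# It uses simple keyword matching, NOT a generative model, but follows the
# same per-row contract: return one label from `labels`, or None.
from typing import Dict, List, Optional


def mock_classify(
    texts: List[str],
    labels: List[str],
    keywords: Dict[str, List[str]],
) -> List[Optional[str]]:
    """Return one label per input text, or None if no keyword matches."""
    results: List[Optional[str]] = []
    for text in texts:
        lowered = text.lower()
        match: Optional[str] = None
        # Check labels in the order given; the first label whose keyword appears wins.
        for label in labels:
            if any(word in lowered for word in keywords.get(label, [])):
                match = label
                break
        results.append(match)
    return results


texts = [
    "This duvet is perfect for a good night's sleep.",
    "Handy-dandy measuring cups for culinary delights.",
    "A compact SUV for the professional commuter.",
    "A set of watercolor paints.",
]
labels = ["kitchen", "bedroom", "garage"]
keywords = {
    "kitchen": ["measuring cup", "baking"],
    "bedroom": ["duvet", "sleep"],
    "garage": ["suv", "car"],
}
print(mock_classify(texts, labels, keywords))
# ['bedroom', 'kitchen', 'garage', None]
```

The last text matches no label, which mirrors the null label the real function returns for text it can't classify.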

Syntax

df.ai.classify(labels=["category1", "category2", "category3"], input_col="text", output_col="classification")

Parameters

labels (required): An array of strings that represents the set of classification labels to match to text values in the input column.

input_col (required): A string that contains the name of an existing column with input text values to classify according to the custom labels.

output_col (optional): A string that contains the name of a new column where you want to store a classification label for each input text row. If you don't set this parameter, a default name is generated for the output column.

error_col (optional): A string that contains the name of a new column. The new column stores any OpenAI errors that result from processing each row of input text. If you don't set this parameter, a default name is generated for the error column. If there are no errors for a row of input, the value in this column is null.
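A quick pre-flight check can catch argument mistakes before you spend model calls. The hypothetical helper below (check_classify_args is an illustrative name, not part of the ai.classify API) encodes the documented rules: labels and input_col are required, and input_col must name an existing column. It also applies a few assumptions of its own: labels should be unique, non-empty strings, and output_col and error_col should not collide with existing columns or with each other, since both name new columns. In a notebook, you would pass df.columns as the first argument.

```python
# Hypothetical pre-flight check for ai.classify arguments. The rules encode the
# documented contract (labels and input_col are required; input_col must be an
# existing column) plus assumptions: labels should be unique, non-empty strings,
# and the new output/error column names should not collide.
from typing import List, Optional, Sequence


def check_classify_args(
    columns: Sequence[str],
    labels: List[str],
    input_col: str,
    output_col: Optional[str] = None,
    error_col: Optional[str] = None,
) -> List[str]:
    """Return a list of problems; an empty list means the arguments look valid."""
    problems: List[str] = []
    if not labels:
        problems.append("labels must contain at least one label")
    elif not all(isinstance(label, str) and label.strip() for label in labels):
        problems.append("every label must be a non-empty string")
    elif len(set(labels)) != len(labels):
        problems.append("labels contain duplicates")
    if input_col not in columns:
        problems.append(f"input_col '{input_col}' is not an existing column")
    # output_col and error_col name *new* columns, so they should not collide
    # with existing columns or with each other.
    for name in (output_col, error_col):
        if name is not None and name in columns:
            problems.append(f"'{name}' already exists in the DataFrame")
    if output_col is not None and output_col == error_col:
        problems.append("output_col and error_col must differ")
    return problems


# Valid arguments produce an empty list; these produce two problems.
print(check_classify_args(["descriptions"], ["kitchen", "kitchen"], "text"))
```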

Returns

The function returns a Spark DataFrame with a new column that holds the classification label for each input text row. If a text value can't be classified, its label is null.
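Because a null label can mean either "no label fits" or "the request failed", it can help to check the error column too. The sketch below is a hypothetical helper (triage_rows is an illustrative name) that works on plain dicts, such as those produced by [row.asDict() for row in result_df.collect()]. It assumes you set output_col and error_col explicitly so the column names are known, because otherwise default names are generated. Collecting rows to the driver is only appropriate for small results; for large DataFrames, filter with Spark instead.

```python
# Hypothetical post-processing helper: split result rows into classified,
# unclassified (null label, no error), and failed (error recorded) buckets.
# Rows are plain dicts, e.g. [row.asDict() for row in result_df.collect()].
from typing import Any, Dict, List


def triage_rows(
    rows: List[Dict[str, Any]], output_col: str, error_col: str
) -> Dict[str, List[Dict[str, Any]]]:
    buckets: Dict[str, List[Dict[str, Any]]] = {
        "classified": [],
        "unclassified": [],
        "failed": [],
    }
    for row in rows:
        if row.get(error_col) is not None:
            # An error was recorded, so a null label here is not a model decision.
            buckets["failed"].append(row)
        elif row.get(output_col) is None:
            buckets["unclassified"].append(row)
        else:
            buckets["classified"].append(row)
    return buckets


# Illustrative rows only; "example error message" is placeholder text.
rows = [
    {"descriptions": "A duvet", "categories": "bedroom", "errors": None},
    {"descriptions": "???", "categories": None, "errors": None},
    {"descriptions": "A car", "categories": None, "errors": "example error message"},
]
buckets = triage_rows(rows, output_col="categories", error_col="errors")
print({name: len(items) for name, items in buckets.items()})
# {'classified': 1, 'unclassified': 1, 'failed': 1}
```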

Example

# This code uses AI. Always review output for mistakes.

# Create a sample DataFrame with one column of product descriptions.
df = spark.createDataFrame([
    ("This duvet, lovingly hand-crafted from all-natural fabric, is perfect for a good night's sleep.",),
    ("Tired of friends judging your baking? With these handy-dandy measuring cups, you'll create culinary delights.",),
    ("Enjoy this *BRAND NEW CAR!* A compact SUV perfect for the professional commuter!",)
], ["descriptions"])

# Classify each description with one of the labels and store the result in the "categories" column.
categories = df.ai.classify(labels=["kitchen", "bedroom", "garage", "other"], input_col="descriptions", output_col="categories")
display(categories)

This example code cell provides the following output:

Screenshot of a DataFrame with 'descriptions' and 'categories' columns. The 'categories' column lists each description's category name.