Udostępnij przez


struktura DML_MULTIHEAD_ATTENTION_OPERATOR_DESC (directml.h)

Wykonuje operację uwagi wielogłowej (aby uzyskać więcej informacji, zobacz Uwaga to wszystko, czego potrzebujesz). Dokładnie jeden tensor zapytania, klucza i wartości musi być obecny, niezależnie od tego, czy są one ułożone. Na przykład jeśli podano wartość StackedQueryKey , zarówno tensory zapytania , jak i klucza muszą mieć wartość null, ponieważ są one już podane w układzie skumulowanym. Dotyczy to również wartości StackedKeyValue i StackedQueryKeyValue. Skumulowane tensory zawsze mają pięć wymiarów i są zawsze ułożone na czwartym wymiarze.

Logicznie algorytm można podzielić na następujące operacje (operacje w nawiasach kwadratowych są opcjonalne):

[Add Bias to query/key/value] -> GEMM(Query, Transposed(Key)) * Scale -> [Add RelativePositionBias] -> [Add Mask] -> Softmax -> GEMM(SoftmaxResult, Value);

Ważne

Ten interfejs API jest dostępny w ramach autonomicznego pakietu redystrybucyjnego DirectML (zobacz Microsoft.AI.DirectML w wersji 1.12 lub nowszej. Zobacz również historię wersji języka DirectML.

Składnia

struct DML_MULTIHEAD_ATTENTION_OPERATOR_DESC
{
    _Maybenull_ const DML_TENSOR_DESC* QueryTensor;
    _Maybenull_ const DML_TENSOR_DESC* KeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* ValueTensor;
    _Maybenull_ const DML_TENSOR_DESC* StackedQueryKeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* StackedKeyValueTensor;
    _Maybenull_ const DML_TENSOR_DESC* StackedQueryKeyValueTensor;
    _Maybenull_ const DML_TENSOR_DESC* BiasTensor;
    _Maybenull_ const DML_TENSOR_DESC* MaskTensor;
    _Maybenull_ const DML_TENSOR_DESC* RelativePositionBiasTensor;
    _Maybenull_ const DML_TENSOR_DESC* PastKeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* PastValueTensor;
    const DML_TENSOR_DESC* OutputTensor;
    _Maybenull_ const DML_TENSOR_DESC* OutputPresentKeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* OutputPresentValueTensor;
    FLOAT Scale;
    FLOAT MaskFilterValue;
    UINT HeadCount;
    DML_MULTIHEAD_ATTENTION_MASK_TYPE MaskType;
};

Członkowie

QueryTensor

Typ: _Maybenull_ const DML_TENSOR_DESC*

Wykonaj zapytanie o kształt [batchSize, sequenceLength, hiddenSize], gdzie hiddenSize = headCount * headSize. Ten tensor wzajemnie wyklucza się z StackedQueryKeyTensor i StackedQueryKeyValueTensor. Tensor może również mieć 4 lub 5 wymiarów, o ile wiodące wymiary to 1.

KeyTensor

Typ: _Maybenull_ const DML_TENSOR_DESC*

Klucz z kształtem [batchSize, keyValueSequenceLength, hiddenSize], gdzie hiddenSize = headCount * headSize. Ten tensor wzajemnie wyklucza się z StackedQueryKeyTensor, StackedKeyValueTensor i StackedQueryKeyValueTensor. Tensor może również mieć 4 lub 5 wymiarów, o ile wiodące wymiary to 1.

ValueTensor

Typ: _Maybenull_ const DML_TENSOR_DESC*

Wartość z kształtem [batchSize, keyValueSequenceLength, valueHiddenSize], gdzie valueHiddenSize = headCount * valueHeadSize. Tensor wzajemnie wyklucza się z StackedKeyValueTensor i StackedQueryKeyValueTensor. Tensor może również mieć 4 lub 5 wymiarów, o ile wiodące wymiary to 1.

StackedQueryKeyTensor

Typ: _Maybenull_ const DML_TENSOR_DESC*

Skumulowane zapytanie i klucz z kształtem [batchSize, sequenceLength, headCount, 2, headSize]. Ten tensor wzajemnie wyklucza się z queryTensor, KeyTensor, StackedKeyValueTensor i StackedQueryKeyValueTensor.

Układ StackedQueryKeyTensor

StackedKeyValueTensor

Typ: _Maybenull_ const DML_TENSOR_DESC*

Skumulowany klucz i wartość z kształtem [batchSize, keyValueSequenceLength, headCount, 2, headSize]. Ten tensor wzajemnie wyklucza się z keyTensor, ValueTensor, StackedQueryKeyTensor i StackedQueryKeyValueTensor.

Układ StackedKeyValueTensor

StackedQueryKeyValueTensor

Typ: _Maybenull_ const DML_TENSOR_DESC*

Skumulowane zapytanie, klucz i wartość z kształtem [batchSize, sequenceLength, headCount, 3, headSize]. Ten tensor wzajemnie wyklucza się z queryTensor, KeyTensor, ValueTensor, StackedQueryKeyTensor i StackedKeyValueTensor.

Układ StackedQueryKeyValueTensor

BiasTensor

Typ: _Maybenull_ const DML_TENSOR_DESC*

Jest to stronniczość kształtu [hiddenSize + hiddenSize + valueHiddenSize], który jest dodawany dowartości// przed pierwszą operacją GEMM. Tensor może mieć również 2, 3, 4 lub 5 wymiarów, o ile wiodące wymiary to 1.

MaskTensor

Typ: _Maybenull_ const DML_TENSOR_DESC*

Jest to maska określająca, które elementy otrzymują ich wartość ustawioną na MaskFilterValue po operacji QxK GEMM. Zachowanie tej maski zależy od wartości MaskType i jest stosowane po relativePositionBiasTensor lub po pierwszej operacji GEMM, jeśli parametr RelativePositionBiasTensor ma wartość null. Aby uzyskać więcej informacji, zobacz definicję MaskType .

RelativePositionBiasTensor

Typ: _Maybenull_ const DML_TENSOR_DESC*

Jest to stronniczość dodawana do wyniku pierwszej operacji GEMM.

PastKeyTensor

Typ: _Maybenull_ const DML_TENSOR_DESC*

Tensor klucza z poprzedniej iteracji z kształtem [batchSize, headCount, pastSequenceLength, headSize]. Gdy tensor nie ma wartości null, jest on połączony z tensorem klucza, co powoduje utworzenie tensoru kształtu [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize].

PastValueTensor

Typ: _Maybenull_ const DML_TENSOR_DESC*

Wartość tensor z poprzedniej iteracji z kształtem [batchSize, headCount, pastSequenceLength, headSize]. Gdy tensor nie ma wartości null, jest on połączony z wartością ValueDesc , co powoduje, że tensor kształtu [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize].

OutputTensor

Typ: const DML_TENSOR_DESC*

Dane wyjściowe kształtu [batchSize, sequenceLength, valueHiddenSize].

OutputPresentKeyTensor

Typ: _Maybenull_ const DML_TENSOR_DESC*

Obecny stan dla klucza uwagi krzyżowego, z kształtem [batchSize, headCount, keyValueSequenceLength, headSize] lub stanem obecnym dla siebie z kształtem [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize]. Zawiera on zawartość tensor klucza lub zawartość połączonych tensorkluczy + do przekazania do następnej iteracji.

OutputPresentValueTensor

Typ: _Maybenull_ const DML_TENSOR_DESC*

Obecny stan wartości uwagi krzyżowej z kształtem [batchSize, headCount, keyValueSequenceLength, headSize] lub stanem obecnym dla siebie z kształtem [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize]. Zawiera on zawartość tensoru wartości lub zawartość połączonych tensorwartości + do przekazania do następnej iteracji.

Scale

Typ: FLOAT

Skaluj, aby pomnożyć wynik operacji QxK GEMM, ale przed operacją Softmax. Ta wartość jest zwykle 1/sqrt(headSize)wartością .

MaskFilterValue

Typ: FLOAT

Wartość dodawana do wyniku operacji QxK GEMM do pozycji zdefiniowanych przez maskę zdefiniowaną jako elementy dopełniania. Ta wartość powinna być bardzo dużą liczbą ujemną (zwykle -10000,0f).

HeadCount

Typ: UINT

Liczba głów uwagi.

MaskType

Typ: DML_MULTIHEAD_ATTENTION_MASK_TYPE

Opisuje zachowanie biblioteki MaskTensor.

DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN. Gdy maska zawiera wartość 0, dodawana jest wartość MaskFilterValue ; ale jeśli zawiera wartość 1, nic nie zostanie dodane.

DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH. Maska kształtu [1, batchSize]zawiera długości sekwencji nieupadowanego obszaru dla każdej partii, a wszystkie elementy po długości sekwencji mają ustawioną wartość MaskFilterValue.

DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START. Maska kształtu [2, batchSize]zawiera indeksy końcowe (wykluczające) i początkowe (włącznie) obszaru, a wszystkie elementy poza obszarem mają ustawioną wartość MaskFilterValue.

DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END. Maska kształtu [batchSize * 3 + 2]ma następujące wartości: [keyLength[0], ..., keyLength[batchSize - 1], queryStart[0], ..., queryStart[batchSize - 1], queryEnd[batchSize - 1], keyStart[0], ..., keyStart[batchSize - 1], keyEnd[batchSize - 1]].

Dostępność

Ten operator został wprowadzony w DML_FEATURE_LEVEL_6_1.

Ograniczenia dotyczące tensorów

BiasTensor, KeyTensor, OutputPresentKeyTensor, OutputPresentValueTensor, OutputTensor, PastKeyTensor, PastValueTensor, QueryTensor, RelativePositionBiasTensor, StackedKeyValueTensor, StackedQueryKeyTensor, StackedQueryKeyValueTensor i ValueTensor muszą mieć ten sam typ danych.

Obsługa biblioteki Tensor

Tensor Rodzaj Obsługiwane liczby wymiarów Obsługiwane typy danych
QueryTensor Opcjonalne dane wejściowe Od 3 do 5 FLOAT32, FLOAT16
Tensor Opcjonalne dane wejściowe Od 3 do 5 FLOAT32, FLOAT16
Tensor wartości Opcjonalne dane wejściowe Od 3 do 5 FLOAT32, FLOAT16
StackedQueryKeyTensor Opcjonalne dane wejściowe 5 FLOAT32, FLOAT16
StackedKeyValueTensor Opcjonalne dane wejściowe 5 FLOAT32, FLOAT16
StackedQueryKeyValueTensor Opcjonalne dane wejściowe 5 FLOAT32, FLOAT16
Tensor odchylenia Opcjonalne dane wejściowe Od 1 do 5 FLOAT32, FLOAT16
MaskTensor Opcjonalne dane wejściowe Od 1 do 5 INT32 powiedział:
Tensor RelativePositionBiasTensor Opcjonalne dane wejściowe Od 4 do 5 FLOAT32, FLOAT16
PastKeyTensor Opcjonalne dane wejściowe Od 4 do 5 FLOAT32, FLOAT16
Tensor PastValueTensor Opcjonalne dane wejściowe Od 4 do 5 FLOAT32, FLOAT16
Tensor wyjściowy Wynik Od 3 do 5 FLOAT32, FLOAT16
OutputPresentKeyTensor Opcjonalne dane wyjściowe Od 4 do 5 FLOAT32, FLOAT16
OutputPresentValueTensor Opcjonalne dane wyjściowe Od 4 do 5 FLOAT32, FLOAT16

Wymagania

   
Nagłówek directml.h