Partager via


structure DML_MULTIHEAD_ATTENTION_OPERATOR_DESC (directml.h)

Effectue une opération d’attention multi-tête (pour plus d’informations, voir Attention est tout ce dont vous avez besoin). Exactement une requête, un capteur clé et valeur doit être présent, qu’ils soient empilés ou non. Par exemple, si StackedQueryKey est fourni, les tenseurs de requête et de clé doivent être null, car ils sont déjà fournis dans une disposition empilée. Il en va de même pour StackedKeyValue et StackedQueryKeyValue. Les tenseurs empilés ont toujours cinq dimensions et sont toujours empilés sur la quatrième dimension.

Logiquement, l’algorithme peut être décomposé dans les opérations suivantes (les opérations entre crochets sont facultatives) :

[Add Bias to query/key/value] -> GEMM(Query, Transposed(Key)) * Scale -> [Add RelativePositionBias] -> [Add Mask] -> Softmax -> GEMM(SoftmaxResult, Value);

Important

Cette API est disponible dans le cadre du package redistribuable autonome DirectML (voir Microsoft.AI.DirectML version 1.12 et ultérieures. Consultez également l’historique des versions DirectML.

Syntaxe

struct DML_MULTIHEAD_ATTENTION_OPERATOR_DESC
{
    _Maybenull_ const DML_TENSOR_DESC* QueryTensor;
    _Maybenull_ const DML_TENSOR_DESC* KeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* ValueTensor;
    _Maybenull_ const DML_TENSOR_DESC* StackedQueryKeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* StackedKeyValueTensor;
    _Maybenull_ const DML_TENSOR_DESC* StackedQueryKeyValueTensor;
    _Maybenull_ const DML_TENSOR_DESC* BiasTensor;
    _Maybenull_ const DML_TENSOR_DESC* MaskTensor;
    _Maybenull_ const DML_TENSOR_DESC* RelativePositionBiasTensor;
    _Maybenull_ const DML_TENSOR_DESC* PastKeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* PastValueTensor;
    const DML_TENSOR_DESC* OutputTensor;
    _Maybenull_ const DML_TENSOR_DESC* OutputPresentKeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* OutputPresentValueTensor;
    FLOAT Scale;
    FLOAT MaskFilterValue;
    UINT HeadCount;
    DML_MULTIHEAD_ATTENTION_MASK_TYPE MaskType;
};

Membres

QueryTensor

Type : _Maybenull_ const DML_TENSOR_DESC*

Interroger avec la forme [batchSize, sequenceLength, hiddenSize], où hiddenSize = headCount * headSize. Ce tensor est mutuellement exclusif avec StackedQueryKeyTensor et StackedQueryKeyValueTensor. Le capteur peut également avoir 4 ou 5 dimensions, tant que les dimensions de début sont de 1.

KeyTensor

Type : _Maybenull_ const DML_TENSOR_DESC*

Clé avec forme [batchSize, keyValueSequenceLength, hiddenSize], où hiddenSize = headCount * headSize. Ce tensor s’exclue mutuellement avec StackedQueryKeyTensor, StackedKeyValueTensor et StackedQueryKeyValueTensor. Le capteur peut également avoir 4 ou 5 dimensions, tant que les dimensions de début sont de 1.

ValueTensor

Type : _Maybenull_ const DML_TENSOR_DESC*

Valeur avec forme [batchSize, keyValueSequenceLength, valueHiddenSize], où valueHiddenSize = headCount * valueHeadSize. Ce tensor est mutuellement exclusif avec StackedKeyValueTensor et StackedQueryKeyValueTensor. Le capteur peut également avoir 4 ou 5 dimensions, tant que les dimensions de début sont de 1.

StackedQueryKeyTensor

Type : _Maybenull_ const DML_TENSOR_DESC*

Requête empilée et clé avec forme [batchSize, sequenceLength, headCount, 2, headSize]. Ce tensor est mutuellement exclusif avec QueryTensor, KeyTensor, StackedKeyValueTensor et StackedQueryKeyValueTensor.

Disposition StackedQueryKeyTensor

StackedKeyValueTensor

Type : _Maybenull_ const DML_TENSOR_DESC*

Clé empilée et valeur avec la forme [batchSize, keyValueSequenceLength, headCount, 2, headSize]. Ce tensor s’exclue mutuellement avec KeyTensor, ValueTensor, StackedQueryKeyTensor et StackedQueryKeyValueTensor.

Disposition StackedKeyValueTensor

StackedQueryKeyValueTensor

Type : _Maybenull_ const DML_TENSOR_DESC*

Requête empilée, clé et valeur avec la forme [batchSize, sequenceLength, headCount, 3, headSize]. Ce tensor est mutuellement exclusif avec QueryTensor, KeyTensor, ValueTensor, StackedQueryKeyTensor et StackedKeyValueTensor.

Disposition StackedQueryKeyValueTensor

BiasTensor

Type : _Maybenull_ const DML_TENSOR_DESC*

Il s’agit du biais de la forme[hiddenSize + hiddenSize + valueHiddenSize], qui est ajouté àla valeur// avant la première opération GEMM. Ce capteur peut également avoir 2, 3, 4 ou 5 dimensions, tant que les dimensions de début sont de 1.

MaskTensor

Type : _Maybenull_ const DML_TENSOR_DESC*

Il s’agit du masque qui détermine quels éléments obtiennent leur valeur définie sur MaskFilterValue après l’opération GEMM QxK. Le comportement de ce masque dépend de la valeur de MaskType et est appliqué après RelativePositionBiasTensor ou après la première opération GEMM si RelativePositionBiasTensor a la valeur Null. Pour plus d’informations, consultez la définition de MaskType .

RelativePositionBiasTensor

Type : _Maybenull_ const DML_TENSOR_DESC*

Il s’agit du biais qui est ajouté au résultat de la première opération GEMM.

PastKeyTensor

Type : _Maybenull_ const DML_TENSOR_DESC*

Capteur de clé de l’itération précédente avec la forme [batchSize, headCount, pastSequenceLength, headSize]. Quand ce tensor n’est pas null, il est concaténé avec le capteur de clé, ce qui entraîne unensoreur de forme [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize].

PastValueTensor

Type : _Maybenull_ const DML_TENSOR_DESC*

La valeur tensoriel de l’itération précédente avec la forme [batchSize, headCount, pastSequenceLength, headSize]. Lorsque ce tensor n’est pas null, il est concaténé avec ValueDesc , ce qui entraîne unensoreur de forme [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize].

OutputTensor

Type : const DML_TENSOR_DESC*

Sortie, de la forme [batchSize, sequenceLength, valueHiddenSize].

OutputPresentKeyTensor

Type : _Maybenull_ const DML_TENSOR_DESC*

État présent pour la touche d’attention croisée, avec forme [batchSize, headCount, keyValueSequenceLength, headSize] ou état présent pour l’attention de soi avec la forme [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize]. Il contient le contenu du capteur de clé ou le contenu du tensoreur PastKey + concaténé à passer à l’itération suivante.

OutputPresentValueTensor

Type : _Maybenull_ const DML_TENSOR_DESC*

État présent pour la valeur de l’attention croisée, avec forme [batchSize, headCount, keyValueSequenceLength, headSize] ou état présent pour l’attention de soi avec la forme [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize]. Il contient le contenu du capteur de valeur ou le contenu du tensoreur PastValue + concaténé à passer à l’itération suivante.

Scale

Type : FLOAT

Effectuez une mise à l’échelle pour multiplier le résultat de l’opération GEMM QxK, mais avant l’opération Softmax. Cette valeur est généralement 1/sqrt(headSize).

MaskFilterValue

Type : FLOAT

Valeur ajoutée au résultat de l’opération GEMM QxK aux positions définies par le masque en tant qu’éléments de remplissage. Cette valeur doit être un nombre très élevé négatif (généralement -10000.0f).

HeadCount

Type : UINT

Nombre de têtes d’attention.

MaskType

Type : DML_MULTIHEAD_ATTENTION_MASK_TYPE

Décrit le comportement de MaskTensor.

DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN. Lorsque le masque contient une valeur de 0, MaskFilterValue est ajouté ; mais lorsqu’elle contient une valeur de 1, rien n’est ajouté.

DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH. Le masque, de la forme [1, batchSize], contient les longueurs de séquence de la zone non pavée pour chaque lot, et tous les éléments après que la longueur de séquence obtiennent leur valeur définie sur MaskFilterValue.

DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START. Le masque, de la forme [2, batchSize], contient les index de fin (exclusif) et de début (inclus) de la zone non pavée, et tous les éléments en dehors de la zone obtiennent leur valeur définie sur MaskFilterValue.

DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END. Le masque, de la forme [batchSize * 3 + 2], a les valeurs suivantes : [keyLength[0], ..., keyLength[batchSize - 1], queryStart[0], ..., queryStart[batchSize - 1], queryEnd[batchSize - 1], keyStart[0], ..., keyStart[batchSize - 1], keyEnd[batchSize - 1]].

Disponibilité

Cet opérateur a été introduit dans DML_FEATURE_LEVEL_6_1.

Contraintes Tensor

BiasTensor, KeyTensor, OutputPresentKeyTensor, OutputPresentValueTensor, OutputTensor, PastKeyTensor, PastValueTensor, QueryTensor, RelativePositionBiasTensor, StackedKeyValueTensor, StackedQueryKeyTensor, StackedQueryKeyValueTensor et ValueTensor doivent avoir le même Type de données.

Prise en charge de Tensor

Tenseur Genre Nombres de dimensions pris en charge Types de données pris en charge
QueryTensor Entrée facultative 3 à 5 FLOAT32, FLOAT16
KeyTensor Entrée facultative 3 à 5 FLOAT32, FLOAT16
Tenseur de valeur Entrée facultative 3 à 5 FLOAT32, FLOAT16
StackedQueryKeyTensor Entrée facultative 5 FLOAT32, FLOAT16
StackedKeyValueTensor Entrée facultative 5 FLOAT32, FLOAT16
StackedQueryKeyValueTensor Entrée facultative 5 FLOAT32, FLOAT16
BiasTensor Entrée facultative 1 à 5 FLOAT32, FLOAT16
MaskTensor Entrée facultative 1 à 5 INT32
RelativePositionBiasTensor Entrée facultative 4 à 5 FLOAT32, FLOAT16
PastKeyTensor Entrée facultative 4 à 5 FLOAT32, FLOAT16
PastValueTensor Entrée facultative 4 à 5 FLOAT32, FLOAT16
Tenseur de sortie Sortie 3 à 5 FLOAT32, FLOAT16
OutputPresentKeyTensor Sortie facultative 4 à 5 FLOAT32, FLOAT16
OutputPresentValueTensor Sortie facultative 4 à 5 FLOAT32, FLOAT16

Spécifications

   
En-tête directml.h