Compartilhar via


estrutura DML_MULTIHEAD_ATTENTION_OPERATOR_DESC (directml.h)

Executa uma operação de atenção de várias cabeças (para obter mais informações, consulte Atenção é tudo o que você precisa). Exatamente um tensor de Consulta, Chave e Valor deve estar presente, independentemente de estarem ou não empilhados. Por exemplo, se StackedQueryKey for fornecido, os tensores de consulta e chave deverão ser nulos, já que eles já são fornecidos em um layout empilhado. O mesmo vale para StackedKeyValue e StackedQueryKeyValue. Os tensores empilhados sempre têm cinco dimensões e são sempre empilhados na quarta dimensão.

Logicamente, o algoritmo pode ser decomposto nas seguintes operações (as operações entre colchetes são opcionais):

[Add Bias to query/key/value] -> GEMM(Query, Transposed(Key)) * Scale -> [Add RelativePositionBias] -> [Add Mask] -> Softmax -> GEMM(SoftmaxResult, Value);

Importante

Essa API está disponível como parte do pacote redistribuível autônomo do DirectML (consulte Microsoft.AI.DirectML versão 1.12 e posterior. Consulte também de histórico de versões do DirectML.

Sintaxe

struct DML_MULTIHEAD_ATTENTION_OPERATOR_DESC
{
    _Maybenull_ const DML_TENSOR_DESC* QueryTensor;
    _Maybenull_ const DML_TENSOR_DESC* KeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* ValueTensor;
    _Maybenull_ const DML_TENSOR_DESC* StackedQueryKeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* StackedKeyValueTensor;
    _Maybenull_ const DML_TENSOR_DESC* StackedQueryKeyValueTensor;
    _Maybenull_ const DML_TENSOR_DESC* BiasTensor;
    _Maybenull_ const DML_TENSOR_DESC* MaskTensor;
    _Maybenull_ const DML_TENSOR_DESC* RelativePositionBiasTensor;
    _Maybenull_ const DML_TENSOR_DESC* PastKeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* PastValueTensor;
    const DML_TENSOR_DESC* OutputTensor;
    _Maybenull_ const DML_TENSOR_DESC* OutputPresentKeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* OutputPresentValueTensor;
    FLOAT Scale;
    FLOAT MaskFilterValue;
    UINT HeadCount;
    DML_MULTIHEAD_ATTENTION_MASK_TYPE MaskType;
};

Membros

QueryTensor

Tipo: _Maybenull_ const DML_TENSOR_DESC*

Consultar com a forma [batchSize, sequenceLength, hiddenSize], onde hiddenSize = headCount * headSize. Este tensor é mutuamente exclusivo com StackedQueryKeyTensor e StackedQueryKeyValueTensor. O tensor também pode ter 4 ou 5 dimensões, desde que as dimensões principais sejam de 1.

KeyTensor

Tipo: _Maybenull_ const DML_TENSOR_DESC*

Chave com forma [batchSize, keyValueSequenceLength, hiddenSize], onde hiddenSize = headCount * headSize. Este tensor é mutuamente exclusivo com StackedQueryKeyTensor, StackedKeyValueTensor e StackedQueryKeyValueTensor. O tensor também pode ter 4 ou 5 dimensões, desde que as dimensões principais sejam de 1.

ValueTensor

Tipo: _Maybenull_ const DML_TENSOR_DESC*

Valor com forma [batchSize, keyValueSequenceLength, valueHiddenSize], onde valueHiddenSize = headCount * valueHeadSize. Este tensor é mutuamente exclusivo com StackedKeyValueTensor e StackedQueryKeyValueTensor. O tensor também pode ter 4 ou 5 dimensões, desde que as dimensões principais sejam de 1.

StackedQueryKeyTensor

Tipo: _Maybenull_ const DML_TENSOR_DESC*

Consulta empilhada e chave com forma [batchSize, sequenceLength, headCount, 2, headSize]. Este tensor é mutuamente exclusivo com QueryTensor, KeyTensor, StackedKeyValueTensor e StackedQueryKeyValueTensor.

Layout StackedQueryKeyTensor

StackedKeyValueTensor

Tipo: _Maybenull_ const DML_TENSOR_DESC*

Chave empilhada e valor com forma [batchSize, keyValueSequenceLength, headCount, 2, headSize]. Este tensor é mutuamente exclusivo com KeyTensor, ValueTensor, StackedQueryKeyTensor e StackedQueryKeyValueTensor.

Layout StackedKeyValueTensor

StackedQueryKeyValueTensor

Tipo: _Maybenull_ const DML_TENSOR_DESC*

Consulta empilhada, chave e valor com forma [batchSize, sequenceLength, headCount, 3, headSize]. Este tensor é mutuamente exclusivo com QueryTensor, KeyTensor, ValueTensor, StackedQueryKeyTensor e StackedKeyValueTensor.

Layout StackedQueryKeyValueTensor

BiasTensor

Tipo: _Maybenull_ const DML_TENSOR_DESC*

Esse é o viés da forma[hiddenSize + hiddenSize + valueHiddenSize], que é adicionada aoValor/ de / antes da primeira operação GEMM. Esse tensor também pode ter dimensões 2, 3, 4 ou 5, desde que as dimensões principais sejam de 1.

MaskTensor

Tipo: _Maybenull_ const DML_TENSOR_DESC*

Essa é a máscara que determina quais elementos obtêm seu valor definido como MaskFilterValue após a operação GEMM do QxK. O comportamento dessa máscara depende do valor de MaskType e é aplicado após RelativePositionBiasTensor ou após a primeira operação GEMM se RelativePositionBiasTensor for nulo. Consulte a definição de MaskType para obter mais informações.

RelativePositionBiasTensor

Tipo: _Maybenull_ const DML_TENSOR_DESC*

Esse é o viés que é adicionado ao resultado da primeira operação GEMM.

PastKeyTensor

Tipo: _Maybenull_ const DML_TENSOR_DESC*

Tensor de chave da iteração anterior com forma [batchSize, headCount, pastSequenceLength, headSize]. Quando esse tensor não é nulo, ele é concatenado com o tensor de chave, o que resulta em um tensor de forma [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize].

PastValueTensor

Tipo: _Maybenull_ const DML_TENSOR_DESC*

Valor tensor da iteração anterior com forma [batchSize, headCount, pastSequenceLength, headSize]. Quando esse tensor não é nulo, ele é concatenado com ValueDesc , o que resulta em um tensor de forma [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize].

OutputTensor

Tipo: const DML_TENSOR_DESC*

Saída, da forma [batchSize, sequenceLength, valueHiddenSize].

OutputPresentKeyTensor

Tipo: _Maybenull_ const DML_TENSOR_DESC*

Apresentar estado para chave de atenção cruzada, com forma [batchSize, headCount, keyValueSequenceLength, headSize] ou estado presente para autoatendimento com forma [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize]. Ele contém o conteúdo do tensor de chave ou o conteúdo do tensor dechave + concatenado para passar para a próxima iteração.

OutputPresentValueTensor

Tipo: _Maybenull_ const DML_TENSOR_DESC*

Apresentar estado para valor de atenção cruzada, com forma [batchSize, headCount, keyValueSequenceLength, headSize] ou estado presente para autoatendimento com forma [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize]. Ele contém o conteúdo do tensor de valor ou o conteúdo do tensor deValor + concatenado para passar para a próxima iteração.

Scale

Tipo: FLOAT

Dimensione para multiplicar o resultado da operação GEMM do QxK, mas antes da operação Softmax. Esse valor geralmente 1/sqrt(headSize)é .

MaskFilterValue

Tipo: FLOAT

Valor que é adicionado ao resultado da operação GEMM do QxK às posições que a máscara definiu como elementos de preenchimento. Esse valor deve ser um número negativo muito grande (geralmente -10000,0f).

HeadCount

Tipo: UINT

Número de cabeças de atenção.

MaskType

Tipo: DML_MULTIHEAD_ATTENTION_MASK_TYPE

Descreve o comportamento de MaskTensor.

DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN. Quando a máscara contém um valor de 0, MaskFilterValue é adicionado; mas quando ele contém um valor de 1, nada é adicionado.

DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH. A máscara, de forma [1, batchSize], contém os comprimentos de sequência da área não adicionada para cada lote e todos os elementos após o comprimento da sequência obtêm seu valor definido como MaskFilterValue.

DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START. A máscara, de forma [2, batchSize], contém os índices final (exclusivo) e inicial (inclusivo) da área não adicionada e todos os elementos fora da área obtêm seu valor definido como MaskFilterValue.

DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END. A máscara, de forma [batchSize * 3 + 2], tem os seguintes valores: [keyLength[0], ..., keyLength[batchSize - 1], queryStart[0], ..., queryStart[batchSize - 1], queryEnd[batchSize - 1], keyStart[0], ..., keyStart[batchSize - 1], keyEnd[batchSize - 1]].

Disponibilidade

Esse operador foi introduzido no DML_FEATURE_LEVEL_6_1.

Restrições do Tensor

BiasTensor, KeyTensor, OutputPresentKeyTensor, OutputPresentValueTensor, OutputTensor, PastKeyTensor, PastValueTensor, QueryTensor, RelativePositionBiasTensor, StackedKeyValueTensor, StackedQueryKeyTensor, StackedQueryKeyValueTensor e ValueTensor devem ter o mesmo DataType.

Suporte ao Tensor

Tensor Tipo Contagens de dimensões com suporte Tipos de dados com suporte
QueryTensor Entrada opcional 3 a 5 FLOAT32, FLOAT16
Tensor de chave Entrada opcional 3 a 5 FLOAT32, FLOAT16
Tensor de valor Entrada opcional 3 a 5 FLOAT32, FLOAT16
StackedQueryKeyTensor Entrada opcional 5 FLOAT32, FLOAT16
StackedKeyValueTensor Entrada opcional 5 FLOAT32, FLOAT16
StackedQueryKeyValueTensor Entrada opcional 5 FLOAT32, FLOAT16
Tensor de viés Entrada opcional 1 a 5 FLOAT32, FLOAT16
Tensor de máscara Entrada opcional 1 a 5 INT32
RelativePositionBiasTensor Entrada opcional 4 a 5 FLOAT32, FLOAT16
PastKeyTensor Entrada opcional 4 a 5 FLOAT32, FLOAT16
Tensor de valor passado Entrada opcional 4 a 5 FLOAT32, FLOAT16
Tensor de saída Saída 3 a 5 FLOAT32, FLOAT16
OutputPresentKeyTensor Saída opcional 4 a 5 FLOAT32, FLOAT16
OutputPresentValueTensor Saída opcional 4 a 5 FLOAT32, FLOAT16

Requisitos

   
cabeçalho directml.h