다중 헤드 주의 작업을 수행합니다(자세한 내용은 주의가 필요한 것만 참조). 정확히 하나의 쿼리, 키 및 값 텐서가 쌓여 있는지 여부에 관계없이 존재해야 합니다. 예를 들어 StackedQueryKey 가 제공된 경우 쿼리 및 키 텐서는 모두 null이어야 합니다. 이러한 텐서는 이미 누적 레이아웃에 제공되어 있기 때문에 Null이어야 합니다. StackedKeyValue 및 StackedQueryKeyValue도 마찬가지입니다. 누적된 텐서에는 항상 5개의 차원이 있으며 항상 4차원에 쌓입니다.
논리적으로 알고리즘은 다음 작업으로 분해될 수 있습니다(대괄호로 묶은 작업은 선택 사항).
[Add Bias to query/key/value] -> GEMM(Query, Transposed(Key)) * Scale -> [Add RelativePositionBias] -> [Add Mask] -> Softmax -> GEMM(SoftmaxResult, Value);
중요합니다
이 API는 DirectML 독립 실행형 재배포 가능 패키지의 일부로 사용할 수 있습니다( Microsoft.AI.DirectML 버전 1.12 이상 참조). 또한 DirectML 버전 기록 참조하세요.
문법
struct DML_MULTIHEAD_ATTENTION_OPERATOR_DESC
{
_Maybenull_ const DML_TENSOR_DESC* QueryTensor;
_Maybenull_ const DML_TENSOR_DESC* KeyTensor;
_Maybenull_ const DML_TENSOR_DESC* ValueTensor;
_Maybenull_ const DML_TENSOR_DESC* StackedQueryKeyTensor;
_Maybenull_ const DML_TENSOR_DESC* StackedKeyValueTensor;
_Maybenull_ const DML_TENSOR_DESC* StackedQueryKeyValueTensor;
_Maybenull_ const DML_TENSOR_DESC* BiasTensor;
_Maybenull_ const DML_TENSOR_DESC* MaskTensor;
_Maybenull_ const DML_TENSOR_DESC* RelativePositionBiasTensor;
_Maybenull_ const DML_TENSOR_DESC* PastKeyTensor;
_Maybenull_ const DML_TENSOR_DESC* PastValueTensor;
const DML_TENSOR_DESC* OutputTensor;
_Maybenull_ const DML_TENSOR_DESC* OutputPresentKeyTensor;
_Maybenull_ const DML_TENSOR_DESC* OutputPresentValueTensor;
FLOAT Scale;
FLOAT MaskFilterValue;
UINT HeadCount;
DML_MULTIHEAD_ATTENTION_MASK_TYPE MaskType;
};
구성원
QueryTensor
형식: _Maybenull_ const DML_TENSOR_DESC*
셰이프를 [batchSize, sequenceLength, hiddenSize]사용하여 쿼리합니다. 여기서 hiddenSize = headCount * headSize. 이 텐서는 StackedQueryKeyTensor 및 StackedQueryKeyValueTensor와 함께 사용할 수 없습니다. 또한 선행 차원이 1인 경우 텐서에는 4 또는 5차원이 있을 수 있습니다.
KeyTensor
형식: _Maybenull_ const DML_TENSOR_DESC*
셰이프 [batchSize, keyValueSequenceLength, hiddenSize]가 있는 키, 여기서 hiddenSize = headCount * headSize. 이 텐서는 StackedQueryKeyTensor, StackedKeyValueTensor 및 StackedQueryKeyValueTensor와 함께 사용할 수 없습니다. 또한 선행 차원이 1인 경우 텐서에는 4 또는 5차원이 있을 수 있습니다.
ValueTensor
형식: _Maybenull_ const DML_TENSOR_DESC*
셰이프 [batchSize, keyValueSequenceLength, valueHiddenSize]가 있는 값( 여기서 valueHiddenSize = headCount * valueHeadSize. 이 텐서는 StackedKeyValueTensor 및 StackedQueryKeyValueTensor와 함께 사용할 수 없습니다. 또한 선행 차원이 1인 경우 텐서에는 4 또는 5차원이 있을 수 있습니다.
StackedQueryKeyTensor
형식: _Maybenull_ const DML_TENSOR_DESC*
셰이프 [batchSize, sequenceLength, headCount, 2, headSize]가 있는 누적 쿼리 및 키입니다. 이 텐서는 QueryTensor, KeyTensor, StackedKeyValueTensor 및 StackedQueryKeyValueTensor와 함께 사용할 수 없습니다.
StackedKeyValueTensor
형식: _Maybenull_ const DML_TENSOR_DESC*
셰이프 [batchSize, keyValueSequenceLength, headCount, 2, headSize]가 있는 누적 키 및 값입니다. 이 텐서는 KeyTensor, ValueTensor, StackedQueryKeyTensor 및 StackedQueryKeyValueTensor와 함께 사용할 수 없습니다.
StackedQueryKeyValueTensor
형식: _Maybenull_ const DML_TENSOR_DESC*
셰이프 [batchSize, sequenceLength, headCount, 3, headSize]가 있는 누적 쿼리, 키 및 값입니다. 이 텐서는 QueryTensor, KeyTensor, ValueTensor, StackedQueryKeyTensor 및 StackedKeyValueTensor와 함께 사용할 수 없습니다.
BiasTensor
형식: _Maybenull_ const DML_TENSOR_DESC*
첫 번째 GEMM 작업 전에 [hiddenSize + hiddenSize + valueHiddenSize]//에 추가되는 셰이프의 바이어스입니다. 이 텐서에는 선행 차원이 1인 한 2, 3, 4 또는 5 차원이 있을 수도 있습니다.
MaskTensor
형식: _Maybenull_ const DML_TENSOR_DESC*
QxK GEMM 작업 후 값을 MaskFilterValue 로 설정하는 요소를 결정하는 마스크입니다. 이 마스크의 동작은 MaskType의 값에 따라 달라지고 RelativePositionBiasTensor 이후 또는 RelativePositionBiasTensor가 null인 경우 첫 번째 GEMM 작업 후에 적용됩니다. 자세한 내용은 MaskType 정의를 참조하세요.
RelativePositionBiasTensor
형식: _Maybenull_ const DML_TENSOR_DESC*
첫 번째 GEMM 작업의 결과에 추가되는 바이어스입니다.
PastKeyTensor
형식: _Maybenull_ const DML_TENSOR_DESC*
도형 [batchSize, headCount, pastSequenceLength, headSize]이 있는 이전 반복의 키 텐서입니다. 이 텐서가 null이 아닌 경우 키 텐서와 연결되어 셰이프 [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize]의 텐서가 생성됩니다.
PastValueTensor
형식: _Maybenull_ const DML_TENSOR_DESC*
도형 [batchSize, headCount, pastSequenceLength, headSize]이 있는 이전 반복의 값 텐서입니다. 이 텐서가 null이 아닌 경우 ValueDesc 와 연결되어 셰이프 [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize]의 텐서가 생성됩니다.
OutputTensor
형식: const DML_TENSOR_DESC*
셰이프 [batchSize, sequenceLength, valueHiddenSize]의 출력입니다.
OutputPresentKeyTensor
형식: _Maybenull_ const DML_TENSOR_DESC*
교차 주의 키에 대한 현재 상태, 도형 [batchSize, headCount, keyValueSequenceLength, headSize] 을 사용한 자기 주의 상태 또는 현재 [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize]상태입니다. 다음 반복에 전달할 키 텐서의 콘텐츠 또는 연결된 PastKey + 키 텐서의 콘텐츠가 포함됩니다.
OutputPresentValueTensor
형식: _Maybenull_ const DML_TENSOR_DESC*
교차 주의 값의 현재 상태( 도형을 사용한 자기 주의에 대한 도형 [batchSize, headCount, keyValueSequenceLength, headSize] 또는 현재 상태)[batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize] 다음 반복에 전달할 값 텐서의 콘텐츠 또는 연결된 PastValue + 값 텐서의 내용이 포함됩니다.
Scale
형식: FLOAT
크기를 조정하여 QxK GEMM 작업의 결과를 곱하지만 Softmax 연산 전에 곱합니다. 이 값은 일반적으로 1/sqrt(headSize).
MaskFilterValue
형식: FLOAT
마스크가 패딩 요소로 정의된 위치에 QxK GEMM 작업의 결과에 추가되는 값입니다. 이 값은 매우 큰 음수(일반적으로 -10000.0f)여야 합니다.
HeadCount
형식: UINT
주의 머리의 수입니다.
MaskType
형식: DML_MULTIHEAD_ATTENTION_MASK_TYPE
MaskTensor의 동작을 설명합니다.
DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN. 마스크 값이 0이면 MaskFilterValue 가 추가됩니다. 하지만 값이 1이면 아무 것도 추가하지 않습니다.
DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH. 셰이프 [1, batchSize]의 마스크에는 각 일괄 처리에 대한 패드가 없는 영역의 시퀀스 길이가 포함되며, 시퀀스 길이 이후의 모든 요소는 해당 값을 MaskFilterValue로 설정합니다.
DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START. 셰이프 [2, batchSize]의 마스크는 패드가 없는 영역의 끝(배타적) 및 시작(포함) 인덱스를 포함하고 영역 외부의 모든 요소는 해당 값을 MaskFilterValue로 설정합니다.
DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END. 셰이프 [batchSize * 3 + 2]의 마스크 값은 다음과 같습니다 [keyLength[0], ..., keyLength[batchSize - 1], queryStart[0], ..., queryStart[batchSize - 1], queryEnd[batchSize - 1], keyStart[0], ..., keyStart[batchSize - 1], keyEnd[batchSize - 1]].
가용도
이 연산자는 DML_FEATURE_LEVEL_6_1 도입되었습니다.
Tensor 제약 조건
BiasTensor, KeyTensor, OutputPresentKeyTensor, OutputPresentValueTensor, OutputTensor, PastKeyTensor, PastValueTensor, QueryTensor, RelativePositionBiasTensor, StackedKeyValueTensor, StackedQueryKeyTensor, StackedQueryKeyValueTensor 및 ValueTensor 는 동일한 DataType을 가져야 합니다.
Tensor 지원
| 텐서 | 친절한 | 지원되는 차원 수 | 지원되는 데이터 형식 |
|---|---|---|---|
| 쿼리텐서 | 선택적 입력 | 3~5 | FLOAT32, FLOAT16 |
| 키텐서 | 선택적 입력 | 3~5 | FLOAT32, FLOAT16 |
| 값 텐서 | 선택적 입력 | 3~5 | FLOAT32, FLOAT16 |
| StackedQueryKey텐서 | 선택적 입력 | 5 | FLOAT32, FLOAT16 |
| StackedKeyValue텐서 | 선택적 입력 | 5 | FLOAT32, FLOAT16 |
| StackedQueryKeyValue텐서 | 선택적 입력 | 5 | FLOAT32, FLOAT16 |
| 바이어스 텐서 | 선택적 입력 | 1 ~5 | FLOAT32, FLOAT16 |
| 마스크텐서 | 선택적 입력 | 1 ~5 | 인T32 |
| 상대적 위치바이어스 텐서 | 선택적 입력 | 4~5 | FLOAT32, FLOAT16 |
| 패스키텐서 | 선택적 입력 | 4~5 | FLOAT32, FLOAT16 |
| 과거값텐서 | 선택적 입력 | 4~5 | FLOAT32, FLOAT16 |
| 출력텐서 | 출력 | 3~5 | FLOAT32, FLOAT16 |
| 출력PresentKeyTensor | 선택적 출력 | 4~5 | FLOAT32, FLOAT16 |
| 출력PresentValueTensor | 선택적 출력 | 4~5 | FLOAT32, FLOAT16 |
요구 사항
| 머리글 | directml.h |