Partager via


DirectMLX

DirectMLX est une bibliothèque d’assistance en-tête C++ uniquement pour DirectML, destinée à faciliter la composition d’opérateurs individuels dans des graphiques.

DirectMLX propose des wrappers pratiques pour tous les types d’opérateurs DirectML (DML), ainsi que des surcharges d’opérateurs intuitives, ce qui simplifie l’instanciation des opérateurs DML et leur enchaînement dans des graphes complexes.

Où trouver DirectMLX.h

DirectMLX.h est distribué en tant que logiciel open source sous la licence MIT. La dernière version est disponible sur DirectML GitHub.

Exigences de version

DirectMLX nécessite DirectML version 1.4.0 ou ultérieure (voir l’historique des versions DirectML). Les versions antérieures de DirectML ne sont pas prises en charge.

DirectMLX.h nécessite un compilateur compatible C++11, y compris (mais pas limité à) :

  • Visual Studio 2017
  • Visual Studio 2019
  • Clang 10

Notez qu’un compilateur C++17 (ou version ultérieure) est l’option que nous vous recommandons. La compilation pour C++11 est possible, mais elle nécessite l’utilisation de bibliothèques tierces (telles que GSL et Abseil) pour remplacer les fonctionnalités de bibliothèque standard manquantes.

Si vous avez une configuration qui ne parvient pas à compiler DirectMLX.h, envoyez un problème sur notre GitHub.

Utilisation de base

#include <DirectML.h>
#include <DirectMLX.h>

IDMLDevice* device;

/* ... */

dml::Graph graph(device);

// Input tensor of type FLOAT32 and sizes { 1, 2, 3, 4 }
auto x = dml::InputTensor(graph, 0, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, {1, 2, 3, 4}));

// Create an operator to compute the square root of x
auto y = dml::Sqrt(x);

// Compile a DirectML operator from the graph. When executed, this compiled operator will compute
// the square root of its input.
DML_EXECUTION_FLAGS flags = DML_EXECUTION_FLAG_NONE;
ComPtr<IDMLCompiledOperator> op = graph.Compile(flags, { y });

// Now initialize and dispatch the DML operator as usual

Voici un autre exemple, qui crée un graphe DirectML capable de calculer la formule quadratique.

#include <DirectML.h>
#include <DirectMLX.h>

IDMLDevice* device;

/* ... */

std::pair<dml::Expression, dml::Expression>
    QuadraticFormula(dml::Expression a, dml::Expression b, dml::Expression c)
{
    // Quadratic formula: given an equation of the form ax^2 + bx + c = 0, x can be found by:
    //   x = -b +/- sqrt(b^2 - 4ac) / (2a)
    // https://en.wikipedia.org/wiki/Quadratic_formula

    // Note: DirectMLX provides operator overloads for common mathematical expressions. So for 
    // example a*c is equivalent to dml::Multiply(a, c).
    auto x1 = -b + dml::Sqrt(b*b - 4*a*c) / (2*a);
    auto x2 = -b - dml::Sqrt(b*b - 4*a*c) / (2*a);

    return { x1, x2 };
}

/* ... */

dml::Graph graph(device);

dml::TensorDimensions inputSizes = {1, 2, 3, 4};
auto a = dml::InputTensor(graph, 0, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, inputSizes));
auto b = dml::InputTensor(graph, 1, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, inputSizes));
auto c = dml::InputTensor(graph, 2, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, inputSizes));

auto [x1, x2] = QuadraticFormula(a, b, c);

// When executed with input tensors a, b, and c, this compiled operator computes the two outputs
// of the quadratic formula, and returns them as two output tensors x1 and x2
DML_EXECUTION_FLAGS flags = DML_EXECUTION_FLAG_NONE;
ComPtr<IDMLCompiledOperator> op = graph.Compile(flags, { x1, x2 });

// Now initialize and dispatch the DML operator as usual

Autres exemples

Vous trouverez des exemples complets à l’aide de DirectMLX sur le dépôt GitHub DirectML.

Options à la compilation

DirectMLX prend en charge les #define au moment de la compilation pour personnaliser différentes parties de l’en-tête.

Choix Descriptif
DMLX_NO_EXCEPTIONS Si #define'd, provoque des erreurs qui entraînent un appel à std::abort au lieu de lever une exception. Cela est défini par défaut si les exceptions ne sont pas disponibles (par exemple, si des exceptions ont été désactivées dans les options du compilateur).
DMLX_USE_WIL Si #define’d, les exceptions sont levées à l’aide de types d’exceptions Bibliothèque d’implémentation Windows. Sinon, les types d’exceptions standard (par exemple std::runtime_error) sont utilisés à la place. Cette option n’a aucun effet si DMLX_NO_EXCEPTIONS est définie.
DMLX_USE_ABSEIL Si #define’d, utilise Abseil comme remplacements de liste déroulante pour les types de bibliothèque standard indisponibles en C++11. Ces types incluent absl::optional (à la place de std::optional), absl::Span (à la place de std::span) et absl::InlinedVector.
DMLX_USE_GSL Contrôle s’il faut utiliser GSL comme remplacement de std::span. Si `#define` est défini, les utilisations de std::span sont remplacées par gsl::span dans les compilateurs sans implémentations natives de std::span. Sinon, une implémentation de dépôt inline est fournie à la place. Notez que cette option est utilisée uniquement lors de la compilation sur un compilateur pré-C++20 sans prise en charge std::span, et quand aucun autre remplacement de bibliothèque standard (comme Abseil) n’est en cours d’utilisation.

Contrôle de la disposition des tenseurs

Pour la plupart des opérateurs, DirectMLX calcule les propriétés des tenseurs de sortie de l’opérateur en votre nom. Par exemple, lors de l'exécution d'un dml::Reduce à travers les axes { 0, 2, 3 } avec un tenseur d'entrée de taille { 3, 4, 5, 6 }, DirectMLX calcule automatiquement les propriétés du tenseur de sortie, y compris la forme correcte de { 1, 4, 1, 1 }.

Toutefois, les autres propriétés d’un tensoriel de sortie incluent strides, TotalTensorSizeInBytes et GuaranteedBaseOffsetAlignment. Par défaut, DirectMLX définit ces propriétés de façon à ce que le tenseur n’ait aucun pas mémoire, aucun alignement garanti de l’adresse de base, et une taille totale en octets calculée par DMLCalcBufferTensorSize.

DirectMLX prend en charge la possibilité de personnaliser ces propriétés de capteur de sortie à l’aide d’objets appelés stratégies de capteur. TensorPolicy est un callback personnalisé invoqué par DirectMLX, qui retourne les propriétés du tenseur de sortie en fonction du type de données calculé, des indicateurs et des tailles du tenseur.

Les stratégies Tensor peuvent être définies sur l’objet dml ::Graph et seront utilisées pour tous les opérateurs suivants sur ce graphique. Les stratégies Tensor peuvent également être définies directement lors de la construction d’un TensorDesc.

La disposition des tenseurs produits par DirectMLX peut donc être contrôlée en définissant une TensorPolicy qui définit les pas appropriés sur ses tenseurs.

Exemple 1

// Define a policy, which is a function that returns a TensorProperties given a data type,
// flags, and sizes.
dml::TensorProperties MyCustomPolicy(
    DML_TENSOR_DATA_TYPE dataType,
    DML_TENSOR_FLAGS flags,
    Span<const uint32_t> sizes)
{
    // Compute your custom strides, total tensor size in bytes, and guaranteed base
    // offset alignment
    dml::TensorProperties props;
    props.strides = /* ... */;
    props.totalTensorSizeInBytes = /* ... */;
    props.guaranteedBaseOffsetAlignment = /* ... */;
    return props;
};

// Set the policy on the dml::Graph
dml::Graph graph(/* ... */);
graph.SetTensorPolicy(dml::TensorPolicy(&MyCustomPolicy));

Exemple 2

DirectMLX fournit également d'autres politiques de tenseur intégrées. La stratégie InterleavedChannel, par exemple, est proposée à titre de commodité et peut être utilisée pour produire des tenseurs avec des pas mémoire permettant un agencement en ordre NHWC.

// Set the InterleavedChannel policy on the dml::Graph
dml::Graph graph(/* ... */);
graph.SetTensorPolicy(dml::TensorPolicy::InterleavedChannel());

// When executed, the tensor `result` will be in NHWC layout (rather than the default NCHW)
auto result = dml::Convolution(/* ... */);

Voir aussi