DirectMLX

DirectMLX 是 DirectML 的C++仅标头帮助程序库,旨在更轻松地将单个运算符组合到图形中。

DirectMLX 为所有 DirectML (DML) 运算符类型以及直观的运算符重载提供了方便的包装器,这使得实例化 DML 运算符更简单,并将其链接到复杂的图形中。

DirectMLX.h 所在位置

DirectMLX.h 在 MIT 许可证下以开源软件的形式分发。 可以在 DirectML GitHub 上找到最新版本。

版本要求

DirectMLX 需要 DirectML 版本 1.4.0 或更高版本(请参阅 DirectML 版本历史记录)。 不支持较旧版本的 DirectML。

DirectMLX.h 需要支持 C++11 的编译器,包括(包括但不限于):

  • Visual Studio 2017
  • Visual Studio 2019
  • Clang 10 (叮当 10)

请注意,C++17(或更新)编译器是我们建议的选项。 可以编译 C++11,但它需要使用第三方库(如 GSLAbseil)来替换缺少的标准库功能。

如果配置无法编译 DirectMLX.h,请在 GitHub 上提出问题

基本用法

#include <DirectML.h>
#include <DirectMLX.h>

IDMLDevice* device;

/* ... */

dml::Graph graph(device);

// Input tensor of type FLOAT32 and sizes { 1, 2, 3, 4 }
auto x = dml::InputTensor(graph, 0, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, {1, 2, 3, 4}));

// Create an operator to compute the square root of x
auto y = dml::Sqrt(x);

// Compile a DirectML operator from the graph. When executed, this compiled operator will compute
// the square root of its input.
DML_EXECUTION_FLAGS flags = DML_EXECUTION_FLAG_NONE;
ComPtr<IDMLCompiledOperator> op = graph.Compile(flags, { y });

// Now initialize and dispatch the DML operator as usual

下面是另一个示例,它创建能够计算 二次公式的 DirectML 图形。

#include <DirectML.h>
#include <DirectMLX.h>

IDMLDevice* device;

/* ... */

std::pair<dml::Expression, dml::Expression>
    QuadraticFormula(dml::Expression a, dml::Expression b, dml::Expression c)
{
    // Quadratic formula: given an equation of the form ax^2 + bx + c = 0, x can be found by:
    //   x = -b +/- sqrt(b^2 - 4ac) / (2a)
    // https://en.wikipedia.org/wiki/Quadratic_formula

    // Note: DirectMLX provides operator overloads for common mathematical expressions. So for 
    // example a*c is equivalent to dml::Multiply(a, c).
    auto x1 = -b + dml::Sqrt(b*b - 4*a*c) / (2*a);
    auto x2 = -b - dml::Sqrt(b*b - 4*a*c) / (2*a);

    return { x1, x2 };
}

/* ... */

dml::Graph graph(device);

dml::TensorDimensions inputSizes = {1, 2, 3, 4};
auto a = dml::InputTensor(graph, 0, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, inputSizes));
auto b = dml::InputTensor(graph, 1, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, inputSizes));
auto c = dml::InputTensor(graph, 2, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, inputSizes));

auto [x1, x2] = QuadraticFormula(a, b, c);

// When executed with input tensors a, b, and c, this compiled operator computes the two outputs
// of the quadratic formula, and returns them as two output tensors x1 and x2
DML_EXECUTION_FLAGS flags = DML_EXECUTION_FLAG_NONE;
ComPtr<IDMLCompiledOperator> op = graph.Compile(flags, { x1, x2 });

// Now initialize and dispatch the DML operator as usual

更多示例

可以在 DirectML GitHub 存储库中找到使用 DirectMLX 的完整示例。

编译时选项

DirectMLX 支持编译时 #define 来自定义标头的各个部件。

选项 DESCRIPTION
DMLX_NO_EXCEPTIONS 如果已 #define,将会导致调用 std::abort 的错误,而不是引发异常。 如果异常不可用(例如,在编译器选项中禁用了异常),则默认定义此项。
DMLX_USE_WIL 如果定义了 #define,则会使用 Windows 实现库 中的异常类型抛出异常。 否则,将改用标准异常类型(例如 std::runtime_error)。 如果定义了 DMLX_NO_EXCEPTIONS ,此选项将不起作用。
DMLX_USE_ABSEIL 如果已 #define,请使用 Abseil 直接替换 C++11 中不可用的标准库类型。 这些类型包括 absl::optional (代替 std::optional)、 absl::Span (代替 std::span)和 absl::InlinedVector
DMLX_USE_GSL 控制是否使用 GSL 作为 std::span 的替代选项。 如果已 #define,则在没有本机 std::span 实现的编译器上,gsl::span 的使用将会替换为 std::span。 否则,将改为提供内联直接实现。 请注意,仅当使用不支持 std::span 的 C++20 之前版本的编译器进行编译时,且没有使用其他替代标准库(如 Abseil)时,才使用此选项。

控制张量布局

对于大多数运算符,DirectMLX 代表你计算运算符的输出张量的属性。 例如,在跨输入张量大小为 dml::Reduce 的轴 { 0, 2, 3 } 执行 { 3, 4, 5, 6 } 时,DirectMLX 将自动计算输出张量的属性,包括 { 1, 4, 1, 1 } 的正确形状。

但输出张量的其他属性还包括 Strides、TotalTensorSizeInBytes 和GuaranteedBaseOffsetAlignment。 默认情况下,DirectMLX 设置这些属性,以使张量没有步幅、没有保证的基本偏移对齐,并且总张量大小(字节)由 DMLCalcBufferTensorSize 计算得出。

DirectMLX 支持使用称为 张量策略的对象自定义这些输出张量属性的功能。 TensorPolicy 是由 DirectMLX 调用的可自定义回调,并在给定张量计算数据类型、标志和大小的情况下返回输出张量属性。

可以在 dml::Graph 对象上设置 Tensor 策略,并将用于该图上的所有后续运算符。 还可以在构造 TensorDesc 时直接设置 Tensor 策略。

因此,DirectMLX 生成的张量布局可以通过设置 TensorPolicy 来控制,该 TensorPolicy 可在其张量上设置适当的步幅。

示例 1

// Define a policy, which is a function that returns a TensorProperties given a data type,
// flags, and sizes.
dml::TensorProperties MyCustomPolicy(
    DML_TENSOR_DATA_TYPE dataType,
    DML_TENSOR_FLAGS flags,
    Span<const uint32_t> sizes)
{
    // Compute your custom strides, total tensor size in bytes, and guaranteed base
    // offset alignment
    dml::TensorProperties props;
    props.strides = /* ... */;
    props.totalTensorSizeInBytes = /* ... */;
    props.guaranteedBaseOffsetAlignment = /* ... */;
    return props;
};

// Set the policy on the dml::Graph
dml::Graph graph(/* ... */);
graph.SetTensorPolicy(dml::TensorPolicy(&MyCustomPolicy));

示例 2

DirectMLX 还提供一些内置的备用张量策略。 例如,InterleavedChannel 策略是为了使用便于使用而提供的,它可用于生成具有步幅的张量,以便按 NHWC 顺序写入。

// Set the InterleavedChannel policy on the dml::Graph
dml::Graph graph(/* ... */);
graph.SetTensorPolicy(dml::TensorPolicy::InterleavedChannel());

// When executed, the tensor `result` will be in NHWC layout (rather than the default NCHW)
auto result = dml::Convolution(/* ... */);

另请参阅