XLA 自定义通话

本文档介绍了如何借助 XLA FFI 库编写和使用 XLA 自定义调用。自定义调用是一种机制,用于在编译时向 XLA 编译器描述 HLO 模块中的外部“操作”,而 XLA FFI 是一种向 XLA 注册此类操作的实现(在运行时注册)的机制。FFI 代表“外部函数接口”,是一组 C API,定义了一个二进制接口 (ABI),以便 XLA 调用使用其他编程语言编写的外部代码。XLA 为使用 C++ 编写的 XLA FFI 提供仅限头文件的绑定,从而向最终用户隐藏底层 C API 的所有低级详细信息。

在 CPU 上创建自定义调用

您可以通过 XLA 的客户端 API 创建表示自定义调用的 HLO 指令。例如,以下代码使用自定义调用在 CPU 上计算 A[i] = B[i % 128]+ C[i]。(当然可以,而且应该!请使用常规 HLO 来执行此操作。)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to rank-1 buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C+ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

在 GPU 上创建自定义调用

XLA FFI 中的 GPU 自定义调用注册几乎完全相同,唯一的区别在于,对于 GPU,您需要请求底层平台数据流(CUDA 或 ROCM 流)才能在设备上启动内核。以下 CUDA 示例所执行计算 (A[i] = B[i % 128] + C[i]) 的方式与上述 CPU 代码相同。

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

首先请注意,GPU 自定义调用函数仍是在 CPU 上执行的函数do_custom_call CPU 函数负责将 GPU 上的工作加入队列。在这里,它会启动 CUDA 内核,但它也可以执行一些其他操作,例如调用 cuBLAS。

参数和结果也存在于主机上,并且数据成员包含指向设备(即 GPU)内存的指针。传递到自定义调用处理程序的缓冲区具有底层设备缓冲区的形状,因此自定义调用可以从中计算内核启动参数。

将元组传递给自定义调用

请参考以下自定义调用。

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

在 CPU 和 GPU 上,元组在内存中以指针数组的形式表示。当 XLA 使用元组参数或结果调用自定义调用时,它会扁平化这些调用并作为常规缓冲区参数或结果进行传递。

元组输出作为临时缓冲区

自定义调用的元组输入是一种便利,但并非绝对必需。如果我们不支持对自定义调用进行元组输入,则您始终可以使用 get-tuple-element 解压缩元组,然后再将其传递给自定义调用。

另一方面,元组输出可让您执行您无法通过其他方式无法实现的操作。

使用元组输出的显而易见原因是,元组输出是自定义调用(或任何其他 XLA 操作)返回多个独立数组的方式。

但不太明显的是,元组输出也是提供自定义调用临时内存的一种方式。可以,输出可以表示临时缓冲区。设想一下,输出缓冲区具有一项操作可向其写入的属性,而且可在该缓冲区被写入后从该缓冲区中读取数据。这正是你想要从临时缓冲区中得到的结果。

在上面的示例中,假设我们将 F32[1024] 用作临时缓冲区。然后,我们可以按上述方式编写 HLO,并且我们永远不会读取自定义调用输出的元组索引 1。