Panggilan kustom XLA

Dokumen ini menjelaskan cara menulis dan menggunakan panggilan kustom XLA menggunakan library XLA FFI. Panggilan kustom adalah mekanisme untuk mendeskripsikan "operasi" eksternal dalam modul HLO ke compiler XLA (pada waktu kompilasi), dan XLA FFI adalah mekanisme untuk mendaftarkan implementasi operasi tersebut dengan XLA (pada waktu proses). FFI adalah singkatan dari "foreign function interface" dan merupakan kumpulan API C yang menentukan antarmuka biner (ABI) bagi XLA untuk memanggil kode eksternal yang ditulis dalam bahasa pemrograman lain. XLA menyediakan binding khusus header untuk XLA FFI yang ditulis dalam C++, yang menyembunyikan semua detail tingkat rendah dari API C yang mendasarinya dari pengguna akhir.

Membuat panggilan kustom pada CPU

Anda dapat membuat petunjuk HLO yang merepresentasikan panggilan kustom melalui API klien XLA. Misalnya, kode berikut menggunakan panggilan kustom untuk menghitung A[i] = B[i % 128]+ C[i] pada CPU. (Tentu saja Anda bisa – dan seharusnya! – lakukan ini dengan HLO reguler.)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to rank-1 buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C+ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

Membuat panggilan kustom di GPU

Pendaftaran panggilan kustom GPU dengan XLA FFI hampir identik. Satu-satunya perbedaan adalah bahwa untuk GPU, Anda perlu meminta aliran platform yang mendasarinya (aliran CUDA atau ROCM) agar dapat meluncurkan kernel di perangkat. Berikut adalah contoh CUDA yang melakukan komputasi yang sama (A[i] = B[i % 128] + C[i]) dengan kode CPU di atas.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

Perhatikan terlebih dahulu bahwa fungsi panggilan khusus GPU masih merupakan fungsi yang dieksekusi pada CPU. Fungsi CPU do_custom_call bertanggung jawab untuk mengantrekan pekerjaan di GPU. Di sini kernel CUDA meluncurkan kernel CUDA, tetapi juga dapat melakukan hal lain, seperti memanggil cuBLAS.

Argumen dan hasil juga ada di host, dan anggota data berisi pointer ke memori perangkat (yaitu GPU). Buffer yang diteruskan ke pengendali panggilan kustom memiliki bentuk buffer perangkat dasar, sehingga panggilan kustom dapat menghitung parameter peluncuran kernel dari buffer tersebut.

Meneruskan tupel ke panggilan kustom

Pertimbangkan panggilan kustom berikut.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

Pada CPU dan GPU, tuple direpresentasikan dalam memori sebagai array pointer. Saat XLA memanggil panggilan kustom dengan argumen tuple atau hasilnya, akan diratakan dan diteruskan sebagai hasil atau argumen buffer reguler.

Output Tuple sebagai buffer sementara

Input tuple ke panggilan kustom memang mudah, tetapi tidak sepenuhnya diperlukan. Jika kita tidak mendukung input tuple ke panggilan kustom, Anda dapat mengekstrak tuple kapan saja menggunakan elemen get-tuple-sebelum meneruskannya ke panggilan kustom.

Di sisi lain, output tuple memungkinkan Anda melakukan hal-hal yang tidak dapat dilakukan jika tidak.

Alasan yang jelas untuk memiliki output tuple adalah karena output tuple merupakan cara panggilan kustom (atau op XLA lainnya) menampilkan beberapa array independen.

Namun yang kurang jelas, output tuple juga merupakan cara untuk memberikan memori sementara panggilan kustom Anda. Ya, output dapat mewakili buffering sementara. Misalkan, buffer output memiliki properti yang dapat ditulis oleh operasi, dan dapat membacanya setelah ditulis. Itulah yang Anda inginkan dari buffer sementara.

Pada contoh di atas, misalkan kita ingin menggunakan F32[1024] sebagai buffer sementara. Kemudian kita akan menulis HLO seperti di atas, dan kita tidak akan pernah membaca indeks tuple 1 dari output panggilan kustom.