Chiamate personalizzate XLA

Questo documento descrive come scrivere e utilizzare chiamate personalizzate XLA utilizzando la libreria XLA FFI. La chiamata personalizzata è un meccanismo per descrivere un'"operazione" esterna nel modulo HLO al compilatore XLA (in fase di compilazione), mentre XLA FFI è un meccanismo per registrare l'implementazione di queste operazioni con XLA (in fase di esecuzione). FFI è l'acronimo di "foreign Functions Interface" ed è un insieme di API C che definiscono un'interfaccia binaria (ABI) che XLA può chiamare in codice esterno scritto in altri linguaggi di programmazione. XLA fornisce associazioni solo di intestazione per XLA FFI scritti in C++, che nascondono tutti i dettagli di basso livello delle API C sottostanti all'utente finale.

Crea una chiamata personalizzata sulla CPU

Puoi creare un'istruzione HLO che rappresenta una chiamata personalizzata tramite l'API client di XLA. Ad esempio, il codice seguente utilizza una chiamata personalizzata per calcolare A[i] = B[i % 128]+ C[i] sulla CPU. (Ovviamente potreste e dovresti! (esegui questa operazione con un normale HLO).

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to rank-1 buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C+ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

Crea una chiamata personalizzata su GPU

La registrazione delle chiamate personalizzate della GPU con XLA FFI è quasi identica, l'unica differenza è che per la GPU devi richiedere uno stream della piattaforma sottostante (stream CUDA o ROCM) per poter avviare il kernel sul dispositivo. Ecco un esempio CUDA che esegue lo stesso calcolo (A[i] = B[i % 128] + C[i]) del codice CPU riportato sopra.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

Nota innanzitutto che la funzione di chiamata personalizzata della GPU è ancora una funzione eseguita sulla CPU. La funzione CPU do_custom_call è responsabile del lavoro accodato sulla GPU. Qui avvia un kernel CUDA, ma potrebbe fare anche qualcos'altro, ad esempio chiamare cuBLAS.

Anche gli argomenti e i risultati risiedono nell'host e il membro dei dati contiene una memoria del puntatore al dispositivo (ad esempio GPU). I buffer passati al gestore di chiamate personalizzato hanno la forma dei buffer del dispositivo sottostanti, quindi la chiamata personalizzata può calcolare i parametri di avvio del kernel da questi.

Passaggio di tuple alle chiamate personalizzate

Considera la seguente chiamata personalizzata.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

Sia su CPU sia su GPU, una tupla è rappresentata in memoria sotto forma di array di puntatori. Quando XLA chiama chiamate personalizzate con argomenti o risultati a tupla, le appiattisce e vengono passate come argomenti o risultati del buffer regolari.

La tupla viene emessa come buffer temporanei

L'inserimento di tuple nelle chiamate personalizzate è pratico, ma non è strettamente necessario. Se non supportavamo gli input tuple nelle chiamate personalizzate, puoi sempre estrarre le tuple utilizzando get-tuple-element prima di passarle alla chiamata personalizzata.

D'altra parte, gli output a tuple ti consentono di eseguire operazioni che altrimenti non avresti potuto eseguire.

Il motivo ovvio degli output di tuple è che gli output di tuple sono il modo in cui una chiamata personalizzata (o qualsiasi altra operazione XLA) restituisce più array indipendenti.

Ma meno ovviamente, un output a tuple è anche un modo per assegnare alla tua memoria temporanea delle chiamate. Sì, un output può rappresentare un buffer temporaneo. Considera che un buffer di output ha la proprietà su cui l'operazione può scrivere e può leggerlo dopo la scrittura. È esattamente quello che vuoi da un buffer di temperatura.

Nell'esempio precedente, supponiamo di voler utilizzare F32[1024] come buffer temporaneo. Quindi scriveremmo l'HLO come indicato sopra e non avremmo mai letto l'indice tuple 1 dell'output della chiamata personalizzata.