שיחות בהתאמה אישית של XLA

במסמך הזה מוסבר איך לכתוב קריאות מותאמות אישית של XLA ולהשתמש בהן באמצעות ספריית XLA FFI. קריאה בהתאמה אישית היא מנגנון לתיאור 'פעולה' חיצונית ב מודול HLO למהדר (compiler) XLA (בזמן הרכבה), ו-XLA FFI הוא מנגנון לרשום פעולות כאלה באמצעות XLA (בזמן ריצה). FFI הוא ראשי תיבות של 'ממשק פונקציה חיצונית', והוא קבוצה של ממשקי API ל-C שמגדירים ממשק בינארי (ABI) כדי לאפשר ל-XLA לבצע קריאה לקוד חיצוני שנכתב בשפות תכנות אחרות. XLA מספק קישורים של כותרות בלבד ל-XLA FFI שנכתב ב-C++‎, שמסתיר ממשתמש הקצה את כל הפרטים ברמה הנמוכה של ממשקי ה-API הבסיסיים של C.

שיחות בהתאמה אישית של JAX + XLA

במסמכי התיעוד של JAX מוסבר איך דוגמאות מלאות לשילוב קריאות מותאמות אישית ו-XLA FFI עם JAX.

קישור XLA FFI

קישור XLA FFI הוא מפרט של חתימה מותאמת אישית של קריאה בזמן הידור: ארגומנטים מותאמים אישית של קריאה, מאפיינים וסוגים שלהם, ופרמטרים נוספים שמועברים דרך הקשר הביצוע (כלומר, gpu stream לקצה העורפי של GPU). XLA FFI ניתן לשייך את הממצא לכל קריאה לפעולה מסוג C++ (מצביע פונקציה, lambda וכו') באמצעות תואמת לחתימה operator(). המבנה של ה-handler מפענח קריאת XLA FFI frame (מוגדר על ידי ממשק ה-API היציב C), .מקלידים את כל הפרמטרים ומעבירים תוצאות מפוענחות לקריאה חוזרת (callback) שהוגדרה על ידי המשתמש.

קישור XLA FFI מסתמך במידה רבה על תכנות מטא באמצעות תבניות כדי שאפשר יהיה לקמפל את הטיפולן שנוצר לקוד המכונה היעיל ביותר. זמן ריצה התקורות מסודרות לפי כמה ננו-שניות לכל שיחה מותאמת אישית הפרמטר.

נקודות התאמה אישית של XLA FFI שהוטמעו כמומחיות של תבנית, משתמשים יכולים להגדיר איך לפענח את הסוגים המותאמים אישית שלהם, כלומר, אפשר כדי להגדיר פענוח מותאם אישית לסוגי enum class בהגדרת המשתמש.

שגיאות חוזרות משיחות בהתאמה אישית

הטמעות של שיחות בהתאמה אישית חייבות להחזיר את הערך xla::ffi::Error לאות או שגיאה לסביבת זמן הריצה של XLA. היא דומה ל-absl::Status ויש בה אותם קודי שגיאה. אנחנו לא משתמשים ב-absl::Status כי אין לו ABI יציב, ולא בטוח להעביר אותו בין ספריית הקריאה בהתאמה אישית שנטענת באופן דינמי לבין XLA עצמה.

// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
    []() { return Error(ErrorCode::kInternal, "Oops!"); });

// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
    []() { return Error::Success(); });

שמירת ארגומנטים ותוצאות במאגר

ב-XLA נעשה שימוש בסגנון העברת יעד לתוצאות: קריאות בהתאמה אישית (או כל פעולה אחרת של XLA, לצורך העניין) לא מקצות זיכרון לתוצאות, אלא כותבות ליעדים שהועברו על ידי סביבת זמן הריצה של XLA. ב-XLA נעשה שימוש בהקצאת מאגרים סטטיים, והמאגרים מוקצים לכל הערכים על סמך טווחי החיים שלהם בזמן הידור.

התוצאות מועברות למטפלים של FFI עטופות בתבנית Result<T>, שיש לה סמנטיקה של מצביע: operator-> נותן גישה לפרמטר הבסיסי.

הארגומנטים והתוצאות של AnyBuffer מספקים גישה לפרמטרים מותאמים אישית של מאגר קריאות מכל סוג נתונים. האפשרות הזו שימושית כשיש לקריאה בהתאמה אישית הטמעה גנרית שפועלת במספר סוגי נתונים, והטמעת הקריאה בהתאמה אישית מבצעת ניתוב בזמן ריצה על סמך סוג הנתונים. AnyBuffer מאפשר גישה לסוג הנתונים במאגר, למאפיינים ולמצביע למאגר עצמו.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any rank and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
    [](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
      void* arg_data = arg.untyped_data();
      void* res_data = res->untyped_data();
      return Error::Success();
    });

ארגומנטים ותוצאות של מאגרים מוגבלים

Buffer מאפשר להוסיף אילוצים על סוג הנתונים ועל הדירוג של המאגר, והם ייבדקו באופן אוטומטי על ידי הטיפולן, ויחזירו שגיאה לסביבת זמן הריצה של XLA אם הארגומנטים של זמן הריצה לא תואמים לחתימה של הטיפולן ב-FFI.

// Buffers of any rank and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
    [](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });
// Buffers of rank 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
    [](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });

ארגומנטים ותוצאות וריאנטים

אם מספר הארגומנטים והתוצאה יכולים להיות שונים במקרים שונים של קריאה מותאמת אישית, אפשר לפענח אותן בזמן הריצה באמצעות RemainingArgs וגם RemainingRets.

auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
    [](RemainingArgs args, RemainingRets results) -> Error {
      ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
      ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);

      if (!arg.has_value()) {
        return Error(ErrorCode::kInternal, arg.error());
      }

      if (!res.has_value()) {
        return Error(ErrorCode::kInternal, res.error());
      }

      return Error::Success();
    });

אפשר להצהיר על ארגומנטים ותוצאות פולימורפיים אחרי ארגומנטים ותוצאות רגילים, אבל אי אפשר לקשר ארגומנטים ותוצאות רגילים אחרי ארגומנטים ותוצאות פולימורפיים.

auto handler =
    Ffi::Bind()
        .Arg<AnyBuffer>()
        .RemainingArgs()
        .Ret<AnyBuffer>()
        .RemainingRets()
        .To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
               RemainingRets results) -> Error { return Error::Success(); });

מאפיינים

XLA FFI תומך בפענוח אוטומטי של mlir::DictionaryAttr שמועברים בתור custom_call backend_config לפרמטרים של טיפולי FFI.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    i32 = 42 : i32,
    str = "string"
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

בדוגמה הזו, לקריאה המותאמת אישית יש ארגומנט אחד של מאגר נתונים זמני ושני מאפיינים, XLA FFI יכול לפענח את הקוד האישי ולהעביר אותו למשתמש שהוגדר לקריאה.

auto handler = Ffi::Bind()
  .Arg<BufferR0<F32>>()
  .Attr<int32_t>("i32")
  .Attr<std::string_view>("str")
  .To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
    return Error::Success();
  });

מאפייני Enum בהגדרת המשתמש

XLA FFI יכול לפענח באופן אוטומטי מאפייני MLIR אינטגרליים למערכי ערכים מוגדרים על ידי משתמש. למחלקה של הטיפוס בן המנייה צריך להיות אותו סוג אינטגרל בסיסי והפענוח של חייב להיות רשום במפורש ב-XLA FFI.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    command = 0 : i32
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
  kAdd = 0,
  kMul = 1,
};

XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);

auto handler = Ffi::Bind().Attr<Command>("command").To(
    [](Command command) -> Error { return Error::Success(); });

קישור כל מאפייני השיחה המותאמים אישית

אפשר לקבל גישה לכל מאפייני השיחה המותאמים אישית כמילון ולפענח באופן עצל רק את המאפיינים הנדרשים בזמן הריצה.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
  return Error::Success();
});

מאפייני מבנה שהוגדרו על ידי המשתמש

XLA FFI יכול לפענח מאפייני מילון למבנים מוגדרים על ידי משתמשים.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    range = { lo = 0 : i64, hi = 42 : i64 }
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

בדוגמה שלמעלה, range הוא מאפיין mlir::DictionaryAttr, ובמקום לגשת לשדות במילון לפי שם, אפשר לפענח אותו באופן אוטומטי כמבנה של C++‎. פענוח צריך להיות רשום במפורש מאקרו XLA_FFI_REGISTER_STRUCT_ATTR_DECODING (מאחורי הסצנה שהיא מגדירה התמחות של תבנית במרחב השמות ::xla::ffi, ולכן צריך להוסיף את המאקרו מרחב השמות הגלובלי).

struct Range {
  int64_t lo;
  int64_t hi;
};

XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("i64"),
                                             StructMember<int64_t>("i64"));

auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
  return Error::Success();
});

אפשר לטעון מאפיינים מותאמים אישית ממילון, בדיוק כמו כל מאפיין אחר. בדוגמה הבאה, כל מאפייני השיחה המותאמים אישית מפענחים כ-Dictionary, ואפשר לגשת ל-range לפי שם.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<Range> range = attrs.get<Range>("range");
  return Error::Success();
});

יצירת קריאה בהתאמה אישית ל-CPU

אפשר ליצור הוראה HLO שמייצגת שיחה בהתאמה אישית דרך הלקוח של XLA API. לדוגמה, הקוד הבא משתמש בקריאה מותאמת אישית כדי לחשב את A[i] = B[i % 128]+ C[i] במעבד (CPU). (כמובן שאפשר – והיית צריך! – מבצעים זאת באמצעות HLO רגיל).

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to rank-1 buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C+ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

יצירת קריאה מותאמת אישית ב-GPU

הרישום של קריאה בהתאמה אישית ל-GPU באמצעות XLA FFI הוא כמעט זהה, ההבדל היחיד הוא שב-GPU צריך לבקש מקור נתונים של פלטפורמה בסיסית (מקור נתונים של CUDA או ROCM) כדי להפעיל את הליבה במכשיר. הנה דוגמה ל-CUDA שמבצעת את אותה חישוב (A[i] = B[i % 128] + C[i]) כמו הקוד ל-CPU שלמעלה.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

קודם כול, חשוב לזכור שפונקציית הקריאה בהתאמה אישית של ה-GPU עדיין פועלת במעבד (CPU). הפונקציה המעבד (CPU) do_custom_call אחראית על הוספת העבודה לתור ב-GPU. כאן הוא מפעיל ליבה של CUDA, אבל הוא יכול גם לעשות משהו אחר, למשל להפעיל את cuBLAS.

הארגומנטים והתוצאות מופיעים גם במארח, וחבר הנתונים מכיל מצביע לזיכרון של המכשיר (כלומר GPU). מאגרי הנתונים (buffers) שמועברים למטפל הקריאה בהתאמה אישית הם באותו פורמט של מאגרי הנתונים הבסיסיים במכשיר, כך שהקריאה בהתאמה אישית יכולה לחשב מהם פרמטרים להפעלת הליבה.

העברת צמדי ערכים (tuples) לשיחות בהתאמה אישית

נבחן את הקריאה בהתאמה אישית הבאה.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

גם במעבד (CPU) וגם ב-GPU, tuple מיוצג בזיכרון כמערך של מצביעים. כש-XLA מבצע קריאות בהתאמה אישית עם ארגומנטים או תוצאות של צמדי ערכי-מפתח (tuple), הוא משטח אותם ומעביר אותם כארגומנטים או כתוצאות רגילים של מאגר.

פלט של צמדי ערכים כמאגרים זמניים

קל להזין קלט כפול לשיחות מותאמות אישית, אבל זה לא ממש מובן הנחוצים. אם לא תמכנו בקלט tuples של קריאות מותאמות אישית, תמיד תוכלו פורקים את הכפולים באמצעות רכיב get-tuplement לפני שמעבירים אותם שיחה.

לעומת זאת, פלט של קבוצת ערכים מאפשר לכם לעשות דברים שלא תוכלו לעשות אחרת.

הסיבה הברורה לפלטים דו-כיוונית היא שפלטים דו-כיווניים הם call (או כל op אחר של XLA) מחזירה כמה מערכים בלתי תלויים.

אבל פחות ברור שאפשר להשתמש ביציאת טופל לצורך הענקת זיכרון זמני לשיחה בהתאמה אישית. כן, פלט יכול לייצג מאגר נתונים זמני. חשוב להשתמש במאגר נתונים זמני של פלט יש את המאפיין שהאופציה יכולה לכתוב לו, והוא יכול לקרוא ממנו אחרי נכתב אליו. זה בדיוק מה שרצית ממאגר נתונים זמני.

בדוגמה שלמעלה, נניח שרצינו להשתמש ב-F32[1024] כמאגר זמני. לאחר מכן נכתוב את ה-HLO בדיוק כמו שלמעלה, ופשוט לא קראנו אף פעם את אינדקס 1 של פלט השיחה המותאמת אישית.