Crate xla

Rust bindings for XLA (Accelerated Linear Algebra).

XLA is a compiler library for machine learning. It can be used to run models efficiently on GPUs, TPUs, and CPUs.

XlaOp values are used to build a computation graph. This graph can be built into an XlaComputation, which can then be compiled into a PjRtLoadedExecutable, and that executable can be run on a PjRtClient. Literal values represent tensors in host memory, and PjRtBuffer values represent views of tensors/memory on the target device.
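The build, compile and run flow can be modelled with a small std-only toy, which may help separate the roles of each stage. This is a sketch under stated assumptions, not xla's implementation. `ToyBuilder`, `ToyOp`, `ToyComputation`, `ToyExecutable` and `Node` are hypothetical names. The toy "compile" step only counts parameters, whereas XLA generates device code. Plain `f32` slices stand in for host-side Literal values.

```rust
// Toy model only: none of these types come from the xla crate.

/// One node of the toy graph; operands refer to earlier nodes by index.
#[derive(Clone, Copy, Debug)]
enum Node {
    Constant(f32),
    Parameter(usize),
    Add(usize, usize),
    Mul(usize, usize),
}

/// Handle to a recorded node (plays the role of an op handle).
#[derive(Clone, Copy, Debug)]
struct ToyOp(usize);

/// Records nodes in creation order (plays the role of a graph builder).
#[derive(Default)]
struct ToyBuilder {
    nodes: Vec<Node>,
}

impl ToyBuilder {
    fn push(&mut self, node: Node) -> ToyOp {
        self.nodes.push(node);
        ToyOp(self.nodes.len() - 1)
    }
    fn constant(&mut self, v: f32) -> ToyOp { self.push(Node::Constant(v)) }
    fn parameter(&mut self, i: usize) -> ToyOp { self.push(Node::Parameter(i)) }
    fn add(&mut self, a: ToyOp, b: ToyOp) -> ToyOp { self.push(Node::Add(a.0, b.0)) }
    fn mul(&mut self, a: ToyOp, b: ToyOp) -> ToyOp { self.push(Node::Mul(a.0, b.0)) }

    /// Freeze the graph up to `root` into a device-independent computation.
    /// Operands always point backwards, so truncating at `root` is safe.
    fn build(&self, root: ToyOp) -> ToyComputation {
        ToyComputation { nodes: self.nodes[..=root.0].to_vec() }
    }
}

/// Device-independent graph; its root is the last node.
struct ToyComputation {
    nodes: Vec<Node>,
}

/// "Compiled" form: this toy only records how many arguments are expected.
struct ToyExecutable {
    nodes: Vec<Node>,
    num_params: usize,
}

impl ToyComputation {
    fn compile(&self) -> ToyExecutable {
        let num_params = self
            .nodes
            .iter()
            .filter_map(|n| match n {
                Node::Parameter(i) => Some(i + 1),
                _ => None,
            })
            .max()
            .unwrap_or(0);
        ToyExecutable { nodes: self.nodes.clone(), num_params }
    }
}

impl ToyExecutable {
    /// Evaluate nodes in creation order, which is already a topological order.
    fn execute(&self, args: &[f32]) -> Result<f32, String> {
        if args.len() != self.num_params {
            return Err(format!("expected {} argument(s), got {}", self.num_params, args.len()));
        }
        let mut values: Vec<f32> = Vec::with_capacity(self.nodes.len());
        for node in &self.nodes {
            let v = match *node {
                Node::Constant(c) => c,
                Node::Parameter(i) => args[i],
                Node::Add(a, b) => values[a] + values[b],
                Node::Mul(a, b) => values[a] * values[b],
            };
            values.push(v);
        }
        values.last().copied().ok_or_else(|| "empty graph".to_string())
    }
}
```

For example, recording `x + 20` with `x` as parameter 0, then building and compiling it, gives an executable whose `execute(&[22.0])` returns `Ok(42.0)`. Keeping the graph, the compiled form, and the call as separate values reflects the device-independent / device-specific split described above.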

The following example illustrates how to build and run a simple computation.

fn main() -> xla::Result<()> {
    // Create a CPU client.
    let client = xla::PjRtClient::cpu()?;

    // A builder object is used to store the graph of XlaOp.
    let builder = xla::XlaBuilder::new("test-builder");

    // Build a simple graph summing two constants.
    let cst20 = builder.constant_r0(20f32);
    let cst22 = builder.constant_r0(22f32);
    let sum = (cst20 + cst22)?;

    // Create a computation from the final node.
    let computation = sum.build()?;

    // Compile this computation for the target device and then execute it.
    // The computation takes no arguments, hence the empty argument slice.
    let executable = client.compile(&computation)?;
    let buffers = executable.execute::<xla::Literal>(&[])?;

    // Copy the first output buffer back to host memory and read it as f32 values.
    // For this graph, `result` should hold the single value 42.0.
    let result = buffers[0][0].to_literal_sync()?.to_vec::<f32>()?;
    println!("{result:?}");
    Ok(())
}

Structs

ArrayShape
Bf16
F16
HloModuleProto
Literal: A literal represents a value, typically a multi-dimensional array, stored in host memory.
PjRtBuffer: A buffer represents a view on a memory slice hosted on a device.
PjRtClient: A client represents a device that can be used to run computations. A computation graph must first be compiled specifically for a device before it can be run.
PjRtDevice: A device attached to a PjRtClient.
PjRtLoadedExecutable
XlaBuilder
XlaComputation: A computation is built from a root XlaOp. Computations are device independent and can be specialized to a given device through a compilation step.
XlaOp
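A host-side literal is described as a value that is typically a multi-dimensional array. One common way to hold such a value is a flat buffer plus a list of dimensions. The sketch below shows that idea with a hypothetical `HostArray` type that is not part of xla. It assumes row-major order (last dimension varies fastest), and it ignores the layout options XLA itself supports.

```rust
/// Hypothetical host-side array: flat row-major data plus dimensions.
#[derive(Debug, Clone, PartialEq)]
struct HostArray {
    dims: Vec<usize>,
    data: Vec<f32>,
}

impl HostArray {
    /// Reject data whose length does not match the product of the dimensions.
    /// A scalar has no dimensions; the empty product is 1.
    fn new(dims: Vec<usize>, data: Vec<f32>) -> Result<Self, String> {
        let count: usize = dims.iter().product();
        if count != data.len() {
            return Err(format!("dims {:?} need {} elements, got {}", dims, count, data.len()));
        }
        Ok(HostArray { dims, data })
    }

    /// Row-major flat offset of a multi-index, or None if out of bounds.
    fn offset(&self, index: &[usize]) -> Option<usize> {
        if index.len() != self.dims.len() {
            return None;
        }
        let mut off = 0;
        for (&i, &d) in index.iter().zip(&self.dims) {
            if i >= d {
                return None;
            }
            off = off * d + i;
        }
        Some(off)
    }

    /// Read one element by multi-index.
    fn get(&self, index: &[usize]) -> Option<f32> {
        self.offset(index).map(|o| self.data[o])
    }
}
```

With dims `[2, 3]` and data `0.0..=5.0`, index `[1, 2]` maps to offset `1 * 3 + 2 = 5`.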

Enums

ElementType
Error: Main library error type.
PrimitiveType: The primitive types supported by XLA. S8 is a signed 1-byte integer, U32 is an unsigned 4-byte integer, etc.
Shape: A shape specifies a primitive type as well as some array dimensions.
TfLogLevel
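The PrimitiveType naming (S8 = signed 1 byte, U32 = unsigned 4 bytes) and the idea of a shape as "primitive type plus dimensions" combine into simple size arithmetic. The sketch below uses hypothetical names `Prim` and `ArrayShapeSketch`. These mirror a subset of the type names but are not xla's `PrimitiveType`, `Shape`, or `ArrayShape`. The byte sizes follow the naming convention above and the standard widths of F16/BF16 (2 bytes), F32 (4), and F64 (8).

```rust
/// Hypothetical subset of primitive types, named after the convention above.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, PartialEq)]
enum Prim { S8, S16, S32, S64, U8, U16, U32, U64, F16, Bf16, F32, F64 }

impl Prim {
    /// Width of one element in bytes.
    fn size_in_bytes(self) -> usize {
        match self {
            Prim::S8 | Prim::U8 => 1,
            Prim::S16 | Prim::U16 | Prim::F16 | Prim::Bf16 => 2,
            Prim::S32 | Prim::U32 | Prim::F32 => 4,
            Prim::S64 | Prim::U64 | Prim::F64 => 8,
        }
    }
}

/// Hypothetical array shape: an element type plus dimension sizes.
#[derive(Clone, Debug, PartialEq)]
struct ArrayShapeSketch {
    ty: Prim,
    dims: Vec<usize>,
}

impl ArrayShapeSketch {
    /// Number of elements; a scalar (no dimensions) has exactly one.
    fn element_count(&self) -> usize {
        self.dims.iter().product()
    }

    /// Total storage needed for a dense array of this shape.
    fn size_in_bytes(&self) -> usize {
        self.element_count() * self.ty.size_in_bytes()
    }
}
```

For example, an F32 shape with dimensions `[2, 3]` has 6 elements and needs 24 bytes.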

Traits

ArrayElement
FromRawBytes
NativeType: A type implementing the NativeType trait can be directly converted to constant ops or literals.
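A trait like NativeType lets a single generic constructor accept values such as `20f32` or `7i64`, because each Rust type statically knows its XLA element type. The sketch below shows this pattern with hypothetical names: `NativeTypeSketch`, `Prim`, `ScalarConst`, and `toy_constant_r0`. It does not reproduce xla's actual trait items or method signatures.

```rust
/// Hypothetical element-type tags.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Prim { S32, S64, U8, F32, F64 }

/// Hypothetical stand-in for a NativeType-style trait: each implementing
/// Rust type carries its primitive type as an associated constant.
trait NativeTypeSketch {
    const TY: Prim;
}

impl NativeTypeSketch for i32 { const TY: Prim = Prim::S32; }
impl NativeTypeSketch for i64 { const TY: Prim = Prim::S64; }
impl NativeTypeSketch for u8 { const TY: Prim = Prim::U8; }
impl NativeTypeSketch for f32 { const TY: Prim = Prim::F32; }
impl NativeTypeSketch for f64 { const TY: Prim = Prim::F64; }

/// A toy scalar constant tagged with its element type.
#[derive(Debug, PartialEq)]
struct ScalarConst<T> {
    ty: Prim,
    value: T,
}

/// One generic constructor covers every implementing type; the tag is
/// chosen at compile time from `T::TY`.
fn toy_constant_r0<T: NativeTypeSketch>(value: T) -> ScalarConst<T> {
    ScalarConst { ty: T::TY, value }
}
```

A design consequence of this pattern is that unsupported types, such as `String`, are rejected at compile time because they lack an implementation. They do not fail at run time.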

Functions

set_tf_min_log_level

Type Aliases

Result
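A crate-level `Result` alias is what lets fallible calls be chained with `?`, because every call shares one error type. The sketch below assumes the common `type Result<T> = std::result::Result<T, Error>` pattern. `ToyError`, `read_f32`, and `add_tagged` are hypothetical names, not xla items.

```rust
use std::fmt;

/// Hypothetical stand-in for a library's main error enum.
#[derive(Debug, PartialEq)]
enum ToyError {
    ElementTypeMismatch { expected: &'static str, got: &'static str },
}

impl fmt::Display for ToyError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ToyError::ElementTypeMismatch { expected, got } => {
                write!(f, "element type mismatch: expected {expected}, got {got}")
            }
        }
    }
}

impl std::error::Error for ToyError {}

/// The alias pattern: callers write `Result<T>` and the error type is implied.
type Result<T> = std::result::Result<T, ToyError>;

/// Hypothetical typed read from a value tagged with its element type.
fn read_f32(tag: &'static str, value: f32) -> Result<f32> {
    if tag == "f32" {
        Ok(value)
    } else {
        Err(ToyError::ElementTypeMismatch { expected: "f32", got: tag })
    }
}

/// `?` returns the first error unchanged since all calls share one error type.
fn add_tagged(a: (&'static str, f32), b: (&'static str, f32)) -> Result<f32> {
    Ok(read_f32(a.0, a.1)? + read_f32(b.0, b.1)?)
}
```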