candle_core

Trait CustomOp1

pub trait CustomOp1 {
    // Required methods
    fn name(&self) -> &'static str;
    fn cpu_fwd(
        &self,
        storage: &CpuStorage,
        layout: &Layout,
    ) -> Result<(CpuStorage, Shape)>;

    // Provided methods
    fn cuda_fwd(
        &self,
        _storage: &CudaStorage,
        _layout: &Layout,
    ) -> Result<(CudaStorage, Shape)> { ... }
    fn metal_fwd(
        &self,
        _storage: &MetalStorage,
        _layout: &Layout,
    ) -> Result<(MetalStorage, Shape)> { ... }
    fn bwd(
        &self,
        _arg: &Tensor,
        _res: &Tensor,
        _grad_res: &Tensor,
    ) -> Result<Option<Tensor>> { ... }
}

Unary ops that can be defined in user-land.

Required Methods

fn name(&self) -> &'static str

fn cpu_fwd( &self, storage: &CpuStorage, layout: &Layout, ) -> Result<(CpuStorage, Shape)>

The forward pass, as run on a CPU device. The storage may use arbitrary strides and offsets, so the associated layout should be used to access it.
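To make the note about strides and offsets concrete, here is a minimal sketch of reading a flat buffer through a shape, a stride per dimension, and a start offset. It uses only the standard library: `StridedView` and `map_strided` are hypothetical stand-ins invented for this illustration and are not candle_core's `Layout` or `CpuStorage` API. It assumes strides are counted in elements, not bytes.

```rust
/// Hypothetical stand-in for the information a layout carries.
/// This is NOT candle_core's `Layout` type.
struct StridedView {
    shape: Vec<usize>,
    /// Step between consecutive indices of each dimension, in elements.
    stride: Vec<usize>,
    /// Position of the first logical element inside the buffer.
    start_offset: usize,
}

/// Visit the elements of `data` in logical row-major order, apply `f`,
/// and collect the results into a new contiguous buffer.
fn map_strided(data: &[f32], view: &StridedView, f: impl Fn(f32) -> f32) -> Vec<f32> {
    let rank = view.shape.len();
    let count: usize = view.shape.iter().product();
    let mut out = Vec::with_capacity(count);
    let mut index = vec![0usize; rank];
    for _ in 0..count {
        // Flat position = start offset + sum over dims of index * stride.
        let pos = view.start_offset
            + index
                .iter()
                .zip(&view.stride)
                .map(|(i, s)| i * s)
                .sum::<usize>();
        out.push(f(data[pos]));
        // Advance the multi-index, last dimension fastest.
        for d in (0..rank).rev() {
            index[d] += 1;
            if index[d] < view.shape[d] {
                break;
            }
            index[d] = 0;
        }
    }
    out
}
```

Iterating the raw buffer front to back would give the wrong element order for a transposed or sliced view. Going through the shape, strides, and offset is what keeps the output in logical order.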

Provided Methods

fn cuda_fwd( &self, _storage: &CudaStorage, _layout: &Layout, ) -> Result<(CudaStorage, Shape)>

The forward pass, as run on a CUDA GPU device. The storage may use arbitrary strides and offsets, so the associated layout should be used to access it.

fn metal_fwd( &self, _storage: &MetalStorage, _layout: &Layout, ) -> Result<(MetalStorage, Shape)>

The forward pass, as run on a Metal GPU device. The storage may use arbitrary strides and offsets, so the associated layout should be used to access it.

fn bwd( &self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor, ) -> Result<Option<Tensor>>

This function takes the argument arg used in the forward pass, the result res produced by the forward operation, and the gradient of the result grad_res. It should return the gradient with respect to the argument.
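The backward pass receives both the forward argument and the forward result because different ops need different ones to apply the chain rule. The sketch below illustrates this on plain slices, using only the standard library. `square_bwd` and `exp_bwd` are hypothetical helpers written for this example and are not part of candle_core; a real bwd implementation would work with Tensor values instead.

```rust
/// Backward for y = x^2. The local derivative is 2x, so the forward
/// argument is needed: grad_arg = grad_res * 2 * arg.
fn square_bwd(arg: &[f32], grad_res: &[f32]) -> Vec<f32> {
    arg.iter().zip(grad_res).map(|(x, g)| 2.0 * x * g).collect()
}

/// Backward for y = exp(x). The local derivative equals y itself, so the
/// forward result can be reused: grad_arg = grad_res * res.
fn exp_bwd(res: &[f32], grad_res: &[f32]) -> Vec<f32> {
    res.iter().zip(grad_res).map(|(y, g)| y * g).collect()
}
```

In both cases the returned buffer has the same number of elements as the argument, since it holds one partial derivative per input element.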

Implementors