candle_core

Trait CustomOp3

Source
pub trait CustomOp3 {
    // Required methods
    fn name(&self) -> &'static str;
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<(CpuStorage, Shape)>;

    // Provided methods
    fn cuda_fwd(
        &self,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<(CudaStorage, Shape)> { ... }
    fn metal_fwd(
        &self,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<(MetalStorage, Shape)> { ... }
    fn bwd(
        &self,
        _arg1: &Tensor,
        _arg2: &Tensor,
        _arg3: &Tensor,
        _res: &Tensor,
        _grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> { ... }
}

Required Methods§

Source

fn name(&self) -> &'static str

Source

fn cpu_fwd( &self, s1: &CpuStorage, l1: &Layout, s2: &CpuStorage, l2: &Layout, s3: &CpuStorage, l3: &Layout, ) -> Result<(CpuStorage, Shape)>

The forward pass, as run on a CPU device. Note that the storage can use arbitrary strides, offsets, etc., so the associated layout should be used to access it.

Provided Methods§

Source

fn cuda_fwd( &self, _: &CudaStorage, _: &Layout, _: &CudaStorage, _: &Layout, _: &CudaStorage, _: &Layout, ) -> Result<(CudaStorage, Shape)>

The forward pass, as run on a CUDA GPU device. Note that the storage can use arbitrary strides, offsets, etc., so the associated layout should be used to access it.

Source

fn metal_fwd( &self, _: &MetalStorage, _: &Layout, _: &MetalStorage, _: &Layout, _: &MetalStorage, _: &Layout, ) -> Result<(MetalStorage, Shape)>

The forward pass, as run on a Metal GPU device. Note that the storage can use arbitrary strides, offsets, etc., so the associated layout should be used to access it.

Source

fn bwd( &self, _arg1: &Tensor, _arg2: &Tensor, _arg3: &Tensor, _res: &Tensor, _grad_res: &Tensor, ) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)>

Implementors§