torch_sys::c_generated

Function atg__triton_multi_head_attention

source
pub unsafe extern "C" fn atg__triton_multi_head_attention(
    out__: *mut *mut C_tensor,
    query_: *mut C_tensor,
    key_: *mut C_tensor,
    value_: *mut C_tensor,
    embed_dim_: i64,
    num_head_: i64,
    qkv_weight_: *mut C_tensor,
    qkv_bias_: *mut C_tensor,
    proj_weight_: *mut C_tensor,
    proj_bias_: *mut C_tensor,
    mask_: *mut C_tensor,
)