torch_sys::c_generated

Function atg__transformer_encoder_layer_fwd

source
/// FFI declaration of the C-ABI entry point `atg__transformer_encoder_layer_fwd`.
///
/// NOTE(review): judging only by its name, this presumably wraps libtorch's
/// `_transformer_encoder_layer_fwd` operator (a fused forward pass of one
/// transformer encoder layer). Confirm against the generated C wrapper.
///
/// This signature declares no return type. Results are presumably delivered
/// through `out__`. TODO confirm how many tensor handles are written there and
/// who owns/frees them.
///
/// # Safety
///
/// This is an `unsafe` foreign function taking raw `*mut C_tensor` pointers
/// that the Rust compiler cannot check. The caller is responsible for passing
/// pointers the C side can legitimately use.
///
/// NOTE(review): this declaration does not show which arguments may be null
/// (e.g. `mask_` for "no mask"). Verify before relying on it.
pub unsafe extern "C" fn atg__transformer_encoder_layer_fwd(
    out__: *mut *mut C_tensor, // output slot(s); presumably filled by the callee — TODO confirm count/ownership
    src_: *mut C_tensor, // presumably the layer's input tensor — confirm expected shape
    embed_dim_: i64, // presumably the model/embedding width — confirm
    num_heads_: i64, // presumably the number of attention heads — confirm
    qkv_weight_: *mut C_tensor, // presumably packed query/key/value projection weight — confirm
    qkv_bias_: *mut C_tensor, // presumably packed query/key/value projection bias — confirm
    proj_weight_: *mut C_tensor, // presumably attention output projection weight — confirm
    proj_bias_: *mut C_tensor, // presumably attention output projection bias — confirm
    use_gelu_: c_int, // C int, presumably used as a boolean flag (GELU vs. other activation) — confirm
    norm_first_: c_int, // C int, presumably used as a boolean flag (normalize before vs. after sublayers) — confirm
    eps_: f64, // presumably the layer-norm epsilon — confirm
    norm_weight_1_: *mut C_tensor, // presumably first normalization weight — confirm
    norm_bias_1_: *mut C_tensor, // presumably first normalization bias — confirm
    norm_weight_2_: *mut C_tensor, // presumably second normalization weight — confirm
    norm_bias_2_: *mut C_tensor, // presumably second normalization bias — confirm
    ffn_weight_1_: *mut C_tensor, // presumably first feed-forward layer weight — confirm
    ffn_bias_1_: *mut C_tensor, // presumably first feed-forward layer bias — confirm
    ffn_weight_2_: *mut C_tensor, // presumably second feed-forward layer weight — confirm
    ffn_bias_2_: *mut C_tensor, // presumably second feed-forward layer bias — confirm
    mask_: *mut C_tensor, // presumably an optional attention mask; nullability not visible here — confirm
    mask_type_v: i64, // with `mask_type_null`, presumably the value half of an optional i64 — confirm
    mask_type_null: i8, // presumably a flag marking `mask_type_v` as absent — confirm encoding
)