torch_sys::c_generated

Function atg__fused_sdp_choice

source
/// Raw FFI declaration of `atg__fused_sdp_choice`, using the C calling convention.
///
/// Accepts four raw `C_tensor` pointers (`query_`, `key_`, `value_`,
/// `attn_mask_`) plus several scalar arguments and returns an `i64`.
///
/// NOTE(review): the name suggests this asks the backend which fused
/// scaled-dot-product-attention implementation it would pick for these inputs,
/// with the returned `i64` identifying that choice — confirm against the C/C++
/// side before relying on specific values.
///
/// # Safety
///
/// The function is declared `unsafe` and takes raw pointers, so nothing about
/// the arguments is checked by the compiler. Presumably every non-optional
/// pointer must be a valid, live tensor handle for the duration of the call —
/// TODO confirm which pointers (e.g. `attn_mask_`) may be null.
pub unsafe extern "C" fn atg__fused_sdp_choice(
    // Raw tensor handles; presumably the attention query/key/value inputs — verify against caller.
    query_: *mut C_tensor,
    key_: *mut C_tensor,
    value_: *mut C_tensor,
    // NOTE(review): looks like an optional mask; confirm whether null is accepted.
    attn_mask_: *mut C_tensor,
    // Passed through as a plain `f64`; valid range is not visible from this declaration.
    dropout_p_: f64,
    // C integer; presumably a boolean flag (0 = false, non-zero = true) — TODO confirm.
    is_causal_: c_int,
    // `scale_v` / `scale_null` appear to form one optional value: presumably
    // `scale_null != 0` means "no scale supplied" and `scale_v` is then ignored — confirm.
    scale_v: f64,
    scale_null: i8,
    // C integer; presumably a boolean flag like `is_causal_` — TODO confirm.
    enable_gqa_: c_int,
) -> i64