use std::{
env,
path::{Path, PathBuf},
};
/// Emits the cargo build-script directives needed to link against the CUDA `cuda` library.
///
/// Nothing is emitted when the `DOCS_RS` environment variable is set or when compiling
/// with `cfg(doc)`. Otherwise every directory returned by [`find_cuda_lib_dirs`] is added
/// as a native link-search path, `cuda` is linked as a dylib, and cargo is told to rerun
/// the build script when `build.rs` or any CUDA location variable changes.
///
/// # Panics
///
/// Panics if [`find_cuda_lib_dirs`] returns no directories.
pub fn include_cuda() {
    let building_docs = env::var("DOCS_RS").is_ok() || cfg!(doc);
    if building_docs {
        return;
    }

    let lib_dirs = find_cuda_lib_dirs();
    if lib_dirs.is_empty() {
        panic!("Could not find a cuda installation");
    }
    lib_dirs
        .iter()
        .for_each(|dir| println!("cargo:rustc-link-search=native={}", dir.display()));

    println!("cargo:rustc-link-lib=dylib=cuda");
    println!("cargo:rerun-if-changed=build.rs");
    println!("cargo:rerun-if-env-changed=CUDA_LIBRARY_PATH");
    println!("cargo:rerun-if-env-changed=CUDA_ROOT");
    println!("cargo:rerun-if-env-changed=CUDA_PATH");
    println!("cargo:rerun-if-env-changed=CUDA_TOOLKIT_ROOT_DIR");
}
/// Returns `true` when `path` looks like a CUDA toolkit root, i.e. when
/// `<path>/include/cuda.h` exists and is a regular file (a directory with that
/// name does not count).
fn is_cuda_root_path<P: AsRef<Path>>(path: P) -> bool {
    let root: &Path = path.as_ref();
    let mut header = root.to_path_buf();
    header.push("include");
    header.push("cuda.h");
    header.is_file()
}
pub fn find_cuda_root() -> Option<PathBuf> {
for path in ["CUDA_PATH", "CUDA_ROOT", "CUDA_TOOLKIT_ROOT_DIR"]
.iter()
.filter_map(|name| std::env::var(*name).ok())
{
if is_cuda_root_path(&path) {
return Some(path.into());
}
}
#[cfg(not(target_os = "windows"))]
let default_paths = ["/usr/local/cuda", "/opt/cuda"];
#[cfg(target_os = "windows")]
let default_paths = ["C:/CUDA"]; for path in default_paths {
if is_cuda_root_path(path) {
return Some(path.into());
}
}
None
}
/// Returns the CUDA library directories to add to the linker search path on Windows.
///
/// The CUDA root comes from [`find_cuda_root`], and the library directory is
/// `<root>/lib/x64`. Returns an empty vector when no CUDA root is found or when that
/// directory does not exist.
///
/// # Panics
///
/// - if cargo did not set the `TARGET` environment variable;
/// - if the target triple's third component is not `windows`;
/// - if the target architecture is not `x86_64` (`i686` gets a dedicated message);
/// - in debug builds, if the triple's second component is not `pc`.
#[cfg(target_os = "windows")]
pub fn find_cuda_lib_dirs() -> Vec<PathBuf> {
    let root_path = match find_cuda_root() {
        Some(root) => root,
        None => return vec![],
    };

    let target = env::var("TARGET")
        .expect("cargo did not set the TARGET environment variable as required.");
    let target_components: Vec<&str> = target.split('-').collect();

    // `get` instead of indexing: a triple with fewer than three components
    // (e.g. `wasm32-wasi`) would otherwise abort with an index-out-of-bounds
    // panic instead of this explanatory message.
    if target_components.get(2).copied() != Some("windows") {
        panic!(
            "The CUDA_PATH variable is only used by cuda-sys on Windows. Your target is {}.",
            target
        );
    }

    // The triple now has at least three components, so the indexing below cannot panic.
    debug_assert_eq!(
        "pc", target_components[1],
        "Expected a Windows target to have the second component be 'pc'. Target: {}",
        target
    );

    let lib_path = match target_components[0] {
        "x86_64" => "x64",
        "i686" => panic!("Rust cuda-sys does not currently support 32-bit Windows."),
        _ => panic!("Rust cuda-sys only supports the x86_64 Windows architecture."),
    };

    let lib_dir = root_path.join("lib").join(lib_path);
    if lib_dir.is_dir() {
        vec![lib_dir]
    } else {
        vec![]
    }
}
/// Returns the directories listed in the `CUDA_LIBRARY_PATH` environment variable.
///
/// The value is split with [`std::env::split_paths`], which uses the platform's
/// path-list separator (`:` on Unix, `;` on Windows). Empty entries, such as those
/// produced by an empty value, a trailing separator or `::`, are skipped. Callers join
/// subdirectories onto each entry, so an empty entry would otherwise silently turn
/// into a path relative to the current directory.
///
/// The variable is read with `var_os`, so non-UTF-8 paths are accepted. Returns an
/// empty vector when the variable is unset.
pub fn read_env() -> Vec<PathBuf> {
    match env::var_os("CUDA_LIBRARY_PATH") {
        Some(value) => env::split_paths(&value)
            .filter(|entry| !entry.as_os_str().is_empty())
            .collect(),
        None => vec![],
    }
}
/// Collects CUDA library directories to add to the linker search path on non-Windows hosts.
///
/// Candidate roots are visited in this order:
/// 1. the entries of `CUDA_LIBRARY_PATH`, via [`read_env`];
/// 2. `/opt/cuda`;
/// 3. `/usr/local/cuda`;
/// 4. every path matching `/usr/local/cuda-*`.
///
/// For each root:
/// - if `<root>/lib64` is a directory, `<root>/lib64` and `<root>/lib64/stubs` are added;
/// - if `<root>/targets/x86_64-linux/include/cuda.h` is a file,
///   `<root>/targets/x86_64-linux/lib` and `<root>/targets/x86_64-linux/lib/stubs` are added.
///
/// The `stubs` directories are added without checking that they exist.
#[cfg(not(target_os = "windows"))]
pub fn find_cuda_lib_dirs() -> Vec<PathBuf> {
    let mut roots = read_env();
    roots.extend(["/opt/cuda", "/usr/local/cuda"].iter().copied().map(PathBuf::from));
    // Glob entries that fail to read are skipped (`flatten` drops the `Err`s).
    roots.extend(glob::glob("/usr/local/cuda-*").unwrap().flatten());

    let mut found = Vec::new();
    for root in &roots {
        let lib64 = root.join("lib64");
        if lib64.is_dir() {
            let stubs = lib64.join("stubs");
            found.push(lib64);
            found.push(stubs);
        }

        let target_root = root.join("targets/x86_64-linux");
        if target_root.join("include/cuda.h").is_file() {
            found.push(target_root.join("lib"));
            found.push(target_root.join("lib/stubs"));
        }
    }
    found
}
/// Locates the OptiX SDK root directory from the environment.
///
/// `OPTIX_ROOT` is consulted first, falling back to `OPTIX_ROOT_DIR`. The value is
/// returned as-is, without checking that the directory exists. It is read with
/// `var_os`, so non-UTF-8 paths are accepted. Returns `None` if neither variable is set.
///
/// The lookup does not depend on the host platform, so this function is not cfg-gated.
pub fn find_optix_root() -> Option<PathBuf> {
    env::var_os("OPTIX_ROOT")
        .or_else(|| env::var_os("OPTIX_ROOT_DIR"))
        .map(PathBuf::from)
}
/// Returns the libNVVM library directory of the detected CUDA toolkit as a `String`.
///
/// The platform implementations return `<cuda root>/nvvm/lib/x64` on Windows and
/// `<cuda root>/nvvm/lib64` on Linux, or an empty string when `DOCS_RS` is set. This
/// variant is only compiled for documentation builds (`cfg(doc)`) and always returns
/// an empty string.
#[cfg(doc)]
pub fn find_libnvvm_bin_dir() -> String {
    String::default()
}
/// Returns the libNVVM library directory of the detected CUDA toolkit on Windows:
/// `<cuda root>/nvvm/lib/x64`, converted lossily to a `String`.
///
/// Returns an empty string when the `DOCS_RS` environment variable is set.
///
/// # Panics
///
/// Panics if [`find_cuda_root`] finds no CUDA installation.
#[cfg(all(target_os = "windows", not(doc)))]
pub fn find_libnvvm_bin_dir() -> String {
    match env::var("DOCS_RS") {
        Ok(_) => String::new(),
        Err(_) => {
            let mut nvvm_lib = find_cuda_root()
                .expect("Failed to find CUDA ROOT, make sure the CUDA SDK is installed and CUDA_PATH or CUDA_ROOT are set!");
            nvvm_lib.extend(["nvvm", "lib", "x64"].iter());
            nvvm_lib.to_string_lossy().into_owned()
        }
    }
}
/// Returns the libNVVM library directory of the detected CUDA toolkit on Linux:
/// `<cuda root>/nvvm/lib64`, converted lossily to a `String`.
///
/// Returns an empty string when the `DOCS_RS` environment variable is set.
///
/// # Panics
///
/// Panics if [`find_cuda_root`] finds no CUDA installation.
#[cfg(all(target_os = "linux", not(doc)))]
pub fn find_libnvvm_bin_dir() -> String {
    match env::var("DOCS_RS") {
        Ok(_) => String::new(),
        Err(_) => {
            let mut nvvm_lib = find_cuda_root()
                .expect("Failed to find CUDA ROOT, make sure the CUDA SDK is installed and CUDA_PATH or CUDA_ROOT are set!");
            nvvm_lib.extend(["nvvm", "lib64"].iter());
            nvvm_lib.to_string_lossy().into_owned()
        }
    }
}