find_cuda_helper/
lib.rs

1//! Tiny crate for common logic for finding and including CUDA.
2
3use std::{
4    env,
5    path::{Path, PathBuf},
6};
7
8pub fn include_cuda() {
9    if env::var("DOCS_RS").is_err() && !cfg!(doc) {
10        let paths = find_cuda_lib_dirs();
11        if paths.is_empty() {
12            panic!("Could not find a cuda installation");
13        }
14        for path in paths {
15            println!("cargo:rustc-link-search=native={}", path.display());
16        }
17
18        println!("cargo:rustc-link-lib=dylib=cuda");
19        println!("cargo:rerun-if-changed=build.rs");
20        println!("cargo:rerun-if-env-changed=CUDA_LIBRARY_PATH");
21        println!("cargo:rerun-if-env-changed=CUDA_ROOT");
22        println!("cargo:rerun-if-env-changed=CUDA_PATH");
23        println!("cargo:rerun-if-env-changed=CUDA_TOOLKIT_ROOT_DIR");
24    }
25}
26
/// Reports whether `path` looks like the root of a CUDA toolkit installation, i.e.
/// whether it contains `include/cuda.h` as a regular file.
fn is_cuda_root_path<P: AsRef<Path>>(path: P) -> bool {
    let header = path.as_ref().join("include").join("cuda.h");
    header.is_file()
}
31
32pub fn find_cuda_root() -> Option<PathBuf> {
33    // search through the common environment variables first
34    for path in ["CUDA_PATH", "CUDA_ROOT", "CUDA_TOOLKIT_ROOT_DIR"]
35        .iter()
36        .filter_map(|name| std::env::var(*name).ok())
37    {
38        if is_cuda_root_path(&path) {
39            return Some(path.into());
40        }
41    }
42
43    // If it wasn't specified by env var, try the default installation paths
44    #[cfg(not(target_os = "windows"))]
45    let default_paths = ["/usr/local/cuda", "/opt/cuda"];
46    #[cfg(target_os = "windows")]
47    let default_paths = ["C:/CUDA"]; // TODO (AL): what's the actual path here?
48
49    for path in default_paths {
50        if is_cuda_root_path(path) {
51            return Some(path.into());
52        }
53    }
54
55    None
56}
57
/// Returns the directories to add to the native library search path on a Windows host.
///
/// Entries from `CUDA_LIBRARY_PATH` (see [`read_env`]) come first and are used as-is, so
/// users can point the build at specific library directories, as on other platforms.
/// If a CUDA root is found via [`find_cuda_root`], its `lib/x64` directory is appended
/// when it exists.
///
/// # Panics
///
/// Only when a CUDA root is found: panics if `TARGET` is unset, if the target is not a
/// Windows target, or if the target architecture is not `x86_64`.
#[cfg(target_os = "windows")]
pub fn find_cuda_lib_dirs() -> Vec<PathBuf> {
    // Directories pinned through CUDA_LIBRARY_PATH come first so they win the link
    // search. `include_cuda` registers this variable for reruns, so it must be honored.
    let mut lib_dirs = read_env();

    let root_path = match find_cuda_root() {
        Some(root_path) => root_path,
        None => return lib_dirs,
    };

    // To do this the right way, we check to see which target we're building for.
    let target = env::var("TARGET")
        .expect("cargo did not set the TARGET environment variable as required.");

    // Targets use '-' separators, e.g. x86_64-pc-windows-msvc.
    let target_components: Vec<&str> = target.split('-').collect();

    // We check that we're building for Windows. This code assumes that the layout in
    // CUDA_PATH matches Windows. `get` avoids an index-out-of-bounds panic on an
    // unusually short triple, so the user sees this message instead.
    if target_components.get(2) != Some(&"windows") {
        panic!(
            "The CUDA_PATH variable is only used by cuda-sys on Windows. Your target is {}.",
            target
        );
    }
    // The vendor component is deliberately not checked: Windows targets with vendors
    // other than "pc" (e.g. x86_64-uwp-windows-msvc) use the same toolkit layout.

    // x86_64 should use the libs in the "lib/x64" directory. If we ever support i686 (which
    // does not ship with cublas support), its libraries are in "lib/Win32".
    let lib_path = match target_components[0] {
        "x86_64" => "x64",
        "i686" => {
            // lib path would be "Win32" if we support i686. "cublas" is not present in the
            // 32-bit install.
            panic!("Rust cuda-sys does not currently support 32-bit Windows.");
        }
        _ => {
            panic!("Rust cuda-sys only supports the x86_64 Windows architecture.");
        }
    };

    let lib_dir = root_path.join("lib").join(lib_path);
    if lib_dir.is_dir() {
        lib_dirs.push(lib_dir);
    }
    lib_dirs
}
109
/// Returns the library directories listed in the `CUDA_LIBRARY_PATH` environment variable.
///
/// The location of libcuda, libcudart and libcublas can be hardcoded with this variable.
/// It is split with the platform's path-list separator (`;` on Windows, `:` elsewhere).
/// Empty entries are dropped, for example those left by a trailing or doubled separator
/// or by an empty variable.
///
/// Returns an empty vector if the variable is unset.
pub fn read_env() -> Vec<PathBuf> {
    match env::var_os("CUDA_LIBRARY_PATH") {
        // `var_os` + `split_paths` keep non-UTF-8 directories instead of silently
        // discarding the whole variable.
        Some(value) => env::split_paths(&value)
            // An empty entry would later be joined into a path relative to the build
            // script's working directory (e.g. "lib64"), which is never intended.
            .filter(|path| !path.as_os_str().is_empty())
            .collect(),
        None => Vec::new(),
    }
}
124
125#[cfg(not(target_os = "windows"))]
126pub fn find_cuda_lib_dirs() -> Vec<PathBuf> {
127    let mut candidates = read_env();
128    candidates.push(PathBuf::from("/opt/cuda"));
129    candidates.push(PathBuf::from("/usr/local/cuda"));
130    for e in glob::glob("/usr/local/cuda-*").unwrap().flatten() {
131        candidates.push(e)
132    }
133
134    let mut valid_paths = vec![];
135    for base in &candidates {
136        let lib = PathBuf::from(base).join("lib64");
137        if lib.is_dir() {
138            valid_paths.push(lib.clone());
139            valid_paths.push(lib.join("stubs"));
140        }
141        let base = base.join("targets/x86_64-linux");
142        let header = base.join("include/cuda.h");
143        if header.is_file() {
144            valid_paths.push(base.join("lib"));
145            valid_paths.push(base.join("lib/stubs"));
146            continue;
147        }
148    }
149    valid_paths
150}
151
/// Returns the root directory of the OptiX SDK, if one is configured.
///
/// The OptiX SDK installer sets `OPTIX_ROOT_DIR` whenever it installs. `OPTIX_ROOT` is
/// checked first so someone can override it without touching the SDK-set variable.
/// An empty value counts as unset, so it cannot shadow the other variable. Non-UTF-8
/// paths are supported.
///
/// Returns `None` when neither variable holds a non-empty value.
pub fn find_optix_root() -> Option<PathBuf> {
    ["OPTIX_ROOT", "OPTIX_ROOT_DIR"]
        .iter()
        .filter_map(|name| env::var_os(name))
        .find(|value| !value.is_empty())
        .map(PathBuf::from)
}
171
172#[cfg(doc)]
173pub fn find_libnvvm_bin_dir() -> String {
174    String::new()
175}
176
177#[cfg(all(target_os = "windows", not(doc)))]
178pub fn find_libnvvm_bin_dir() -> String {
179    if env::var("DOCS_RS").is_ok() {
180        return String::new();
181    }
182    find_cuda_root()
183        .expect("Failed to find CUDA ROOT, make sure the CUDA SDK is installed and CUDA_PATH or CUDA_ROOT are set!")
184        .join("nvvm")
185        .join("lib")
186        .join("x64")
187        .to_string_lossy()
188        .into_owned()
189}
190
191#[cfg(all(target_os = "linux", not(doc)))]
192pub fn find_libnvvm_bin_dir() -> String {
193    if env::var("DOCS_RS").is_ok() {
194        return String::new();
195    }
196    find_cuda_root()
197        .expect("Failed to find CUDA ROOT, make sure the CUDA SDK is installed and CUDA_PATH or CUDA_ROOT are set!")
198        .join("nvvm")
199        .join("lib64")
200        .to_string_lossy()
201        .into_owned()
202}