// noxla 0.4.0
//
// Rust bindings to XLA's C++ API
use std::io::Read;
use std::path::{Path, PathBuf};
use std::{env, io};

use anyhow::Context;
use flate2::read::GzDecoder;
use tar::Archive;

/// Operating systems this build script knows how to target.
///
/// Values are produced by [`OS::get`] from the `CARGO_CFG_TARGET_OS`
/// environment variable.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
// NOTE(review): presumably needed because `MacOS` ends with the enum's own
// name (`OS`), which trips `clippy::enum_variant_names` — confirm.
#[allow(clippy::enum_variant_names)]
enum OS {
    /// `CARGO_CFG_TARGET_OS == "linux"`.
    Linux,
    /// `CARGO_CFG_TARGET_OS == "macos"`.
    MacOS,
    /// `CARGO_CFG_TARGET_OS == "windows"`.
    Windows,
}

impl OS {
    fn get() -> Self {
        let os = env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS");
        match os.as_str() {
            "linux" => Self::Linux,
            "macos" => Self::MacOS,
            "windows" => Self::Windows,
            os => panic!("Unsupported system {os}"),
        }
    }
}

/// Looks up the environment variable `name`, first printing a
/// `cargo:rerun-if-env-changed` directive for it on stdout.
///
/// Returns `None` when `env::var` fails for `name`.
fn env_var_rerun(name: &str) -> Option<String> {
    println!("cargo:rerun-if-env-changed={}", name);
    match env::var(name) {
        Ok(value) => Some(value),
        Err(_) => None,
    }
}

/// Build-script entry point.
///
/// In order, it:
/// 1. resolves the XLA extension directory (`XLA_EXTENSION_DIR`, or
///    `$OUT_DIR/xla_extension`) and downloads an archive when it is missing;
/// 2. compiles the vendored jaxlib C++ sources (plus the CUDA kernels when
///    the `cuda` feature is enabled);
/// 3. downloads the `jax_metal` wheel when targeting macOS;
/// 4. prints the linker directives for `xla_extension`.
///
/// # Errors
///
/// Returns an error when a download or archive extraction fails, or when the
/// `cuda` feature is enabled but no CUDA root can be found.
///
/// # Panics
///
/// Panics when `OUT_DIR` is not set or the target OS is unsupported.
fn main() -> anyhow::Result<()> {
    let out_dir = PathBuf::from(env::var("OUT_DIR").expect("missing out dir"));
    let os = OS::get();

    // NOTE(review): the download helpers always unpack into `out_dir`, so a
    // user-supplied `XLA_EXTENSION_DIR` that does not exist is not populated
    // by the download below — confirm whether that is intended.
    let xla_dir = env_var_rerun("XLA_EXTENSION_DIR")
        .map_or_else(|| out_dir.join("xla_extension"), PathBuf::from);
    if !xla_dir.exists() {
        if cfg!(feature = "shared") {
            download_shared_xla(&out_dir)?;
        } else {
            download_xla(&out_dir)?;
        }
    }

    let mut config = cpp_build::Config::new();
    config
        .flag("-std=c++17")
        .flag("-DLLVM_ON_UNIX=1")
        .flag("-DLLVM_VERSION_STRING=")
        .flag(&format!("-isystem{}", xla_dir.join("include").display()))
        .file("./vendor/jaxlib/cpu/cpu_kernels.cc")
        .file("./vendor/jaxlib/cpu/lapack_kernels.cc")
        .include("./vendor");
    if cfg!(feature = "cuda") {
        find_cuda_helper::include_cuda();
        // Report a missing CUDA installation as a build error instead of an
        // opaque `unwrap` panic.
        let cuda = find_cuda_helper::find_cuda_root()
            .context("Unable to locate the CUDA toolkit root")?;
        config
            .include(cuda.join("include"))
            .include(&cuda)
            .define("JAX_GPU_CUDA", Some("1"))
            .define("EL_CUDA", Some("1"))
            .file("./vendor/jaxlib/gpu/cholesky_update_kernel.cc")
            .file("./vendor/jaxlib/gpu/lu_pivot_kernels.cc")
            .file("./vendor/jaxlib/gpu/prng_kernels.cc");

        // The GPU kernels are compiled separately with `cc` in CUDA mode
        // into the `jaxlib_cuda` library.
        let mut cuda_config = cc::Build::new();
        cuda_config
            .flag("-std=c++17")
            .flag("-DLLVM_ON_UNIX=1")
            .flag("-DLLVM_VERSION_STRING=")
            .include(xla_dir.join("include"))
            .include("./vendor");
        cuda_config
            .flag("--disable-warnings")
            .cuda(true)
            .cudart("static")
            .include(cuda)
            // Code generation is hard-coded to compute_89 / sm_89.
            .flag("-gencode")
            .flag("arch=compute_89,code=sm_89")
            .file("./vendor/jaxlib/gpu/blas_kernels.cc")
            .file("./vendor/jaxlib/gpu/cholesky_update_kernel.cu")
            .file("./vendor/jaxlib/gpu/gpu_kernel_helpers.cc")
            .file("./vendor/jaxlib/gpu/gpu_kernels.cc")
            .file("./vendor/jaxlib/gpu/lu_pivot_kernels.cu")
            .file("./vendor/jaxlib/gpu/prng_kernels.cu")
            .file("./vendor/jaxlib/gpu/rnn_kernels.cc")
            .file("./vendor/jaxlib/gpu/solver_kernels.cc")
            .file("./vendor/jaxlib/gpu/sparse_kernels.cc")
            .define("JAX_GPU_CUDA", Some("1"));
        cuda_config.compile("jaxlib_cuda");
    }
    config.build("src/lib.rs");

    // Re-run the build script whenever one of these Rust sources changes.
    for source in [
        "executable.rs",
        "literal.rs",
        "op.rs",
        "shape.rs",
        "native_type.rs",
        "builder.rs",
        "error.rs",
        "client.rs",
        "buffer.rs",
        "computation.rs",
        "hlo_module.rs",
    ] {
        println!("cargo:rerun-if-changed=src/{source}");
    }

    let jax_metal_dir =
        env_var_rerun("JAX_METAL_DIR").map_or_else(|| out_dir.join("jax_metal"), PathBuf::from);
    // Decide on the *target* OS: inside a build script `cfg!(target_os = ..)`
    // describes the host the script runs on, which differs when
    // cross-compiling.
    if os == OS::MacOS && !jax_metal_dir.exists() {
        download_jax_metal(&jax_metal_dir)?;
    }

    // Skip the linker directives on docs.rs as the C++ library would not be
    // available there.
    // NOTE(review): the downloads and the native compilation above have
    // already run by the time this check is reached — confirm that is
    // acceptable on docs.rs.
    if std::env::var("DOCS_RS").is_ok() {
        return Ok(());
    }

    // `--copy-dt-needed-entries` and `-lstdc++` work around a
    // "DSO missing from command line" link error such as:
    // undefined reference to symbol '_ZStlsIcSt11char_traitsIcESaIcEERSt13basic_ostreamIT_T0_ES7_RKNSt7__cxx1112basic_stringIS4_S5_T1_EE@@GLIBCXX_3.4.21'
    if os == OS::Linux {
        println!("cargo:rustc-link-arg=-Wl,--copy-dt-needed-entries");
        println!("cargo:rustc-link-arg=-Wl,-lstdc++");
    }

    if cfg!(feature = "shared") {
        println!("cargo:rustc-link-search={}", xla_dir.join("lib").display());
        println!("cargo:rustc-link-lib=dylib=xla_extension");
    } else {
        println!(
            "cargo:rustc-link-search=native={}",
            xla_dir.join("lib").display()
        );
        println!("cargo:rustc-link-lib=static=xla_extension");
    }
    if os == OS::MacOS {
        println!("cargo:rustc-link-lib=framework=Foundation");
        println!("cargo:rustc-link-lib=framework=SystemConfiguration");
        println!("cargo:rustc-link-lib=framework=Security");
    }

    Ok(())
}

/// Fetches the pinned `jax_metal` 0.0.4 wheel and extracts it, as a zip
/// archive, into `jax_dir`.
///
/// # Errors
///
/// Fails when the download fails, or when the payload cannot be opened or
/// extracted as a zip archive.
fn download_jax_metal(jax_dir: &Path) -> anyhow::Result<()> {
    const WHEEL_URL: &str = "https://files.pythonhosted.org/packages/7e/59/ff91dc65e7f945479b08509185d07de0c947e81c07705367b018cb072ee9/jax_metal-0.0.4-py3-none-macosx_11_0_arm64.whl";

    let wheel = io::Cursor::new(download_file(WHEEL_URL)?);
    zip::ZipArchive::new(wheel)?.extract(jax_dir)?;
    Ok(())
}

/// Downloads the shared-library XLA extension archive (elixir-nx/xla v0.7.0)
/// for the current target and unpacks the gzipped tarball into `xla_dir`.
///
/// The target is read from `CARGO_CFG_TARGET_OS` / `CARGO_CFG_TARGET_ARCH`.
/// On Linux the CUDA 12 archive is selected when the `cuda` feature is
/// enabled and the CPU archive otherwise; macOS always uses the CPU archive.
///
/// # Errors
///
/// Fails when the download or the extraction fails.
///
/// # Panics
///
/// Panics when either target variable is missing or the target OS is neither
/// `macos` nor `linux`.
fn download_shared_xla(xla_dir: &Path) -> anyhow::Result<()> {
    let os = env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS");
    let arch = env::var("CARGO_CFG_TARGET_ARCH").expect("Unable to get TARGET_ARCH");
    let url = match (os.as_str(), arch.as_str()) {
        ("macos", arch) => format!(
            "https://github.com/elixir-nx/xla/releases/download/v0.7.0/xla_extension-{arch}-darwin-cpu.tar.gz"
        ),
        // This used to be guarded by a hard-coded `if true`, which left the
        // CPU archive below unreachable.
        ("linux", arch) if cfg!(feature = "cuda") => format!(
            "https://github.com/elixir-nx/xla/releases/download/v0.7.0/xla_extension-{arch}-linux-gnu-cuda12.tar.gz"
        ),
        ("linux", arch) => format!(
            "https://github.com/elixir-nx/xla/releases/download/v0.7.0/xla_extension-{arch}-linux-gnu-cpu.tar.gz"
        ),
        (os, arch) => panic!("{}-{} is an unsupported platform", os, arch),
    };

    let buf = download_file(&url)?;
    let mut bytes = io::Cursor::new(buf);

    let tar = GzDecoder::new(&mut bytes);
    let mut archive = Archive::new(tar);
    archive.unpack(xla_dir)?;

    Ok(())
}

/// Downloads the CPU XLA extension archive (elodin-sys/xla v0.5.4) matching
/// `CARGO_CFG_TARGET_OS` / `CARGO_CFG_TARGET_ARCH` and unpacks the gzipped
/// tarball into `xla_dir`.
///
/// # Errors
///
/// Fails when the download or the extraction fails.
///
/// # Panics
///
/// Panics when either target variable is missing or the target OS is neither
/// `macos` nor `linux`.
fn download_xla(xla_dir: &Path) -> anyhow::Result<()> {
    let target_os = env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS");
    let target_arch = env::var("CARGO_CFG_TARGET_ARCH").expect("Unable to get TARGET_ARCH");

    // Only the OS selects the archive flavour; the architecture is spliced
    // into the file name as-is.
    let url = match target_os.as_str() {
        "macos" => format!("https://github.com/elodin-sys/xla/releases/download/v0.5.4/xla_extension-{}-darwin-cpu.tar.gz", target_arch),
        "linux" => format!("https://github.com/elodin-sys/xla/releases/download/v0.5.4/xla_extension-{}-linux-gnu-cpu.tar.gz", target_arch),
        other => panic!("{}-{} is an unsupported platform", other, target_arch),
    };

    let tarball = io::Cursor::new(download_file(&url)?);
    Archive::new(GzDecoder::new(tarball)).unpack(xla_dir)?;
    Ok(())
}

/// Performs a GET request for `url` and returns the whole response body.
///
/// When the response carries a `Content-Length` header, at most that many
/// bytes are read and a shorter body is rejected as truncated. When the
/// header is absent, the body is read until end of stream.
///
/// # Errors
///
/// Fails when the request fails, the `Content-Length` header is not a valid
/// number, reading the body fails, or the body is shorter than advertised.
fn download_file(url: &str) -> anyhow::Result<Vec<u8>> {
    let res = ureq::get(url)
        .call()
        .with_context(|| format!("Failed to download {url}"))?;

    // The header is treated as optional so that responses without it are
    // still downloadable.
    let content_length = res
        .header("Content-Length")
        .map(|value| value.parse::<usize>())
        .transpose()
        .context("Invalid Content-Length header")?;

    let mut buf = Vec::with_capacity(content_length.unwrap_or(0));
    let mut reader = res.into_reader();
    match content_length {
        Some(len) => reader.take(len as u64).read_to_end(&mut buf),
        None => reader.read_to_end(&mut buf),
    }
    .context("Failed to read response")?;

    // Surface a truncated transfer here rather than as a confusing archive
    // error further down the line.
    if let Some(len) = content_length {
        anyhow::ensure!(
            buf.len() == len,
            "Truncated response from {url}: expected {len} bytes, got {}",
            buf.len()
        );
    }
    Ok(buf)
}