# tch 0.18.0
#
# Rust wrappers for the PyTorch C++ API (libtorch).
# Documentation: https://docs.rs/tch
# Runtime dependencies. Entries with `optional = true` are only compiled when
# the feature of the same name is enabled (e.g. `--features regex`); several
# examples list them in their `required-features`.
# torch-sys is versioned alongside this crate (both are 0.18.0 here).
[dependencies]
clap = { version = "4.2.4", features = ["derive"], optional = true }
cpython = { version = "0.7.1", optional = true }
half = "2"
image = { version = "0.24.5", optional = true }
lazy_static = "1.3.0"
libc = "0.2.0"
memmap2 = { version = "0.6.1", optional = true }
ndarray = "0.15"
rand = "0.8"
regex = { version = "1.6.0", optional = true }
safetensors = "0.3.0"
serde_json = { version = "1.0.96", optional = true }
thiserror = "1"
torch-sys = "0.18.0"
zip = "0.6"

# Dependencies used only by tests and examples.
# Cargo requirements are caret by default, so "1.0.60" means ">=1.0.60, <2".
[dev-dependencies]
anyhow = "1.0.60"

# Example binaries. Example auto-discovery is turned off for this package
# (`autoexamples = false`), so every example is declared here and run with
# `cargo run --example <name>`. An example with `required-features` is not
# built unless all of those features are enabled, e.g.
# `cargo run --example llama --features regex,clap,serde_json,memmap2`.

[[example]]
path = "examples/basics.rs"
name = "basics"

[[example]]
path = "examples/char-rnn/main.rs"
name = "char-rnn"

[[example]]
path = "examples/cifar/main.rs"
name = "cifar"

[[example]]
path = "examples/custom-optimizer/main.rs"
name = "custom-optimizer"

[[example]]
path = "examples/gan/main.rs"
name = "gan"

[[example]]
path = "examples/jit/main.rs"
name = "jit"

[[example]]
path = "examples/jit-quantized/main.rs"
name = "jit-quantized"

[[example]]
path = "examples/jit-trace/main.rs"
name = "jit-trace"

[[example]]
path = "examples/jit-train/main.rs"
name = "jit-train"

[[example]]
path = "examples/llama/main.rs"
name = "llama"
required-features = ["regex", "clap", "serde_json", "memmap2"]

[[example]]
path = "examples/memory_test.rs"
name = "memory_test"

[[example]]
path = "examples/min-gpt/main.rs"
name = "min-gpt"

[[example]]
path = "examples/mnist/main.rs"
name = "mnist"

[[example]]
path = "examples/neural-style-transfer/main.rs"
name = "neural-style-transfer"

[[example]]
path = "examples/pretrained-models/main.rs"
name = "pretrained-models"

[[example]]
path = "examples/reinforcement-learning/main.rs"
name = "reinforcement-learning"
required-features = ["rl-python"]

[[example]]
path = "examples/stable-diffusion/main.rs"
name = "stable-diffusion"
required-features = ["regex"]

[[example]]
path = "examples/tensor-tools.rs"
name = "tensor-tools"

[[example]]
path = "examples/transfer-learning/main.rs"
name = "transfer-learning"

[[example]]
path = "examples/translation/main.rs"
name = "translation"

[[example]]
path = "examples/vae/main.rs"
name = "vae"

[[example]]
path = "examples/yolo/main.rs"
name = "yolo"

# Cargo features. Optional dependencies (clap, cpython, image, memmap2, regex,
# serde_json) also get implicit features of the same name.
[features]
# Forwarded unchanged to torch-sys.
doc-only = ["torch-sys/doc-only"]  # enabled for docs.rs via package.metadata.docs.rs
download-libtorch = ["torch-sys/download-libtorch"]
python-extension = ["torch-sys/python-extension"]
# Local features.
cuda-tests = []  # enables nothing; presumably gates CUDA-only tests — confirm in tests/
rl-python = ["cpython"]  # required by the reinforcement-learning example

# Library target: crate `tch`, rooted at src/lib.rs.
[lib]
path = "src/lib.rs"
name = "tch"

[package]
name = "tch"
version = "0.18.0"
authors = ["Laurent Mazare <lmazare@gmail.com>"]
# Target auto-discovery is disabled for bins, examples, tests and benches;
# examples and tests are declared explicitly via [[example]] / [[test]].
autobenches = false
autobins = false
autoexamples = false
autotests = false
build = "build.rs"
categories = ["science"]
edition = "2021"
# Keep example media out of the published crate.
exclude = ["examples/stable-diffusion/media/*"]
keywords = ["pytorch", "deep-learning", "machine-learning"]
# SPDX license expression; the old `/` separator ("MIT/Apache-2.0") is
# deprecated by Cargo in favour of `OR`.
license = "MIT OR Apache-2.0"
readme = "README.md"
repository = "https://github.com/LaurentMazare/tch-rs"
description = "Rust wrappers for the PyTorch C++ api (libtorch)."

# docs.rs builds the documentation with the `doc-only` feature (forwarded to
# torch-sys) — presumably so docs build without a local libtorch; confirm in torch-sys.
[package.metadata.docs.rs]
features = ["doc-only"]

# Integration test targets. Test auto-discovery is turned off for this
# package (`autotests = false`), so each file under tests/ is declared here.
# NOTE(review): tests/test_utils.rs looks like a shared helper module; if the
# other tests pull it in via `mod test_utils;`, it may not need its own target.

[[test]]
path = "tests/autocast.rs"
name = "autocast"

[[test]]
path = "tests/data_tests.rs"
name = "data_tests"

[[test]]
path = "tests/device_tests.rs"
name = "device_tests"

[[test]]
path = "tests/display_tests.rs"
name = "display_tests"

[[test]]
path = "tests/jit_tests.rs"
name = "jit_tests"

[[test]]
path = "tests/nn_tests.rs"
name = "nn_tests"

[[test]]
path = "tests/serialization_tests.rs"
name = "serialization_tests"

[[test]]
path = "tests/tensor_indexing.rs"
name = "tensor_indexing"

[[test]]
path = "tests/tensor_tests.rs"
name = "tensor_tests"

[[test]]
path = "tests/test_utils.rs"
name = "test_utils"

[[test]]
path = "tests/var_store.rs"
name = "var_store"

[[test]]
path = "tests/vision_tests.rs"
name = "vision_tests"