pub enum GraphOptimizationLevel {
    Disable,
    Level1,
    Level2,
    Level3,
}

ONNX Runtime provides various graph optimizations to improve performance. Graph optimizations are essentially graph-level transformations, ranging from small graph simplifications and node eliminations to more complex node fusions and layout optimizations.

Graph optimizations are divided into several categories (or levels) based on their complexity and functionality. They can be performed either online or offline. In online mode, the optimizations are applied before inference is performed, while in offline mode, the runtime saves the optimized graph to disk (most commonly used when converting an ONNX model to an ONNX Runtime model).

The optimizations belonging to one level are performed after the optimizations of the previous level have been applied (e.g., extended optimizations are applied after basic optimizations have been applied).
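
The cumulative ordering described above can be sketched as follows. This is a minimal illustration, not the crate's implementation: `Level` is a local stand-in for `GraphOptimizationLevel`, and `enabled_categories` and the category names are hypothetical, chosen to match the basic/extended/layout terminology used on this page.

```rust
/// Local stand-in for `GraphOptimizationLevel` (hypothetical; it only mirrors
/// the variant names and their order).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum Level {
    Disable,
    Level1,
    Level2,
    Level3,
}

/// Returns the optimization categories that run for `level`, in the order
/// they are applied. Each level includes every lower level.
fn enabled_categories(level: Level) -> Vec<&'static str> {
    let all = [
        (Level::Level1, "basic"),
        (Level::Level2, "extended"),
        (Level::Level3, "layout"),
    ];
    all.iter()
        // Keep a category if the requested level is at least the level
        // that introduces it (derived `Ord` follows declaration order).
        .filter(|(min_level, _)| level >= *min_level)
        .map(|(_, name)| *name)
        .collect()
}
```

Under these assumptions, `enabled_categories(Level::Level2)` yields the basic category followed by the extended one, and `Level::Disable` yields an empty list.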

All optimizations (i.e. GraphOptimizationLevel::Level3) are enabled by default.

Online/offline mode

All optimizations can be performed either online or offline. In online mode, all enabled graph optimizations are applied when an inference session is initialized, before model inference is performed. Applying all optimizations each time a session is initialized can add overhead to model startup time (especially for complex models), which can be critical in production scenarios. This is where offline mode helps: after performing graph optimizations, ONNX Runtime serializes the resulting model to disk. Subsequent sessions can then reduce startup time by loading the already optimized model and disabling all optimizations.

Notes:

  • When running in offline mode, make sure to use exactly the same options (e.g., execution providers, optimization level) and hardware as the target machine that model inference will run on (e.g., you cannot run a model pre-optimized for a GPU execution provider on a machine equipped only with a CPU).
  • When layout optimizations are enabled, the offline model can only be used on hardware compatible with the environment in which it was saved. For example, if a model has its layout optimized for AVX2, the offline model requires CPUs that support AVX2.
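
The two notes above amount to a compatibility check between the environment where the model was optimized and the one where it will run. The sketch below is only an illustration of that rule: `OptimizationEnv` and `can_reuse_offline_model` are hypothetical names that do not exist in the crate, and the real runtime may consider more factors than the three modelled here.

```rust
/// Hypothetical description of an environment a model was optimized in
/// (or is about to run in). None of these names come from the crate.
#[derive(Debug, Clone, PartialEq, Eq)]
struct OptimizationEnv {
    execution_provider: String, // e.g. "CPU" or "CUDA"
    optimization_level: u8,     // 0 = Disable .. 3 = Level3
    cpu_features: Vec<String>,  // e.g. ["AVX2"]
}

/// Returns true if a model optimized offline in `saved` can be reused in `target`.
fn can_reuse_offline_model(saved: &OptimizationEnv, target: &OptimizationEnv) -> bool {
    // The same options must be used on both sides.
    let same_options = saved.execution_provider == target.execution_provider
        && saved.optimization_level == target.optimization_level;
    // Layout optimizations (level 3) tie the model to the CPU features it was
    // optimized for, so the target must support all of them.
    let hardware_ok = saved.optimization_level < 3
        || saved
            .cpu_features
            .iter()
            .all(|feature| target.cpu_features.contains(feature));
    same_options && hardware_ok
}
```

With this sketch, a model saved at level 3 on a CPU with AVX2 is rejected for a target that lacks AVX2 or that uses a different execution provider.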

Variants

Disable

Disables all graph optimizations.


Level1

Level 1 includes semantics-preserving graph rewrites which remove redundant nodes and redundant computation. They run before graph partitioning and thus apply to all the execution providers. Available basic/level 1 graph optimizations are as follows:

  • Constant Folding: Statically computes parts of the graph that rely only on constant initializers. This eliminates the need to compute them during runtime.
  • Redundant node eliminations: Remove all redundant nodes without changing the graph structure. The following such optimizations are currently supported:
    • Identity Elimination
    • Slice Elimination
    • Unsqueeze Elimination
    • Dropout Elimination
  • Semantics-preserving node fusions: Fuse/fold multiple nodes into a single node. For example, Conv Add fusion folds the Add operator as the bias of the Conv operator. The following such optimizations are currently supported:
    • Conv Add Fusion
    • Conv Mul Fusion
    • Conv BatchNorm Fusion
    • Relu Clip Fusion
    • Reshape Fusion
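
To make constant folding concrete, here is a minimal sketch on a toy expression graph. It is not how ONNX Runtime represents or rewrites graphs: the `Node` type and `fold_constants` function are hypothetical and exist only for this illustration.

```rust
/// A tiny expression graph: constants, named runtime inputs, and two operators.
#[derive(Debug, Clone, PartialEq)]
enum Node {
    Const(f32),
    Input(String),
    Add(Box<Node>, Box<Node>),
    Mul(Box<Node>, Box<Node>),
}

/// Recursively replaces every subgraph that depends only on constants
/// with a single `Const` node holding its precomputed value.
fn fold_constants(node: Node) -> Node {
    match node {
        Node::Add(lhs, rhs) => match (fold_constants(*lhs), fold_constants(*rhs)) {
            (Node::Const(a), Node::Const(b)) => Node::Const(a + b),
            (l, r) => Node::Add(Box::new(l), Box::new(r)),
        },
        Node::Mul(lhs, rhs) => match (fold_constants(*lhs), fold_constants(*rhs)) {
            (Node::Const(a), Node::Const(b)) => Node::Const(a * b),
            (l, r) => Node::Mul(Box::new(l), Box::new(r)),
        },
        // Constants and runtime inputs are already as simple as they can be.
        leaf => leaf,
    }
}
```

For a graph computing `x * (2 + 3)`, the constant subgraph `2 + 3` is evaluated once ahead of time, leaving `x * 5` to be computed at runtime.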

Level2

Level 2 optimizations include complex node fusions. They are run after graph partitioning and are only applied to the nodes assigned to the CPU or CUDA execution provider. Available extended/level 2 graph optimizations are as follows:

| Optimization | EPs | Comments |
|---|---|---|
| GEMM Activation Fusion | CPU | |
| Matmul Add Fusion | CPU | |
| Conv Activation Fusion | CPU | |
| GELU Fusion | CPU, CUDA | |
| Layer Normalization Fusion | CPU, CUDA | |
| BERT Embedding Layer Fusion | CPU, CUDA | Fuses BERT embedding layers, layer normalization, & attention mask length |
| Attention Fusion* | CPU, CUDA | |
| Skip Layer Normalization Fusion | CPU, CUDA | Fuse bias of fully connected layers, skip connections, and layer normalization |
| Bias GELU Fusion | CPU, CUDA | Fuse bias of fully connected layers & GELU activation |
| GELU Approximation* | CUDA | Disabled by default; enable with OrtSessionOptions::EnableGeluApproximation |

NOTE: To optimize performance of the BERT model, approximation is used in GELU Approximation and Attention Fusion for the CUDA execution provider. The impact on accuracy is negligible based on our evaluation; F1 score for a BERT model on SQuAD v1.1 is almost the same (87.05 vs 87.03).
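
To give a sense of how small the numerical difference of a GELU approximation can be, the sketch below compares a reference GELU with the widely used tanh-based approximation. It assumes that this tanh form is the kind of approximation meant here; the page itself does not state the formula. The `erf` helper is a polynomial approximation included only because the Rust standard library has no `erf`.

```rust
/// erf via Abramowitz-Stegun formula 7.1.26 (absolute error around 1.5e-7).
fn erf(x: f64) -> f64 {
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + 0.3275911 * x);
    // Horner evaluation of the degree-5 polynomial in t.
    let poly = t
        * (0.254829592
            + t * (-0.284496736 + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
    sign * (1.0 - poly * (-x * x).exp())
}

/// Reference GELU: x * Phi(x), with Phi the standard normal CDF.
fn gelu(x: f64) -> f64 {
    0.5 * x * (1.0 + erf(x / 2.0_f64.sqrt()))
}

/// Tanh-based approximation of GELU (assumed form, see the note above).
fn gelu_tanh(x: f64) -> f64 {
    let inner = (2.0 / std::f64::consts::PI).sqrt() * (x + 0.044715 * x * x * x);
    0.5 * x * (1.0 + inner.tanh())
}
```

Comparing `gelu(x)` and `gelu_tanh(x)` over a handful of inputs in the range -3 to 3 shows the two staying within about 1e-3 of each other, which is consistent with the negligible accuracy impact reported above.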


Level3

Level 3 optimizations include memory layout optimizations, which may optimize the graph to use the NCHWc memory layout rather than NCHW to improve spatial locality for some targets.
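
The difference between the two layouts can be illustrated with offset arithmetic. The sketch below assumes the usual blocked definition of NCHWc, in which channels are split into fixed-size blocks and the position within a block becomes the innermost dimension; the function names and the block size are hypothetical and not taken from ONNX Runtime.

```rust
/// Flat offset of element (n, c, h, w) in a dense NCHW tensor.
/// `dims` is [N, C, H, W].
fn nchw_offset(dims: [usize; 4], n: usize, c: usize, h: usize, w: usize) -> usize {
    let [_, channels, height, width] = dims;
    ((n * channels + c) * height + h) * width + w
}

/// Flat offset of the same element in a blocked NCHWc tensor laid out as
/// [N, C / block, H, W, block]. Assumes `C` is a multiple of `block`.
fn nchwc_offset(dims: [usize; 4], block: usize, n: usize, c: usize, h: usize, w: usize) -> usize {
    let [_, channels, height, width] = dims;
    let blocks = channels / block;
    // Split the channel index into a block index and a position in the block.
    let (outer, inner) = (c / block, c % block);
    (((n * blocks + outer) * height + h) * width + w) * block + inner
}
```

With `dims = [1, 16, 2, 2]` and a block size of 8, channels 0 and 1 at the same spatial position are 4 elements apart in NCHW but adjacent in NCHWc, which is the spatial-locality benefit that vectorized kernels can exploit.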

Trait Implementations

impl Debug for GraphOptimizationLevel

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.

impl From<GraphOptimizationLevel> for GraphOptimizationLevel

fn from(val: GraphOptimizationLevel) -> Self

Converts to this type from the input type.

Auto Trait Implementations

Blanket Implementations

impl<T> Any for T where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T> Instrument for T

fn instrument(self, span: Span) -> Instrumented<Self>

Instruments this type with the provided Span, returning an Instrumented wrapper.

fn in_current_span(self) -> Instrumented<Self>

Instruments this type with the current Span, returning an Instrumented wrapper.

impl<T, U> Into<U> for T where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T, U> TryFrom<U> for T where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<T> WithSubscriber for T

fn with_subscriber<S>(self, subscriber: S) -> WithDispatch<Self> where S: Into<Dispatch>,

Attaches the provided Subscriber to this type, returning a WithDispatch wrapper.

fn with_current_subscriber(self) -> WithDispatch<Self>

Attaches the current default Subscriber to this type, returning a WithDispatch wrapper.