neural_compressor.tensorflow.quantization.utils.graph_rewriter.generic.fuse_decomposed_in

Fuse Decomposed InstanceNorm Graph Rewriter.

Module Contents

Classes

FuseDecomposedINOptimizer

Fuse the small ops of a decomposed InstanceNorm back into a single InstanceNorm op.

Functions

node_name_from_input(node_name)

Strips off ports and other decorations to get the underlying node name.

node_from_map(node_map, name)

Pulls a node def from a dictionary for a given name.

values_from_const(node_def)

Extracts the values from a const NodeDef as a numpy ndarray.

valid_reshape_inputs(reshape_in0_ndef, reshape_in1_ndef)

Check if the inputs of the Reshape are valid.

bypass_reshape(input_node_map, input_name)

Get Reshape input nodes.

get_const_dim_count(node_def)

Get the number of dimensions for a Const node.

class neural_compressor.tensorflow.quantization.utils.graph_rewriter.generic.fuse_decomposed_in.FuseDecomposedINOptimizer(input_graph_def)

Fuse the small ops of a decomposed InstanceNorm back into a single InstanceNorm op.

neural_compressor.tensorflow.quantization.utils.graph_rewriter.generic.fuse_decomposed_in.node_name_from_input(node_name)

Strips off ports and other decorations to get the underlying node name.
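The behaviour described above can be illustrated with a small sketch. It assumes the usual TensorFlow GraphDef conventions for input strings, where a leading `^` marks a control dependency and a trailing `:<port>` selects an output of the producing node. The function name `strip_input_decorations` is hypothetical and is not the library's implementation.

```python
import re


def strip_input_decorations(input_name: str) -> str:
    """Hypothetical stand-in that mimics the documented behaviour."""
    # Assumption: a leading "^" marks a control-dependency input.
    if input_name.startswith("^"):
        input_name = input_name[1:]
    # Assumption: a trailing ":<digits>" is an output port index.
    match = re.search(r"(.*):\d+$", input_name)
    if match:
        input_name = match.group(1)
    return input_name


print(strip_input_decorations("conv1/BiasAdd:0"))  # conv1/BiasAdd
print(strip_input_decorations("^init_op"))         # init_op
print(strip_input_decorations("input"))            # input
```

Normalizing names this way lets an entry in a node's `input` list be used directly as a key when looking up the producing node.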

neural_compressor.tensorflow.quantization.utils.graph_rewriter.generic.fuse_decomposed_in.node_from_map(node_map, name)

Pulls a node def from a dictionary for a given name.

Parameters:

  • node_map – Dictionary mapping each node's name to its NodeDef.
  • name – Name of the node to look up.

Returns:

NodeDef of the node with the given name.

Raises:

ValueError – If the node isn’t present in the dictionary.
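A minimal sketch of the lookup contract described above, using plain Python objects in place of real NodeDef protos. The helper name `lookup_node`, the `SimpleNamespace` nodes, and the error message are all hypothetical. The sketch looks the name up exactly as given; whether the library normalizes the name first is not stated in this reference.

```python
from types import SimpleNamespace


def lookup_node(node_map, name):
    """Hypothetical stand-in: return the node stored under `name`."""
    if name not in node_map:
        # Mirrors the documented contract: a missing node raises ValueError.
        raise ValueError(f"No node named {name!r} found in map.")
    return node_map[name]


# Stand-ins for NodeDef objects (only the fields used here).
nodes = [
    SimpleNamespace(name="input", op="Placeholder"),
    SimpleNamespace(name="gamma", op="Const"),
]
# One dictionary entry per node, indexed by node name.
node_map = {node.name: node for node in nodes}

print(lookup_node(node_map, "gamma").op)

try:
    lookup_node(node_map, "missing")
except ValueError as err:
    print("lookup failed:", err)
```

Raising instead of returning `None` means a pattern matcher that walks the graph fails loudly on a dangling input rather than continuing with a missing node.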

neural_compressor.tensorflow.quantization.utils.graph_rewriter.generic.fuse_decomposed_in.values_from_const(node_def)

Extracts the values from a const NodeDef as a numpy ndarray.

Parameters:

node_def – Const NodeDef that has the values we want to access.

Returns:

Numpy ndarray containing the values.

Raises:

ValueError – If the node isn’t a Const.

neural_compressor.tensorflow.quantization.utils.graph_rewriter.generic.fuse_decomposed_in.valid_reshape_inputs(reshape_in0_ndef, reshape_in1_ndef)

Check if the inputs of the Reshape are valid.

neural_compressor.tensorflow.quantization.utils.graph_rewriter.generic.fuse_decomposed_in.bypass_reshape(input_node_map, input_name)

Get Reshape input nodes.

neural_compressor.tensorflow.quantization.utils.graph_rewriter.generic.fuse_decomposed_in.get_const_dim_count(node_def)

Get the number of dimensions for a Const node.

Parameters:

node_def – Const NodeDef.

Returns:

Number of dimensions for the Const node.
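To make the idea of a Const node's dimension count concrete, here is a rough sketch that avoids TensorFlow entirely. It is only an illustration under stated assumptions: the stand-in node stores its values as nested Python lists rather than a tensor proto, and the helpers `const_values` and `nested_dim_count` are hypothetical names, not library functions.

```python
from types import SimpleNamespace


def const_values(node):
    """Hypothetical stand-in: return the values held by a Const node."""
    if node.op != "Const":
        # Mirrors the documented contract: non-Const nodes are rejected.
        raise ValueError(f"Node {node.name!r} is not a Const (op={node.op}).")
    return node.values


def nested_dim_count(values):
    """Count dimensions by nesting depth (assumes regular nested lists)."""
    count = 0
    while isinstance(values, list):
        count += 1
        values = values[0] if values else None
    return count


scalar = SimpleNamespace(name="eps", op="Const", values=1e-5)
vector = SimpleNamespace(name="gamma", op="Const", values=[1.0, 1.0, 1.0])
matrix = SimpleNamespace(name="w", op="Const", values=[[1, 2], [3, 4]])

for node in (scalar, vector, matrix):
    print(node.name, nested_dim_count(const_values(node)))
```

In this sketch a scalar has 0 dimensions, a flat vector has 1, and a 2x2 matrix has 2, which matches what the `ndim` attribute of a NumPy array would report for the same data.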