neural_compressor.model.tensorflow_model

Classes and helper functions for Tensorflow models.

Module Contents

Classes

TensorflowBaseModel

Build Tensorflow Base Model.

TensorflowSavedModelModel

Build Tensorflow saved model.

TensorflowQATModel

Build Tensorflow QAT model.

TensorflowCheckpointModel

Build Tensorflow checkpoint model.

TensorflowModel

A wrapper to construct a Tensorflow Model.

Functions

get_model_type(model)

Get the Tensorflow model type.

validate_graph_node(graph_def, node_names)

Validate nodes exist in the graph_def.

validate_and_inference_input_output(graph_def, ...)

Validate and infer the input and output tensor names of graph_def.

graph_session(model, input_tensor_names, ...)

Build a session from a tf.compat.v1.Graph.

graph_def_session(model, input_tensor_names, ...)

Build a session from a tf.compat.v1.GraphDef.

frozen_pb_session(model, input_tensor_names, ...)

Build a session from a frozen pb file.

load_saved_model(model, saved_model_tags, ...)

Load a graph_def from a saved model using the default serving signature key.

keras_session(model, input_tensor_names, ...)

Build a session from a Keras model.

slim_session(model, input_tensor_names, ...)

Build a session from a slim model.

checkpoint_session(model, input_tensor_names, ...)

Build a session from a checkpoint (ckpt) model.

estimator_session(model, input_tensor_names, ...)

Build a session from an estimator model.

saved_model_session(model, input_tensor_names, ...)

Build a session from a saved model.

neural_compressor.model.tensorflow_model.get_model_type(model)

Get the Tensorflow model type.

Parameters:

model (string or model object) – model path or model object.

Returns:

model type

Return type:

string
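As a rough illustration of path-based detection, the sketch below classifies a path by its on-disk layout. It is a simplified, pure-Python approximation, not the library's logic: the helper `guess_tf_model_type` is hypothetical, the returned strings are assumed to mirror the session builders in this module (`frozen_pb`, `saved_model`, `checkpoint`, `slim`), and model objects, which the real function also accepts, are not handled.

```python
import os
import tempfile


def guess_tf_model_type(path):
    """Hypothetical helper: classify a TensorFlow model path by its on-disk layout."""
    path = os.path.abspath(os.path.expanduser(path))
    if os.path.isfile(path) and path.endswith(".pb"):
        # the saved_model.pb inside a SavedModel directory belongs to that SavedModel
        if os.path.basename(path) == "saved_model.pb":
            return "saved_model"
        return "frozen_pb"  # a single serialized GraphDef
    if os.path.isfile(path) and path.endswith(".ckpt"):
        return "slim"  # assumption: a bare .ckpt file is treated as a slim checkpoint
    if os.path.isdir(path):
        entries = os.listdir(path)
        if "saved_model.pb" in entries:
            return "saved_model"  # SavedModel directory
        if "checkpoint" in entries or any(e.endswith(".meta") for e in entries):
            return "checkpoint"  # TF1 checkpoint directory with a .meta graph
    raise ValueError(f"unrecognized model type for {path!r}")


# Usage: create throwaway layouts in a temp directory and classify them.
root = tempfile.mkdtemp()
saved_dir = os.path.join(root, "saved")
os.makedirs(saved_dir)
open(os.path.join(saved_dir, "saved_model.pb"), "wb").close()
ckpt_dir = os.path.join(root, "ckpt")
os.makedirs(ckpt_dir)
open(os.path.join(ckpt_dir, "model.ckpt.meta"), "wb").close()
frozen = os.path.join(root, "frozen.pb")
open(frozen, "wb").close()

print(guess_tf_model_type(saved_dir))  # saved_model
print(guess_tf_model_type(ckpt_dir))   # checkpoint
print(guess_tf_model_type(frozen))     # frozen_pb
```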

neural_compressor.model.tensorflow_model.validate_graph_node(graph_def, node_names)

Validate nodes exist in the graph_def.

Parameters:
  • graph_def (tf.compat.v1.GraphDef) – tf.compat.v1.GraphDef object.

  • node_names (list of string) – node names to be validated.
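The check is essentially set membership over `graph_def.node`. The sketch below assumes a boolean return value (none is documented above) and that an empty `node_names` list counts as invalid; `FakeNode`, `FakeGraphDef`, and `validate_graph_node_sketch` are hypothetical stand-ins so it runs without TensorFlow. A real `tf.compat.v1.GraphDef` exposes the same `node[i].name` field.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class FakeNode:
    """Hypothetical stand-in for a NodeDef: only the name is needed here."""
    name: str


@dataclass
class FakeGraphDef:
    """Hypothetical stand-in for tf.compat.v1.GraphDef."""
    node: List[FakeNode]


def validate_graph_node_sketch(graph_def, node_names):
    """Return True only if every requested node name exists in graph_def."""
    if not node_names:
        return False  # assumption: nothing to validate counts as invalid
    existing = {node.name for node in graph_def.node}
    return all(name in existing for name in node_names)


graph_def = FakeGraphDef(node=[FakeNode("input"), FakeNode("logits")])
print(validate_graph_node_sketch(graph_def, ["input", "logits"]))   # True
print(validate_graph_node_sketch(graph_def, ["input", "softmax"]))  # False
```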

neural_compressor.model.tensorflow_model.validate_and_inference_input_output(graph_def, input_tensor_names, output_tensor_names)

Validate and infer the input and output tensor names of graph_def.

Parameters:
  • graph_def (tf.compat.v1.GraphDef) – tf.compat.v1.GraphDef object.

  • input_tensor_names (list of string) – input_tensor_names of graph_def.

  • output_tensor_names (list of string) – output_tensor_names of graph_def.

Returns:

A tuple (input_tensor_names, output_tensor_names) containing the validated input and output tensor names.

Return type:

tuple of (list of string, list of string)
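When the supplied names are missing or do not match the graph, names have to be inferred from the graph itself. The heuristic in this sketch is an assumption, not necessarily the library's exact rule: inputs are `Placeholder` nodes, and outputs are non-placeholder nodes that no other node consumes. Tensor names such as `x:0` are mapped to node names by dropping the `:N` suffix (and a leading `^` for control inputs). `Node`, `GraphDef`, and all helper names are hypothetical stand-ins so the sketch runs without TensorFlow.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class Node:
    name: str
    op: str
    input: List[str] = field(default_factory=list)


@dataclass
class GraphDef:
    node: List[Node]


def to_node_name(tensor_name: str) -> str:
    """Map 'scope/op:0' (or a '^op' control input) to the node name 'scope/op'."""
    return tensor_name.lstrip("^").split(":")[0]


def all_nodes_exist(graph_def: GraphDef, names: List[str]) -> bool:
    existing = {n.name for n in graph_def.node}
    return bool(names) and all(to_node_name(x) in existing for x in names)


def infer_io(graph_def: GraphDef) -> Tuple[List[str], List[str]]:
    """Assumed heuristic: Placeholders are inputs, unconsumed nodes are outputs."""
    consumed = {to_node_name(i) for n in graph_def.node for i in n.input}
    inputs = [n.name for n in graph_def.node if n.op == "Placeholder"]
    outputs = [n.name for n in graph_def.node
               if n.name not in consumed and n.op != "Placeholder"]
    return inputs, outputs


def validate_and_infer_io(graph_def, input_tensor_names, output_tensor_names):
    """Keep the caller's names when they resolve, otherwise fall back to inference."""
    inferred_inputs, inferred_outputs = infer_io(graph_def)
    if not all_nodes_exist(graph_def, input_tensor_names):
        input_tensor_names = inferred_inputs
    if not all_nodes_exist(graph_def, output_tensor_names):
        output_tensor_names = inferred_outputs
    return input_tensor_names, output_tensor_names


g = GraphDef(node=[
    Node("x", "Placeholder"),
    Node("w", "Const"),
    Node("matmul", "MatMul", ["x", "w"]),
    Node("probs", "Softmax", ["matmul"]),
])
print(validate_and_infer_io(g, ["x:0"], ["probs:0"]))  # (['x:0'], ['probs:0'])
print(validate_and_infer_io(g, [], ["typo"]))           # (['x'], ['probs'])
```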

neural_compressor.model.tensorflow_model.graph_session(model, input_tensor_names, output_tensor_names, **kwargs)

Build a session from a tf.compat.v1.Graph.

Parameters:
  • model (tf.compat.v1.Graph) – tf.compat.v1.Graph object.

  • input_tensor_names (list of string) – input_tensor_names of model.

  • output_tensor_names (list of string) – output_tensor_names of model.

Returns:

A tuple (sess, input_tensor_names, output_tensor_names): the tf.compat.v1.Session object, the validated input_tensor_names, and the validated output_tensor_names.

Return type:

tuple of (tf.compat.v1.Session, list of string, list of string)
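The core of building a session from a graph can be sketched with the public `tf.compat.v1` API, assuming TensorFlow 2.x is installed. The helper `open_graph_session` is hypothetical: it binds a `tf.compat.v1.Session` to an existing graph and resolves the named tensors, treating a bare node name such as `y` as its first output `y:0`. Unlike the library function, it does no validation or inference of names.

```python
import numpy as np
import tensorflow as tf


def as_tensor_name(name):
    """Treat a bare node name such as 'y' as its first output tensor 'y:0'."""
    return name if ":" in name else name + ":0"


def open_graph_session(graph, input_tensor_names, output_tensor_names):
    """Hypothetical helper: bind a Session to `graph` and look up tensors by name."""
    sess = tf.compat.v1.Session(graph=graph)
    inputs = [graph.get_tensor_by_name(as_tensor_name(n)) for n in input_tensor_names]
    outputs = [graph.get_tensor_by_name(as_tensor_name(n)) for n in output_tensor_names]
    return sess, inputs, outputs


# Build a tiny graph in graph mode: y = 2 * x.
graph = tf.Graph()
with graph.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=[None, 2], name="x")
    y = tf.multiply(x, 2.0, name="y")

sess, inputs, outputs = open_graph_session(graph, ["x"], ["y"])
result = sess.run(outputs, feed_dict={inputs[0]: np.array([[1.0, 2.0]], dtype=np.float32)})
print(result[0])  # [[2. 4.]]
sess.close()
```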

neural_compressor.model.tensorflow_model.graph_def_session(model, input_tensor_names, output_tensor_names, **kwargs)

Build a session from a tf.compat.v1.GraphDef.

Parameters:
  • model (tf.compat.v1.GraphDef) – tf.compat.v1.GraphDef object.

  • input_tensor_names (list of string) – input_tensor_names of model.

  • output_tensor_names (list of string) – output_tensor_names of model.

Returns:

A tuple (sess, input_tensor_names, output_tensor_names): the tf.compat.v1.Session object, the validated input_tensor_names, and the validated output_tensor_names.

Return type:

tuple of (tf.compat.v1.Session, list of string, list of string)
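A GraphDef is a serialized protobuf, so it must first be imported into a fresh `tf.Graph` before a session can run it. A minimal sketch, assuming TensorFlow 2.x; passing `name=""` to `tf.compat.v1.import_graph_def` keeps the original node names so the caller's tensor names still resolve. The helper `graph_def_to_session` is hypothetical.

```python
import numpy as np
import tensorflow as tf


def graph_def_to_session(graph_def, input_tensor_names, output_tensor_names):
    """Hypothetical helper: import a GraphDef and open a Session on it."""
    graph = tf.Graph()
    with graph.as_default():
        # name="" avoids the default "import/" prefix on every node
        tf.compat.v1.import_graph_def(graph_def, name="")
    sess = tf.compat.v1.Session(graph=graph)
    inputs = [graph.get_tensor_by_name(n) for n in input_tensor_names]
    outputs = [graph.get_tensor_by_name(n) for n in output_tensor_names]
    return sess, inputs, outputs


# Serialize a small graph (y = x + 1) to a GraphDef, then rebuild it.
src = tf.Graph()
with src.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=[3], name="x")
    tf.add(x, 1.0, name="y")
graph_def = src.as_graph_def()

sess, inputs, outputs = graph_def_to_session(graph_def, ["x:0"], ["y:0"])
out = sess.run(outputs[0], feed_dict={inputs[0]: np.zeros(3, dtype=np.float32)})
sess.close()
```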

neural_compressor.model.tensorflow_model.frozen_pb_session(model, input_tensor_names, output_tensor_names, **kwargs)

Build a session from a frozen pb file.

Parameters:
  • model (string) – model path.

  • input_tensor_names (list of string) – input_tensor_names of model.

  • output_tensor_names (list of string) – output_tensor_names of model.

Returns:

A tuple (sess, input_tensor_names, output_tensor_names): the tf.compat.v1.Session object, the validated input_tensor_names, and the validated output_tensor_names.

Return type:

tuple of (tf.compat.v1.Session, list of string, list of string)
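Loading a frozen `.pb` adds one step in front of the GraphDef case: parsing the serialized bytes from disk. A minimal sketch, assuming TensorFlow 2.x and a binary (not text-format) GraphDef; the file name `frozen.pb` and the helper names are illustrative.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def load_frozen_pb(path):
    """Parse a binary GraphDef file into a tf.compat.v1.GraphDef."""
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(path, "rb") as f:
        graph_def.ParseFromString(f.read())
    return graph_def


def frozen_pb_to_session(path, input_tensor_names, output_tensor_names):
    """Hypothetical helper: frozen .pb file -> Session plus resolved tensors."""
    graph = tf.Graph()
    with graph.as_default():
        tf.compat.v1.import_graph_def(load_frozen_pb(path), name="")
    sess = tf.compat.v1.Session(graph=graph)
    inputs = [graph.get_tensor_by_name(n) for n in input_tensor_names]
    outputs = [graph.get_tensor_by_name(n) for n in output_tensor_names]
    return sess, inputs, outputs


# Write a tiny frozen graph (y = x * x) to disk, then load it back.
src = tf.Graph()
with src.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=[2], name="x")
    tf.square(x, name="y")
path = os.path.join(tempfile.mkdtemp(), "frozen.pb")
with open(path, "wb") as f:
    f.write(src.as_graph_def().SerializeToString())

sess, inputs, outputs = frozen_pb_to_session(path, ["x:0"], ["y:0"])
out = sess.run(outputs[0], feed_dict={inputs[0]: np.array([3.0, 4.0], dtype=np.float32)})
sess.close()
```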

neural_compressor.model.tensorflow_model.load_saved_model(model, saved_model_tags, input_tensor_names, output_tensor_names)

Load a graph_def from a saved model using the default serving signature key.

Parameters:
  • model (string) – Directory of the SavedModel.

  • saved_model_tags – Set of tags identifying the MetaGraphDef within the SavedModel to analyze.

  • input_tensor_names (list of string) – input_tensor_names of model.

  • output_tensor_names (list of string) – output_tensor_names of model.

Returns:

A tuple (graph_def, input_tensors, output_tensors): the loaded GraphDef, the list of input tensors, and the list of output tensors.

Return type:

tuple of (tf.compat.v1.GraphDef, list, list)
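With the TF1-style loader, pulling a graph and its default serving signature out of a SavedModel looks roughly like the sketch below. It assumes TensorFlow 2.x with `tf.compat.v1`, uses `tf.compat.v1.saved_model.simple_save` only to produce a throwaway SavedModel, ignores caller-supplied tensor names, and does not freeze variables (the example graph has none). The helper `load_serving_graph` is hypothetical.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def load_serving_graph(export_dir, tags):
    """Hypothetical helper: return (graph_def, input names, output names) of 'serving_default'."""
    graph = tf.Graph()
    with tf.compat.v1.Session(graph=graph) as sess:
        meta_graph = tf.compat.v1.saved_model.load(sess, tags, export_dir)
        signature = meta_graph.signature_def[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
        input_names = [t.name for t in signature.inputs.values()]
        output_names = [t.name for t in signature.outputs.values()]
        return graph.as_graph_def(), input_names, output_names


# Export a variable-free graph (y = x - 1) as a SavedModel tagged "serve".
export_dir = os.path.join(tempfile.mkdtemp(), "export")
with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
    x = tf.compat.v1.placeholder(tf.float32, shape=[2], name="x")
    y = tf.subtract(x, 1.0, name="y")
    tf.compat.v1.saved_model.simple_save(sess, export_dir, inputs={"x": x}, outputs={"y": y})

graph_def, input_names, output_names = load_serving_graph(export_dir, [tf.saved_model.SERVING])
print(input_names, output_names)  # ['x:0'] ['y:0']
```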

neural_compressor.model.tensorflow_model.keras_session(model, input_tensor_names, output_tensor_names, **kwargs)

Build a session from a Keras model.

Parameters:
  • model (string or tf.keras.Model) – model path or tf.keras.Model object.

  • input_tensor_names (list of string) – input_tensor_names of model.

  • output_tensor_names (list of string) – output_tensor_names of model.

Returns:

A tuple (sess, input_tensor_names, output_tensor_names): the tf.compat.v1.Session object, the validated input_tensor_names, and the validated output_tensor_names.

Return type:

tuple of (tf.compat.v1.Session, list of string, list of string)
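For an in-memory Keras model, one common way to obtain a session-friendly graph is to trace the model into a concrete function and fold its variables into constants. This sketch assumes TensorFlow 2.x and uses `convert_variables_to_constants_v2` from `tensorflow.python.framework.convert_to_constants`, an internal (non-public) module path that may change between releases; it is not necessarily how the library does it, and `keras_to_session` is a hypothetical helper.

```python
import numpy as np
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)


def keras_to_session(model, input_spec):
    """Hypothetical helper: freeze a Keras model and open a TF1 Session on the result."""
    concrete = tf.function(lambda x: model(x)).get_concrete_function(input_spec)
    frozen = convert_variables_to_constants_v2(concrete)  # variables become Const nodes
    graph = tf.Graph()
    with graph.as_default():
        tf.compat.v1.import_graph_def(frozen.graph.as_graph_def(), name="")
    input_names = [t.name for t in frozen.inputs]
    output_names = [t.name for t in frozen.outputs]
    return tf.compat.v1.Session(graph=graph), input_names, output_names


model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(2)])
sess, input_names, output_names = keras_to_session(model, tf.TensorSpec([None, 4], tf.float32))

# The frozen graph should reproduce the eager model's output.
batch = np.random.rand(3, 4).astype(np.float32)
session_out = sess.run(output_names[0], feed_dict={input_names[0]: batch})
eager_out = model(batch).numpy()
sess.close()
```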

neural_compressor.model.tensorflow_model.slim_session(model, input_tensor_names, output_tensor_names, **kwargs)

Build a session from a slim model.

Parameters:
  • model (string) – model path.

  • input_tensor_names (list of string) – input_tensor_names of model.

  • output_tensor_names (list of string) – output_tensor_names of model.

Returns:

A tuple (sess, input_tensor_names, output_tensor_names): the tf.compat.v1.Session object, the validated input_tensor_names, and the validated output_tensor_names.

Return type:

tuple of (tf.compat.v1.Session, list of string, list of string)

neural_compressor.model.tensorflow_model.checkpoint_session(model, input_tensor_names, output_tensor_names, **kwargs)

Build a session from a checkpoint (ckpt) model.

Parameters:
  • model (string) – model path.

  • input_tensor_names (list of string) – input_tensor_names of model.

  • output_tensor_names (list of string) – output_tensor_names of model.

Returns:

A tuple (sess, input_tensor_names, output_tensor_names): the tf.compat.v1.Session object, the validated input_tensor_names, and the validated output_tensor_names.

Return type:

tuple of (tf.compat.v1.Session, list of string, list of string)

neural_compressor.model.tensorflow_model.estimator_session(model, input_tensor_names, output_tensor_names, **kwargs)

Build a session from an estimator model.

Parameters:
  • model (tf.estimator.Estimator) – tf.estimator.Estimator object.

  • input_tensor_names (list of string) – input_tensor_names of model.

  • output_tensor_names (list of string) – output_tensor_names of model.

  • kwargs (dict) – other required parameters, like input_fn.

Returns:

A tuple (sess, input_tensor_names, output_tensor_names): the tf.compat.v1.Session object, the validated input_tensor_names, and the validated output_tensor_names.

Return type:

tuple of (tf.compat.v1.Session, list of string, list of string)

neural_compressor.model.tensorflow_model.saved_model_session(model, input_tensor_names, output_tensor_names, **kwargs)

Build a session from a saved model.

Parameters:
  • model (string) – model path.

  • input_tensor_names (list of string) – input_tensor_names of model.

  • output_tensor_names (list of string) – output_tensor_names of model.

Returns:

A tuple (sess, input_tensor_names, output_tensor_names): the tf.compat.v1.Session object, the validated input_tensor_names, and the validated output_tensor_names.

Return type:

tuple of (tf.compat.v1.Session, list of string, list of string)

class neural_compressor.model.tensorflow_model.TensorflowBaseModel(model, **kwargs)

Build Tensorflow Base Model.

class neural_compressor.model.tensorflow_model.TensorflowSavedModelModel(model, **kwargs)

Build Tensorflow saved model.

class neural_compressor.model.tensorflow_model.TensorflowQATModel(model='', **kwargs)

Build Tensorflow QAT model.

class neural_compressor.model.tensorflow_model.TensorflowCheckpointModel(model, **kwargs)

Build Tensorflow checkpoint model.

class neural_compressor.model.tensorflow_model.TensorflowModel

A wrapper to construct a Tensorflow Model.