:py:mod:`neural_compressor.model.tensorflow_model`
==================================================

.. py:module:: neural_compressor.model.tensorflow_model

.. autoapi-nested-parse::

   Class for Tensorflow model.



Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   neural_compressor.model.tensorflow_model.TensorflowBaseModel
   neural_compressor.model.tensorflow_model.TensorflowSavedModelModel
   neural_compressor.model.tensorflow_model.TensorflowQATModel
   neural_compressor.model.tensorflow_model.TensorflowCheckpointModel
   neural_compressor.model.tensorflow_model.TensorflowModel



Functions
~~~~~~~~~

.. autoapisummary::

   neural_compressor.model.tensorflow_model.get_model_type
   neural_compressor.model.tensorflow_model.validate_graph_node
   neural_compressor.model.tensorflow_model.validate_and_inference_input_output
   neural_compressor.model.tensorflow_model.graph_session
   neural_compressor.model.tensorflow_model.graph_def_session
   neural_compressor.model.tensorflow_model.frozen_pb_session
   neural_compressor.model.tensorflow_model.load_saved_model
   neural_compressor.model.tensorflow_model.keras_session
   neural_compressor.model.tensorflow_model.slim_session
   neural_compressor.model.tensorflow_model.checkpoint_session
   neural_compressor.model.tensorflow_model.estimator_session
   neural_compressor.model.tensorflow_model.saved_model_session



.. py:function:: get_model_type(model)

   Get the Tensorflow model type.

   :param model: model path or model object.
   :type model: string or model object

   :returns: model type
   :rtype: string
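
   The stdlib-only sketch below shows one way a model path could be classified. The helper ``guess_model_type`` and its rules are hypothetical. Its type labels are assumed to match the session builders in this module (``frozen_pb_session``, ``keras_session``, and so on). It also ignores in-memory model objects, which ``get_model_type`` accepts.

   .. code-block:: python

      import os


      def guess_model_type(path):
          """Classify a model path (hypothetical helper, not the library's own logic)."""
          path = os.path.abspath(os.path.expanduser(path))
          if os.path.isfile(path):
              if path.endswith(".pb"):
                  return "frozen_pb"    # single serialized GraphDef file
              if path.endswith(".h5"):
                  return "keras"        # Keras HDF5 file
              if path.endswith(".ckpt"):
                  return "slim"         # TF-Slim style checkpoint file
          elif os.path.isdir(path):
              if os.path.isfile(os.path.join(path, "saved_model.pb")):
                  return "saved_model"  # SavedModel directory
              if os.path.isfile(os.path.join(path, "checkpoint")):
                  return "checkpoint"   # directory holding a checkpoint index file
          raise ValueError(f"cannot infer a model type for {path!r}")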


.. py:function:: validate_graph_node(graph_def, node_names)

   Validate that the given nodes exist in graph_def.

   :param graph_def: tf.compat.v1.GraphDef object.
   :type graph_def: tf.compat.v1.GraphDef
   :param node_names: node names to be validated.
   :type node_names: list of string
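
   The check amounts to a membership test over the node names in the graph. In the sketch below, a plain list of names stands in for ``[node.name for node in graph_def.node]``. The helper names are hypothetical, and treating an empty request as invalid is an assumption.

   .. code-block:: python

      def missing_nodes(graph_node_names, node_names):
          """Return the requested names that are absent from the graph."""
          existing = set(graph_node_names)
          return [name for name in node_names if name not in existing]


      def validate_nodes(graph_node_names, node_names):
          # Assumption: an empty request counts as invalid, so callers can fall back
          # to inferring names from the graph instead.
          return bool(node_names) and not missing_nodes(graph_node_names, node_names)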


.. py:function:: validate_and_inference_input_output(graph_def, input_tensor_names, output_tensor_names)

   Validate the input and output tensor names of graph_def, inferring them from the graph when they cannot be validated.

   :param graph_def: tf.compat.v1.GraphDef object.
   :type graph_def: tf.compat.v1.GraphDef
   :param input_tensor_names: input_tensor_names of graph_def.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of graph_def.
   :type output_tensor_names: list of string

   :returns: A tuple ``(input_tensor_names, output_tensor_names)`` holding the
             validated input and output tensor names.
   :rtype: tuple(list of string, list of string)


.. py:function:: graph_session(model, input_tensor_names, output_tensor_names, **kwargs)

   Build a session from a tf.compat.v1.Graph.

   :param model: tf.compat.v1.Graph object.
   :type model: tf.compat.v1.Graph
   :param input_tensor_names: input_tensor_names of model.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of model.
   :type output_tensor_names: list of string

   :returns: A tuple ``(sess, input_tensor_names, output_tensor_names)``: the
             tf.compat.v1.Session object and the validated input and output
             tensor names.
   :rtype: tuple(tf.compat.v1.Session, list of string, list of string)


.. py:function:: graph_def_session(model, input_tensor_names, output_tensor_names, **kwargs)

   Build a session from a tf.compat.v1.GraphDef.

   :param model: tf.compat.v1.GraphDef object.
   :type model: tf.compat.v1.GraphDef
   :param input_tensor_names: input_tensor_names of model.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of model.
   :type output_tensor_names: list of string

   :returns: A tuple ``(sess, input_tensor_names, output_tensor_names)``: the
             tf.compat.v1.Session object and the validated input and output
             tensor names.
   :rtype: tuple(tf.compat.v1.Session, list of string, list of string)


.. py:function:: frozen_pb_session(model, input_tensor_names, output_tensor_names, **kwargs)

   Build a session from a frozen pb model.

   :param model: model path.
   :type model: string
   :param input_tensor_names: input_tensor_names of model.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of model.
   :type output_tensor_names: list of string

   :returns: A tuple ``(sess, input_tensor_names, output_tensor_names)``: the
             tf.compat.v1.Session object and the validated input and output
             tensor names.
   :rtype: tuple(tf.compat.v1.Session, list of string, list of string)


.. py:function:: load_saved_model(model, saved_model_tags, input_tensor_names, output_tensor_names)

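   Load the GraphDef from a SavedModel using the default serving signature key.

   :param model: Directory of the SavedModel.
   :type model: string
   :param saved_model_tags: Set of tags identifying the MetaGraphDef within the
                            SavedModel to analyze.
   :param input_tensor_names: input_tensor_names of model.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of model.
   :type output_tensor_names: list of string

   :returns: A tuple ``(graph_def, input_tensors, output_tensors)``: the loaded
             GraphDef, the list of input tensors, and the list of output tensors.
   :rtype: tuple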


.. py:function:: keras_session(model, input_tensor_names, output_tensor_names, **kwargs)

   Build a session from a Keras model.

   :param model: model path or tf.keras.Model object.
   :type model: string or tf.keras.Model
   :param input_tensor_names: input_tensor_names of model.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of model.
   :type output_tensor_names: list of string

   :returns: A tuple ``(sess, input_tensor_names, output_tensor_names)``: the
             tf.compat.v1.Session object and the validated input and output
             tensor names.
   :rtype: tuple(tf.compat.v1.Session, list of string, list of string)


.. py:function:: slim_session(model, input_tensor_names, output_tensor_names, **kwargs)

   Build a session from a slim model.

   :param model: model path.
   :type model: string
   :param input_tensor_names: input_tensor_names of model.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of model.
   :type output_tensor_names: list of string

   :returns: A tuple ``(sess, input_tensor_names, output_tensor_names)``: the
             tf.compat.v1.Session object and the validated input and output
             tensor names.
   :rtype: tuple(tf.compat.v1.Session, list of string, list of string)


.. py:function:: checkpoint_session(model, input_tensor_names, output_tensor_names, **kwargs)

   Build a session from a checkpoint (ckpt) model.

   :param model: model path.
   :type model: string
   :param input_tensor_names: input_tensor_names of model.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of model.
   :type output_tensor_names: list of string

   :returns: A tuple ``(sess, input_tensor_names, output_tensor_names)``: the
             tf.compat.v1.Session object and the validated input and output
             tensor names.
   :rtype: tuple(tf.compat.v1.Session, list of string, list of string)


.. py:function:: estimator_session(model, input_tensor_names, output_tensor_names, **kwargs)

   Build a session from an estimator model.

   :param model: tf.estimator.Estimator object.
   :type model: tf.estimator.Estimator
   :param input_tensor_names: input_tensor_names of model.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of model.
   :type output_tensor_names: list of string
   :param kwargs: Other required parameters, such as input_fn.
   :type kwargs: dict

   :returns: A tuple ``(sess, input_tensor_names, output_tensor_names)``: the
             tf.compat.v1.Session object and the validated input and output
             tensor names.
   :rtype: tuple(tf.compat.v1.Session, list of string, list of string)


.. py:function:: saved_model_session(model, input_tensor_names, output_tensor_names, **kwargs)

   Build a session from a saved model.

   :param model: model path.
   :type model: string
   :param input_tensor_names: input_tensor_names of model.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of model.
   :type output_tensor_names: list of string

   :returns: A tuple ``(sess, input_tensor_names, output_tensor_names)``: the
             tf.compat.v1.Session object and the validated input and output
             tensor names.
   :rtype: tuple(tf.compat.v1.Session, list of string, list of string)
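
The ``*_session`` builders above share one signature, ``(model, input_tensor_names, output_tensor_names, **kwargs)``, and each returns a session together with the validated input and output tensor names. That makes a lookup table keyed by model type a natural way to choose among them. The sketch below uses stub builders in place of TensorFlow. ``SESSION_BUILDERS``, ``build_session``, and the type labels are hypothetical.

.. code-block:: python

   def _make_stub_builder(kind):
       """Return a stand-in for a builder such as frozen_pb_session."""
       def build(model, input_tensor_names, output_tensor_names, **kwargs):
           sess = f"<{kind} session for {model}>"  # placeholder for tf.compat.v1.Session
           return sess, list(input_tensor_names), list(output_tensor_names)
       return build


   SESSION_BUILDERS = {
       kind: _make_stub_builder(kind)
       for kind in ("graph", "graph_def", "frozen_pb", "keras", "slim",
                    "checkpoint", "estimator", "saved_model")
   }


   def build_session(model_type, model, input_tensor_names, output_tensor_names, **kwargs):
       """Dispatch to the builder registered for model_type."""
       try:
           builder = SESSION_BUILDERS[model_type]
       except KeyError:
           raise ValueError(f"unsupported model type: {model_type!r}") from None
       return builder(model, input_tensor_names, output_tensor_names, **kwargs)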


.. py:class:: TensorflowBaseModel(model, **kwargs)

   Bases: :py:obj:`neural_compressor.model.base_model.BaseModel`

   Build Tensorflow Base Model.

   .. py:property:: name

      Return name.

   .. py:property:: weights

      Return weights.

   .. py:property:: q_config

      Return q_config.

   .. py:property:: workspace_path

      Return workspace path.

   .. py:property:: model_type

      Return model type.

   .. py:property:: model

      Return model itself.

   .. py:property:: graph_def

      Return graph definition.

   .. py:property:: graph_info

      Return graph info.

   .. py:property:: sess

      Return Session object.

   .. py:property:: graph

      Return model graph.

   .. py:property:: iter_op

      Return model iter op list.

   .. py:property:: input_tensor_names

      Return input tensor names.

   .. py:property:: output_tensor_names

      Return output tensor names.

   .. py:property:: input_node_names

      Return input node names.

   .. py:property:: output_node_names

      Return output node names.
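
   TensorFlow names a tensor ``<op_name>:<output_index>``, so a node name can be recovered from a tensor name by dropping that suffix. The helper below is a hypothetical sketch of the mapping. Whether ``input_node_names`` and ``output_node_names`` are computed exactly this way is an assumption.

   .. code-block:: python

      def tensor_to_node_name(tensor_name):
          """Map a tensor name such as "conv/BiasAdd:0" to its producing node name."""
          node_name, sep, index = tensor_name.rpartition(":")
          # Names without a ":<digits>" suffix are assumed to be node names already.
          if sep and index.isdigit():
              return node_name
          return tensor_name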

   .. py:property:: input_tensor

      Return input tensor.

   .. py:property:: output_tensor

      Return output tensor.

   .. py:method:: framework()

      Return framework.


   .. py:method:: save(root=None)

      Save Tensorflow model.



.. py:class:: TensorflowSavedModelModel(model, **kwargs)

   Bases: :py:obj:`TensorflowBaseModel`

   Build Tensorflow saved model.

   .. py:property:: model

      Return model itself.

   .. py:method:: get_all_weight_names()

      Get weight names of model.

      :returns: weight names list.
      :rtype: list


   .. py:method:: update_weights(tensor_name, new_tensor)

      Update model weights.


   .. py:method:: get_weight(tensor_name)

      Return the model weight with the given tensor name.

      :param tensor_name: name of a tensor.
      :type tensor_name: str


   .. py:method:: report_sparsity()

      Get sparsity of the model.

      :returns: A tuple ``(df, total_sparsity)``, where ``df`` is a DataFrame with the
                sparsity of each weight and ``total_sparsity`` (float) is the total
                sparsity of the model.
      :rtype: tuple(DataFrame, float)
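
      To illustrate what the report measures, the stdlib-only sketch below computes the fraction of zero-valued entries for each weight and for the whole model. The input format (weight name mapped to a flat list of values), the row keys, and weighting the total by element count are assumptions. The real method returns a pandas DataFrame.

      .. code-block:: python

         def sparsity_report(weights):
             """weights maps a weight name to a flat list of values (hypothetical format)."""
             rows, zeros_total, size_total = [], 0, 0
             for name, values in weights.items():
                 zeros = sum(1 for v in values if v == 0)
                 size = len(values)
                 rows.append({"name": name, "size": size,
                              "sparsity": zeros / size if size else 0.0})
                 zeros_total += zeros
                 size_total += size
             # Overall sparsity is weighted by element count, not averaged per weight.
             total_sparsity = zeros_total / size_total if size_total else 0.0
             return rows, total_sparsity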


   .. py:method:: build_saved_model(root=None)

      Build Tensorflow saved model.

      :param root: path to saved model. Defaults to None.
      :type root: str, optional

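      :returns: A tuple ``(root, builder)``, where ``root`` (str) is the path to the
                saved model and ``builder``
                (tf.compat.v1.saved_model.builder.SavedModelBuilder) builds the
                SavedModel protocol buffer and saves variables and assets.
      :rtype: tuple(str, tf.compat.v1.saved_model.builder.SavedModelBuilder)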


   .. py:method:: save(root=None)

      Save Tensorflow model.



.. py:class:: TensorflowQATModel(model='', **kwargs)

   Bases: :py:obj:`TensorflowSavedModelModel`

   Build Tensorflow QAT model.

   .. py:property:: model

      Return model itself.

   .. py:method:: save(root=None)

      Save Tensorflow QAT model.



.. py:class:: TensorflowCheckpointModel(model, **kwargs)

   Bases: :py:obj:`TensorflowBaseModel`

   Build Tensorflow checkpoint model.

   .. py:property:: graph_def

      Return graph definition.


.. py:class:: TensorflowModel

   Bases: :py:obj:`object`

   A wrapper to construct a Tensorflow Model.
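
   One common way to build such a wrapper is a factory whose ``__new__`` returns an instance of a concrete model class chosen by model type. The sketch below shows that pattern with stub classes. The class names, the mapping, and the ``(model_type, model, **kwargs)`` argument order are assumptions, not the documented constructor.

   .. code-block:: python

      class StubBaseModel:
          """Stand-in for a concrete model class such as TensorflowBaseModel."""

          def __init__(self, model, **kwargs):
              self.model = model
              self.kwargs = kwargs


      class StubSavedModel(StubBaseModel):
          pass


      class StubCheckpointModel(StubBaseModel):
          pass


      MODEL_CLASSES = {
          "frozen_pb": StubBaseModel,
          "graph_def": StubBaseModel,
          "saved_model": StubSavedModel,
          "checkpoint": StubCheckpointModel,
      }


      class ModelWrapper:
          """Hypothetical factory: constructing it returns an instance of the mapped class."""

          def __new__(cls, model_type, model, **kwargs):
              try:
                  model_cls = MODEL_CLASSES[model_type]
              except KeyError:
                  raise ValueError(f"unsupported model type: {model_type!r}") from None
              # The result is not a ModelWrapper, so ModelWrapper.__init__ is never called.
              return model_cls(model, **kwargs)

   Because ``__new__`` returns the concrete object directly, callers receive a fully typed model instead of a proxy that forwards attribute access.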