:py:mod:`neural_compressor.model.tensorflow_model`
==================================================

.. py:module:: neural_compressor.model.tensorflow_model

.. autoapi-nested-parse::

   Classes for TensorFlow models.



Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   neural_compressor.model.tensorflow_model.TensorflowBaseModel
   neural_compressor.model.tensorflow_model.TensorflowSavedModelModel
   neural_compressor.model.tensorflow_model.TensorflowQATModel
   neural_compressor.model.tensorflow_model.TensorflowCheckpointModel
   neural_compressor.model.tensorflow_model.TensorflowModel



Functions
~~~~~~~~~

.. autoapisummary::

   neural_compressor.model.tensorflow_model.get_model_type
   neural_compressor.model.tensorflow_model.validate_graph_node
   neural_compressor.model.tensorflow_model.validate_and_inference_input_output
   neural_compressor.model.tensorflow_model.graph_session
   neural_compressor.model.tensorflow_model.graph_def_session
   neural_compressor.model.tensorflow_model.frozen_pb_session
   neural_compressor.model.tensorflow_model.load_saved_model
   neural_compressor.model.tensorflow_model.keras_session
   neural_compressor.model.tensorflow_model.slim_session
   neural_compressor.model.tensorflow_model.checkpoint_session
   neural_compressor.model.tensorflow_model.estimator_session
   neural_compressor.model.tensorflow_model.saved_model_session



.. py:function:: get_model_type(model)

   Get the TensorFlow model type.

   :param model: model path or model object.
   :type model: string or model object

   :returns: model type.
   :rtype: string


.. py:function:: validate_graph_node(graph_def, node_names)

   Validate that the given nodes exist in graph_def.

   :param graph_def: tf.compat.v1.GraphDef object.
   :type graph_def: tf.compat.v1.GraphDef
   :param node_names: node names to be validated.
   :type node_names: list of string


.. py:function:: validate_and_inference_input_output(graph_def, input_tensor_names, output_tensor_names)

   Validate and infer the input and output tensor names of graph_def.

   :param graph_def: tf.compat.v1.GraphDef object.
   :type graph_def: tf.compat.v1.GraphDef
   :param input_tensor_names: input_tensor_names of graph_def.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of graph_def.
   :type output_tensor_names: list of string

   :returns: ``(input_tensor_names, output_tensor_names)``: the validated input and output tensor names.
   :rtype: tuple of (list of string, list of string)


.. py:function:: graph_session(model, input_tensor_names, output_tensor_names, **kwargs)

   Build a session with a tf.compat.v1.Graph.

   :param model: tf.compat.v1.Graph object.
   :type model: tf.compat.v1.Graph
   :param input_tensor_names: input_tensor_names of model.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of model.
   :type output_tensor_names: list of string

   :returns: ``(sess, input_tensor_names, output_tensor_names)``: the tf.compat.v1.Session object and the validated input and output tensor names.
   :rtype: tuple of (tf.compat.v1.Session, list of string, list of string)


.. py:function:: graph_def_session(model, input_tensor_names, output_tensor_names, **kwargs)

   Build a session with a tf.compat.v1.GraphDef.

   :param model: tf.compat.v1.GraphDef object.
   :type model: tf.compat.v1.GraphDef
   :param input_tensor_names: input_tensor_names of model.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of model.
   :type output_tensor_names: list of string

   :returns: ``(sess, input_tensor_names, output_tensor_names)``: the tf.compat.v1.Session object and the validated input and output tensor names.
   :rtype: tuple of (tf.compat.v1.Session, list of string, list of string)


.. py:function:: frozen_pb_session(model, input_tensor_names, output_tensor_names, **kwargs)

   Build a session with a frozen pb model.

   :param model: model path.
   :type model: string
   :param input_tensor_names: input_tensor_names of model.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of model.
   :type output_tensor_names: list of string

   :returns: ``(sess, input_tensor_names, output_tensor_names)``: the tf.compat.v1.Session object and the validated input and output tensor names.
   :rtype: tuple of (tf.compat.v1.Session, list of string, list of string)


.. py:function:: load_saved_model(model, saved_model_tags, input_tensor_names, output_tensor_names)

   Load a graph_def from a SavedModel with the default serving signature key.

   :param model: directory of the SavedModel.
   :param saved_model_tags: set of tags identifying the MetaGraphDef within the SavedModel to analyze.
   :param input_tensor_names: input_tensor_names of model.
   :param output_tensor_names: output_tensor_names of model.

   :returns: ``(graph_def, input_tensors, output_tensors)``: the loaded GraphDef, the list of input tensors and the list of output tensors.
   :rtype: tuple


.. py:function:: keras_session(model, input_tensor_names, output_tensor_names, **kwargs)

   Build a session with a Keras model.

   :param model: model path or tf.keras.Model object.
   :type model: string or tf.keras.Model
   :param input_tensor_names: input_tensor_names of model.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of model.
   :type output_tensor_names: list of string

   :returns: ``(sess, input_tensor_names, output_tensor_names)``: the tf.compat.v1.Session object and the validated input and output tensor names.
   :rtype: tuple of (tf.compat.v1.Session, list of string, list of string)


.. py:function:: slim_session(model, input_tensor_names, output_tensor_names, **kwargs)

   Build a session with a slim model.

   :param model: model path.
   :type model: string
   :param input_tensor_names: input_tensor_names of model.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of model.
   :type output_tensor_names: list of string

   :returns: ``(sess, input_tensor_names, output_tensor_names)``: the tf.compat.v1.Session object and the validated input and output tensor names.
   :rtype: tuple of (tf.compat.v1.Session, list of string, list of string)


.. py:function:: checkpoint_session(model, input_tensor_names, output_tensor_names, **kwargs)

   Build a session with a checkpoint (ckpt) model.

   :param model: model path.
   :type model: string
   :param input_tensor_names: input_tensor_names of model.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of model.
   :type output_tensor_names: list of string

   :returns: ``(sess, input_tensor_names, output_tensor_names)``: the tf.compat.v1.Session object and the validated input and output tensor names.
   :rtype: tuple of (tf.compat.v1.Session, list of string, list of string)


.. py:function:: estimator_session(model, input_tensor_names, output_tensor_names, **kwargs)

   Build a session with an estimator model.

   :param model: tf.estimator.Estimator object.
   :type model: tf.estimator.Estimator
   :param input_tensor_names: input_tensor_names of model.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of model.
   :type output_tensor_names: list of string
   :param kwargs: other required parameters, such as input_fn.
   :type kwargs: dict

   :returns: ``(sess, input_tensor_names, output_tensor_names)``: the tf.compat.v1.Session object and the validated input and output tensor names.
   :rtype: tuple of (tf.compat.v1.Session, list of string, list of string)


.. py:function:: saved_model_session(model, input_tensor_names, output_tensor_names, **kwargs)

   Build a session with a saved model.

   :param model: model path.
   :type model: string
   :param input_tensor_names: input_tensor_names of model.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of model.
   :type output_tensor_names: list of string

   :returns: ``(sess, input_tensor_names, output_tensor_names)``: the tf.compat.v1.Session object and the validated input and output tensor names.
   :rtype: tuple of (tf.compat.v1.Session, list of string, list of string)


.. py:class:: TensorflowBaseModel(model, **kwargs)

   Bases: :py:obj:`neural_compressor.model.base_model.BaseModel`

   Build a TensorFlow base model.

   .. py:property:: name

      Return name.

   .. py:property:: weights

      Return weights.

   .. py:property:: q_config

      Return q_config.

   .. py:property:: workspace_path

      Return workspace path.

   .. py:property:: model_type

      Return model type.

   .. py:property:: model

      Return the model itself.

   .. py:property:: graph_def

      Return graph definition.

   .. py:property:: graph_info

      Return graph info.

   .. py:property:: sess

      Return Session object.

   .. py:property:: graph

      Return model graph.

   .. py:property:: iter_op

      Return the model's iter op list.

   .. py:property:: input_tensor_names

      Return input tensor names.

   .. py:property:: output_tensor_names

      Return output tensor names.

   .. py:property:: input_node_names

      Return input node names.

   .. py:property:: output_node_names

      Return output node names.

   .. py:property:: input_tensor

      Return input tensor.

   .. py:property:: output_tensor

      Return output tensor.

   .. py:method:: framework()

      Return framework.

   .. py:method:: save(root=None)

      Save the TensorFlow model.


.. py:class:: TensorflowSavedModelModel(model, **kwargs)

   Bases: :py:obj:`TensorflowBaseModel`

   Build a TensorFlow saved model.

   .. py:property:: model

      Return the model itself.

   .. py:method:: get_all_weight_names()

      Get weight names of model.

      :returns: weight names list.
      :rtype: list

   .. py:method:: update_weights(tensor_name, new_tensor)

      Update model weights.

   .. py:method:: get_weight(tensor_name)

      Return the model weight with the given tensor name.

      :param tensor_name: name of a tensor.
      :type tensor_name: str

   .. py:method:: report_sparsity()

      Get sparsity of the model.

      :returns: ``(df, total_sparsity)``: a DataFrame with the sparsity of each weight and the total sparsity (float) of the model.
      :rtype: tuple of (DataFrame, float)

   .. py:method:: build_saved_model(root=None)

      Build a TensorFlow saved model.

      :param root: path to saved model. Defaults to None.
      :type root: str, optional

      :returns: ``(root, builder)``: the path to the saved model (str) and the tf.compat.v1.saved_model.builder.SavedModelBuilder that builds the SavedModel protocol buffer and saves variables and assets.
      :rtype: tuple of (str, tf.compat.v1.saved_model.builder.SavedModelBuilder)

   .. py:method:: save(root=None)

      Save the TensorFlow model.


.. py:class:: TensorflowQATModel(model='', **kwargs)

   Bases: :py:obj:`TensorflowSavedModelModel`

   Build a TensorFlow QAT model.

   .. py:property:: model

      Return the model itself.

   .. py:method:: save(root=None)

      Save the TensorFlow QAT model.


.. py:class:: TensorflowCheckpointModel(model, **kwargs)

   Bases: :py:obj:`TensorflowBaseModel`

   Build a TensorFlow checkpoint model.

   .. py:property:: graph_def

      Return graph definition.


.. py:class:: TensorflowModel

   Bases: :py:obj:`object`

   A wrapper to construct a TensorFlow model.
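
:py:func:`get_model_type` takes a model path or object and returns a type string, but this reference does not list its detection rules or its possible return values. The sketch below classifies a path using common TensorFlow on-disk conventions: a SavedModel directory contains ``saved_model.pb``, a checkpoint directory contains a ``checkpoint`` index file, frozen graphs use ``.pb`` and Keras files use ``.h5`` or ``.keras``. The function ``guess_model_type`` and its return strings are hypothetical; this is not the neural_compressor implementation.

```python
import os


def guess_model_type(path: str) -> str:
    """Heuristically classify a TensorFlow model path.

    Hypothetical helper: the conventions and return strings are assumptions
    for illustration, not neural_compressor's get_model_type.
    """
    if os.path.isdir(path):
        # A SavedModel directory holds the serialized graph as saved_model.pb.
        if os.path.isfile(os.path.join(path, "saved_model.pb")):
            return "saved_model"
        # Checkpoint directories usually carry a "checkpoint" index file.
        if os.path.isfile(os.path.join(path, "checkpoint")):
            return "checkpoint"
        raise ValueError(f"unrecognized model directory: {path}")
    # For single files, fall back to the extension.
    if path.endswith(".pb"):
        return "frozen_pb"
    if path.endswith((".h5", ".keras")):
        return "keras"
    raise ValueError(f"unrecognized model file: {path}")
```

Model objects (for example a ``tf.keras.Model`` instance) would need an additional ``isinstance`` branch, which is omitted here to keep the sketch free of a TensorFlow dependency.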
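
A check like the one :py:func:`validate_graph_node` describes only needs the node names of the graph. With a real ``tf.compat.v1.GraphDef`` they are available as ``[node.name for node in graph_def.node]``. The sketch below takes that list directly so it runs without TensorFlow. The helper names ``missing_nodes`` and ``nodes_exist``, and the choice to treat an empty request as invalid, are assumptions for illustration.

```python
from typing import Iterable, List


def missing_nodes(graph_node_names: Iterable[str], node_names: List[str]) -> List[str]:
    """Return the requested node names that are absent from the graph (hypothetical helper)."""
    available = set(graph_node_names)  # set lookup keeps the check O(1) per name
    return [name for name in node_names if name not in available]


def nodes_exist(graph_node_names: Iterable[str], node_names: List[str]) -> bool:
    """True when at least one node is requested and every requested node exists.

    Treating an empty request as invalid is an assumption of this sketch.
    """
    return bool(node_names) and not missing_nodes(graph_node_names, node_names)
```

Returning the list of missing names, rather than only a boolean, makes it easy to report exactly which user-supplied names were mistyped.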
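
:py:meth:`TensorflowSavedModelModel.report_sparsity` returns a DataFrame with per-weight sparsity and the total sparsity of the model. This sketch assumes that sparsity is the fraction of zero-valued elements and that the total is weighted by element count; neither is confirmed by this reference. A dict of flat lists stands in for the weight tensors, and a dict replaces the DataFrame. The function name ``sparsity_summary`` is hypothetical.

```python
from typing import Dict, List, Tuple


def sparsity_summary(weights: Dict[str, List[float]]) -> Tuple[Dict[str, float], float]:
    """Return per-weight sparsity and element-weighted total sparsity (hypothetical helper)."""
    per_weight: Dict[str, float] = {}
    zeros_total = 0
    elements_total = 0
    for name, values in weights.items():
        zeros = sum(1 for v in values if v == 0)
        # Guard against empty tensors to avoid division by zero.
        per_weight[name] = zeros / len(values) if values else 0.0
        zeros_total += zeros
        elements_total += len(values)
    # Total sparsity counts zeros over all elements, so large weights dominate.
    total = zeros_total / elements_total if elements_total else 0.0
    return per_weight, total
```

Weighting by element count means the total is not the mean of the per-weight values whenever weights differ in size.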
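
:py:class:`TensorflowModel` is documented only as a wrapper that constructs a TensorFlow model. One common way to write such a wrapper is a factory whose ``__new__`` looks up the concrete class registered for a model type and returns an instance of it. This pattern is assumed here for illustration and is not taken from neural_compressor. ``MODEL_REGISTRY``, ``ModelFactory`` and the stub classes are hypothetical.

```python
from typing import Any, Callable, Dict


class KerasStub:
    """Hypothetical stand-in for a concrete model wrapper class."""

    def __init__(self, root: Any, **kwargs: Any) -> None:
        self.root = root
        self.kwargs = kwargs


class SavedModelStub(KerasStub):
    """Hypothetical stand-in for a SavedModel wrapper class."""


# Hypothetical registry: model-type string -> wrapper class.
MODEL_REGISTRY: Dict[str, Callable[..., Any]] = {
    "keras": KerasStub,
    "saved_model": SavedModelStub,
}


class ModelFactory:
    """Construct the registered wrapper for model_type instead of a ModelFactory."""

    def __new__(cls, model_type: str, root: Any, **kwargs: Any) -> Any:
        try:
            wrapper = MODEL_REGISTRY[model_type]
        except KeyError:
            raise ValueError(f"unsupported model type: {model_type!r}") from None
        return wrapper(root, **kwargs)
```

Because ``__new__`` returns an object that is not an instance of ``ModelFactory``, Python skips ``ModelFactory.__init__``, and callers receive the concrete wrapper directly.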