:py:mod:`neural_compressor.model.tensorflow_model`
==================================================

.. py:module:: neural_compressor.model.tensorflow_model

.. autoapi-nested-parse::

   Class for TensorFlow model.



Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   neural_compressor.model.tensorflow_model.TensorflowBaseModel
   neural_compressor.model.tensorflow_model.TensorflowSavedModelModel
   neural_compressor.model.tensorflow_model.TensorflowLLMModel
   neural_compressor.model.tensorflow_model.TensorflowQATModel
   neural_compressor.model.tensorflow_model.TensorflowCheckpointModel
   neural_compressor.model.tensorflow_model.TensorflowModel



Functions
~~~~~~~~~

.. autoapisummary::

   neural_compressor.model.tensorflow_model.get_model_type
   neural_compressor.model.tensorflow_model.validate_graph_node
   neural_compressor.model.tensorflow_model.validate_and_inference_input_output
   neural_compressor.model.tensorflow_model.graph_session
   neural_compressor.model.tensorflow_model.graph_def_session
   neural_compressor.model.tensorflow_model.frozen_pb_session
   neural_compressor.model.tensorflow_model.load_saved_model
   neural_compressor.model.tensorflow_model.try_loading_keras
   neural_compressor.model.tensorflow_model.keras_session
   neural_compressor.model.tensorflow_model.slim_session
   neural_compressor.model.tensorflow_model.checkpoint_session
   neural_compressor.model.tensorflow_model.estimator_session
   neural_compressor.model.tensorflow_model.saved_model_session



.. py:function:: get_model_type(model)

   Get the TensorFlow model type.

   :param model: model path or model object.
   :type model: string or model object

   :returns: model type
   :rtype: string


.. py:function:: validate_graph_node(graph_def, node_names)

   Validate that the nodes exist in graph_def.

   :param graph_def: tf.compat.v1.GraphDef object.
   :type graph_def: tf.compat.v1.GraphDef
   :param node_names: node names to be validated.
   :type node_names: list of string


.. py:function:: validate_and_inference_input_output(graph_def, input_tensor_names, output_tensor_names)

   Validate and infer the input and output tensor names of graph_def.

   :param graph_def: tf.compat.v1.GraphDef object.
   :type graph_def: tf.compat.v1.GraphDef
   :param input_tensor_names: input_tensor_names of graph_def.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of graph_def.
   :type output_tensor_names: list of string

   :returns: A tuple ``(input_tensor_names, output_tensor_names)`` containing the validated input and output tensor names.
   :rtype: tuple(list of string, list of string)


.. py:function:: graph_session(model, input_tensor_names, output_tensor_names, **kwargs)

   Helper to build a session with tf.compat.v1.Graph.

   :param model: tf.compat.v1.Graph object.
   :type model: tf.compat.v1.Graph
   :param input_tensor_names: input_tensor_names of model.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of model.
   :type output_tensor_names: list of string

   :returns: A tuple ``(sess, input_tensor_names, output_tensor_names)``, where ``sess`` is the tf.compat.v1.Session object and the two lists are the validated input and output tensor names.
   :rtype: tuple(tf.compat.v1.Session, list of string, list of string)


.. py:function:: graph_def_session(model, input_tensor_names, output_tensor_names, **kwargs)

   Build a session with tf.compat.v1.GraphDef.

   :param model: tf.compat.v1.GraphDef object.
   :type model: tf.compat.v1.GraphDef
   :param input_tensor_names: input_tensor_names of model.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of model.
   :type output_tensor_names: list of string

   :returns: A tuple ``(sess, input_tensor_names, output_tensor_names)``, where ``sess`` is the tf.compat.v1.Session object and the two lists are the validated input and output tensor names.
   :rtype: tuple(tf.compat.v1.Session, list of string, list of string)


.. py:function:: frozen_pb_session(model, input_tensor_names, output_tensor_names, **kwargs)

   Build a session with a frozen pb.

   :param model: model path.
   :type model: string
   :param input_tensor_names: input_tensor_names of model.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of model.
   :type output_tensor_names: list of string

   :returns: A tuple ``(sess, input_tensor_names, output_tensor_names)``, where ``sess`` is the tf.compat.v1.Session object and the two lists are the validated input and output tensor names.
   :rtype: tuple(tf.compat.v1.Session, list of string, list of string)


.. py:function:: load_saved_model(model, saved_model_tags, input_tensor_names, output_tensor_names)

   Load graph_def from a saved model with the default serving signature key.

   :param model: Directory of the SavedModel.
   :param saved_model_tags: Set of tags identifying the MetaGraphDef within the SavedModel to analyze.
   :param input_tensor_names: input_tensor_names of model.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of model.
   :type output_tensor_names: list of string

   :returns: A tuple ``(graph_def, input_tensors, output_tensors)`` containing the loaded GraphDef, the list of input tensors, and the list of output tensors.
   :rtype: tuple


.. py:function:: try_loading_keras(model, input_tensor_names, output_tensor_names)

   Try different ways of loading Keras models.

   :param model: model path or tf.keras.Model object.
   :type model: string or tf.keras.Model
   :param input_tensor_names: input tensor names of the model.
   :type input_tensor_names: list of string
   :param output_tensor_names: output tensor names of the model.
   :type output_tensor_names: list of string

   :returns: A tuple ``(graph_def, input_names, output_names)``, where ``graph_def`` is a tf.compat.v1.Session object and the two lists are the validated input and output names.
   :rtype: tuple(tf.compat.v1.Session, list of string, list of string)


.. py:function:: keras_session(model, input_tensor_names, output_tensor_names, **kwargs)

   Build a session with a Keras model.

   :param model: model path or tf.keras.Model object.
   :type model: string or tf.keras.Model
   :param input_tensor_names: input_tensor_names of model.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of model.
   :type output_tensor_names: list of string

   :returns: tf.compat.v1.Session object.
   :rtype: tf.compat.v1.Session


.. py:function:: slim_session(model, input_tensor_names, output_tensor_names, **kwargs)

   Build a session with a slim model.

   :param model: model path.
   :type model: string
   :param input_tensor_names: input_tensor_names of model.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of model.
   :type output_tensor_names: list of string

   :returns: A tuple ``(sess, input_tensor_names, output_tensor_names)``, where ``sess`` is the tf.compat.v1.Session object and the two lists are the validated input and output tensor names.
   :rtype: tuple(tf.compat.v1.Session, list of string, list of string)


.. py:function:: checkpoint_session(model, input_tensor_names, output_tensor_names, **kwargs)

   Build a session with a ckpt model.

   :param model: model path.
   :type model: string
   :param input_tensor_names: input_tensor_names of model.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of model.
   :type output_tensor_names: list of string

   :returns: A tuple ``(sess, input_tensor_names, output_tensor_names)``, where ``sess`` is the tf.compat.v1.Session object and the two lists are the validated input and output tensor names.
   :rtype: tuple(tf.compat.v1.Session, list of string, list of string)


.. py:function:: estimator_session(model, input_tensor_names, output_tensor_names, **kwargs)

   Build a session with an estimator model.

   :param model: tf.estimator.Estimator object.
   :type model: tf.estimator.Estimator
   :param input_tensor_names: input_tensor_names of model.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of model.
   :type output_tensor_names: list of string
   :param kwargs: other required parameters, such as input_fn.
   :type kwargs: dict

   :returns: A tuple ``(sess, input_tensor_names, output_tensor_names)``, where ``sess`` is the tf.compat.v1.Session object and the two lists are the validated input and output tensor names.
   :rtype: tuple(tf.compat.v1.Session, list of string, list of string)


.. py:function:: saved_model_session(model, input_tensor_names, output_tensor_names, **kwargs)

   Build a session with a saved model.

   :param model: model path.
   :type model: string
   :param input_tensor_names: input_tensor_names of model.
   :type input_tensor_names: list of string
   :param output_tensor_names: output_tensor_names of model.
   :type output_tensor_names: list of string

   :returns: A tuple ``(sess, input_tensor_names, output_tensor_names)``, where ``sess`` is the tf.compat.v1.Session object and the two lists are the validated input and output tensor names.
   :rtype: tuple(tf.compat.v1.Session, list of string, list of string)


.. py:class:: TensorflowBaseModel(model, **kwargs)

   Build a TensorFlow base model.


.. py:class:: TensorflowSavedModelModel(model, **kwargs)

   Build a TensorFlow saved model.


.. py:class:: TensorflowLLMModel(model, **kwargs)

   Build a TensorFlow saved model whose GraphDef exceeds the maximum protobuf size of 2GB.


.. py:class:: TensorflowQATModel(model='', **kwargs)

   Build a TensorFlow QAT model.


.. py:class:: TensorflowCheckpointModel(model, **kwargs)

   Build a TensorFlow checkpoint model.


.. py:class:: TensorflowModel

   A wrapper to construct a TensorFlow model.
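
The node-name checks described for ``validate_graph_node`` and ``validate_and_inference_input_output`` can be illustrated with a minimal pure-Python sketch. It assumes the standard TensorFlow convention that a tensor name is its producing node's name followed by ``:<output index>`` (for example ``input:0``), and it uses a plain list of node names as a stand-in for ``graph_def.node``. The helpers ``tensor_to_node_name`` and ``validate_node_names`` are hypothetical and are not part of ``neural_compressor``.

.. code-block:: python

   from typing import Iterable, List


   def tensor_to_node_name(tensor_name: str) -> str:
       """Map a tensor name such as 'input:0' to its node name 'input'."""
       node, sep, index = tensor_name.rpartition(":")
       # Treat the suffix as an output index only when it is all digits.
       if sep and index.isdigit():
           return node
       # No ':<index>' suffix: the name already refers to a node.
       return tensor_name


   def validate_node_names(graph_node_names: Iterable[str], names: List[str]) -> List[str]:
       """Resolve tensor names to node names and check that each node exists.

       ``graph_node_names`` stands in for ``[node.name for node in graph_def.node]``.
       """
       available = set(graph_node_names)
       resolved = [tensor_to_node_name(name) for name in names]
       missing = [name for name in resolved if name not in available]
       if missing:
           # Hypothetical error handling: report every missing node at once.
           raise ValueError(f"nodes not found in graph: {missing}")
       return resolved

For example, ``validate_node_names(["input", "conv1/Conv2D", "logits"], ["input:0", "logits:0"])`` returns ``["input", "logits"]``, while a name with no matching node raises ``ValueError``. Raising an exception is a choice made for this sketch; the library's own error handling may differ.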