:py:mod:`neural_compressor.tensorflow.quantization.utils.utility`
=================================================================

.. py:module:: neural_compressor.tensorflow.quantization.utils.utility

.. autoapi-nested-parse::

   TensorFlow utility helper functions.



Module Contents
---------------


Functions
~~~~~~~~~

.. autoapisummary::

   neural_compressor.tensorflow.quantization.utils.utility.read_graph
   neural_compressor.tensorflow.quantization.utils.utility.write_graph
   neural_compressor.tensorflow.quantization.utils.utility.is_ckpt_format
   neural_compressor.tensorflow.quantization.utils.utility.is_saved_model_format
   neural_compressor.tensorflow.quantization.utils.utility.get_tensor_by_name
   neural_compressor.tensorflow.quantization.utils.utility.iterator_sess_run
   neural_compressor.tensorflow.quantization.utils.utility.collate_tf_preds
   neural_compressor.tensorflow.quantization.utils.utility.get_input_output_node_names
   neural_compressor.tensorflow.quantization.utils.utility.fix_ref_type_of_graph_def
   neural_compressor.tensorflow.quantization.utils.utility.strip_unused_nodes
   neural_compressor.tensorflow.quantization.utils.utility.get_estimator_graph
   neural_compressor.tensorflow.quantization.utils.utility.strip_equivalent_nodes
   neural_compressor.tensorflow.quantization.utils.utility.get_graph_def
   neural_compressor.tensorflow.quantization.utils.utility.get_model_input_shape
   neural_compressor.tensorflow.quantization.utils.utility.get_tensor_val_from_graph_node
   neural_compressor.tensorflow.quantization.utils.utility.int8_node_name_reverse
   neural_compressor.tensorflow.quantization.utils.utility.generate_feed_dict
   neural_compressor.tensorflow.quantization.utils.utility.get_weight_from_input_tensor
   neural_compressor.tensorflow.quantization.utils.utility.apply_inlining
   neural_compressor.tensorflow.quantization.utils.utility.construct_function_from_graph_def
   neural_compressor.tensorflow.quantization.utils.utility.parse_saved_model
   neural_compressor.tensorflow.quantization.utils.utility.reconstruct_saved_model



.. py:function:: read_graph(in_graph, in_graph_is_binary=True)

   Read an input graph file as a GraphDef.

   :param in_graph: path to the input graph file.
   :param in_graph_is_binary: whether the input graph file is binary; default True.

   :return: the input GraphDef.


.. py:function:: write_graph(out_graph_def, out_graph_file)

   Write the output GraphDef to a file.

   :param out_graph_def: the output GraphDef.
   :param out_graph_file: path to the output graph file.

   :return: None.


.. py:function:: is_ckpt_format(model_path)

   Check whether model_path contains a model in checkpoint (ckpt) format.

   :param model_path: the model folder path
   :type model_path: string

   :returns: the ckpt prefix if model_path contains ckpt-format data, else None.
   :rtype: string


.. py:function:: is_saved_model_format(model_path)

   Check whether model_path contains a model in saved_model format.

   :param model_path: the model folder path
   :type model_path: string

   :returns: True if model_path contains a saved_model, else False.
   :rtype: bool


.. py:function:: get_tensor_by_name(graph, name, try_cnt=3)

   Get a tensor by name.

   Because a model may be imported more than once, the tensor may sit under an 'import' scope.
   Both the name:0 and the bare name formats are handled.

   :param graph: the model to get the tensor from
   :type graph: tf.compat.v1.GraphDef
   :param name: tensor name, either tensor_name:0 or tensor_name without a suffix
   :type name: string
   :param try_cnt: the number of times to prepend 'import/' when searching for the tensor

   :returns: the tensor with the given name.
   :rtype: tensor


.. py:function:: iterator_sess_run(sess, iter_op, feed_dict, output_tensor, iteration=-1, measurer=None)

   Run a graph that has an iterator integrated into it.
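
   The loop this function performs can be pictured with the plain-Python sketch below. It is a
   hypothetical illustration under stated assumptions, not the library code: a Python iterator
   stands in for the iterator initialized by ``iter_op``, ``StopIteration`` stands in for
   ``tf.errors.OutOfRangeError``, and the helper ``run_iterator`` and its ``run_step`` callback
   are invented for this example.

   .. code-block:: python

      def run_iterator(batches, run_step, iteration=-1):
          """Hypothetical sketch of an iterator-driven prediction loop.

          Runs ``run_step`` on up to ``iteration`` batches, or on every batch
          when ``iteration`` is negative (as with the default -1), and returns
          the per-batch results in order.
          """
          it = iter(batches)  # plays the role of running the MakeIterator op once
          preds = []
          while iteration < 0 or len(preds) < iteration:
              try:
                  batch = next(it)  # plays the role of fetching the next element
              except StopIteration:  # plays the role of tf.errors.OutOfRangeError
                  break
              preds.append(run_step(batch))
          return preds


      # Run only two of the three available batches.
      preds = run_iterator([[1, 2], [3, 4], [5, 6]], lambda b: [x * 10 for x in b], iteration=2)
      print(preds)  # [[10, 20], [30, 40]]

   With the default ``iteration=-1`` the same call processes all three batches and returns
   ``[[10, 20], [30, 40], [50, 60]]``.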

   :param sess: the session used to run the graph
   :type sess: tf.compat.v1.Session
   :param iter_op: the MakeIterator op
   :type iter_op: Operator
   :param feed_dict: the feeds used to initialize a new iterator
   :type feed_dict: dict
   :param output_tensor: the output tensors
   :type output_tensor: list
   :param iteration: the number of iterations to run; when set to -1, run until the iterator is exhausted
   :type iteration: int

   :returns: the prediction results
   :rtype: preds


.. py:function:: collate_tf_preds(results)

   Collate the prediction results.


.. py:function:: get_input_output_node_names(graph_def)

   Get the input and output node names of the graph_def.


.. py:function:: fix_ref_type_of_graph_def(graph_def)

   Fix the ref types of the graph_def.


.. py:function:: strip_unused_nodes(graph_def, input_node_names, output_node_names)

   Strip unused nodes from the graph_def.

   The strip_unused_nodes pass is taken from tensorflow/python/tools/strip_unused_lib.py
   in the official TensorFlow r1.15 branch.


.. py:function:: get_estimator_graph(estimator, input_fn)

   Get the graph of the estimator.

   :param estimator: tf estimator model
   :param input_fn: input function

   :returns: graph


.. py:function:: strip_equivalent_nodes(graph_def, output_node_names)

   Strip nodes that have the same inputs and attributes.


.. py:function:: get_graph_def(model, outputs=[], auto_input_output=False)

   Get the model's graph_def.


.. py:function:: get_model_input_shape(model)

   Get the input shape of the input model.


.. py:function:: get_tensor_val_from_graph_node(graph_node_name_mapping, node_name)

   Get the tensor value for the given node name.

   :param graph_node_name_mapping: mapping from node name (key) to node (value)
   :param node_name: name of the node to query

   :returns: tensor_val, the tensor value of the node
   :rtype: numpy array


.. py:function:: int8_node_name_reverse(node)

   Reverse the int8 node name.


.. py:function:: generate_feed_dict(input_tensor, inputs)

   Helper function that generates a feed dict from input_tensor and inputs.


.. py:function:: get_weight_from_input_tensor(model, input_tensor_names, op_types)

   Extract weight tensors and their associated nodes from the input tensors of a smooth quant node.

   :param model: A TensorFlow model containing a ``graph_def`` attribute.
   :param input_tensor_names: A list of input tensor names to search for weight tensors.
   :param op_types: A list of operation types to look for when searching for weight tensors.

   :returns:

      - sq_weight_tensors: A dictionary mapping each input tensor name to a dict of its associated weight tensors, keyed by weight name.
      - sq_weights_nodes: A dictionary mapping each input tensor name to a dict of its associated weight nodes, keyed by weight name.

   :rtype: A tuple of two dictionaries


.. py:function:: apply_inlining(func)

   Apply an inlining optimization to the function's graph definition.

   :param func: A concrete function obtained from the saved_model.

   :returns: The optimized graph in graph_def format.
   :rtype: new_graph_def


.. py:function:: construct_function_from_graph_def(func, graph_def, frozen_func=None)

   Rebuild a function from a graph_def.

   :param func: The original concrete function obtained from the saved_model.
   :param graph_def: The optimized graph after applying the inlining optimization.

   :returns: The reconstructed function.
   :rtype: new_func


.. py:function:: parse_saved_model(model, freeze=False, input_tensor_names=[], output_tensor_names=[])

   Parse an input saved_model.

   :param model: The input saved_model.
   :type model: string or AutoTrackable object

   :returns:

      - graph_def: The graph_def parsed from the saved_model.
      - _saved_model: TF AutoTrackable object loaded from the saved_model.
      - func: The concrete function obtained from the saved_model.
      - frozen_func: The function reconstructed from the inlining-optimized graph.

   :rtype: graph_def


.. py:function:: reconstruct_saved_model(graph_def, func, frozen_func, trackable, path)

   Reconstruct a saved_model.

   :param graph_def: The input graph_def.
   :param func: The concrete function obtained from the original saved_model.
   :param frozen_func: The function reconstructed from the inlining-optimized graph.
   :param trackable: TF AutoTrackable object loaded from the original saved_model.
   :param path: The destination path for the reconstructed saved_model.
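
   After ``reconstruct_saved_model`` has written its output, a quick sanity check on ``path`` can
   help. The sketch below assumes the standard SavedModel directory layout, a ``saved_model.pb``
   (or ``saved_model.pbtxt``) file next to a ``variables/`` directory. The helper
   ``looks_like_saved_model`` is hypothetical and is not necessarily the check that
   ``is_saved_model_format`` performs; the demo uses an empty directory that only mimics the layout.

   .. code-block:: python

      import os
      import tempfile


      def looks_like_saved_model(path):
          """Hypothetical check: does ``path`` have the usual SavedModel layout?"""
          if not os.path.isdir(path):
              return False
          entries = os.listdir(path)
          # The serialized graph is saved_model.pb (binary) or saved_model.pbtxt (text).
          has_graph = "saved_model.pb" in entries or "saved_model.pbtxt" in entries
          # Variable values are stored under a variables/ subdirectory.
          has_variables = os.path.isdir(os.path.join(path, "variables"))
          return has_graph and has_variables


      with tempfile.TemporaryDirectory() as out_dir:
          before = looks_like_saved_model(out_dir)  # empty directory
          open(os.path.join(out_dir, "saved_model.pb"), "wb").close()
          os.mkdir(os.path.join(out_dir, "variables"))
          after = looks_like_saved_model(out_dir)  # layout present, files are empty stand-ins

      print(before, after)  # False True

   The check only inspects file names, so it cannot tell whether the graph and variables are valid;
   loading the directory with TensorFlow is the definitive test.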