:py:mod:`neural_compressor.adaptor.tf_utils.util`
=================================================

.. py:module:: neural_compressor.adaptor.tf_utils.util

.. autoapi-nested-parse::

   Tensorflow Utils Helper functions.


Module Contents
---------------


Functions
~~~~~~~~~

.. autoapisummary::

   neural_compressor.adaptor.tf_utils.util.version1_lt_version2
   neural_compressor.adaptor.tf_utils.util.version1_gt_version2
   neural_compressor.adaptor.tf_utils.util.version1_eq_version2
   neural_compressor.adaptor.tf_utils.util.version1_gte_version2
   neural_compressor.adaptor.tf_utils.util.version1_lte_version2
   neural_compressor.adaptor.tf_utils.util.disable_random
   neural_compressor.adaptor.tf_utils.util.read_graph
   neural_compressor.adaptor.tf_utils.util.write_graph
   neural_compressor.adaptor.tf_utils.util.is_ckpt_format
   neural_compressor.adaptor.tf_utils.util.is_saved_model_format
   neural_compressor.adaptor.tf_utils.util.get_estimator_graph
   neural_compressor.adaptor.tf_utils.util.get_tensor_by_name
   neural_compressor.adaptor.tf_utils.util.iterator_sess_run
   neural_compressor.adaptor.tf_utils.util.collate_tf_preds
   neural_compressor.adaptor.tf_utils.util.get_input_output_node_names
   neural_compressor.adaptor.tf_utils.util.fix_ref_type_of_graph_def
   neural_compressor.adaptor.tf_utils.util.strip_unused_nodes
   neural_compressor.adaptor.tf_utils.util.strip_equivalent_nodes
   neural_compressor.adaptor.tf_utils.util.get_graph_def
   neural_compressor.adaptor.tf_utils.util.get_model_input_shape
   neural_compressor.adaptor.tf_utils.util.get_tensor_val_from_graph_node
   neural_compressor.adaptor.tf_utils.util.int8_node_name_reverse
   neural_compressor.adaptor.tf_utils.util.tf_diagnosis_helper
   neural_compressor.adaptor.tf_utils.util.generate_feed_dict
   neural_compressor.adaptor.tf_utils.util.get_weight_from_input_tensor
   neural_compressor.adaptor.tf_utils.util.apply_inlining
   neural_compressor.adaptor.tf_utils.util.construct_function_from_graph_def
   neural_compressor.adaptor.tf_utils.util.parse_saved_model
   neural_compressor.adaptor.tf_utils.util.reconstruct_saved_model


.. py:function:: version1_lt_version2(version1, version2)

   Check if version1 is less than version2.


.. py:function:: version1_gt_version2(version1, version2)

   Check if version1 is greater than version2.


.. py:function:: version1_eq_version2(version1, version2)

   Check if version1 is equal to version2.


.. py:function:: version1_gte_version2(version1, version2)

   Check if version1 is greater than or equal to version2.


.. py:function:: version1_lte_version2(version1, version2)

   Check if version1 is less than or equal to version2.


.. py:function:: disable_random(seed=1)

   A decorator to disable the TF random seed.


.. py:function:: read_graph(in_graph, in_graph_is_binary=True)

   Read the input graph file as a GraphDef.

   :param in_graph: input graph file.
   :param in_graph_is_binary: whether the input graph is binary, default True.
   :return: input GraphDef.


.. py:function:: write_graph(out_graph_def, out_graph_file)

   Write the output GraphDef to a file.

   :param out_graph_def: output GraphDef.
   :param out_graph_file: path to the output graph file.
   :return: None.


.. py:function:: is_ckpt_format(model_path)

   Check whether the model_path is in ckpt format.

   :param model_path: the model folder path.
   :type model_path: string

   :returns: the ckpt prefix if the model_path contains ckpt format data, else None.
   :rtype: string
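
A minimal usage sketch of the version-comparison and graph I/O helpers documented above.
The ``.pb`` file names are hypothetical placeholders, not files shipped with the library:

.. code-block:: python

   import tensorflow as tf

   from neural_compressor.adaptor.tf_utils.util import (
       read_graph,
       version1_gte_version2,
       write_graph,
   )

   # Compare the installed TensorFlow version against a minimum requirement.
   if version1_gte_version2(tf.__version__, "2.10.0"):
       print("TensorFlow is at least 2.10.0")

   # Round-trip a frozen GraphDef; the paths are placeholders.
   graph_def = read_graph("frozen_model.pb", in_graph_is_binary=True)
   write_graph(graph_def, "frozen_model_copy.pb")
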
.. py:function:: is_saved_model_format(model_path)

   Check whether the model_path is in saved_model format.

   :param model_path: the model folder path.
   :type model_path: string

   :returns: True if the model_path contains saved_model format data, else False.
   :rtype: bool


.. py:function:: get_estimator_graph(estimator, input_fn)

   Get the graph of the estimator.

   :param estimator: tf estimator model
   :param input_fn: input function

   :returns: graph


.. py:function:: get_tensor_by_name(graph, name, try_cnt=3)

   Get the tensor by name.

   Considers the 'import' scope when the model may be imported more than once,
   and handles both the name:0 and name naming formats.

   :param graph: the model to get the tensor from
   :type graph: tf.compat.v1.GraphDef
   :param name: tensor name as tensor_name:0 or tensor_name without suffixes
   :type name: string
   :param try_cnt: the number of times to prepend 'import/' when looking up the tensor

   :returns: tensor found by name.
   :rtype: tensor


.. py:function:: iterator_sess_run(sess, iter_op, feed_dict, output_tensor, iteration=-1, measurer=None)

   Run a graph that has an iterator integrated in it.

   :param sess: the model sess to run the graph
   :type sess: tf.compat.v1.Session
   :param iter_op: the MakeIterator op
   :type iter_op: Operator
   :param feed_dict: the feeds to initialize a new iterator
   :type feed_dict: dict
   :param output_tensor: the output tensors
   :type output_tensor: list
   :param iteration: iterations to run; when set to -1, run to the end of the iterator
   :type iteration: int

   :returns: the results of the predictions
   :rtype: preds


.. py:function:: collate_tf_preds(results)

   Collate the prediction results.


.. py:function:: get_input_output_node_names(graph_def)

   Get the input node names and output node names of the graph_def.


.. py:function:: fix_ref_type_of_graph_def(graph_def)

   Fix the ref type of the graph_def.


.. py:function:: strip_unused_nodes(graph_def, input_node_names, output_node_names)

   Strip unused nodes of the graph_def.

   The strip_unused_nodes pass is from tensorflow/python/tools/strip_unused_lib.py
   of the official tensorflow r1.15 branch.


.. py:function:: strip_equivalent_nodes(graph_def, output_node_names)

   Strip nodes with the same input and attr.


.. py:function:: get_graph_def(model, outputs=[], auto_input_output=False)

   Get the model's graph_def.


.. py:function:: get_model_input_shape(model)

   Get the input shape of the input model.


.. py:function:: get_tensor_val_from_graph_node(graph_node_name_mapping, node_name)

   Get the tensor value for a given node name.

   :param graph_node_name_mapping: key: node name, val: node
   :param node_name: query node

   :returns: numpy array
   :rtype: tensor_val


.. py:function:: int8_node_name_reverse(node)

   Reverse the int8 node name.


.. py:function:: tf_diagnosis_helper(fp32_model, quan_model, tune_cfg, save_path)

   Tensorflow diagnosis helper function.


.. py:function:: generate_feed_dict(input_tensor, inputs)

   Generate feed dict helper function.


.. py:function:: get_weight_from_input_tensor(model, input_tensor_names, op_types)

   Extracts weight tensors and their associated nodes from a smooth quant node's input tensor.

   :param model: A TensorFlow model containing a `graph_def` attribute.
   :param input_tensor_names: A list of input tensor names to search for weight tensors.
   :param op_types: A list of operation types to search for when looking for weight tensors.

   :returns: - sq_weight_tensors: A dictionary mapping each input tensor name to a dict of its
               associated weight tensors with weight name.
             - sq_weights_nodes: A dictionary mapping each input tensor name to a dict of its
               associated weight nodes with weight name.
   :rtype: A tuple of two dictionaries
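
A short sketch of how ``read_graph``, ``get_input_output_node_names`` and ``get_tensor_by_name``
might be combined to look up an output tensor of a frozen graph. The model path is a
hypothetical placeholder, and importing the GraphDef into a ``tf.Graph`` is one possible way
to obtain the graph argument:

.. code-block:: python

   import tensorflow as tf

   from neural_compressor.adaptor.tf_utils.util import (
       get_input_output_node_names,
       get_tensor_by_name,
       read_graph,
   )

   # Load a frozen GraphDef (placeholder path) and discover its boundary nodes.
   graph_def = read_graph("frozen_model.pb")
   input_names, output_names = get_input_output_node_names(graph_def)

   # Import the GraphDef into a Graph, then resolve an output tensor by name;
   # get_tensor_by_name retries with an 'import/' prefix up to try_cnt times.
   graph = tf.Graph()
   with graph.as_default():
       tf.compat.v1.import_graph_def(graph_def, name="")
   output_tensor = get_tensor_by_name(graph, output_names[0])
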
.. py:function:: apply_inlining(func)

   Apply an inlining optimization to the function's graph definition.

   :param func: A concrete function obtained from a saved_model.

   :returns: The optimized graph in graph_def format.
   :rtype: new_graph_def


.. py:function:: construct_function_from_graph_def(func, graph_def, frozen_func=None)

   Rebuild a function from a graph_def.

   :param func: The original concrete function obtained from a saved_model.
   :param graph_def: The optimized graph after applying the inlining optimization.

   :returns: The reconstructed function.
   :rtype: new_func


.. py:function:: parse_saved_model(model, freeze=False, input_tensor_names=[], output_tensor_names=[])

   Parse an input saved_model.

   :param model: The input saved_model.
   :type model: string or AutoTrackable object

   :returns: graph_def: The graph_def parsed from the saved_model.
             _saved_model: TF AutoTrackable object loaded from the saved_model.
             func: The concrete function obtained from the saved_model.
             frozen_func: The function reconstructed from the inlining-optimized graph.
   :rtype: graph_def


.. py:function:: reconstruct_saved_model(graph_def, func, frozen_func, trackable, path)

   Reconstruct a saved_model.

   :param graph_def: The input graph_def.
   :param func: The concrete function obtained from the original saved_model.
   :param frozen_func: The function reconstructed from the inlining-optimized graph.
   :param trackable: TF AutoTrackable object loaded from the original saved_model.
   :param path: The destination path to save the reconstructed saved_model.
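
A hedged round-trip sketch for the saved_model helpers above, assuming the return order
documented for ``parse_saved_model`` (graph_def, trackable object, concrete function,
frozen function). The directory paths are hypothetical placeholders:

.. code-block:: python

   from neural_compressor.adaptor.tf_utils.util import (
       parse_saved_model,
       reconstruct_saved_model,
   )

   # Parse an existing saved_model directory (placeholder path).
   graph_def, trackable, func, frozen_func = parse_saved_model(
       "./original_saved_model", freeze=True
   )

   # The graph_def could be transformed here before being written back out.

   # Rebuild a saved_model at a new destination (placeholder path).
   reconstruct_saved_model(graph_def, func, frozen_func, trackable, "./rebuilt_saved_model")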