Source code for tlt.datasets.text_classification.tf_custom_text_classification_dataset

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
#

import os
import tensorflow as tf

from tlt.datasets.tf_dataset import TFDataset
from tlt.datasets.text_classification.text_classification_dataset import TextClassificationDataset
from tlt.utils.dataset_utils import prepare_huggingface_input_data
from tlt.utils.inc_utils import INCTFDataLoader


class TFCustomTextClassificationDataset(TextClassificationDataset, TFDataset):
    """
    A custom text classification dataset that can be used with TensorFlow models. Note that this dataset class
    expects a .csv file with two columns, where the first column is the label and the second column is the
    text/sentence to classify.

    For example, a comma-separated values file will look similar to the snippet below:

    .. code-block:: text

        class_a,<text>
        class_b,<text>
        class_a,<text>
        ...

    If the .csv file has more columns, the select_cols or exclude_cols parameters can be used to filter out
    which columns will be parsed.

    Args:
        dataset_dir (str): Directory containing the dataset
        dataset_name (str): Name of the dataset. If no dataset name is given, the csv file name (without its
                            extension) will be used as the dataset name.
        csv_file_name (str): Name of the csv file to load from the dataset directory
        class_names (list): List of ordered class names
        label_map_func (function): optional; A function that is mapped across the label column of the dataset
                                   to transform each label. For example, if the .csv file has string class
                                   labels instead of numerical values, provide a function that maps the string
                                   to a numerical value.
        defaults (list): optional; List of default values for the .csv file fields. Defaults to
                         [tf.string, tf.string].
        delimiter (str): optional; String character that separates the label and text in each row. Defaults
                         to ",".
        header (bool): optional; Boolean indicating whether or not the csv file has a header line that should
                       be skipped. Defaults to False.
        select_cols (list): optional; Specify a list of sorted indices for columns from the dataset file(s)
                            that should be parsed. Defaults to parsing all columns. At most one of select_cols
                            and exclude_cols can be specified.
        exclude_cols (list): optional; Specify a list of sorted indices for columns from the dataset file(s)
                             that should be excluded from parsing. Defaults to parsing all columns. At most
                             one of select_cols and exclude_cols can be specified.
        shuffle_files (bool): optional; Whether to shuffle the data. Defaults to True.
        seed (int): optional; Random seed for shuffling

    Raises:
        FileNotFoundError: if the csv file is not found in the dataset directory
        TypeError: if the class_names parameter is not a list, or if the label_map_func is not callable
    """
    def __init__(self, dataset_dir, dataset_name, csv_file_name, class_names=[], label_map_func=None,
                 defaults=[tf.string, tf.string], delimiter=",", header=False, select_cols=None,
                 exclude_cols=None, shuffle_files=True, seed=None, **kwargs):
        """
        Class constructor
        """
        dataset_file = os.path.join(dataset_dir, csv_file_name)
        if not os.path.exists(dataset_file):
            raise FileNotFoundError("The dataset file ({}) does not exist".format(dataset_file))

        if not isinstance(class_names, list):
            raise TypeError("The class_names parameter is expected to be a list, but found a {}".format(
                type(class_names)))

        if label_map_func and not callable(label_map_func):
            raise TypeError("The label_map_func is expected to be a function, but found a {}".format(
                type(label_map_func)))

        # The dataset name is only used for informational purposes. Default to the file name without extension.
        if not dataset_name:
            dataset_name = csv_file_name[:csv_file_name.index('.')] if '.' in csv_file_name else csv_file_name

        TextClassificationDataset.__init__(self, dataset_dir, dataset_name, dataset_catalog=None)

        self._dataset = tf.data.experimental.CsvDataset(filenames=dataset_file, record_defaults=defaults,
                                                        field_delim=delimiter, use_quote_delim=False,
                                                        header=header, select_cols=select_cols,
                                                        exclude_cols=exclude_cols)

        # Count the number of lines in the csv file to get the dataset length
        with open(dataset_file) as f:
            dataset_len = sum(1 for _ in f)
        if header:
            dataset_len -= 1

        if shuffle_files:
            # Use the dataset length as the shuffle buffer size so that the rows are fully shuffled
            self._dataset = self._dataset.shuffle(dataset_len, seed=seed)

        # Set the cardinality so that the dataset length can be used for shuffle splits and progress bars
        self._dataset = self._dataset.apply(tf.data.experimental.assert_cardinality(dataset_len))

        # If a map function has not been defined, we know that we at least need to convert the label from the
        # string read out of the csv file to an integer. Use a graph-safe conversion, since the function is
        # traced by tf.data.Dataset.map.
        if not label_map_func:
            def label_map_func(x):
                return tf.strings.to_number(x, tf.int32)

        # Reorder the columns so that each element is (text, label) and apply the label transform
        self._dataset = self._dataset.map(lambda x, y: (y, label_map_func(x)))

        self._info = {
            "name": dataset_name,
            "dataset_dir": dataset_dir,
            "file_name": csv_file_name,
            "delimiter": delimiter,
            "defaults": defaults,
            "header": header,
            "select_cols": select_cols,
            "exclude_cols": exclude_cols
        }
        self._preprocessed = None
        self._class_names = class_names
        self._train_pct = 1.0
        self._val_pct = 0
        self._test_pct = 0
        self._validation_type = None
        self._train_subset = None
        self._validation_subset = None
        self._test_subset = None
    @property
    def class_names(self):
        """
        Returns the list of class names
        """
        return self._class_names

    @property
    def info(self):
        """
        Returns a dictionary of information about the dataset
        """
        return {'dataset_info': self._info, 'preprocessing_info': self._preprocessed}

    @property
    def dataset(self):
        """
        Returns the framework dataset object (tf.data.Dataset)
        """
        return self._dataset

    def preprocess(self, batch_size):
        """
        Batch the dataset

        Args:
            batch_size (int): desired batch size

        Raises:
            ValueError: if the batch_size is not a positive integer, or if the dataset has already been
                        preprocessed
        """
        if not isinstance(batch_size, int) or batch_size < 1:
            raise ValueError("batch_size should be a positive integer")

        if self._preprocessed:
            raise ValueError("Data has already been preprocessed: {}".format(self._preprocessed))

        # Get the non-None splits
        split_list = ['_dataset', '_train_subset', '_validation_subset', '_test_subset']
        subsets = [s for s in split_list if getattr(self, s, None) is not None]

        # Cache, batch, and prefetch each split
        for subset in subsets:
            setattr(self, subset, getattr(self, subset).cache())
            setattr(self, subset, getattr(self, subset).batch(batch_size))
            setattr(self, subset, getattr(self, subset).prefetch(tf.data.AUTOTUNE))

        self._preprocessed = {'batch_size': batch_size}

    def get_inc_dataloaders(self, hub_name, max_seq_length):
        """
        Tokenize the train and validation subsets and wrap them in dataloaders that can be used for
        Intel Neural Compressor calibration and evaluation.

        Args:
            hub_name (str): Name of the Hugging Face hub model whose tokenizer should be used
            max_seq_length (int): Maximum sequence length for the tokenized inputs

        Returns:
            tuple of the calibration and evaluation dataloaders
        """
        calib_data, calib_labels = prepare_huggingface_input_data(self.train_subset, hub_name, max_seq_length)
        calib_data['label'] = tf.convert_to_tensor(calib_labels)
        eval_data, eval_labels = prepare_huggingface_input_data(self.validation_subset, hub_name, max_seq_length)
        eval_data['label'] = tf.convert_to_tensor(eval_labels)

        # Remove the token_type_ids field from the tokenized inputs
        calib_data.pop('token_type_ids')
        eval_data.pop('token_type_ids')

        calib_dataloader = INCTFDataLoader(calib_data, batch_size=self._preprocessed['batch_size'])
        eval_dataloader = INCTFDataLoader(eval_data, batch_size=self._preprocessed['batch_size'])

        return calib_dataloader, eval_dataloader
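
Below is a minimal usage sketch for this class. The directory, file name, and class names are hypothetical, and the label_map_func shown is just one graph-safe way to turn string labels into integer ids; adapt all of these to your own data.

import tensorflow as tf

from tlt.datasets.text_classification.tf_custom_text_classification_dataset import \
    TFCustomTextClassificationDataset

# Hypothetical dataset: a CSV with rows like "pos,great movie" / "neg,terrible plot"
dataset = TFCustomTextClassificationDataset(
    dataset_dir="/tmp/sentiment_data",
    dataset_name="sentiment",
    csv_file_name="reviews.csv",
    class_names=["neg", "pos"],
    # The mapping function is traced by tf.data.Dataset.map, so it must use graph-safe ops
    label_map_func=lambda label: tf.cast(label == "pos", tf.int32))

# Cache, batch, and prefetch the dataset
dataset.preprocess(batch_size=32)

# Each element is now a batch of (text, label) pairs
for text_batch, label_batch in dataset.dataset.take(1):
    print(text_batch.shape, label_batch)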