#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
#
import os
from typing import List, Optional
import pandas as pd
from datasets.arrow_dataset import Dataset
from tlt.datasets.hf_dataset import HFDataset
from tlt.datasets.text_classification.text_classification_dataset import TextClassificationDataset
[docs]class HFCustomTextClassificationDataset(TextClassificationDataset, HFDataset):
"""
A custom text classification dataset that can be used with Transformer models.
"""
[docs] def __init__(
self,
dataset_dir,
dataset_name: Optional[str],
csv_file_name: str,
class_names: Optional[List[str]] = None,
column_names: Optional[List[str]] = None,
label_map_func: Optional[callable] = None,
label_col: Optional[int] = 0,
delimiter: Optional[str] = ",",
header: Optional[bool] = False,
select_cols: Optional[List[int]] = None,
exclude_cols: Optional[List[int]] = None,
shuffle_files: Optional[bool] = True,
num_workers: Optional[int] = 0,
):
"""
A custom text classification dataset that can be used with Transformer models.
Note that this dataset class expects a .csv file with two columns where the first column is the label and
the second column is the text/sentence to classify.
For example, a comma separated value file will look similar to the snippet below:
.. code-block:: text
class_a,<text>
class_b,<text>
class_a,<text>
...
If the .csv files has more columns, the select_cols or exclude_cols parameters can be used to filter out which
columns will be parsed.
Args:
dataset_dir (str): Directory containing the dataset
dataset_name (str): Name of the dataset. If no dataset name is given, the dataset_dir folder name
will be used as the dataset name.
csv_file_name (str): Name of the file to load from the dataset directory
class_names (list(str)): optional; List of ordered class names. If None, class_names are inferred from
label_col column
column_names (list(str)): optional; List of column names. If given, there must be exactly one value as
"label" in the position corresponding to the 'label_col' argument. If None, column names are assigned
as "label" for the label_col column and "text_1", "text_2", ... for the rest of the columns.
label_map_func (function): optional; Maps the label_map_func across the label column of the dataset to
apply a transform to the elements. For example, if the .csv file has string class labels
instead of numerical values, you can provide a function that maps the string to a numerical
value or specify the index of the label column to apply a default label_map_func which assigns an
integer for every unique class label, starting with 0.
label_col (int): optional; Column index of the dataset to use as label column. Defaults to "0"
delimiter (str): String character that separates the text in each row. Defaults to ","
header (bool): optional; Boolean indicating whether or not the csv file has a header line that should be
skipped. Defaults to False.
select_cols (list): optional; Specify a list of sorted indices for columns from the dataset file(s) that
should be parsed. Defaults to parsing all columns. At most one of select_cols and exclude_cols can
be specified.
exclude_cols (list): optional; Specify a list of sorted indices for columns from the dataset file(s) that
should be excluded from parsing. Defaults to parsing all columns. At most one of select_cols and
exclude_cols can be specified.
shuffle_files (bool): optional; Whether to shuffle the data. Defaults to True.
num_workers (int): Number of workers to pass into a DataLoader.
Raises:
FileNotFoundError: if the csv file is not found in the dataset directory
TypeError: if label_map_func is not callable
ValueError: if class_names list is empty
ValueError: if column_names list does not contain the value 'label'
ValueError: if index of 'label' in column_names and label_col mismatch
ValueError: if the values of column_names are not strings.
ValueError: if column_names contains more than one value as 'label'
"""
# Sanity checks
dataset_file = os.path.join(dataset_dir, csv_file_name)
if not os.path.exists(dataset_file):
raise FileNotFoundError("The dataset file ({}) does not exist".format(dataset_file))
if isinstance(class_names, list) and len(class_names) == 0:
raise ValueError("The class_names list cannot be empty.")
if label_map_func and not callable(label_map_func):
raise TypeError("The label_map_func is expected to be a function, but found a {}", type(label_map_func))
# The dataset name is only used for informational purposes. Default to use the file name without extension.
if not dataset_name:
dataset_name = os.path.splitext(csv_file_name)[0]
if column_names:
if 'label' not in column_names:
raise ValueError("The column_names list must contain one value as 'label'")
if column_names.count('label') > 1:
raise ValueError("There must be exactly one value as 'label' in column_names.")
if not all(isinstance(c, str) for c in column_names):
raise ValueError("All column names must be strings.")
if column_names.index('label') != label_col:
raise ValueError("The label_col index ({}) does not match with column_names {}."
"Either specify label_col argument (or) make the first value "
"in your column_names as 'label'".format(label_col, column_names))
TextClassificationDataset.__init__(self, dataset_dir, dataset_name, dataset_catalog=None)
print("WARNING: Using column {} as label column. To change this behavior, "
"specify the label_col argument".format(label_col))
if delimiter == 't':
delimiter = '\t'
if header:
dataset_df = pd.read_csv(dataset_file, delimiter=delimiter, encoding='utf-8', dtype=str, names=column_names,
header=0)
else:
dataset_df = pd.read_csv(dataset_file, delimiter=delimiter, encoding='utf-8', dtype=str, names=column_names,
header=None)
if not column_names:
column_names = {i: 'label' if i == label_col else f'text_{i}' for i in dataset_df.columns}
dataset_df.rename(column_names, axis=1, inplace=True)
if select_cols and not exclude_cols:
dataset_df = dataset_df[dataset_df.columns[select_cols]]
elif exclude_cols and not select_cols:
dataset_df = dataset_df.drop(dataset_df.columns[exclude_cols], axis=1)
elif select_cols and exclude_cols:
if not set(select_cols).isdisjoint(exclude_cols):
raise ValueError("select_cols and exclude_cols lists are ambiguous. \
Please make sure they are disjoint")
dataset_df = dataset_df.drop(dataset_df.columns[exclude_cols], axis=1)
dataset_df = dataset_df[dataset_df.columns[select_cols]]
if not class_names:
class_names = dataset_df.iloc[:, label_col].unique()
if not label_map_func:
label_str_dict = {label_name: idx for idx, label_name in enumerate(class_names)}
def label_map_func(x):
return label_str_dict[x]
dataset_df.iloc[:, label_col] = dataset_df.iloc[:, label_col].map(label_map_func)
self._dataset = Dataset.from_pandas(dataset_df)
self._info = {
"name": dataset_name,
"dataset_dir": dataset_dir,
"file_name": csv_file_name,
"delimiter": delimiter,
"header": header,
"select_cols": select_cols,
"exclude_cols": exclude_cols,
'class_names': class_names
}
self._class_names = class_names
self._validation_type = None
self._preprocessed = {}
self._shuffle = shuffle_files
self._num_workers = num_workers
@property
def dataset(self):
return self._dataset
@property
def class_names(self):
return self._class_names
@property
def info(self):
return {'dataset_info': self._info, 'preprocessing_info': self._preprocessed}