Explaining a Fine-Tuned Text Classifier with PyTorch Using the Intel® Explainable AI API

This notebook demonstrates fine-tuning pretrained models from Hugging Face using text classification datasets from the Hugging Face Datasets catalog or a custom dataset. The notebook uses Intel® Extension for PyTorch*, which extends PyTorch with optimizations for an extra performance boost on Intel hardware.

Please install the dependencies from the pytorch_requirements.txt file before executing this notebook.

The notebook performs the following steps:

1. Import dependencies and set up parameters
2. Prepare the dataset
3. Prepare the Model for Fine Tuning and Evaluation
4. Export the model
5. Reload the model and make predictions
6. Get Explanations with Intel Explainable AI Tools

1. Import dependencies and setup parameters

This notebook assumes that you have already followed the instructions in the README.md to set up a PyTorch environment with all the dependencies required to run the notebook.

[ ]:
import intel_extension_for_pytorch as ipex
import logging
import numpy as np
import os
import pandas as pd
import sys
import torch
import warnings
import typing
import pickle

from tqdm.auto import tqdm
from torch.optim import AdamW
from torch.utils.data import DataLoader
from datasets import ClassLabel, load_dataset, load_metric, Split
from datasets import logging as datasets_logging
from transformers.utils import logging as transformers_logging
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    get_scheduler
)
from tlt.utils.file_utils import download_and_extract_zip_file

# Set the logging stream to stdout
for handler in transformers_logging._get_library_root_logger().handlers:
    handler.setStream(sys.stdout)

sh = datasets_logging.logging.StreamHandler(sys.stdout)

datasets_logging.set_verbosity_error()
warnings.filterwarnings('ignore')
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
[ ]:
# Specify the name of the Hugging Face pretrained model to use (https://huggingface.co/models)
# For example:
#   albert-base-v2
#   bert-base-uncased
#   distilbert-base-uncased
#   distilbert-base-uncased-finetuned-sst-2-english
#   roberta-base
model_name = "distilbert-base-uncased"

# Define an output directory
output_dir = os.environ["OUTPUT_DIR"] if "OUTPUT_DIR" in os.environ else \
    os.path.join(os.environ["HOME"], "output", model_name)

# Define a dataset directory
dataset_dir = os.environ["DATASET_DIR"] if "DATASET_DIR" in os.environ else \
    os.path.join(os.environ["HOME"], "dataset")

print("Model name:", model_name)
print("Output directory:", output_dir)
print("Dataset directory:", dataset_dir)

2. Prepare the dataset

The notebook has two options for getting a dataset:

* Option A: Use a dataset from the Hugging Face Datasets catalog
* Option B: Use a custom dataset (downloaded from another source or from your local system)

In both cases, the code ends up defining Hugging Face datasets.Dataset objects (https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset) for the train and evaluation splits.
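When a train or eval size is given, the `define_train_eval_splits` helper in the base class shuffles the split and keeps the first N examples; otherwise it keeps the whole split. A minimal pure-Python sketch of that logic, assuming a plain list of examples stands in for a `datasets.Dataset` (the `select_subset` function and its `seed` parameter are hypothetical; the notebook itself calls `shuffle().select(range(size))`):

```python
import random
from typing import List, Optional, Sequence, TypeVar

T = TypeVar("T")

def select_subset(examples: Sequence[T], size: Optional[int], seed: int = 42) -> List[T]:
    """Shuffle a copy of the examples and keep the first `size` of them.

    When `size` is None (or 0), the full split is returned unchanged, mirroring
    the `if train_size ... else` branches in define_train_eval_splits.
    """
    if not size:
        return list(examples)
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)  # seeded so repeated runs pick the same subset
    return shuffled[:size]

train_split = [{"text": f"review {i}", "label": i % 2} for i in range(10)]
subset = select_subset(train_split, size=4)
print(len(subset))  # 4
```

Using a seeded generator is a design choice of this sketch; the notebook's `shuffle()` call is unseeded, so its subsets differ between runs.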

Execute the following cell to load the tokenizer and declare the base class used for the dataset setup.

[ ]:
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

class TextClassificationData():
    """
    Base class used for defining the text classification dataset being used. Defines Hugging Face datasets.Dataset
    objects for train and evaluations splits, along with helper functions for preprocessing the dataset.
    """

    def __init__(self, dataset_name, tokenizer, sentence1_key, sentence2_key, label_key):
        self.tokenizer = tokenizer
        self.dataset_name = dataset_name
        self.class_labels = None

        # Tokenized train and eval ds
        self.train_ds = None
        self.eval_ds = None

        # Column keys
        self.sentence1_key = sentence1_key
        self.sentence2_key = sentence2_key
        self.label_key = label_key

    def tokenize_function(self, examples):
        # Define the tokenizer args, depending on if the data has 2 sentences or just 1
        args = ((examples[self.sentence1_key],) if self.sentence2_key is None \
                 else (examples[self.sentence1_key], examples[self.sentence2_key]))
        return self.tokenizer(*args, padding="max_length", truncation=True)

    def tokenize_dataset(self, dataset):
        # Apply the tokenize function to the dataset
        tokenized_dataset = dataset.map(self.tokenize_function, batched=True)

        # Remove the raw text from the tokenized dataset
        raw_text_columns = [self.sentence1_key, self.sentence2_key] if self.sentence2_key else [self.sentence1_key]
        return tokenized_dataset.remove_columns(raw_text_columns)

    def define_train_eval_splits(self, dataset, train_split_name, eval_split_name, train_size=None, eval_size=None):
        # Shuffle and select a subset of each split if a size was given, otherwise use the full split
        self.train_ds = dataset[train_split_name].shuffle().select(range(train_size)) if train_size \
            else dataset[train_split_name]
        self.eval_ds = dataset[eval_split_name].shuffle().select(range(eval_size)) if eval_size \
            else dataset[eval_split_name]

    def get_label_names(self):
        if self.class_labels:
            return self.class_labels.names
        else:
            raise ValueError("Class labels were not defined")

    def display_sample(self, split_name="train", sample_size=7):
        # Display a sample of the raw data
        sentence1_sample = self.dataset[split_name][self.sentence1_key][:sample_size]
        sentence2_sample = self.dataset[split_name][self.sentence2_key][:sample_size] if self.sentence2_key else None
        label_sample = self.dataset[split_name][self.label_key][:sample_size]
        dataset_sample = zip(sentence1_sample, sentence2_sample, label_sample) if self.sentence2_key \
            else zip(sentence1_sample, label_sample)

        columns = [self.sentence1_key, self.sentence2_key, self.label_key] if self.sentence2_key else \
            [self.sentence1_key, self.label_key]

        # Display the sample using a dataframe
        sample = pd.DataFrame(dataset_sample, columns=columns)
        # Styler.hide(axis="index") replaces Styler.hide_index(), which was removed in pandas 2.0
        return sample.style.hide(axis="index")

Now that the base class is defined, run either Option A to use a dataset from the Hugging Face Datasets catalog or Option B to use a custom dataset downloaded from the web or stored on your local system.

Option A: Use a Hugging Face dataset

Hugging Face Datasets has a catalog of datasets that can be specified by name. Information about the dataset is available in the catalog (including information on the size of the dataset and the splits).

The next cell gets the IMDb movie review dataset using the Hugging Face datasets API. If the notebook is executed multiple times, the dataset is loaded from the cache in the dataset directory, which speeds up subsequent runs.

The IMDb dataset on Hugging Face has three splits: train, test, and unsupervised. This notebook uses data from the train split for training and data from the test split for evaluation. The data has two columns: text (a string with the movie review) and label (an integer class label). The code in the next cell is set up for the IMDb dataset, so if you use a different dataset, you may need to change the split names and/or the column names.
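The `tokenize_function` in the base class passes either one or two text columns to the tokenizer, depending on whether `sentence2_key` is set. For a sentence-pair dataset you would set both keys; for example, the GLUE MRPC task uses `sentence1`, `sentence2`, and `label` columns. A sketch of that selection logic, with a hypothetical helper name and small made-up batches (plain dicts) standing in for a batch of dataset examples:

```python
from typing import Dict, List, Optional, Tuple

def build_tokenizer_args(batch: Dict[str, List[str]],
                         sentence1_key: str,
                         sentence2_key: Optional[str]) -> Tuple[List[str], ...]:
    """Return the positional args that would be passed to the tokenizer.

    Single-text datasets (such as IMDb) yield a 1-tuple; sentence-pair datasets
    yield a 2-tuple so that the tokenizer encodes each pair together.
    """
    if sentence2_key is None:
        return (batch[sentence1_key],)
    return (batch[sentence1_key], batch[sentence2_key])

# Illustrative batches, not real rows from IMDb or MRPC
imdb_batch = {"text": ["Great movie", "Terrible plot"], "label": [1, 0]}
pair_batch = {"sentence1": ["A man sleeps."], "sentence2": ["A person rests."], "label": [1]}

print(len(build_tokenizer_args(imdb_batch, "text", None)))              # 1
print(len(build_tokenizer_args(pair_batch, "sentence1", "sentence2")))  # 2
```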

[ ]:
class HFDSTextClassificationData(TextClassificationData):
    """
    Class used for loading and preprocessing text classification datasets from the Hugging Face datasets catalog
    """

    def __init__(self, tokenizer, dataset_dir, dataset_name, train_size, eval_size, train_split_name,
                 eval_split_name, sentence1_key, sentence2_key, label_key):
        """
        Initialize the HFDSTextClassificationData class for a text classification dataset from Hugging Face.

        :param tokenizer: Tokenizer to preprocess the dataset
        :param dataset_dir: Cache directory used when loading the dataset
        :param dataset_name: Name of the dataset to load from the Hugging Face catalog
        :param train_size: Size of the training dataset. For quicker training or debug, use a subset of the data.
                           Set to `None` to use all the data.
        :param eval_size: Size of the evaluation dataset. Set to `None` to use all the data.
        :param train_split_name: Name of the split in the loaded dataset to use for training (e.g. "train" or
                                 Split.TRAIN)
        :param eval_split_name: Name of the split in the loaded dataset to use for evaluation (e.g. "test" or
                                Split.TEST)
        :param sentence1_key: Name of the sentence1 column
        :param sentence2_key: Name of the sentence2 column or `None` if there's only one text column
        :param label_key: Name of the label column
        """

        # Init base class
        TextClassificationData.__init__(self, dataset_name, tokenizer, sentence1_key, sentence2_key, label_key)

        # Load the dataset from the Hugging Face dataset API
        self.dataset = load_dataset(dataset_name, cache_dir=dataset_dir)

        # Tokenize the dataset
        tokenized_dataset = self.tokenize_dataset(self.dataset)

        # Get the training and eval dataset based on the specified dataset sizes
        self.define_train_eval_splits(tokenized_dataset, train_split_name, eval_split_name, train_size, eval_size)

        # Save the class label information to use later when predicting
        self.class_labels = self.dataset[train_split_name].features[label_key]

# Name of the Hugging Face dataset
dataset_name = "imdb"

# For quicker training and debug runs, use a subset of the dataset by specifying the size of the train/eval datasets.
# Set the sizes to `None` to use the full dataset. The full IMDb dataset has 25,000 training and 25,000 test examples.
train_dataset_size = 1000
eval_dataset_size = 1000

# Name of the columns in the dataset (the column names may vary if you are not using the IMDb dataset)
sentence1_key = "text"
sentence2_key = None
label_key = "label"

dataset = HFDSTextClassificationData(tokenizer, dataset_dir, dataset_name, train_dataset_size, eval_dataset_size,
                                     Split.TRAIN, Split.TEST, sentence1_key, sentence2_key, label_key)

# Print a sample of the data
dataset.display_sample(Split.TRAIN, sample_size=5)

Skip to Step 3 (Prepare the Model for Fine Tuning and Evaluation) to continue using the dataset from the Hugging Face catalog.

Option B: Use a custom dataset

Instead of using a dataset from the Hugging Face Datasets catalog, you can use a custom dataset, either downloaded from the web or stored on your local system.

In this example, we download the SMS Spam Collection dataset. The zip file contains a single tab-separated values (TSV) file with two columns. The first column is the label (ham or spam) and the second column is the text of the SMS message:

<ham or spam>   <text>
<ham or spam>   <text>
<ham or spam>   <text>
...
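Because each line holds a label, a tab, and the message text, the file can be read with Python's `csv` module using a tab delimiter. The sketch below parses a few in-memory lines (made-up messages, not rows from the dataset) and converts the string labels to integers the same way the `map_spam` function does later in this section:

```python
import csv
import io

# Illustrative lines in the "<label>\t<text>" layout; not real rows from the dataset
raw_tsv = "ham\tSee you at lunch\nspam\tYou won a prize, reply now\nham\tOk, on my way\n"

rows = []
# QUOTE_NONE keeps quote characters inside the SMS text from being treated as field quoting
reader = csv.reader(io.StringIO(raw_tsv), delimiter="\t", quoting=csv.QUOTE_NONE)
for label, text in reader:
    # 1 for "spam", 0 for "ham", matching label_names = ["ham", "spam"]
    rows.append({"label": int(label == "spam"), "text": text})

print([r["label"] for r in rows])  # [0, 1, 0]
```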

If you are using a custom dataset with a similarly formatted CSV or TSV file, you can use the class defined below. Create your object by passing in custom values for the data file name, delimiter, label names, mapping function, and so on.
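The `map_spam` function used below handles two labels with a single comparison, which does not generalize to datasets with more string labels. One option is to build the mapping function from the list of label names, so that each label's index in `label_names` becomes its numerical value. `make_label_map_function` is a hypothetical helper, not part of the notebook or any library:

```python
from typing import Callable, Dict, List

def make_label_map_function(label_names: List[str], label_key: str = "label") -> Callable[[Dict], Dict]:
    """Build a map function that replaces a string label with its index in label_names."""
    label_to_id = {name: idx for idx, name in enumerate(label_names)}

    def map_labels(example: Dict) -> Dict:
        # Raises KeyError for a label that is not in label_names, surfacing bad rows early
        example[label_key] = label_to_id[example[label_key]]
        return example

    return map_labels

map_fn = make_label_map_function(["negative", "neutral", "positive"])
print(map_fn({"text": "It was fine", "label": "neutral"}))  # {'text': 'It was fine', 'label': 1}
```

A function built this way could be passed as `map_function`, as long as the same `label_names` list is given to the class so the indices line up with the `ClassLabel` names.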

[ ]:
class CustomCsvTextClassificationData(TextClassificationData):
    """
    Class used for loading and preprocessing text classification datasets from CSV files
    """

    def __init__(self, tokenizer, dataset_name, dataset_dir, data_files, delimiter, label_names, sentence1_key, sentence2_key,
                 label_key, train_percent=0.8, eval_percent=0.2, train_size=None, eval_size=None, map_function=None):
        """
        Initialize the CustomCsvTextClassificationData class for a text classification
        dataset. The class uses the Hugging Face datasets API to load the CSV file
        and split it into train and eval datasets based on the specified percentages.
        If train_size and eval_size are also defined, the datasets are reduced to the
        specified number of examples.

        :param tokenizer: Tokenizer to preprocess the dataset
        :param dataset_name: Dataset name for identification purposes
        :param dataset_dir: Directory where the csv file(s) are located
        :param data_files: List of data file names
        :param delimiter: Delimiter for the csv files
        :param label_names: List of label names
        :param sentence1_key: Name of the sentence1 column
        :param sentence2_key: Name of the sentence2 column or `None` if there's only one text column
        :param label_key: Name of the label column
        :param train_percent: Decimal value for the percentage of the dataset that should be used for training
                              (e.g. 0.8 for 80%)
        :param eval_percent: Decimal value for the percentage of the dataset that should be used for validation
                             (e.g. 0.2 for 20%)
        :param train_size: Size of the training dataset. For quicker training or debug, use a subset of the data.
                           Set to `None` to use all the data.
        :param eval_size: Size of the eval dataset. Set to `None` to use all the data.
        :param map_function: (Optional) Map function to apply to the dataset. For example, if the csv file has string
                             labels instead of numerical values, map function can do the conversion.
        """
        # Init base class
        TextClassificationData.__init__(self, dataset_name, tokenizer, sentence1_key, sentence2_key, label_key)

        if (train_percent + eval_percent) > 1:
            raise ValueError("The combined value of the train percentage and eval percentage " \
                             "cannot be greater than 1")

        # Create a list of the column names
        column_names = [label_key, sentence1_key, sentence2_key] if sentence2_key else [label_key, sentence1_key]

        # Load the dataset using the Hugging Face API
        self.dataset = load_dataset(dataset_dir, delimiter=delimiter, data_files=data_files, column_names=column_names)

        # Optionally map the dataset labels using the map_function
        if map_function:
            self.dataset = self.dataset.map(map_function)

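        # Set up the class labels and cast the label column to the ClassLabel feature so that the
        # dataset schema records the label names (cast_column returns a new DatasetDict)
        self.class_labels = ClassLabel(num_classes=len(label_names), names=label_names)
        self.dataset = self.dataset.cast_column(label_key, self.class_labels)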

        # Split the dataset based on the percentages defined
        self.dataset = self.dataset[Split.TRAIN].train_test_split(train_size=train_percent, test_size=eval_percent)

        # Tokenize the dataset
        tokenized_dataset = self.tokenize_dataset(self.dataset)

        # Get the training and eval dataset based on the specified dataset sizes
        self.define_train_eval_splits(tokenized_dataset, Split.TRAIN, Split.TEST, train_size, eval_size)


# Modify the variables below to use a different dataset or a csv file on your local system.
# The csv_name file in dataset_dir should have 2 columns (the label and the text)
dataset_url = "https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip"
dataset_dir = os.path.join(dataset_dir, "smsspamcollection")
csv_name = "SMSSpamCollection"
delimiter = "\t"
label_names = ["ham", "spam"]

# Rename the file to include the csv extension so that the dataset API knows how to load the file
renamed_csv = "{}.csv".format(csv_name)

# If we don't already have the csv file, download and extract the zip file to get it.
if not os.path.exists(os.path.join(dataset_dir, csv_name)) and \
                      not os.path.exists(os.path.join(dataset_dir, renamed_csv)):
    download_and_extract_zip_file(dataset_url, dataset_dir)

if not os.path.exists(os.path.join(dataset_dir, renamed_csv)):
    os.rename(os.path.join(dataset_dir, csv_name), os.path.join(dataset_dir, renamed_csv))

# Columns
sentence1_key = "text"
sentence2_key = None
label_key = "label"

# Map function to translate labels in the csv file to numerical values when loading the dataset
def map_spam(example):
    example["label"] = int(example["label"] == "spam")
    return example

dataset = CustomCsvTextClassificationData(tokenizer, "smsspamcollection", dataset_dir, [renamed_csv], delimiter,
                                          label_names, sentence1_key, sentence2_key, label_key, train_size=1000,
                                          eval_size=1000, map_function=map_spam)

# Print a sample of the data
dataset.display_sample(Split.TRAIN, 10)

3. Prepare the Model for Fine Tuning and Evaluation

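The notebook has two options for training the model:

* Option A: Use the Trainer API from Hugging Face
* Option B: Use the native PyTorch API

In both cases, the model is a Hugging Face Transformers model wrapped in the TextClassificationModel class, which selects the training API based on its constructor arguments: if training_args are given, the Trainer API is used; otherwise, the native PyTorch API is used.

Execute the following cell to declare the class used to load, train, and evaluate the text classification model.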

[ ]:
class TextClassificationModel():
    """
    Class used for model loading, training and evaluation.
    """
    def __init__(self,
                 model_name: str,
                 num_labels: int,
                 training_args: TrainingArguments = None,
                 ipex_optimize: bool = True,
                 device: str = "cpu"):
        """
        Initialize the TextClassificationModel class for a text classification model with
        PyTorch. The class uses the model_name to load the pre-trained PyTorch model from
        Hugging Face. If the training_args are given then the Trainer API is selected for
        training and evaluation of the model otherwise native PyTorch API is selected for
        model training and evaluation

        :param model_name: Name of the pre-trained model to load from Hugging Face
        :param num_labels: Number of class labels
        :param training_args: A TrainingArguments object if using the Trainer API to train
                              the model. If None, native PyTorch API is used for training.
        :param ipex_optimize: If True, then the model is optimized to run on intel hardware.
        :param device: Device to run on the PyTorch model.
        """
        self.model_name = model_name
        self.num_labels = num_labels
        self.training_args = training_args
        self.device = device
        self.trainer = None

        self.train_ds = dataset.train_ds
        self.eval_ds = dataset.eval_ds

        # Load the model using the pretrained weights
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)

        # Apply the ipex optimize function to the model
        if ipex_optimize:
            self.model = ipex.optimize(self.model)

    def train(self,
              dataset: TextClassificationData,
              optimizers: typing.Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR],
              num_train_epochs: int = 1,
              batch_size: int = 16,
              compute_metrics: typing.Callable = None,
              shuffle_samples: bool = True
             ):

        # If training_args are given, we use the `Trainer` API to train the model
        if self.training_args:
            self.model.train()
            self.trainer = Trainer(model=self.model,
                                   args=self.training_args,
                                   train_dataset=self.train_ds,
                                   eval_dataset=self.eval_ds,
                                   optimizers=optimizers,
                                   compute_metrics=compute_metrics)
            self.trainer.train()

        # If training_args are not given, we use native PyTorch API to train the model
        else:

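            # Rename the `label` column to `labels` because the model expects the argument to be named `labels`.
            # The check skips the rename if it was already done, e.g. when train() is called more than once.
            if "label" in self.train_ds.column_names:
                self.train_ds = self.train_ds.rename_column("label", "labels")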

            # Set the format of the dataset to return PyTorch tensors instead of lists
            self.train_ds.set_format("torch")

            train_dataloader = DataLoader(self.train_ds, shuffle=shuffle_samples, batch_size=batch_size)

            # Unpack the `optimizers` parameter to get optimizer and lr_scheduler
            optimizer, lr_scheduler = optimizers[0], optimizers[1]

            # Define number of training steps for the training progress bar
            num_training_steps = num_train_epochs * len(train_dataloader)
            progress_bar = tqdm(range(num_training_steps))

            # Training loop
            self.model.to(self.device)
            self.model.train()
            for epoch in range(num_train_epochs):
                for batch in train_dataloader:
                    batch = {k: v.to(self.device) for k, v in batch.items()}
                    outputs = self.model(**batch)
                    loss = outputs.loss
                    loss.backward()

                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()
                    progress_bar.update(1)

    def evaluate(self, batch_size=16):

        if self.trainer:
            self.model.eval()
            metrics = self.trainer.evaluate()
            for key in metrics.keys():
                print("{}: {}".format(key, metrics[key]))
        else:
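            # Rename the `label` column to `labels` because the model expects the argument to be named `labels`.
            # The check skips the rename if it was already done, e.g. when evaluate() is called more than once.
            if "label" in self.eval_ds.column_names:
                self.eval_ds = self.eval_ds.rename_column("label", "labels")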

            # Set the format of the dataset to return PyTorch tensors instead of lists
            self.eval_ds.set_format("torch")

            eval_dataloader = DataLoader(self.eval_ds, batch_size=batch_size)
            progress_bar = tqdm(range(len(eval_dataloader)))

            metric = load_metric("accuracy")
            self.model.eval()
            for batch in eval_dataloader:
                batch = {k: v.to(self.device) for k, v in batch.items()}
                with torch.no_grad():
                    outputs = self.model(**batch)

                logits = outputs.logits
                predictions = torch.argmax(logits, dim=-1)
                metric.add_batch(predictions=predictions, references=batch["labels"])
                progress_bar.update(1)

            print(metric.compute())

    def predict(self, raw_input_text):
        if isinstance(raw_input_text, str):
            raw_input_text = [raw_input_text]

        # Encode the raw text using the tokenizer (truncating inputs longer than the model's maximum length)
        encoded_input = tokenizer(raw_input_text, padding=True, truncation=True, return_tensors='pt')

        # Input the encoded text(s) to the model and get the predicted results (no gradients needed)
        self.model.eval()
        with torch.no_grad():
            output = self.model(**encoded_input)
        _, predictions = torch.max(output.logits, dim=1)

        # Translate the predictions to class label strings
        prediction_labels = dataset.class_labels.int2str(predictions.tolist())

        # Create a dataframe to display the results
        result_list = [list(x) for x in zip(raw_input_text, prediction_labels)]
        result_df = pd.DataFrame(result_list, columns=["Input Text", "Predicted Label"])
        return result_df.style.hide(axis="index")

    def parameters(self):
        return self.model.parameters()

    def save(self, output_dir):
        self.model.save_pretrained(output_dir)

    @classmethod
    def load(cls, output_dir):
        return cls(output_dir, num_labels=len(dataset.get_label_names()))

Now that the TextClassificationModel class is defined, use either Option A to train with the Trainer API from Hugging Face (https://huggingface.co/docs/transformers/v4.16.2/en/main_classes/trainer#transformers.Trainer) or Option B to use the native PyTorch API.

Option A: Use the Trainer API from Hugging Face

This step gets the pretrained model from Hugging Face and sets up the TrainingArguments and the Trainer. For simplicity, this example uses default values for most of the training arguments, but specifies the output directory and the number of training epochs. Checkpoints are saved to the output directory during training. Because trainer.train() is called without the resume_from_checkpoint argument, training starts from the pretrained weights even if the output directory already has checkpoints from a previous run; passing resume_from_checkpoint=True resumes from the last checkpoint instead.
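The Trainer writes checkpoints into subdirectories of the output directory named `checkpoint-<step>`. The sketch below finds the most recent one by parsing the step number, which is similar in spirit to `transformers.trainer_utils.get_last_checkpoint`; `find_last_checkpoint` itself is a hypothetical helper, and the temporary directory stands in for a real output directory:

```python
import os
import re
import tempfile
from typing import Optional

_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")

def find_last_checkpoint(output_dir: str) -> Optional[str]:
    """Return the path of the checkpoint-<step> subdirectory with the highest step, or None."""
    candidates = []
    for name in os.listdir(output_dir):
        match = _CHECKPOINT_RE.match(name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            candidates.append((int(match.group(1)), name))
    if not candidates:
        return None
    # Compare numerically so that checkpoint-1000 sorts after checkpoint-500
    _, latest = max(candidates)
    return os.path.join(output_dir, latest)

with tempfile.TemporaryDirectory() as tmp:
    for step in (500, 1000, 63):
        os.makedirs(os.path.join(tmp, f"checkpoint-{step}"))
    print(os.path.basename(find_last_checkpoint(tmp)))  # checkpoint-1000
```

A path found this way could be passed to `trainer.train(resume_from_checkpoint=...)`, which also accepts a checkpoint path string.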

Expect to see a warning at this step about some weights not being used. This is because the pretraining head from the original model is replaced with a classification head.

[ ]:
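import math

num_train_epochs = 2
batch_size = 16
num_labels = len(dataset.get_label_names())

# Define a TrainingArguments object for the Trainer API to use. The batch sizes are set explicitly
# so that the Trainer uses the same batch size as the learning rate scheduler calculation below.
training_args = TrainingArguments(output_dir=output_dir,
                                  num_train_epochs=num_train_epochs,
                                  per_device_train_batch_size=batch_size,
                                  per_device_eval_batch_size=batch_size)

# Get the model from Hugging Face. Since we are specifying training_args, the model is trained and
# evaluated with the Trainer API.
model = TextClassificationModel(model_name=model_name, num_labels=num_labels, training_args=training_args)

# Define model training parameters. The learning rate scheduler is stepped once per batch, so the
# number of training steps is the number of batches per epoch multiplied by the number of epochs.
learning_rate      = 5e-5
optimizer          = AdamW(model.parameters(), lr=learning_rate)
num_training_steps = num_train_epochs * math.ceil(len(dataset.train_ds) / batch_size)
metric             = load_metric("accuracy")
lr_scheduler       = get_scheduler(
                        name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
                     )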

# Helper function for the Trainer API to compute metrics
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

Train and evaluate the model with the Trainer API

[ ]:
model.train(
    dataset,
    optimizers=(optimizer, lr_scheduler),
    num_train_epochs=num_train_epochs,
    batch_size=batch_size,
    compute_metrics=compute_metrics
)
[ ]:
model.evaluate()

Option B: Use the native PyTorch API

This step gets the pretrained model from Hugging Face and uses the native PyTorch API to train and evaluate the model.
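Both training options pair AdamW with a linear learning-rate schedule from `get_scheduler(name="linear", ...)`. The scheduler is stepped once per batch, so its length is measured in optimizer steps. The sketch below computes the step count for a given dataset size and batch size, and the learning rate that linear warmup followed by linear decay produces at a given step. It re-implements the formula in plain Python for illustration and is not the transformers implementation:

```python
import math

def count_training_steps(num_examples: int, batch_size: int, num_epochs: int) -> int:
    """Optimizer steps for a run: one step per batch, with a final partial batch counted."""
    return num_epochs * math.ceil(num_examples / batch_size)

def linear_lr(step: int, base_lr: float, num_training_steps: int, num_warmup_steps: int = 0) -> float:
    """Learning rate after `step` scheduler steps: linear warmup, then linear decay to 0."""
    if step < num_warmup_steps:
        return base_lr * step / max(1, num_warmup_steps)
    remaining = num_training_steps - step
    return base_lr * max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))

# 1,000 training examples, batch size 16, 2 epochs, as in this notebook's defaults
steps = count_training_steps(num_examples=1000, batch_size=16, num_epochs=2)
print(steps)  # 126
print(linear_lr(0, 5e-5, steps), linear_lr(steps // 2, 5e-5, steps), linear_lr(steps, 5e-5, steps))
```

If the step count is larger than the number of batches actually processed (for example, counting examples instead of batches), the learning rate never decays far from its starting value before training ends.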

Expect to see a warning at this step about some weights not being used. This is because the pretraining head from the original model is replaced with a classification head.

[ ]:
import math

num_train_epochs = 2
batch_size = 16
num_labels = len(dataset.get_label_names())

# Get the model from Hugging Face. Since we are not specifying training_args, the model is trained and
# evaluated with the native PyTorch API.
model = TextClassificationModel(model_name=model_name, num_labels=num_labels)

# Define model training parameters. The learning rate scheduler is stepped once per batch, so the
# number of training steps is the number of batches per epoch multiplied by the number of epochs.
learning_rate      = 5e-5
optimizer          = AdamW(model.parameters(), lr=learning_rate)
num_training_steps = num_train_epochs * math.ceil(len(dataset.train_ds) / batch_size)
lr_scheduler       = get_scheduler(
                        name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
                     )

Train and evaluate the model with the native PyTorch API

[ ]:
model.train(
    dataset,
    optimizers=(optimizer, lr_scheduler),
    num_train_epochs=num_train_epochs,
    batch_size=batch_size
)
[ ]:
model.evaluate()

4. Export the model

[ ]:
# Save the model to our output directory
model.save(output_dir)

5. Reload the model and make predictions

The output directory is used to reload the model. In the next cell, we evaluate the reloaded model to verify that we get the same metrics that we saw after fine-tuning.
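On the native PyTorch path, `evaluate` takes the argmax of the logits for each example and compares it with the reference label; accuracy is the fraction of matches. A dependency-free sketch of that computation on toy logits (the numbers are made up for illustration):

```python
from typing import List, Sequence

def argmax(values: Sequence[float]) -> int:
    """Index of the largest value; the first index wins on ties."""
    return max(range(len(values)), key=lambda i: values[i])

def accuracy(logits: List[Sequence[float]], labels: List[int]) -> float:
    """Fraction of examples whose highest-scoring class equals the reference label."""
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same length")
    predictions = [argmax(row) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return correct / len(labels)

# Toy logits for a 2-class problem: 3 of the 4 predictions match the labels
toy_logits = [[2.1, -1.0], [-0.3, 0.8], [0.5, 0.1], [-1.2, 1.9]]
toy_labels = [0, 1, 1, 1]
print(accuracy(toy_logits, toy_labels))  # 0.75
```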

[ ]:
reloaded_model = TextClassificationModel.load(output_dir)

reloaded_model.evaluate()

Next, we demonstrate how to encode raw text input and get predictions from the reloaded model.
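The `predict` method tokenizes the batch with `padding=True`, so shorter inputs are padded to the longest input in the batch, and then maps the index of the largest logit to a label name with `ClassLabel.int2str`. The label lookup step can be sketched without the model, adding a softmax to report a confidence as well; the logits are made up, and `label_names` is assumed to be IMDb-style `neg`/`pos`:

```python
import math
from typing import List, Sequence, Tuple

def softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def to_label(row: Sequence[float], label_names: List[str]) -> Tuple[str, float]:
    """Return the predicted label name and its softmax probability."""
    probs = softmax(row)
    best = max(range(len(probs)), key=probs.__getitem__)
    return label_names[best], probs[best]

label_names = ["neg", "pos"]
for row in [[1.5, -0.5], [-2.0, 2.0]]:
    name, prob = to_label(row, label_names)
    print(name, round(prob, 3))
```

Because softmax is monotonic, the predicted label is the same whether you take the argmax of the logits (as `predict` does) or of the probabilities.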

[ ]:
model = reloaded_model
[ ]:
# Setup some raw text input
raw_text_input = ["It was okay. I finished it, but wouldn't watch it again.",
                  "So bad",
                  "Definitely not my favorite",
                  "Highly recommended"]

model.predict(raw_text_input)

6. Get Explanations with Intel Explainable AI Tools
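The partition explainer treats the classifier as a black box. The prediction function `f` receives an array of (partially masked) strings and must return one row of class probabilities per string, which is why it applies a softmax to the logits. The `r"\W+"` argument is a regular expression that splits the text into word-level pieces that can be hidden. The sketch below shows that splitting and masking idea, assuming the regex acts as a token separator as it does for `shap.maskers.Text`. The `split_words` and `mask_tokens` helpers and the `"..."` placeholder are hypothetical and are not the library's internals:

```python
import re
from typing import List, Sequence

def split_words(text: str, pattern: str = r"\W+") -> List[str]:
    """Split text into word pieces, using the pattern as the separator and dropping empty strings."""
    return [piece for piece in re.split(pattern, text) if piece]

def mask_tokens(words: Sequence[str], keep: Sequence[bool], mask_token: str = "...") -> str:
    """Rebuild a string in which words with keep=False are replaced by a placeholder."""
    return " ".join(w if k else mask_token for w, k in zip(words, keep))

words = split_words("So bad, I left early!")
print(words)  # ['So', 'bad', 'I', 'left', 'early']
print(mask_tokens(words, [True, True, False, False, False]))  # So bad ... ... ...
```

An explainer of this kind would pass many such masked strings to `f` and attribute the changes in predicted probability to the words that were hidden.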

[ ]:
from intel_ai_safety.explainer import attributions
[ ]:
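from scipy.special import softmax

# Define a prediction function that maps an array of strings to one row of class probabilities per string
def f(x):
    encoded_input = tokenizer(x.tolist(), padding='max_length', max_length=512, truncation=True, return_tensors='pt')
    # Gradients are not needed for explanations, so run the forward pass without tracking them
    with torch.no_grad():
        outputs = model.model(**encoded_input)
    return softmax(outputs.logits.detach().numpy(), axis=1)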
[ ]:
from intel_ai_safety.explainer import attributions
# Get shap values
text_for_shap = dataset.dataset['test'][:10]['text']
partition_explainer = attributions.partition_text_explainer(f, dataset.class_labels.names, text_for_shap, r"\W+", )
[ ]:
partition_explainer.visualize()

Citations

@InProceedings{maas-EtAl:2011:ACL-HLT2011,
  author    = {Maas, Andrew L.  and  Daly, Raymond E.  and  Pham, Peter T.  and  Huang, Dan  and  Ng, Andrew Y.  and  Potts, Christopher},
  title     = {Learning Word Vectors for Sentiment Analysis},
  booktitle = {Proceedings of the 49th Annual Meeting of the Association for Computational Linguistics: Human Language Technologies},
  month     = {June},
  year      = {2011},
  address   = {Portland, Oregon, USA},
  publisher = {Association for Computational Linguistics},
  pages     = {142--150},
  url       = {http://www.aclweb.org/anthology/P11-1015}
}

@misc{misc_sms_spam_collection_228,
  author       = {Almeida, Tiago},
  title        = {{SMS Spam Collection}},
  year         = {2012},
  howpublished = {UCI Machine Learning Repository}
}