COVID Forecasting

Stay home and stay safe, folks. This is a sad application of machine learning.

This example covers the following topics using the Python API:

  • Writing a new model. We’ll be wrapping fbprophet.

    • Overriding the SimpleModel __init__() method.

    • Using joblib to save and load the fbprophet model class.

  • Creating pandas DataFrames from DFFML Record objects.

  • Downloading dataset files with hash verification.

  • Loading data from files using the high level API.

  • Using multiple models, where one model's predictions are used as feature data for another model. This problem could be approached in many ways. We implemented it this way to show how one model's prediction can be used as an input feature to another model.
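The following standalone sketch illustrates chaining models without DFFML or fbprophet. Both "models" are hypothetical stand-ins. The first is a straight-line trend fitted to (day, cases) pairs, standing in for the Prophet forecaster. The second is a least-squares line fitted to (cases, deaths) pairs, playing the role SLRModel plays in the example. Only the chaining step mirrors the example: the first model's predicted cases become the second model's input feature.

```python
from typing import List, Tuple


def fit_line(points: List[Tuple[float, float]]) -> Tuple[float, float]:
    """Ordinary least squares fit of y = slope * x + intercept."""
    n = len(points)
    mean_x = sum(x for x, _ in points) / n
    mean_y = sum(y for _, y in points) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in points)
    var = sum((x - mean_x) ** 2 for x, _ in points)
    slope = cov / var
    return slope, mean_y - slope * mean_x


# Hypothetical toy training data (not taken from the Oregon dataset)
day_to_cases = [(0, 100.0), (1, 110.0), (2, 120.0), (3, 130.0)]
cases_to_deaths = [(100.0, 2.0), (200.0, 4.0), (300.0, 6.0)]

# Model 1 forecasts cases from the day; model 2 maps cases to deaths
cases_slope, cases_intercept = fit_line(day_to_cases)
deaths_slope, deaths_intercept = fit_line(cases_to_deaths)


def forecast(day: float) -> Tuple[float, float]:
    # Step 1: predict cases for a day that has no observed value
    predicted_cases = cases_slope * day + cases_intercept
    # Step 2: feed that prediction to the second model as its input feature
    predicted_deaths = deaths_slope * predicted_cases + deaths_intercept
    return predicted_cases, predicted_deaths


print(forecast(5))
```

The same shape appears in main() below. When a record has no actual cases value, the predicted value is written back as the cases feature before the record is passed to the cases-to-deaths model.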

Plan

The dataset we’ll be working with is the state of Oregon’s COVID case and death numbers by county.

Our goal is to forecast the number of cases and deaths in each county on any given day.
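Forecasting "in each county" means training one forecaster per county, so the records first have to be bucketed by county. The sketch below shows that step with plain dictionaries standing in for DFFML Record objects. The county names, values and the helper name group_by_county are hypothetical.

```python
from collections import defaultdict
from typing import Dict, List

# Hypothetical rows standing in for Record feature data
rows = [
    {"county": "CountyA", "date": "2020-10-01", "cases": 10, "deaths": 1},
    {"county": "CountyB", "date": "2020-10-01", "cases": 20, "deaths": 2},
    {"county": "CountyA", "date": "2020-10-02", "cases": 12, "deaths": 1},
]


def group_by_county(rows: List[dict]) -> Dict[str, List[dict]]:
    """Bucket rows by county so each county can get its own forecaster."""
    groups: Dict[str, List[dict]] = defaultdict(list)
    for row in rows:
        groups[row["county"]].append(row)
    return dict(groups)


by_county = group_by_county(rows)
print(sorted(by_county))  # ['CountyA', 'CountyB']
print(len(by_county["CountyA"]))  # 2
```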

The dataset has already been divided into two files, training and test. We’ll download them using the cached_download() function.
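cached_download() takes care of hash verification for you. The following sketch uses only the standard library's hashlib to show what checking a file against an expected SHA-384 hex digest involves. It illustrates the idea and is not DFFML's implementation. The helper name verify_sha384 is hypothetical.

```python
import hashlib
import pathlib
import tempfile


def verify_sha384(path: pathlib.Path, expected_hex: str) -> None:
    """Raise ValueError if the file's SHA-384 digest differs from expected_hex."""
    digest = hashlib.sha384()
    with open(path, "rb") as handle:
        # Hash in chunks so large files do not need to fit in memory
        for chunk in iter(lambda: handle.read(65536), b""):
            digest.update(chunk)
    actual = digest.hexdigest()
    if actual != expected_hex:
        raise ValueError(f"Hash mismatch for {path}: got {actual}")


with tempfile.TemporaryDirectory() as tmp:
    path = pathlib.Path(tmp, "example.csv")
    path.write_bytes(b"county,date,cases\n")
    expected = hashlib.sha384(b"county,date,cases\n").hexdigest()
    verify_sha384(path, expected)  # matching digest: no exception
    try:
        verify_sha384(path, "0" * 96)
    except ValueError:
        print("mismatch detected")
```

The 96-character hex strings passed to cached_download() in the code below are SHA-384 digests of this kind.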

fbprophet is great at forecasting int or float values given a date, so we're going to wrap it in a DFFML model class. For more explanation, first read the Writing a Model tutorial.
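Prophet is fitted on a table whose date column is named ds and whose target column is named y. The wrapper's main job is to rename the configured date and target features to those two names. The sketch below does that with plain dictionaries standing in for records. It skips rows that lack either value, which is roughly what sources.with_features() achieves in the real code. The helper name to_prophet_rows and the sample rows are hypothetical.

```python
from typing import Iterable, List, Mapping


def to_prophet_rows(
    records: Iterable[Mapping], date_key: str, value_key: str
) -> List[dict]:
    """Rename the date and target features to Prophet's ds and y columns,
    skipping records that are missing either value."""
    return [
        {"ds": record[date_key], "y": record[value_key]}
        for record in records
        if date_key in record and value_key in record
    ]


records = [
    {"county": "CountyA", "date": "2020-10-01", "cases": 10},
    {"county": "CountyA", "date": "2020-10-02", "cases": 12},
    {"county": "CountyA", "date": "2020-10-03"},  # no label, skipped
]
rows = to_prophet_rows(records, "date", "cases")
print(rows)  # [{'ds': '2020-10-01', 'y': 10}, {'ds': '2020-10-02', 'y': 12}]
```

With pandas installed, passing these rows to pandas.DataFrame.from_records() yields the two-column frame that the train() method hands to Prophet.fit().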

Setup

We need to install the dependencies we'll be importing.

See the Prophet installation documentation at https://facebook.github.io/prophet/docs/installation.html#python for more details.

$ python -m pip install joblib pandas "pystan==2.19.1.1"
$ python -m pip install fbprophet

Code

There are lots of comments in the following code, so we won't explain everything on this documentation page. Please open an issue or ask on Gitter if anything is unclear or if you think something could be changed or improved.

or_covid_data_by_county.py

import asyncio
import pathlib
import datetime
from typing import AsyncIterator, Type

from dffml import *

import joblib
import pandas
from fbprophet import Prophet


@config
class FBProphetModelConfig:
    date: Feature = field("Name of feature containing date value")
    predict: Feature = field("Label or the value to be predicted")
    location: pathlib.Path = field("Location where state should be saved")


@entrypoint("fbprophet")
class FBProphetModel(SimpleModel):
    # The configuration class needs to be set as the CONFIG property
    CONFIG: Type = FBProphetModelConfig

    def __init__(self, config) -> None:
        super().__init__(config)
        # The saved model
        self.saved = None
        self.saved_filepath = pathlib.Path(
            self.config.location, "model.joblib"
        )
        # Load saved model if it exists
        if self.saved_filepath.is_file():
            self.saved = joblib.load(str(self.saved_filepath))
            self.is_trained = True

    async def train(self, sources: SourcesContext) -> None:
        # Create a pandas DataFrame from the records with the features we care
        # about. Prophet wants each row in the DataFrame to have two features,
        # ds for the date, and y for the label (the value to predict).
        df = pandas.DataFrame.from_records(
            [
                {
                    "ds": record.feature(self.config.date.name),
                    "y": record.feature(self.config.predict.name),
                }
                async for record in sources.with_features(
                    [self.config.date.name, self.config.predict.name]
                )
            ]
        )
        # Use self.logger to report how many records are being used for training
        self.logger.debug("Number of training records: %d", len(df))
        # Create an instance of the Prophet model class
        self.saved = Prophet()
        # Train the model
        self.saved.fit(df)
        # Save the model
        joblib.dump(self.saved, str(self.saved_filepath))
        # Set the is_trained flag to True
        self.is_trained = True

    async def accuracy(self, sources: SourcesContext) -> Accuracy:
        # We're not going to calculate accuracy in the example
        raise NotImplementedError()

    async def predict(self, sources: SourcesContext) -> AsyncIterator[Record]:
        # Ensure there is a saved model
        if not self.is_trained:
            raise ModelNotTrained("Train the model before making predictions")
        # Load all the records into memory for the sake of speed
        records = [
            record
            async for record in sources.with_features([self.config.date.name])
        ]
        # Create a pandas DataFrame from the records that have a date feature
        future = pandas.DataFrame.from_records(
            [
                {"ds": record.feature(self.config.date.name)}
                for record in records
            ]
        )
        # Ask the model for predictions
        for record, value in zip(
            records, self.saved.predict(future).itertuples()
        ):
            # Set the predicted value
            # TODO do not use nan for confidence
            record.predicted(
                self.config.predict.name, value.yhat, float("nan")
            )
            # Yield the record
            yield record


async def main():
    # DFFML has a function to download files and validate their contents using
    # SHA 384 hashes. If you need to download files from an http:// site, you need
    # to add the following to the call to cached_download()
    #   protocol_allowlist=["https://", "http://"]
    training_file = await cached_download(
        "https://github.com/intel/dffml/files/5773999/COVID.Oregon.Counties.Train.Clean.to.2020-10-24.csv.gz",
        "training.csv.gz",
        "af9536ab41580e04dd72b1285f6b2b703977aee5b95b80422bbe7cc11262297da265e6c0e333bfc1faa7b4f263f5496e",
    )
    test_file = await cached_download(
        "https://github.com/intel/dffml/files/5773998/COVID.Oregon.Counties.Test.Clean.2020-10-25.to.2020-10-31.csv.gz",
        "test.csv.gz",
        "10ee8bcf06a511019f98c3e0e40f315585b2ed84d4a736f743567861d72438afcb7914f117e16640800959324f0f518d",
    )

    # Load the training data
    training_data = [record async for record in load(training_file)]

    # Load the test data
    test_data = [record async for record in load(test_file)]

    # Deaths and cases should have a relatively linear relationship. The cases
    # to deaths model will be trained to predict the number of deaths given the
    # number of cases. Try swapping the SLRModel for another model on the
    # plugins page.
    cases_to_deaths_model = SLRModel(
        features=Features(Feature("cases", int)),
        predict=Feature("deaths", int),
        location="cases_to_deaths.model",
    )

    # Train the model to learn the relationship between cases and deaths
    await train(cases_to_deaths_model, *training_data)

    # Find the set of counties by looking through all the training data and
    # recording each county seen
    counties = {record.feature("county") for record in training_data}

    # We want to forecast the number of cases by county. We'll train a
    # forecasting model on past data to predict the number of cases given the
    # date for the given county.
    date_to_cases_models = {
        county: FBProphetModel(
            date=Feature("date", str),
            predict=Feature("cases", int),
            location=f"date_to_cases.{county}.model",
        )
        for county in counties
    }

    # Group training data by county
    training_data_by_county = {county: [] for county in counties}
    for record in training_data:
        training_data_by_county[record.feature("county")].append(record)

    # Group test data by county
    test_data_by_county = {county: [] for county in counties}
    for record in test_data:
        test_data_by_county[record.feature("county")].append(record)

    # Get today's date
    todays_date = datetime.datetime.now()

    for county, model in date_to_cases_models.items():
        # Train a model for each county
        await train(model, *training_data_by_county[county])
        # We want predictions for the test data for this county. Copy the list
        # so that appending future dates below does not modify
        # test_data_by_county.
        want_predictions = list(test_data_by_county[county])
        # We also want to ask for today through four days from now
        prediction_dates = [
            (todays_date + datetime.timedelta(days=i)).strftime("%Y-%m-%d")
            for i in range(0, 5)
        ]
        # We append those records to the set of records we want predictions for
        want_predictions += [
            Record(
                key=date, data={"features": {"county": county, "date": date}}
            )
            for date in prediction_dates
        ]
        # Predict the number of cases for the county
        async for record in predict(
            model, *want_predictions, keep_record=True
        ):
            # Get predicted value for cases
            predicted_cases = record.prediction("cases")["value"]
            # Report actual value for cases if we have it
            actual_cases = "Actual Cases Unknown"
            features = record.features()
            if "cases" in features:
                actual_cases = features["cases"]
            else:
                # If we don't have an actual value for cases set the cases
                # feature to the predicted value so that we can feed it into the
                # next model for a number of deaths prediction
                record.evaluated({"cases": predicted_cases})
            # Use the cases to deaths model to predict the deaths. It's a loop
            # but we're only feeding it the one record from the loop we're in.
            async for record in predict(
                cases_to_deaths_model, record, keep_record=True
            ):
                # Get predicted value for deaths
                predicted_deaths = record.prediction("deaths")["value"]
                # Report actual value for deaths if we have it
                actual_deaths = "Actual Deaths Unknown"
                if "deaths" in features:
                    actual_deaths = features["deaths"]
                # Print out predictions
                print("---------------------------------------")
                print(f'county:           : {features["county"]}')
                print(f'date:             : {features["date"]}')
                print(f"predicted_cases   : {predicted_cases}")
                print(f"actual_cases      : {actual_cases}")
                print(f"predicted_deaths  : {predicted_deaths}")
                print(f"actual_deaths     : {actual_deaths}")


if __name__ == "__main__":
    asyncio.run(main())
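FBProphetModel uses a persistence pattern worth isolating. __init__() reloads a previously saved model if the file exists, and train() writes the model to disk and sets is_trained. The sketch below reproduces that pattern in a hypothetical PersistedModel class, using the standard library's pickle in place of joblib and a trivial "model" (the mean of the training values) in place of Prophet. Both swaps are made so the example runs without extra dependencies.

```python
import pathlib
import pickle
import tempfile
from typing import Optional


class PersistedModel:
    """Hypothetical stand-in for the load-if-saved / save-after-train
    pattern, using pickle instead of joblib."""

    def __init__(self, location: pathlib.Path) -> None:
        self.saved_filepath = pathlib.Path(location, "model.pickle")
        self.saved: Optional[dict] = None
        self.is_trained = False
        # Reload state written by a previous run, if any
        if self.saved_filepath.is_file():
            with open(self.saved_filepath, "rb") as handle:
                self.saved = pickle.load(handle)
            self.is_trained = True

    def train(self, values: list) -> None:
        # The "model" here is just the mean of the training values
        self.saved = {"mean": sum(values) / len(values)}
        self.saved_filepath.parent.mkdir(parents=True, exist_ok=True)
        with open(self.saved_filepath, "wb") as handle:
            pickle.dump(self.saved, handle)
        self.is_trained = True


with tempfile.TemporaryDirectory() as tmp:
    first = PersistedModel(pathlib.Path(tmp))
    first.train([1.0, 2.0, 3.0])
    # A fresh instance pointed at the same directory picks up the saved state
    second = PersistedModel(pathlib.Path(tmp))
    print(second.is_trained, second.saved)  # True {'mean': 2.0}
```

Only load pickle or joblib files you created yourself, because unpickling untrusted data can execute arbitrary code.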

Run the file to see predictions for the test data, along with predictions for today and the next four days.
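"Today plus the next four days" comes from range(0, 5), which yields the five offsets 0 through 4. Here is that date-generation step pulled out into a hypothetical helper with a fixed start date, so its result is reproducible. The start date 2021-01-05 is chosen to match the future dates in the sample run.

```python
import datetime
from typing import List


def upcoming_dates(start: datetime.date, days: int = 5) -> List[str]:
    """Return `days` consecutive dates beginning at start, as YYYY-MM-DD."""
    return [
        (start + datetime.timedelta(days=i)).strftime("%Y-%m-%d")
        for i in range(days)
    ]


print(upcoming_dates(datetime.date(2021, 1, 5)))
# ['2021-01-05', '2021-01-06', '2021-01-07', '2021-01-08', '2021-01-09']
```

Because timedelta arithmetic handles month and year boundaries, the same code works for a start date such as 2020-12-30.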

$ python or_covid_data_by_county.py
INFO:fbprophet:Disabling yearly seasonality. Run prophet with yearly_seasonality=True to override this.
INFO:fbprophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
Initial log joint probability = -4.48963
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes
      99       470.387     0.0198861       992.252       0.934       0.934      117
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes
     195       493.794   0.000557792       208.725   1.765e-06       0.001      272  LS failed, Hessian reset
     199       494.869     0.0142908       273.213           1           1      276
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes
     274       506.305    0.00034089       135.527   1.069e-06       0.001      410  LS failed, Hessian reset
     299       512.079     0.0479183       140.335           1           1      443
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes
     399       523.789   0.000554393       243.109   3.188e-06       0.001      599  LS failed, Hessian reset
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes
     499       530.926     0.0021352       316.262      0.4123      0.4123      721
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes
     597       537.419   0.000436284       246.382   7.017e-07       0.001      876  LS failed, Hessian reset
     599       537.634    0.00161266       87.1624           1           1      878
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes
     699       541.955   0.000221894       75.1759      0.8981      0.8981     1009
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes
     799       542.227     0.0031715       187.747           1           1     1137
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes
     899       542.449   0.000102382       70.2315      0.6954      0.6954     1257
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes
     999       544.304    0.00911989       138.497           1           1     1379
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes
    1099       546.541   2.90923e-06       78.8871      0.2538      0.2538     1509
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes
    1199       546.616    0.00152174       144.315           1           1     1627
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes
    1274       547.735   0.000108332       104.365   7.121e-07       0.001     1759  LS failed, Hessian reset
    1299       547.759   8.21456e-06       72.6784      0.6158      0.6158     1794
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes
    1317       547.761   5.74449e-06       68.9362   6.756e-08       0.001     1853  LS failed, Hessian reset
    1399       547.886    0.00178702       80.0121           1           1     1949
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes
    1403       547.901   9.98032e-05       107.405   1.012e-06       0.001     1996  LS failed, Hessian reset
    1471       548.011    8.6135e-09        74.944      0.3285      0.3285     2085
Optimization terminated normally:
  Convergence detected: absolute parameter change was below tolerance
---------------------------------------
county:           : Lincoln
date:             : 2020-10-25
predicted_cases   : 510.5316275788172
actual_cases      : 517
predicted_deaths  : 9
actual_deaths     : 13
---------------------------------------
county:           : Lincoln
date:             : 2020-10-26
predicted_cases   : 512.1564675524057
actual_cases      : 517
predicted_deaths  : 9
actual_deaths     : 13
---------------------------------------
county:           : Lincoln
date:             : 2020-10-27
predicted_cases   : 511.7605249084097
actual_cases      : 517
predicted_deaths  : 9
actual_deaths     : 13
---------------------------------------
county:           : Lincoln
date:             : 2020-10-28
predicted_cases   : 512.410408061449
actual_cases      : 518
predicted_deaths  : 9
actual_deaths     : 13
---------------------------------------
county:           : Lincoln
date:             : 2020-10-29
predicted_cases   : 512.9605657720964
actual_cases      : 518
predicted_deaths  : 9
actual_deaths     : 13
---------------------------------------
county:           : Lincoln
date:             : 2020-10-30
predicted_cases   : 514.5428400672448
actual_cases      : 519
predicted_deaths  : 9
actual_deaths     : 13
---------------------------------------
county:           : Lincoln
date:             : 2020-10-31
predicted_cases   : 514.999483186029
actual_cases      : 519
predicted_deaths  : 9
actual_deaths     : 13
INFO:dffml.record:Evaluated 2021-01-05 {'county': 'Lincoln', 'date': '2021-01-05', 'cases': 576.0638433267727}
---------------------------------------
county:           : Lincoln
date:             : 2021-01-05
predicted_cases   : 576.0638433267727
actual_cases      : Actual Cases Unknown
predicted_deaths  : 10
actual_deaths     : Actual Deaths Unknown
INFO:dffml.record:Evaluated 2021-01-06 {'county': 'Lincoln', 'date': '2021-01-06', 'cases': 576.7137264798116}
---------------------------------------
county:           : Lincoln
date:             : 2021-01-06
predicted_cases   : 576.7137264798116
actual_cases      : Actual Cases Unknown
predicted_deaths  : 10
actual_deaths     : Actual Deaths Unknown
INFO:dffml.record:Evaluated 2021-01-07 {'county': 'Lincoln', 'date': '2021-01-07', 'cases': 577.2638841904583}
---------------------------------------
county:           : Lincoln
date:             : 2021-01-07
predicted_cases   : 577.2638841904583
actual_cases      : Actual Cases Unknown
predicted_deaths  : 10
actual_deaths     : Actual Deaths Unknown
INFO:dffml.record:Evaluated 2021-01-08 {'county': 'Lincoln', 'date': '2021-01-08', 'cases': 578.8461584856071}
---------------------------------------
county:           : Lincoln
date:             : 2021-01-08
predicted_cases   : 578.8461584856071
actual_cases      : Actual Cases Unknown
predicted_deaths  : 10
actual_deaths     : Actual Deaths Unknown
INFO:dffml.record:Evaluated 2021-01-09 {'county': 'Lincoln', 'date': '2021-01-09', 'cases': 579.3028016043917}
---------------------------------------
county:           : Lincoln
date:             : 2021-01-09
predicted_cases   : 579.3028016043917
actual_cases      : Actual Cases Unknown
predicted_deaths  : 10
actual_deaths     : Actual Deaths Unknown
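The example deliberately leaves accuracy() unimplemented. One way to read the sample output is to compute the mean absolute error (MAE) over the seven Lincoln county test days, using the values printed above. MAE is our choice of metric for this illustration and is not something the example defines. The helper name mean_absolute_error is hypothetical.

```python
from typing import Sequence

# Values copied from the Lincoln county test rows in the sample output
predicted_cases = [
    510.5316275788172,
    512.1564675524057,
    511.7605249084097,
    512.410408061449,
    512.9605657720964,
    514.5428400672448,
    514.999483186029,
]
actual_cases = [517, 517, 517, 518, 518, 519, 519]
predicted_deaths = [9] * 7
actual_deaths = [13] * 7


def mean_absolute_error(
    predicted: Sequence[float], actual: Sequence[float]
) -> float:
    """Average of |predicted - actual| over paired values."""
    if len(predicted) != len(actual):
        raise ValueError("predicted and actual must have the same length")
    return sum(abs(p - a) for p, a in zip(predicted, actual)) / len(actual)


mae_cases = mean_absolute_error(predicted_cases, actual_cases)
mae_deaths = mean_absolute_error(predicted_deaths, actual_deaths)
print(f"cases MAE:  {mae_cases:.2f}")  # cases MAE:  5.09
print(f"deaths MAE: {mae_deaths:.2f}")  # deaths MAE: 4.00
```

Every prediction in this window falls below the actual value, so over these seven days the models under-forecast both cases and deaths. A single county over one week is far too small a sample to draw general conclusions from.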