Accuracy Scorer

In this tutorial we will learn how to implement an accuracy scorer. Our scorer will compute the mean squared error (MSE), a common metric for regression models. Unlike a classification accuracy, a lower MSE means better predictions.
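For n records with true values y_i and predicted values ŷ_i, the quantity the scorer computes is the average of the squared differences:

```latex
\mathrm{MSE} = \frac{1}{n} \sum_{i=1}^{n} \left( y_i - \hat{y}_i \right)^2
```

For example, with true values 3 and 5 and predictions 2.5 and 5.5, the squared differences are 0.25 and 0.25, so the MSE is (0.25 + 0.25) / 2 = 0.25.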

We will work in a new file, mse.py. Because the imports below are relative, this file must sit in the same package as accuracy.py.

Imports

First we need to import a few modules from DFFML (these are DFFML's own modules, not the Python standard library):

  • config, a decorator that helps us create the configuration class for the scorer we are implementing.

  • AccuracyScorer and AccuracyContext, the base classes our scorer inherits from; we override the score method of the context.

We also need Feature, ModelContext, and SourcesContext for the score method's parameters, and the entrypoint decorator to register the scorer.

from ..base import config
from ..feature import Feature
from ..model import ModelContext
from ..util.entrypoint import entrypoint
from ..source.source import SourcesContext
from .accuracy import (
    AccuracyScorer,
    AccuracyContext,
)

Config

Next we define MeanSquaredErrorAccuracyConfig, the config class for our scorer. MSE takes no parameters, so the class is empty.

@config
class MeanSquaredErrorAccuracyConfig:
    pass

Context

Now we implement MeanSquaredErrorAccuracyContext, which inherits from AccuracyContext.

Its score method takes a ModelContext, a SourcesContext, and the Feature being predicted. Through the model context it runs predictions over the records in the sources, and it compares each record's true value for the feature with the model's predicted value.

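class MeanSquaredErrorAccuracyContext(AccuracyContext):
    """
    Mean Squared Error
    """

    async def score(
        self, mctx: ModelContext, sources: SourcesContext, feature: Feature,
    ):
        # Collect the true values and the model's predictions for the feature
        y_true = []
        y_predict = []
        async for record in mctx.predict(sources):
            y_true.append(record.feature(feature.name))
            y_predict.append(record.prediction(feature.name).value)
        # Avoid a ZeroDivisionError when the sources contain no records
        if not y_true:
            raise ValueError("No records to score")
        # Mean of the squared differences between true and predicted values
        accuracy = sum(
            (actual - predicted) ** 2
            for actual, predicted in zip(y_true, y_predict)
        ) / len(y_true)
        return accuracy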

Scorer

Finally we implement MeanSquaredErrorAccuracy, which inherits from AccuracyScorer and ties the config and context classes together through its CONFIG and CONTEXT attributes. The entrypoint decorator registers it under the name mse, so we can also use this scorer from the CLI.

@entrypoint("mse")
class MeanSquaredErrorAccuracy(AccuracyScorer):
    CONFIG = MeanSquaredErrorAccuracyConfig
    CONTEXT = MeanSquaredErrorAccuracyContext
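The pattern above (a scorer class that names its config and context classes, plus a decorator that labels it) can be illustrated with a small self-contained sketch. Everything here is hypothetical: REGISTRY, register, BaseScorer, and the Toy classes are invented for illustration, and this is a simplified picture of name-based lookup, not how DFFML's entrypoint decorator or its plugin loading is actually implemented.

```python
from typing import Callable, Dict

# Hypothetical registry mapping a label to a scorer class
REGISTRY: Dict[str, type] = {}


def register(name: str) -> Callable[[type], type]:
    # Hypothetical decorator: label the class and record it for lookup by name
    def decorator(cls: type) -> type:
        cls.LABEL = name
        REGISTRY[name] = cls
        return cls

    return decorator


class BaseScorer:
    CONFIG: type
    CONTEXT: type

    def __init__(self, config=None):
        # Fall back to a default instance of the declared config class
        self.config = config if config is not None else self.CONFIG()

    def context(self):
        # Create a context bound to this scorer (and therefore its config)
        return self.CONTEXT(self)


class ToyConfig:
    pass


class ToyContext:
    def __init__(self, parent: BaseScorer):
        self.parent = parent


@register("toy")
class ToyScorer(BaseScorer):
    CONFIG = ToyConfig
    CONTEXT = ToyContext


# Look the scorer up by its label, as a CLI might, then open a context
scorer = REGISTRY["toy"]()
ctx = scorer.context()
```

Keeping the per-run logic in a separate context class lets the scorer object hold only configuration, while each context does the actual work of a scoring run.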