from ..base import config
from ..feature import Feature
from ..model import ModelContext
from ..util.entrypoint import entrypoint
from ..source.source import SourcesContext
from .accuracy import (
AccuracyScorer,
AccuracyContext,
)
@config
class MeanSquaredErrorAccuracyConfig:
    """
    Configuration for the mean squared error scorer.

    The class body declares no fields, so this scorer takes no options.
    """
class MeanSquaredErrorAccuracyContext(AccuracyContext):
    """
    Mean Squared Error

    Scores a model by averaging, over every record it predicts, the squared
    magnitude of the difference between the record's actual value and the
    model's prediction for the same feature.
    """

    async def score(
        self, mctx: ModelContext, sources: SourcesContext, feature: Feature,
    ):
        """
        Compute the mean squared error of ``mctx``'s predictions.

        Parameters
        ----------
        mctx : ModelContext
            Model whose ``predict`` is run over ``sources``.
        sources : SourcesContext
            Records to predict on and compare against.
        feature : Feature
            Feature being predicted. Its ``name`` selects both the actual
            value (``record.feature``) and the predicted value
            (``record.prediction(...).value``) on each record.

        Returns
        -------
        The mean of ``abs(actual - predicted) ** 2`` over all records
        yielded by ``mctx.predict(sources)``.

        Raises
        ------
        ValueError
            If ``mctx.predict`` yields no records, since the mean of zero
            terms is undefined.
        """
        total = 0
        count = 0
        # Accumulate while records stream in instead of buffering two lists.
        # Terms are added left to right starting from 0, exactly as sum()
        # would add them, so the numeric result is unchanged.
        async for record in mctx.predict(sources):
            actual = record.feature(feature.name)
            predicted = record.prediction(feature.name).value
            # abs() keeps each term a real squared magnitude even if the
            # values are complex; for real numbers it has no effect.
            total += abs(actual - predicted) ** 2
            count += 1
        if not count:
            raise ValueError(
                f"Cannot compute mean squared error for {feature.name!r}: "
                "no records were predicted"
            )
        return total / count
@entrypoint("mse")
class MeanSquaredErrorAccuracy(AccuracyScorer):
    """
    Mean squared error accuracy scorer, labelled ``"mse"`` via the
    ``entrypoint`` decorator.

    The scoring logic lives in the ``score`` method of
    ``MeanSquaredErrorAccuracyContext`` (the ``CONTEXT`` class below).
    The configuration class, ``MeanSquaredErrorAccuracyConfig``, declares
    no options.
    """

    CONFIG = MeanSquaredErrorAccuracyConfig
    CONTEXT = MeanSquaredErrorAccuracyContext