Source code for dffml.accuracy.mse

from ..base import config
from ..feature import Feature
from ..model import ModelContext
from ..util.entrypoint import entrypoint
from ..source.source import SourcesContext
from .accuracy import (
    AccuracyScorer,
    AccuracyContext,
)


@config
class MeanSquaredErrorAccuracyConfig:
    """Configuration for the mean squared error scorer; declares no fields."""
class MeanSquaredErrorAccuracyContext(AccuracyContext):
    """
    Mean Squared Error

    Scores a model by averaging, over every record the model yields, the
    squared difference between the record's stored value for a feature and
    the value the model predicted for that same feature.
    """

    async def score(
        self, mctx: ModelContext, sources: SourcesContext, feature: Feature,
    ):
        """
        Compute the mean squared error of the model's predictions.

        Parameters
        ----------
        mctx : ModelContext
            Model context; ``mctx.predict(sources)`` is iterated
            asynchronously to obtain the records to score.
        sources : SourcesContext
            Passed unchanged to ``mctx.predict``.
        feature : Feature
            Feature being scored. Its ``name`` selects both the stored value
            (``record.feature``) and the predicted value
            (``record.prediction(...).value``) on each record.

        Returns
        -------
        The sum of the squared errors divided by the number of records.

        Raises
        ------
        ZeroDivisionError
            If ``mctx.predict(sources)`` yields no records, since the mean of
            zero errors is undefined.
        """
        # One pass over the prediction stream: keep only the per-record
        # squared error instead of two parallel lists of raw values.
        squared_errors = [
            abs(
                record.feature(feature.name)
                - record.prediction(feature.name).value
            )
            ** 2
            async for record in mctx.predict(sources)
        ]
        if not squared_errors:
            # Same exception type the bare division would raise, but with a
            # message that points at the actual cause.
            raise ZeroDivisionError(
                "Cannot compute mean squared error for feature "
                f"{feature.name!r}: no records were predicted"
            )
        return sum(squared_errors) / len(squared_errors)
@entrypoint("mse")
class MeanSquaredErrorAccuracy(AccuracyScorer):
    """
    Mean squared error accuracy scorer, decorated with the ``mse`` entrypoint.

    Pairs :class:`MeanSquaredErrorAccuracyConfig` with
    :class:`MeanSquaredErrorAccuracyContext`, which performs the scoring.
    """

    CONTEXT = MeanSquaredErrorAccuracyContext
    CONFIG = MeanSquaredErrorAccuracyConfig