Source code for dffml.cli.ml

import inspect

from ..model.model import Model
from ..source.source import Sources, SubsetSources
from ..util.cli.cmd import CMD, CMDOutputOverride
from ..high_level.ml import train, predict, score
from ..util.config.fields import FIELD_SOURCES
from ..util.cli.cmds import (
    SourcesCMD,
    ModelCMD,
    KeysCMD,
    ModelCMDConfig,
    SourcesCMDConfig,
    KeysCMDConfig,
)
from ..base import config, field
from ..accuracy import AccuracyScorer
from ..feature import Features


@config
class MLCMDConfig(SourcesCMDConfig, ModelCMDConfig):
    """
    Combined configuration for commands that need both sources and a model.

    Declares no fields of its own; every field is inherited from
    ``SourcesCMDConfig`` and ``ModelCMDConfig``.
    """
@config
class AccuracyCMDConfig:
    """
    Configuration consumed by the ``Accuracy`` command.

    Fields: the model to assess, the scorer used to assess it, the
    feature(s) being predicted, and the sources that supply the data.
    """

    # Both the model and the scorer must be supplied by the caller.
    model: Model = field("Model used for ML", required=True)
    scorer: AccuracyScorer = field(
        "Method to use to score accuracy",
        required=True,
    )
    # NOTE(review): this single Features() instance is created once, when the
    # class body executes, and serves as the default -- confirm that nothing
    # mutates it in place.
    features: Features = field(
        "Predict Feature(s)",
        default=Features(),
    )
    # Shared field definition imported from ..util.config.fields.
    sources: Sources = FIELD_SOURCES
class MLCMD(SourcesCMD, ModelCMD):
    """
    Commands which use models share many similar arguments.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialise through the regular MRO chain, then run the ``ModelCMD``
        initialiser explicitly with the same arguments.
        """
        # First pass: whatever the MRO resolves super() to.
        super().__init__(*args, **kwargs)
        # Second pass: ModelCMD's initialiser is always invoked directly.
        # NOTE(review): if SourcesCMD.__init__ already chains to ModelCMD via
        # super(), this runs ModelCMD.__init__ twice -- confirm that is
        # intended / harmless.
        ModelCMD.__init__(self, *args, **kwargs)
class Train(MLCMD):
    """Train a model on data from given sources"""

    # changes : model(features) -> model()

    CONFIG = MLCMDConfig

    async def run(self):
        """Await the high-level ``train`` helper and return its result."""
        outcome = await train(self.model, self.sources)
        return outcome
class Accuracy(MLCMD):
    """Assess model accuracy on data from given sources"""

    CONFIG = AccuracyCMDConfig

    async def run(self):
        """
        Score the model with the configured scorer and return the result of
        the high-level ``score`` helper.
        """
        # The scorer may still be a class rather than an instance at this
        # point; if so, build the instance from the extra configuration.
        scorer_is_class = inspect.isclass(self.scorer)
        if scorer_is_class:
            self.scorer = self.scorer.withconfig(self.extra_config)

        result = await score(
            self.model,
            self.scorer,
            self.features,
            self.sources,
        )
        return result
@config
class PredictAllConfig(MLCMDConfig):
    """
    Configuration for prediction commands.

    Adds two boolean switches, both off by default, on top of the fields
    inherited from ``MLCMDConfig``.
    """

    # Forwarded to predict() as its ``update`` keyword argument.
    update: bool = field("Update record with sources", default=False)
    # When set, the command prints each record instead of yielding it.
    pretty: bool = field("Outputs data in tabular form", default=False)
class PredictAll(MLCMD):
    """Predicts for all sources"""

    CONFIG = PredictAllConfig

    async def run(self):
        """
        Asynchronously iterate over the records produced by the high-level
        ``predict`` helper.

        With ``pretty`` unset, each record is yielded to the caller. With
        ``pretty`` set, each record is printed instead and a single
        ``CMDOutputOverride`` is yielded once iteration has finished.
        """
        predictions = predict(
            self.model,
            self.sources,
            update=self.update,
            keep_record=True,
        )
        async for record in predictions:
            if not self.pretty:
                yield record
            else:
                print(record)
        # NOTE(review): presumably tells the CLI layer that output has
        # already been written -- confirm against CMDOutputOverride's users.
        if self.pretty:
            yield CMDOutputOverride
@config
class PredictRecordConfig(PredictAllConfig, KeysCMDConfig):
    """
    Configuration for predicting on individual records.

    Declares no fields of its own; every field is inherited from
    ``PredictAllConfig`` and ``KeysCMDConfig``.
    """
class PredictRecord(PredictAll, KeysCMD):
    """Predictions for individual records"""

    CONFIG = PredictRecordConfig

    def __init__(self, *args, **kwargs):
        """
        Initialise as usual, then replace ``self.sources`` with a
        ``SubsetSources`` built from the existing sources and ``self.keys``.
        """
        super().__init__(*args, **kwargs)
        # Unpack the already-configured sources into the subset wrapper.
        subset = SubsetSources(*self.sources, keys=self.keys)
        self.sources = subset
class Predict(CMD):
    """Evaluate features against records and produce a prediction"""

    # Command classes exposed as attributes of this command.
    # ``record`` maps to prediction for individual records.
    record = PredictRecord
    # ``_all`` maps to prediction over all sources.
    _all = PredictAll