neural_compressor.contrib.strategy.tpe
Define the tuning strategy that uses TPE search in the tuning space.
Module Contents
Classes
TpeTuneStrategy | The tuning strategy that uses TPE search in the tuning space.
- class neural_compressor.contrib.strategy.tpe.TpeTuneStrategy(model, conf, q_dataloader, q_func=None, eval_dataloader=None, eval_func=None, dicts=None, q_hooks=None)
Bases: neural_compressor.strategy.strategy.TuneStrategy
The tuning strategy that uses TPE (Tree-structured Parzen Estimator) search in the tuning space.
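TPE splits the trials evaluated so far into a "good" group (the best `gamma` fraction by loss) and a "bad" group, fits a density to each, l(x) and g(x), and proposes the candidate that maximizes l(x)/g(x). The sketch below shows that rule for a single categorical knob, such as a per-op data type. It is an illustration, not the TpeTuneStrategy implementation: the function name `tpe_suggest`, the `gamma` and `prior_weight` defaults, and the choice names are assumptions. Because a categorical knob has few values, it scores every choice directly instead of sampling candidates from l(x).

```python
from collections import Counter


def tpe_suggest(history, choices, gamma=0.25, prior_weight=1.0):
    """Pick the next value of one categorical knob with a TPE-style rule.

    history: list of (choice, loss) pairs already evaluated; lower loss is better.
    choices: every candidate value for this knob.
    gamma: fraction of trials treated as "good".
    prior_weight: pseudo-count added to every choice (uniform prior).
    """
    if not history:
        return choices[0]  # nothing observed yet: fall back to the first choice

    ranked = sorted(history, key=lambda item: item[1])
    n_good = max(1, int(gamma * len(ranked)))
    good = Counter(choice for choice, _ in ranked[:n_good])
    bad = Counter(choice for choice, _ in ranked[n_good:])

    def density(counts, total):
        # Smoothed categorical density: observed counts plus a uniform prior.
        denom = total + prior_weight * len(choices)
        return {c: (counts[c] + prior_weight) / denom for c in choices}

    l = density(good, n_good)              # density of the good trials
    g = density(bad, len(ranked) - n_good)  # density of the remaining trials
    # TPE proposes the choice with the highest l(x) / g(x) ratio.
    return max(choices, key=lambda c: l[c] / g[c])
```

For example, with a history in which the two lowest-loss trials both used `"int8"`, `tpe_suggest(history, ["int8", "bf16", "fp32"])` favors `"int8"` because its good-group density is high relative to its bad-group density.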
- Parameters:
model (object) – The FP32 model specified for low precision tuning.
conf (Conf) – The Conf class instance initialized from the user's YAML config file.
q_dataloader (generator) – Data loader for calibration, mandatory for post-training quantization. It is iterable and should yield a tuple (input, label) for a calibration dataset that contains labels, or (input, _) for a label-free calibration dataset. The input can be an object, list, tuple, or dict, depending on the user implementation, as long as the model can take it as input.
q_func (function, optional) – Reserved for future use.
eval_dataloader (generator, optional) – Data loader for evaluation. It is iterable and should yield a tuple of (input, label). The input can be an object, list, tuple, or dict, depending on the user implementation, as long as the model can take it as input. The label must be accepted as input by the supported metrics. If this parameter is not None, the user needs to specify pre-defined evaluation metrics in the configuration file and should set the eval_func parameter to None. The tuner then combines the model, eval_dataloader, and pre-defined metrics to run the evaluation.
eval_func (function, optional) –
The evaluation function provided by the user. It takes the model as its parameter, encapsulates the evaluation dataset and metrics in its own implementation, and returns a higher-is-better scalar accuracy value.
The pseudo code looks something like:
    def eval_func(model):
        input, label = dataloader()
        output = model(input)
        accuracy = metric(output, label)
        return accuracy
dicts (dict, optional) – The dict containing resume information. Defaults to None.
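A runnable version of the eval_func contract described above: the function takes only the model, keeps the dataset and metric inside its own body, and returns a higher-is-better scalar. Everything here (`make_toy_dataloader`, `toy_model`, the data, and top-1 accuracy as the metric) is a hypothetical stand-in; a real eval_func would wrap your own dataloader, model call, and metric.

```python
def make_toy_dataloader():
    """Hypothetical dataloader: an iterable of (input, label) tuples."""
    return [([1.0, 2.0], 1), ([3.0, 0.5], 0), ([0.2, 0.9], 0), ([4.0, 1.0], 0)]


def toy_model(inputs):
    """Hypothetical model: predicts the index of the largest input feature."""
    return max(range(len(inputs)), key=lambda i: inputs[i])


def eval_func(model):
    """Evaluate `model` and return a higher-is-better accuracy scalar."""
    dataloader = make_toy_dataloader()  # the dataset is encapsulated here
    correct = 0
    total = 0
    for inputs, label in dataloader:
        output = model(inputs)
        correct += int(output == label)  # top-1 accuracy as the metric
        total += 1
    return correct / total
```

Here `eval_func(toy_model)` returns 0.75, since the model misclassifies one of the four samples.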
- traverse()
Run the TPE traversal logic over the tuning space.
- add_loss_to_tuned_history_and_find_best(tuning_history_list)
Add the computed loss to the tuning history and find the best tuned result.
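The method name suggests two steps: attach a loss to each tuning-history entry, then keep the entry with the lowest loss. A minimal sketch of that idea, assuming each entry is a dict with hypothetical `acc_diff`, `lat_diff`, and `config` keys and that lower loss is better. `loss_fn` stands in for something like `calculate_loss`, whose actual formula is not documented here, so the toy loss in the usage note is purely illustrative.

```python
def add_loss_and_find_best(tuning_history, loss_fn):
    """Attach a loss to every history entry and return the lowest-loss entry.

    tuning_history: list of dicts with hypothetical keys "acc_diff", "lat_diff", "config".
    loss_fn: callable(acc_diff, lat_diff, config) -> float; lower is better.
    Returns None when the history is empty.
    """
    best = None
    for entry in tuning_history:
        # Store the loss on the entry so later steps can reuse it.
        entry["loss"] = loss_fn(entry["acc_diff"], entry["lat_diff"], entry["config"])
        if best is None or entry["loss"] < best["loss"]:
            best = entry
    return best
```

With a toy loss such as `lambda acc, lat, cfg: abs(acc) + lat`, the entry with the smallest combined accuracy and latency penalty is returned.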
- object_evaluation(tune_cfg, model)
Check whether the configuration was already evaluated.
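One common way to avoid re-evaluating a configuration is to memoize results under a hashable key derived from the config. The sketch below assumes `tune_cfg` is a nested structure of dicts, lists, and scalars (dict keys may be tuples) and that an evaluation callable is supplied; `freeze` and `EvaluationCache` are hypothetical names, not part of neural_compressor.

```python
def freeze(obj):
    """Turn nested dicts/lists into a hashable, order-independent key."""
    if isinstance(obj, dict):
        # Sort by repr(key) so insertion order does not change the key;
        # this also works for tuple keys, which JSON cannot serialize.
        items = ((repr(k), freeze(v)) for k, v in obj.items())
        return tuple(sorted(items, key=lambda kv: kv[0]))
    if isinstance(obj, (list, tuple)):
        return tuple(freeze(v) for v in obj)
    return obj  # scalars (str, int, float, bool, None) are already hashable


class EvaluationCache:
    """Remember results of tuning configs that were already evaluated."""

    def __init__(self, evaluate):
        self._evaluate = evaluate  # callable(tune_cfg) -> result
        self._results = {}
        self.calls = 0  # number of real evaluations performed

    def __call__(self, tune_cfg):
        key = freeze(tune_cfg)
        if key not in self._results:
            self.calls += 1
            self._results[key] = self._evaluate(tune_cfg)
        return self._results[key]
```

Two configs with the same items in a different order map to the same key, so the second lookup returns the cached result without running the evaluation again.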
- calculate_loss(acc_diff, lat_diff, config)
Calculate the accuracy loss.
- stop(timeout, trials_count)
Check whether to stop traversing the tuning space, either because the accuracy goal is met or because the timeout is reached.
- Returns:
True if tuning should stop, otherwise False.
- Return type:
bool
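The stop check can be sketched as a pure function. The accuracy-goal and timeout conditions come from the description above; the `max_trials` cap on `trials_count`, the use of elapsed seconds instead of a clock, and `None` meaning "no limit" are assumptions made for illustration, not the documented behavior of `stop()`.

```python
def should_stop(best_accuracy, accuracy_target, elapsed, timeout,
                trials_count, max_trials=None):
    """Return True if traversal should stop, otherwise False.

    best_accuracy: best accuracy found so far, or None if nothing is evaluated yet.
    elapsed / timeout: seconds; timeout=None disables the time check.
    max_trials: hypothetical cap on trials_count; None disables it.
    """
    if best_accuracy is not None and best_accuracy >= accuracy_target:
        return True  # accuracy goal met
    if timeout is not None and elapsed >= timeout:
        return True  # timeout reached
    if max_trials is not None and trials_count >= max_trials:
        return True  # trial budget used up (assumed condition)
    return False
```

Keeping the check free of clock calls makes it easy to test; a caller would pass `time.time() - start` as `elapsed`.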