# Source code for niftynet.evaluation.evaluation_application_driver

# -*- coding: utf-8 -*-
"""
This module defines a general procedure for running evaluations
Example usage:
    app_driver = EvaluationApplicationDriver()
    app_driver.initialise_application(system_param, input_data_param)
    app_driver.run_application()

system_param and input_data_param should be generated using:
niftynet.utilities.user_parameters_parser.run()
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
import itertools

import pandas as pd
import tensorflow as tf

from niftynet.engine.application_factory import ApplicationFactory
from niftynet.io.misc_io import touch_folder
from niftynet.io.image_sets_partitioner import ImageSetsPartitioner

FILE_PREFIX = 'model.ckpt'


class EvaluationApplicationDriver(object):
    """
    This class represents the application logic for evaluating a set
    of results inferred within NiftyNet (or externally generated).
    """

    def __init__(self):
        self.app = None
        self.model_dir = None
        self.summary_dir = None
        self.session_prefix = None
        self.outputs_collector = None
        self.gradients_collector = None
        # Declared here (and populated by initialise_application) so the
        # object's full state is visible in one place.
        self.num_threads = 1
        self.app_param = None
        self.eval_param = None

    def initialise_application(self, workflow_param, data_param):
        """
        This function receives all parameters from user config file,
        create an instance of application.

        :param workflow_param: a dictionary of user parameters,
            keys correspond to sections in the config file
        :param data_param: a dictionary of input image parameters,
            keys correspond to data properties to be used by image_reader
        :return:
        """
        try:
            system_param = workflow_param.get('SYSTEM', None)
            net_param = workflow_param.get('NETWORK', None)
            infer_param = workflow_param.get('INFERENCE', None)
            eval_param = workflow_param.get('EVALUATION', None)
            app_param = workflow_param.get('CUSTOM', None)
        except AttributeError:
            tf.logging.fatal('parameters should be dictionaries')
            raise

        # evaluation is single-threaded; multi-threading/GPU options are
        # deliberately disabled here (see commented lines below)
        self.num_threads = 1
        # self.num_threads = max(system_param.num_threads, 1)
        # self.num_gpus = system_param.num_gpus
        # set_cuda_device(system_param.cuda_devices)

        # set output TF model folders
        self.model_dir = touch_folder(
            os.path.join(system_param.model_dir, 'models'))
        self.session_prefix = os.path.join(self.model_dir, FILE_PREFIX)

        assert infer_param, 'inference parameters not specified'

        # create an application instance
        assert app_param, 'application specific param. not specified'
        self.app_param = app_param
        app_module = ApplicationFactory.create(app_param.name)
        self.app = app_module(net_param, infer_param, system_param.action)
        self.eval_param = eval_param

        # the application may register extra inferred-output entries in
        # the data specification before readers are built
        data_param, self.app_param = \
            self.app.add_inferred_output(data_param, self.app_param)

        # initialise data input
        data_partitioner = ImageSetsPartitioner()
        # clear the cached file lists
        data_partitioner.reset()
        if data_param:
            data_partitioner.initialise(
                data_param=data_param,
                new_partition=False,
                ratios=None,
                data_split_file=system_param.dataset_split_file)
        self.app.initialise_dataset_loader(
            data_param, self.app_param, data_partitioner)
        self.app.initialise_evaluator(eval_param)

    def run(self, application):
        """
        This is the main application logic for evaluation. Computation
        of all metrics for all subjects is delegated to an Evaluator
        object owned by the application object. The resulting metrics
        are aggregated as defined by the evaluation classes and output
        to one or more csv files (based on their 'group_by' headings).
        For example, per-subject metrics will be in one file,
        per-label-class metrics will be in another and
        per-subject-per-class will be in a third.

        :param application: application instance owning the evaluator
        :return:
        """
        start_time = time.time()
        try:
            # create the output folder; EAFP so a concurrent creator
            # does not make this raise spuriously
            try:
                os.makedirs(self.eval_param.save_csv_dir)
            except OSError:
                if not os.path.isdir(self.eval_param.save_csv_dir):
                    raise

            # compute all metrics; result is a dict keyed by the
            # 'group_by' heading tuple of each aggregation
            all_results = application.evaluator.evaluate()

            for group_by, data_frame in all_results.items():
                # (None,) marks the ungrouped aggregation -> no suffix
                if group_by == (None,):
                    csv_id = ''
                else:
                    csv_id = '_'.join(group_by)
                csv_path = os.path.join(self.eval_param.save_csv_dir,
                                        'eval_' + csv_id + '.csv')
                # 'csv_file' (not 'csv') to avoid shadowing the stdlib
                # module name
                with open(csv_path, 'w') as csv_file:
                    csv_file.write(
                        data_frame.reset_index().to_csv(index=False))
        except KeyboardInterrupt:
            tf.logging.warning('User cancelled application')
        except RuntimeError:
            import sys
            import traceback
            exc_type, exc_value, exc_traceback = sys.exc_info()
            traceback.print_exception(
                exc_type, exc_value, exc_traceback, file=sys.stdout)
        finally:
            tf.logging.info('Cleaning up...')
            tf.logging.info(
                "%s stopped (time in second %.2f).",
                type(application).__name__, (time.time() - start_time))