Source code for niftynet.io.image_sets_partitioner

# -*- coding: utf-8 -*-
"""
This module manages a table of subject ids and
their associated image file names.
A subset of the table can be retrieved by partitioning the set of images
into subsets of ``Train``, ``Validation``, and ``Inference``.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os
import random
import shutil

import pandas
import tensorflow as tf  # to use the system level logging

from niftynet.engine.signal import TRAIN, VALID, INFER, ALL
from niftynet.utilities.decorators import singleton
from niftynet.utilities.filename_matching import KeywordsMatching
from niftynet.utilities.niftynet_global_config import NiftyNetGlobalConfig
from niftynet.utilities.util_common import look_up_operations
from niftynet.utilities.util_csv import match_and_write_filenames_to_csv
from niftynet.utilities.util_csv import write_csv

COLUMN_UNIQ_ID = 'subject_id'
COLUMN_PHASE = 'phase'
SUPPORTED_PHASES = {TRAIN, VALID, INFER, ALL}


@singleton
class ImageSetsPartitioner(object):
    """
    This class maintains a pandas.dataframe of filenames for all input
    sections. The list of filenames is obtained by searching the specified
    folders or loading from an existing csv file.

    Users can query a subset of the dataframe by train/valid/infer
    partition label and input section names.
    """

    # dataframe (table) of file names in a shape of subject x modality
    _file_list = None
    # dataframes of subject_id:phase_id
    _partition_ids = None

    data_param = None
    ratios = None
    new_partition = False

    # for saving the splitting index
    data_split_file = ""
    # default parent folder location for searching the image files
    default_image_file_location = \
        NiftyNetGlobalConfig().get_niftynet_home_folder()

    def initialise(self,
                   data_param,
                   new_partition=False,
                   data_split_file=None,
                   ratios=None):
        """
        Set the data partitioner parameters.

        :param data_param: corresponding to all config sections
        :param new_partition: bool value indicating whether to generate new
            partition ids and overwrite the csv file
            (this class will write the partition file iff new_partition)
        :param data_split_file: location of the partition id file
        :param ratios: a tuple/list with two elements:
            ``(fraction of the validation set, fraction of the inference set)``
            setting this to None disables data partitioning,
            and get_file_list always returns all subjects.
        """
        self.data_param = data_param
        if data_split_file is None:
            self.data_split_file = os.path.join('.', 'dataset_split.csv')
        else:
            self.data_split_file = data_split_file
        self.ratios = ratios

        self._file_list = None
        self._partition_ids = None

        self.load_data_sections_by_subject()
        self.new_partition = new_partition
        self.randomly_split_dataset(overwrite=new_partition)
        tf.logging.info(self)
        return self

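    # Usage sketch (illustrative only; the section name 'T1' and the paths
    # are hypothetical, not defined by this module):
    #
    #   partitioner = ImageSetsPartitioner()
    #   partitioner.initialise(
    #       data_param={'T1': {'path_to_search': '~/data/T1',
    #                          'filename_contains': ('img',)}},
    #       new_partition=True,
    #       ratios=(0.2, 0.1))  # 20% validation, 10% inference
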
    def number_of_subjects(self, phase=ALL):
        """
        Query the number of subjects according to phase.

        :param phase: one of the ``SUPPORTED_PHASES``
        :return: number of subjects in the phase
        """
        if self._file_list is None:
            return 0
        try:
            phase = look_up_operations(phase.lower(), SUPPORTED_PHASES)
        except (ValueError, AttributeError):
            tf.logging.fatal('Unknown phase argument.')
            raise

        if phase == ALL:
            return self._file_list[COLUMN_UNIQ_ID].count()
        if self._partition_ids is None:
            return 0
        selector = self._partition_ids[COLUMN_PHASE] == phase
        selected = self._partition_ids[selector][[COLUMN_UNIQ_ID]]
        subset = pandas.merge(
            self._file_list, selected, on=COLUMN_UNIQ_ID, sort=True)
        return subset.count()[COLUMN_UNIQ_ID]

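    # Usage sketch (counts illustrative):
    #
    #   partitioner.number_of_subjects()       # all subjects, e.g. 10
    #   partitioner.number_of_subjects(TRAIN)  # e.g. 7
    #   partitioner.number_of_subjects('foo')  # raises ValueError
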
    def get_file_list(self, phase=ALL, *section_names):
        """
        Get file names as a dataframe, by partitioning phase and section
        names. Set phase to ALL to load all subsets.

        :param phase: the label of the subset generated by
            self._partition_ids; should be one of the SUPPORTED_PHASES
        :param section_names: one or multiple input section names
        :return: a pandas.dataframe of file names
        """
        if self._file_list is None:
            tf.logging.warning('Empty file list, please initialise '
                               'ImageSetsPartitioner first.')
            return []
        try:
            phase = look_up_operations(phase.lower(), SUPPORTED_PHASES)
        except (ValueError, AttributeError):
            tf.logging.fatal('Unknown phase argument.')
            raise
        for name in section_names:
            try:
                look_up_operations(name, set(self._file_list))
            except ValueError:
                tf.logging.fatal(
                    'Requesting files under input section [%s],\n'
                    'however the section does not exist in the config.',
                    name)
                raise

        if phase == ALL:
            self._file_list = self._file_list.sort_values(COLUMN_UNIQ_ID)
            if section_names:
                section_names = [COLUMN_UNIQ_ID] + list(section_names)
                return self._file_list[section_names]
            return self._file_list

        if self._partition_ids is None or self._partition_ids.empty:
            tf.logging.fatal('No partition ids available.')
            if self.new_partition:
                tf.logging.fatal('Unable to create new partitions, '
                                 'splitting ratios: %s, writing file %s',
                                 self.ratios, self.data_split_file)
            elif os.path.isfile(self.data_split_file):
                tf.logging.fatal(
                    'Unable to load %s, initialise the '
                    'ImageSetsPartitioner with `new_partition=True` '
                    'to overwrite the file.',
                    self.data_split_file)
            raise ValueError

        selector = self._partition_ids[COLUMN_PHASE] == phase
        selected = self._partition_ids[selector][[COLUMN_UNIQ_ID]]
        if selected.empty:
            tf.logging.warning(
                'Empty subset for phase [%s], returning None as file list. '
                'Please adjust splitting fractions.', phase)
            return None
        subset = pandas.merge(
            self._file_list, selected, on=COLUMN_UNIQ_ID, sort=True)
        if subset.empty:
            tf.logging.warning(
                'No subject id matched between file names and '
                'partition files.\nPlease check the partition file %s,\n'
                'or remove it to generate a new file automatically.',
                self.data_split_file)
        if section_names:
            section_names = [COLUMN_UNIQ_ID] + list(section_names)
            return subset[section_names]
        return subset

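    # Usage sketch (the section name 'T1' is hypothetical):
    #
    #   train_t1 = partitioner.get_file_list(TRAIN, 'T1')
    #   # -> a pandas.DataFrame with columns ['subject_id', 'T1'],
    #   #    one row per training subject
    #   all_rows = partitioner.get_file_list()  # ALL phases, all sections
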
    def load_data_sections_by_subject(self):
        """
        Go through all input data sections, converting each section
        to a list of file names. These lists are merged on
        ``COLUMN_UNIQ_ID``.

        This function sets ``self._file_list``.
        """
        if not self.data_param:
            tf.logging.fatal(
                'Nothing to load, please check input sections in the config.')
            raise ValueError
        self._file_list = None
        for section_name in self.data_param:
            if isinstance(self.data_param[section_name], dict):
                mod_spec = self.data_param[section_name]
            else:
                mod_spec = vars(self.data_param[section_name])
            if mod_spec.get('csv_data_file', None):
                # has csv_data_file -- skip file search
                continue
            modality_file_list = self.grep_files_by_data_section(section_name)
            if self._file_list is None:
                # adding all rows of the first modality
                self._file_list = modality_file_list
                continue
            n_rows = self._file_list[COLUMN_UNIQ_ID].count()
            self._file_list = pandas.merge(self._file_list,
                                           modality_file_list,
                                           how='outer',
                                           on=COLUMN_UNIQ_ID)
            if self._file_list[COLUMN_UNIQ_ID].count() < n_rows:
                tf.logging.warning('rows not matched in section [%s]',
                                   section_name)

        if self._file_list is None or self._file_list.size == 0:
            tf.logging.fatal(
                "Empty filename lists, please check the csv "
                "files (removing the csv_file keyword if it is in the config "
                "file, to automatically search folders and generate new csv "
                "files again).\n\n"
                "Please note that in the matched file names, each subject id "
                "is created by removing all keywords listed in "
                "`filename_contains` in the config.\n"
                "E.g., `filename_contains=foo, bar` will match file "
                "foo_subject42_bar.nii.gz, and the subject id is "
                "_subject42_.\n\n")
            raise IOError

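    # Sketch of the resulting table (section names and paths hypothetical):
    # sections 'T1' and 'label' are outer-merged on subject_id, so an
    # unmatched subject keeps its row with a missing value:
    #
    #   subject_id  T1                       label
    #   sub001      /data/T1/sub001.nii.gz   /data/seg/sub001.nii.gz
    #   sub002      /data/T1/sub002.nii.gz   NaN
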
    def grep_files_by_data_section(self, modality_name):
        """
        List all files by a given input data section:
        if the ``csv_file`` property of ``data_param[modality_name]``
        corresponds to a file, read the list from the file;
        otherwise write the list to ``csv_file``.

        :return: a table with two columns, the column names are
            ``(COLUMN_UNIQ_ID, modality_name)``.
        """
        if modality_name not in self.data_param:
            tf.logging.fatal('unknown section name [%s], '
                             'current input section names: %s.',
                             modality_name, list(self.data_param))
            raise ValueError

        # input data section must have a ``csv_file`` section for loading
        # or writing filename lists
        if isinstance(self.data_param[modality_name], dict):
            mod_spec = self.data_param[modality_name]
        else:
            mod_spec = vars(self.data_param[modality_name])

        #########################
        # guess the csv_file path
        #########################
        temp_csv_file = None
        try:
            csv_file = os.path.expanduser(mod_spec.get('csv_file', None))
            if not os.path.isfile(csv_file):
                # writing to the same folder as data_split_file
                default_csv_file = os.path.join(
                    os.path.dirname(self.data_split_file),
                    '{}.csv'.format(modality_name))
                tf.logging.info('`csv_file = %s` not found, '
                                'writing to "%s" instead.',
                                csv_file, default_csv_file)
                csv_file = default_csv_file
                if os.path.isfile(csv_file):
                    tf.logging.info('Overwriting existing: "%s".', csv_file)
            csv_file = os.path.abspath(csv_file)
        except (AttributeError, KeyError, TypeError):
            tf.logging.debug('`csv_file` not specified, writing the list of '
                             'filenames to a temporary file.')
            import tempfile
            temp_csv_file = os.path.join(
                tempfile.mkdtemp(), '{}.csv'.format(modality_name))
            csv_file = temp_csv_file

        ##############################################
        # writing csv file if path_to_search specified
        ##############################################
        if mod_spec.get('path_to_search', None):
            if not temp_csv_file:
                tf.logging.info(
                    '[%s] search file folders, writing csv file %s',
                    modality_name, csv_file)
            # grep files by section properties and write csv
            try:
                matcher = KeywordsMatching.from_dict(
                    input_dict=mod_spec,
                    default_folder=self.default_image_file_location)
                match_and_write_filenames_to_csv([matcher], csv_file)
            except (IOError, ValueError) as reading_error:
                tf.logging.warning('Ignoring input section: [%s], '
                                   'due to the following error:',
                                   modality_name)
                tf.logging.warning(repr(reading_error))
                return pandas.DataFrame(
                    columns=[COLUMN_UNIQ_ID, modality_name])
        else:
            tf.logging.info(
                '[%s] using existing csv file %s, skipped filenames search',
                modality_name, csv_file)

        if not os.path.isfile(csv_file):
            tf.logging.fatal(
                '[%s] csv file %s not found.', modality_name, csv_file)
            raise IOError

        ###############################
        # loading the file as dataframe
        ###############################
        try:
            csv_list = pandas.read_csv(
                csv_file,
                header=None,
                dtype=(str, str),
                names=[COLUMN_UNIQ_ID, modality_name],
                skipinitialspace=True)
        except Exception as csv_error:
            tf.logging.fatal(repr(csv_error))
            raise
        if temp_csv_file:
            shutil.rmtree(os.path.dirname(temp_csv_file), ignore_errors=True)
        return csv_list

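    # The csv files read and written here have no header row and two
    # columns: subject id, then file path, e.g. (paths hypothetical):
    #
    #   sub001,/data/T1/foo_sub001_bar.nii.gz
    #   sub002,/data/T1/foo_sub002_bar.nii.gz
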
    # pylint: disable=broad-except
    def randomly_split_dataset(self, overwrite=False):
        """
        Label each subject as one of the ``TRAIN``, ``VALID``, ``INFER``,
        use ``self.ratios`` to compute the size of each set.

        The results will be written to ``self.data_split_file`` if
        ``overwrite`` is True, otherwise it tries to read partition
        labels from that file.

        This function sets ``self._partition_ids``.
        """
        if overwrite:
            try:
                valid_fraction, infer_fraction = self.ratios
                valid_fraction = max(min(1.0, float(valid_fraction)), 0.0)
                infer_fraction = max(min(1.0, float(infer_fraction)), 0.0)
            except (TypeError, ValueError):
                tf.logging.fatal(
                    'Unknown format of fraction values %s', self.ratios)
                raise

            if (valid_fraction + infer_fraction) <= 0:
                tf.logging.warning(
                    'To split the dataset into training/validation, '
                    'please make sure the '
                    '"exclude_fraction_for_validation" parameter is set to '
                    'a float between 0 and 1. Current value: %s.',
                    valid_fraction)
                # raise ValueError

            n_total = self.number_of_subjects()
            n_valid = int(math.ceil(n_total * valid_fraction))
            n_infer = int(math.ceil(n_total * infer_fraction))
            n_train = int(n_total - n_infer - n_valid)
            phases = [TRAIN] * n_train + \
                     [VALID] * n_valid + \
                     [INFER] * n_infer
            if len(phases) > n_total:
                phases = phases[:n_total]
            random.shuffle(phases)
            write_csv(self.data_split_file,
                      zip(self._file_list[COLUMN_UNIQ_ID], phases))
        elif os.path.isfile(self.data_split_file):
            tf.logging.warning(
                'Loading from existing partitioning file %s, '
                'ignoring partitioning ratios.', self.data_split_file)

        if os.path.isfile(self.data_split_file):
            try:
                self._partition_ids = pandas.read_csv(
                    self.data_split_file,
                    header=None,
                    dtype=(str, str),
                    names=[COLUMN_UNIQ_ID, COLUMN_PHASE],
                    skipinitialspace=True)
                assert not self._partition_ids.empty, \
                    "partition file is empty."
            except Exception as csv_error:
                tf.logging.warning(
                    "Unable to load the existing partition file %s, %s",
                    self.data_split_file, repr(csv_error))
                self._partition_ids = None

            try:
                phase_strings = self._partition_ids[COLUMN_PHASE]
                phase_strings = phase_strings.astype(str).str.lower()
                is_valid_phase = phase_strings.isin(SUPPORTED_PHASES)
                assert is_valid_phase.all(), \
                    "Partition file contains unknown phase id."
                self._partition_ids[COLUMN_PHASE] = phase_strings
            except (TypeError, AssertionError):
                tf.logging.warning(
                    'Please make sure the values in the second column '
                    'of the data splitting file %s are in the set of '
                    'phases: %s.\n'
                    'Remove %s to generate a random data partition file.',
                    self.data_split_file, SUPPORTED_PHASES,
                    self.data_split_file)
                raise ValueError

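    # Worked example of the split arithmetic above (numbers assumed):
    # with n_total = 10 and ratios = (0.2, 0.1),
    #
    #   n_valid = ceil(10 * 0.2) = 2
    #   n_infer = ceil(10 * 0.1) = 1
    #   n_train = 10 - 2 - 1     = 7
    #
    # and data_split_file gets one "subject_id,phase" row per subject,
    # with the phase labels shuffled across subjects.
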
    def __str__(self):
        return self.to_string()

    def to_string(self):
        """
        Summarise the partitioner as a printable string.
        """
        n_subjects = self.number_of_subjects()
        summary_str = '\n\nNumber of subjects {}, '.format(n_subjects)
        if self._file_list is not None:
            summary_str += 'input section names: {}\n'.format(
                list(self._file_list))
        if self._partition_ids is not None and n_subjects > 0:
            n_train = self.number_of_subjects(TRAIN)
            n_valid = self.number_of_subjects(VALID)
            n_infer = self.number_of_subjects(INFER)
            summary_str += \
                'Dataset partitioning:\n' \
                '-- {} {} cases ({:.2f}%),\n' \
                '-- {} {} cases ({:.2f}%),\n' \
                '-- {} {} cases ({:.2f}%).\n'.format(
                    TRAIN, n_train,
                    float(n_train) / float(n_subjects) * 100.0,
                    VALID, n_valid,
                    float(n_valid) / float(n_subjects) * 100.0,
                    INFER, n_infer,
                    float(n_infer) / float(n_subjects) * 100.0)
        else:
            summary_str += '-- using all subjects ' \
                           '(without data partitioning).\n'
        return summary_str

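    # Example summary (numbers illustrative, and assuming the TRAIN/VALID/
    # INFER constants render as the phase names shown):
    #
    #   Number of subjects 10, input section names: ['subject_id', 'T1']
    #   Dataset partitioning:
    #   -- training 7 cases (70.00%),
    #   -- validation 2 cases (20.00%),
    #   -- inference 1 cases (10.00%).
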
    def has_phase(self, phase):
        """
        :return: True if the `phase` subset of images is not empty.
        """
        if self._partition_ids is None or self._partition_ids.empty:
            return False
        selector = self._partition_ids[COLUMN_PHASE] == phase
        if not selector.any():
            return False
        selected = self._partition_ids[selector][[COLUMN_UNIQ_ID]]
        subset = pandas.merge(
            left=self._file_list,
            right=selected,
            on=COLUMN_UNIQ_ID,
            sort=False)
        return not subset.empty

    @property
    def has_training(self):
        """
        :return: True if the TRAIN subset of images is not empty.
        """
        return self.has_phase(TRAIN)

    @property
    def has_inference(self):
        """
        :return: True if the INFER subset of images is not empty.
        """
        return self.has_phase(INFER)

    @property
    def has_validation(self):
        """
        :return: True if the VALID subset of images is not empty.
        """
        return self.has_phase(VALID)

    @property
    def validation_files(self):
        """
        :return: the list of validation filenames.
        """
        if self.has_validation:
            return self.get_file_list(VALID)
        return self.all_files

    @property
    def train_files(self):
        """
        :return: the list of training filenames.
        """
        if self.has_training:
            return self.get_file_list(TRAIN)
        return self.all_files

    @property
    def inference_files(self):
        """
        :return: the list of inference filenames (defaulting to the list
            of all filenames if there is no partition definition).
        """
        if self.has_inference:
            return self.get_file_list(INFER)
        return self.all_files

    @property
    def all_files(self):
        """
        :return: the list of all filenames.
        """
        return self.get_file_list()

    def get_file_lists_by(self, phase=None, action='train'):
        """
        Get file lists by action and phase.

        This function returns file lists for training/validation/inference
        based on the phase or action specified by the user; ``phase`` has
        the higher priority:
        if ``phase`` is specified, the function returns the corresponding
        file list (as a list);
        otherwise the function checks ``action``: it returns the train and
        validation file lists for a training action, and the inference
        file list otherwise.

        :param action: an action
        :param phase: an element from ``{TRAIN, VALID, INFER, ALL}``
        :return: a list of file-list dataframes
        """
        if phase:
            try:
                return [self.get_file_list(phase=phase)]
            except (ValueError, AttributeError):
                tf.logging.warning('`phase` parameter %s ignored', phase)
        if action and TRAIN.startswith(action):
            file_lists = [self.train_files]
            if self.has_validation:
                file_lists.append(self.validation_files)
            return file_lists
        return [self.inference_files]

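    # Usage sketch:
    #
    #   partitioner.get_file_lists_by(action='train')
    #   # -> [train file list, validation file list] if validation exists
    #   partitioner.get_file_lists_by(phase=INFER)
    #   # -> [inference file list], regardless of action
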
    def reset(self):
        """
        Reset all fields of this singleton class.
        """
        self._file_list = None
        self._partition_ids = None
        self.data_param = None
        self.ratios = None
        self.new_partition = False

        self.data_split_file = ""
        self.default_image_file_location = \
            NiftyNetGlobalConfig().get_niftynet_home_folder()