# Source code for niftynet.utilities.util_common

# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import datetime
import os
import re
from functools import partial

import numpy as np
import tensorflow as tf
from scipy import ndimage
from six import string_types

def traverse_nested(input_lists, types=(list, tuple)):
    """
    Lazily flatten an arbitrarily nested list/tuple structure.

    :param input_lists: a (possibly nested) container, or a leaf value
    :param types: container types that are descended into; anything else
        is yielded as a leaf
    :return: generator over the leaves, in depth-first left-to-right order
    """
    if not isinstance(input_lists, types):
        # a leaf: emit it unchanged
        yield input_lists
        return
    for element in input_lists:
        for leaf in traverse_nested(element, types=types):
            yield leaf
def list_depth_count(input_list):
    """
    Return the maximum nesting depth of a list/tuple structure.

    Used to check that a user-provided structure is compatible with the
    system API. Non-list/tuple values have depth 0; an empty list or tuple
    has depth 1.

    :param input_list: value to inspect
    :return: integer depth
    """
    if isinstance(input_list, (list, tuple)):
        if not input_list:
            return 1
        return 1 + max(list_depth_count(item) for item in input_list)
    return 0
def average_gradients(multi_device_gradients):
    """
    Average gradients that were computed on multiple devices.

    :param multi_device_gradients: gradients grouped by device (outermost
        list is over devices). A nesting depth of 3 (as counted by
        ``list_depth_count``) is treated as one list of
        ``(gradient, variable)`` pairs per device. A depth of 4 is treated as
        several such lists per device; these are averaged position by
        position across devices.
    :return: the input unchanged when it is ``None`` or empty, the only
        device's gradients when there is a single device, otherwise the
        averaged gradients.
    :raises RuntimeError: if the gradients are nested at any other depth.
    """
    if not multi_device_gradients:
        # None or empty: nothing to average
        return multi_device_gradients
    if len(multi_device_gradients) == 1:
        # only one device, so drop the outer list that loops over devices
        return multi_device_gradients[0]

    nested_grads_depth = list_depth_count(multi_device_gradients)
    if nested_grads_depth == 4:
        # regroup so each group holds the same gradient list from every device
        gradients = zip(*multi_device_gradients)
        return [__average_grads(g) for g in gradients]
    if nested_grads_depth == 3:
        return __average_grads(multi_device_gradients)

    # separate the sentences explicitly: adjacent literals are concatenated
    # without any whitespace
    error_msg = (
        "The list of gradients are nested in an unusual way. "
        "application's gradient is not compatible with app driver. "
        "Please check the return value of gradients_collector "
        "in _connect_data_and_network() of the application")
    tf.logging.fatal(error_msg)
    raise RuntimeError(error_msg)
def __average_grads(tower_grads):
    """
    Average per-variable gradients across towers.

    :param tower_grads: in form of ``[[tower_1_grad], [tower_2_grad], ...]``
        where each tower entry is a list of ``(gradient, variable)`` pairs
        in the same variable order
    :return: list of ``(mean_gradient, variable)`` pairs; variables whose
        gradient is ``None`` on every tower are dropped
    """
    averaged = []
    # zip(*) regroups the per-tower lists into per-variable tuples
    for per_variable in zip(*tower_grads):
        stacked = [tf.expand_dims(grad, 0)
                   for grad, _ in per_variable if grad is not None]
        if stacked:
            mean_grad = tf.reduce_mean(tf.concat(stacked, 0), 0)
            # the variable object is shared, so take it from the first tower
            averaged.append((mean_grad, per_variable[0][1]))
    return averaged
def has_bad_inputs(input_args):
    """
    Check whether any configuration parameter was left unset.

    Every attribute of every section whose value is ``None`` is reported
    on stdout.

    :param input_args: mapping from section name to a namespace-like object
        whose attributes are that section's parameters
    :return: ``True`` if at least one parameter is ``None``, else ``False``
    """
    found_missing = False
    for section_name in input_args:
        params = input_args[section_name]
        for param_name in vars(params):
            if getattr(params, param_name) is not None:
                continue
            print('{} not set in section [{}] the config file'.format(
                param_name, section_name))
            found_missing = True
    return found_missing
def __print_argparse_section(args, section):
    """
    Print one configuration section and return the printed lines.

    :param args: mapping from section name to a namespace-like object
    :param section: name of the section to print
    :return: list of the lines printed: an upper-cased ``[SECTION]`` header
        followed by one ``-- name: value`` line per attribute
    """
    header = '[{}]'.format(section.upper())
    print(header)
    printed = [header]
    params = args[section]
    for name in vars(params):
        line = "-- {}: {}".format(name, getattr(params, name))
        print(line)
        printed.append(line)
    return printed
class MorphologyOps(object):
    """
    Morphological helpers on a 3D binary image: border extraction and
    connected-component labelling. Used in the evaluation.
    """

    def __init__(self, binary_img, neigh):
        """
        :param binary_img: 3D array; stored as ``int8``
        :param neigh: stored as ``self.neigh``; not used by the methods
            of this class shown here (NOTE(review): presumably a
            connectivity setting — confirm against callers)
        """
        assert len(binary_img.shape) == 3, 'currently supports 3d inputs only'
        self.binary_map = np.asarray(binary_img, dtype=np.int8)
        self.neigh = neigh

    def border_map(self):
        """
        Creates the border for a 3D image.

        A foreground voxel is on the border when at least one of its six
        face neighbours is background. Voxels outside the volume count as
        background, because ``ndimage.shift`` pads with zeros.

        :return: boolean array, same shape as the input, ``True`` on border
            voxels
        """
        unit_offsets = (
            [-1, 0, 0], [1, 0, 0],
            [0, 1, 0], [0, -1, 0],
            [0, 0, 1], [0, 0, -1],
        )
        # number of foreground face neighbours for every voxel
        neighbour_count = sum(
            ndimage.shift(self.binary_map, offset, order=0)
            for offset in unit_offsets)
        not_interior = neighbour_count < 6
        return (not_interior * self.binary_map) == 1

    def foreground_component(self):
        """
        Label connected foreground components.

        :return: the ``(labelled_array, num_features)`` pair from
            ``scipy.ndimage.label`` with its default structuring element
        """
        return ndimage.label(self.binary_map)
def CachedFunction(func):
    """
    Decorator that memoises ``func`` on the values of its arguments.

    Each decorated function gets its own cache dict, kept alive by the
    closure for the decorator's lifetime. The cache is unbounded.
    All positional and keyword argument values must be hashable.

    :param func: function to memoise
    :return: wrapped function returning the cached result for repeated
        argument combinations
    """
    from functools import wraps

    # the original code referenced ``cache`` without ever defining it,
    # which raised NameError on the first call
    cache = {}

    @wraps(func)
    def decorated(*args, **kwargs):
        key = (func, args, frozenset(kwargs.items()))
        if key not in cache:
            cache[key] = func(*args, **kwargs)
        return cache[key]

    return decorated
def CachedFunctionByID(func):
    """
    Decorator that memoises ``func`` on the *identity* of its arguments.

    Arguments are keyed by ``id()``, so unhashable objects (lists, arrays,
    ...) can be used. The same object passed again hits the cache, even if
    it was mutated in between. Each decorated function gets its own
    unbounded cache.

    :param func: function to memoise
    :return: wrapped function
    """
    from functools import wraps

    # the original code referenced ``cache`` without ever defining it
    cache = {}

    @wraps(func)
    def decorated(*args, **kwargs):
        id_args = tuple(id(a) for a in args)
        # must be a tuple: a generator expression hashes by its own identity,
        # so the key never matched and every call recomputed and leaked
        id_kwargs = tuple((k, id(kwargs[k])) for k in sorted(kwargs))
        key = (func, id_args, id_kwargs)
        if key not in cache:
            # keep the arguments alive next to the result; otherwise they
            # could be garbage-collected and their ids reused by unrelated
            # objects, producing false cache hits
            cache[key] = (func(*args, **kwargs), args, kwargs)
        return cache[key][0]

    return decorated
class CacheFunctionOutput(object):
    """
    Method decorator that caches outputs per instance, so repeated heavy
    computations with the same arguments are done only once.

    The cache dict lives on the instance itself, in the name-mangled
    attribute ``_CacheFunctionOutput__cache``. Remaining positional and
    keyword arguments must be hashable.
    """

    def __init__(self, func):
        # remember the undecorated function
        self.func = func

    def __get__(self, obj, _=None):
        if obj is not None:
            # bound access: pre-fill the instance as the first argument
            return partial(self, obj)
        # class-level access returns the decorator itself
        return self

    def __call__(self, *args, **kw):
        owner = args[0]
        try:
            memo = owner.__cache
        except AttributeError:
            memo = owner.__cache = {}
        lookup = (self.func, args[1:], frozenset(kw.items()))
        if lookup not in memo:
            memo[lookup] = self.func(*args, **kw)
        return memo[lookup]
def look_up_operations(type_str, supported):
    """
    Validate ``type_str`` against a collection of supported options.

    - if ``supported`` is a ``dict``, return ``supported[type_str]``;
    - if ``supported`` is a ``set`` or ``frozenset``, return ``type_str``;
    - otherwise raise ``ValueError``, suggesting the closest supported
      option (Damerau-Levenshtein distance <= 3) when there is one.

    :param type_str: the requested option name
    :param supported: ``dict`` or ``set``/``frozenset`` of valid options
    :return: the looked-up value (dict) or ``type_str`` itself (set)
    :raises ValueError: if ``type_str`` is not supported
    """
    assert isinstance(type_str, string_types), 'unrecognised type string'
    if isinstance(supported, dict):
        if type_str in supported:
            return supported[type_str]
        set_to_check = set(supported)
    elif isinstance(supported, (set, frozenset)):
        # frozenset used to be rejected even when it contained type_str
        if type_str in supported:
            return type_str
        set_to_check = supported
    else:
        set_to_check = set()

    edit_distances = {}
    for supported_key in set_to_check:
        edit_distance = damerau_levenshtein_distance(supported_key, type_str)
        if edit_distance <= 3:
            edit_distances[supported_key] = edit_distance
    if edit_distances:
        # break ties on the key text so the suggestion does not depend on
        # set iteration order (which varies with hash randomisation)
        guess_at_correct_spelling = min(
            edit_distances,
            key=lambda key: (edit_distances[key], str(key)))
        raise ValueError('By "{0}", did you mean "{1}"?\n'
                         '"{0}" is not a valid option.\n'
                         'Available options are {2}\n'.format(
                             type_str, guess_at_correct_spelling, supported))
    raise ValueError("No supported option \"{}\" "
                     "is not found.\nAvailable options are {}\n".format(
                         type_str, supported))
def damerau_levenshtein_distance(s1, s2):
    """
    Calculate an edit distance, for typo detection.

    Implements the optimal-string-alignment variant of the
    Damerau-Levenshtein distance: insertions, deletions, substitutions and
    transpositions of adjacent characters each cost 1. Based on the
    algorithm given in the Wikipedia article "Damerau-Levenshtein distance".

    :param s1: first string (any indexable sequence)
    :param s2: second string (any indexable sequence)
    :return: integer edit distance
    """
    rows, cols = len(s1), len(s2)
    # table[a][b] is the distance between the prefixes s1[:a] and s2[:b]
    table = [[0] * (cols + 1) for _ in range(rows + 1)]
    for a in range(rows + 1):
        table[a][0] = a
    for b in range(cols + 1):
        table[0][b] = b

    for a in range(1, rows + 1):
        for b in range(1, cols + 1):
            cost = 0 if s1[a - 1] == s2[b - 1] else 1
            best = min(
                table[a - 1][b] + 1,         # deletion
                table[a][b - 1] + 1,         # insertion
                table[a - 1][b - 1] + cost,  # substitution
            )
            swapped = (a > 1 and b > 1
                       and s1[a - 1] == s2[b - 2]
                       and s1[a - 2] == s2[b - 1])
            if swapped:
                best = min(best, table[a - 2][b - 2] + cost)  # transposition
            table[a][b] = best
    return table[rows][cols]
def otsu_threshold(img, nbins=256):
    """
    Implementation of Otsu thresholding.

    The intensities of ``img`` are binned into ``nbins`` histogram bins.
    For every split point ``i``, class 1 holds bins ``[0, i]`` and class 2
    holds bins ``[i + 1, nbins - 1]``. The centre of bin ``i`` that
    maximises the between-class variance
    ``w_1 * w_2 * (mu_1 - mu_2) ** 2`` is returned.

    :param img: array-like of intensities (flattened before binning)
    :param nbins: number of histogram bins
    :return: the selected threshold (a bin centre). ``bin_centers[0]`` is
        returned when no split gives a positive between-class variance,
        e.g. for a constant image or ``nbins == 1``.
    """
    hist, bin_edges = np.histogram(np.ravel(img), bins=nbins)
    hist = hist.astype(float)
    half_bin_size = (bin_edges[1] - bin_edges[0]) * 0.5
    bin_centers = bin_edges[:-1] + half_bin_size
    if hist.shape[0] < 2:
        # no split point exists
        return bin_centers[0]

    weighted = hist * bin_centers
    # class weights and intensity sums from the left (class 1) and from the
    # right (class 2). The previous loop seeded the sums with np.copy(hist),
    # so the first/last bin contributed hist instead of hist * bin_center
    # and skewed both class means.
    weight_1 = np.cumsum(hist)
    mean_1 = np.cumsum(weighted)
    weight_2 = np.cumsum(hist[::-1])[::-1]
    mean_2 = np.cumsum(weighted[::-1])[::-1]

    # empty classes give 0/0; those splits are discarded below
    with np.errstate(divide='ignore', invalid='ignore'):
        ratio_1 = mean_1[:-1] / weight_1[:-1]
        ratio_2 = mean_2[1:] / weight_2[1:]
        target = weight_1[:-1] * weight_2[1:] * (ratio_1 - ratio_2) ** 2
    target = np.where(np.isnan(target), 0.0, target)

    # argmax returns the first maximum, matching a strict ">" scan
    best = int(np.argmax(target))
    if target[best] > 0:
        return bin_centers[best]
    return bin_centers[0]
# def otsu_threshold(img, nbins=256): # """ Implementation of otsu thresholding """ # hist, bin_edges = np.histogram(img.ravel(), bins=nbins, density=True) # hist = hist.astype(float) * (bin_edges[1] - bin_edges[0]) # centre_bins = 0.5 * (bin_edges[:-1] + bin_edges[1:]) # # hist_mul_val = hist * centre_bins # sum_tot = np.sum(hist_mul_val) # # threshold, target_max = centre_bins[0], 0 # sum_im, mean_im = 0, 0 # for i in range(0, hist.shape[0]-1): # mean_im = mean_im + hist_mul_val[i] # mean_ip = sum_tot - mean_im # # sum_im = sum_im + hist[i] # sum_ip = 1 - sum_im # # target = sum_ip * sum_im * np.square(mean_ip/sum_ip - mean_im/sum_im) # if target > target_max: # threshold, target_max = centre_bins[i], target # return threshold # Print iterations progress
def set_cuda_device(cuda_devices):
    """
    Restrict the GPUs visible to the process via ``CUDA_VISIBLE_DEVICES``.

    :param cuda_devices: device specification string, e.g. ``"0,1"``. If it
        contains no digit, the environment is left untouched and
        TensorFlow's default device choice applies.
    """
    if re.search(r"\d", cuda_devices):
        os.environ["CUDA_VISIBLE_DEVICES"] = cuda_devices
        # the opening of this logging call was missing, leaving a dangling
        # ``"...".format(...))`` that made the module a syntax error
        tf.logging.info(
            "set CUDA_VISIBLE_DEVICES to {}".format(cuda_devices))
    # otherwise: no device id given, keep TensorFlow's default choice
class ParserNamespace(object):
    """
    Attribute-style container for parameters parsed from the config file.

    e.g.::

        system_params = ParserNamespace(action='train')
        action_str = system_params.action
    """

    def __init__(self, **kwargs):
        # every keyword becomes an instance attribute
        vars(self).update(kwargs)

    def update(self, **kwargs):
        """Add new attributes or overwrite existing ones from ``kwargs``."""
        vars(self).update(kwargs)