# Source code for niftynet.network.base_net

from __future__ import absolute_import, print_function

import tensorflow as tf
from niftynet.layer.base_layer import TrainableLayer


class BaseNet(TrainableLayer):
    """
    Template for networks.

    Holds the configuration shared by network implementations: the
    number of output classes, the activation-function name, and the
    initializers / regularizers stored under the ``'w'`` and ``'b'``
    keys. Only ``__init__`` is defined here; constructing the actual
    graph is presumably left to subclasses (NOTE(review): confirm
    against ``TrainableLayer``'s contract).
    """

    def __init__(self,
                 num_classes=0,
                 w_initializer=None,
                 w_regularizer=None,
                 b_initializer=None,
                 b_regularizer=None,
                 acti_func='prelu',
                 name="net_template"):
        """
        Store the network configuration and register the layer name.

        :param num_classes: number of output classes; stored as
            ``self.num_classes`` (default ``0``).
        :param w_initializer: initializer stored as
            ``self.initializers['w']`` (presumably for weights).
        :param w_regularizer: regularizer stored as
            ``self.regularizers['w']`` (presumably for weights).
        :param b_initializer: initializer stored as
            ``self.initializers['b']`` (presumably for biases).
        :param b_regularizer: regularizer stored as
            ``self.regularizers['b']`` (presumably for biases).
        :param acti_func: name of the activation function; stored as
            ``self.acti_func`` (default ``'prelu'``).
        :param name: layer name, forwarded to ``TrainableLayer`` and
            logged at INFO level.
        """
        # Explicit two-argument super() keeps Python 2 compatibility,
        # consistent with the module's ``from __future__`` import.
        super(BaseNet, self).__init__(name=name)

        self.num_classes = num_classes
        self.acti_func = acti_func

        # Grouped by parameter kind so callers can look up 'w' / 'b'.
        self.initializers = {'w': w_initializer, 'b': b_initializer}
        self.regularizers = {'w': w_regularizer, 'b': b_regularizer}

        # Pass ``name`` as a lazy logging argument instead of formatting
        # eagerly with str.format; the message is only built if INFO
        # records are actually emitted.
        tf.logging.info('using %s', name)