# Source code for niftynet.contrib.csv_reader.toynet_features
# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function
from niftynet.layer.convolution import ConvolutionalLayer
from niftynet.network.base_net import BaseNet
class ToyNetFeat(BaseNet):
    """Toy network that returns the output of a single convolutional layer.

    The network applies one ``ConvolutionalLayer`` (kernel size 3, ``'relu'``
    activation, named ``'conv_input'``) to the input and returns the result
    directly; no layer sized by ``num_classes`` is applied afterwards.

    NOTE(review): ``acti_func`` is forwarded to ``BaseNet.__init__`` but
    ``layer_op`` hard-codes ``'relu'``, so the constructor argument does not
    affect the layer built here. The default ``name`` is ``'ToyNet'`` (not
    ``'ToyNetFeat'``). Both are kept as-is for backward compatibility.
    """

    def __init__(self,
                 num_classes,
                 w_initializer=None,
                 w_regularizer=None,
                 b_initializer=None,
                 b_regularizer=None,
                 acti_func='prelu',
                 name='ToyNet'):
        """Forward all arguments to ``BaseNet`` and set the feature width.

        :param num_classes: forwarded to ``BaseNet``; not used by
            ``layer_op`` in this class.
        :param w_initializer: weight initializer, forwarded to ``BaseNet``.
        :param w_regularizer: weight regularizer, forwarded to ``BaseNet``.
        :param b_initializer: bias initializer, forwarded to ``BaseNet``.
        :param b_regularizer: bias regularizer, forwarded to ``BaseNet``.
        :param acti_func: forwarded to ``BaseNet``; see the class-level note.
        :param name: name forwarded to ``BaseNet``.
        """
        super(ToyNetFeat, self).__init__(
            num_classes=num_classes,
            w_initializer=w_initializer,
            w_regularizer=w_regularizer,
            b_initializer=b_initializer,
            b_regularizer=b_regularizer,
            acti_func=acti_func,
            name=name)
        # First positional argument given to the convolutional layer in
        # layer_op (presumably its number of output channels -- confirm
        # against ConvolutionalLayer).
        self.hidden_features = 10

    def layer_op(self, images, is_training=True, **unused_kwargs):
        """Apply the single feature convolution to ``images``.

        :param images: input passed straight to the convolutional layer
            (presumably an image tensor -- shape is not checked here).
        :param is_training: passed positionally as the layer's second
            argument.
        :param unused_kwargs: accepted and ignored.
        :return: the output of the ``'conv_input'`` layer.
        """
        # self.initializers / self.regularizers are expected to be provided
        # by BaseNet from the constructor arguments.
        conv_1 = ConvolutionalLayer(self.hidden_features,
                                    kernel_size=3,
                                    w_initializer=self.initializers['w'],
                                    w_regularizer=self.regularizers['w'],
                                    b_initializer=self.initializers['b'],
                                    b_regularizer=self.regularizers['b'],
                                    acti_func='relu',
                                    name='conv_input')
        # A second, 1x1 'conv_output' layer mapping to self.num_classes
        # channels (no activation) was previously sketched here but is
        # intentionally disabled: this network returns the conv_input
        # features themselves.
        return conv_1(images, is_training)