Fast-SCNN的解释以及使用Tensorflow 2.0的实现( 三 )

现在 , 让我们将这两个输入添加到特征融合模块中 。
ff_final = tf.keras.layers.add([ff_layer1, ff_layer2])ff_final = tf.keras.layers.BatchNormalization()(ff_final)ff_final = tf.keras.activations.relu(ff_final)4. 分类器在分类器部分 , 引入了2个深度可分离的卷积层和1个Point-wise的卷积层 。 在每个层之后 , 还进行了BatchNorm层和ReLU激活 。
这里需要注意的是 , 在原论文中 , 没有提到在point-wise卷积层之后添加上采样和Dropout层 , 但在本文的后面部分描述了这些层是在 point-wise卷积层之后添加的 。 因此 , 在实现过程中 , 我也按照论文的要求引入了这两层 。
在根据最终输出的需要进行上采样之后 , SoftMax将作为最后一层的激活 。
classifier = tf.keras.layers.SeparableConv2D(128, (3, 3), padding='same', strides = (1, 1), name = 'DSConv1_classifier')(ff_final)classifier = tf.keras.layers.BatchNormalization()(classifier)classifier = tf.keras.activations.relu(classifier)classifier = tf.keras.layers.SeparableConv2D(128, (3, 3), padding='same', strides = (1, 1), name = 'DSConv2_classifier')(classifier)classifier = tf.keras.layers.BatchNormalization()(classifier)classifier = tf.keras.activations.relu(classifier)classifier = conv_block(classifier, 'conv', 19, (1, 1), strides=(1, 1), padding='same', relu=True)classifier = tf.keras.layers.Dropout(0.3)(classifier)classifier = tf.keras.layers.UpSampling2D((8, 8))(classifier)classifier = tf.keras.activations.softmax(classifier)编译模型现在我们已经添加了所有的层 , 让我们创建最终的模型并编译它 。 为了创建模型 , 如上所述 , 我们使用了来自TF.Keras的函数api 。 这里 , 模型的输入是学习下采样模块中描述的初始输入层 , 输出是最终分类器的输出 。
fast_scnn = tf.keras.Model(inputs = input_layer , outputs = classifier, name = 'Fast_SCNN')现在 , 让我们用优化器和损失函数来编译它 。 在原论文中 , 作者在训练过程中使用了动量值为0.9 , 批大小为12的SGD优化器 。 他们还在学习率策略中使用了多项式学习率 , base值为0.045 , power为0.9 。 为了简单起见 , 我在这里没有使用任何学习率策略 , 但如果需要 , 你可以自己添加 。 此外 , 在编译模型时从ADAM optimizer开始总是一个好主意 , 但是在这个CityScapes dataset的特殊情况下 , 作者只使用了SGD 。 但在一般情况下 , 最好从ADAM optimizer开始 , 然后根据需要转向其他不同的优化器 。 对于损失函数 , 作者使用了交叉熵损失 , 在实现过程中也使用了交叉熵损失 。
optimizer = tf.keras.optimizers.SGD(momentum=0.9, lr=0.045)fast_scnn.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])在本文中 , 作者使用CityScapes数据集中的19个类别进行训练和评价 。 通过这个实现 , 你可以根据特定项目所需的任意数量的输出进行调整 。
下面是一些Fast-SCNN的验证结果 , 与输入图像和ground truth进行了比较 。
Fast-SCNN的解释以及使用Tensorflow 2.0的实现文章插图
来自原始论文中的图
【Fast-SCNN的解释以及使用Tensorflow 2.0的实现】英文原文: