Deeplab v3 (2): Source Code Analysis
The main functions in train.py, with comments, are as follows:
main()
# Multi-GPU setup
config = slim.deployment.model_deploy.DeploymentConfig(xxx)  # Create a DeploymentConfig for multi-GPU training
# Get a slim Dataset instance describing the chosen split
dataset = deeplab.datasets.segmentation_dataset.get_dataset(xxx)
# Build the input pipeline of preprocessed samples
samples = input_generator.get(dataset, xxx)
# Creates a queue to prefetch tensors from `tensors`
inputs_queue = prefetch_queue.prefetch_queue(samples, capacity=128 * config.num_clones)
# Build one model clone per GPU; each Clone is a namedtuple (outputs, scope, device)
clones = model_deploy.create_clones(config, _build_deeplab, args=[inputs_queue, xxx])
learning_rate = train_utils.get_model_learning_rate(xxx)
slim.learning.train(xxx)
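The multi-GPU part above follows slim's model_deploy pattern. Below is a minimal runnable sketch of that pattern, assuming the deployment/model_deploy.py module from tensorflow/models/research/slim is on the Python path; the toy model_fn is hypothetical and only stands in for _build_deeplab.

import tensorflow as tf
from deployment import model_deploy

def toy_model_fn(inputs):
    # Each clone builds an identical copy of this graph on its own device;
    # clone i > 0 reuses the variables created by clone 0.
    net = tf.layers.dense(inputs, 10)
    tf.losses.add_loss(tf.reduce_mean(tf.square(net)))

config = model_deploy.DeploymentConfig(num_clones=2, clone_on_cpu=True)
with tf.device(config.inputs_device()):
    inputs = tf.random_normal([8, 4])
# Each returned Clone is a namedtuple (outputs, scope, device); the per-clone
# losses are later gathered and averaged into the total training loss.
clones = model_deploy.create_clones(config, toy_model_fn, args=[inputs])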
deeplab.datasets.segmentation_dataset.get_dataset(dataset_name, split_name, dataset_dir):
# Map each feature key in the serialized tf.Example to the tf parser for it
keys_to_features
# Map each decoded item to the slim handler that produces it
items_to_handlers
# Combine the two maps into a decoder
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
return dataset.Dataset(xxx)
deeplab.utils.input_generator.get(dataset, xxx)
# A provider that reads and parses examples from the dataset
data_provider = slim.dataset_data_provider.DatasetDataProvider(dataset, xxx)
# Read the raw tensors, preprocess them, and batch them
image, height, width = data_provider.get([common.IMAGE, common.HEIGHT, common.WIDTH])
original_image, image, label = input_preprocess.preprocess_image_and_label(xxx)
return tf.train.batch(xxx)
_build_deeplab(inputs_queue, outputs_to_num_classes, ignore_label)
# Dequeue a batch of samples
samples = inputs_queue.dequeue()
model_options = common.ModelOptions(xxx)
# Build the network and compute logits at each scale
outputs_to_scales_to_logits = model.multi_scale_logits(xxx)
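_build_deeplab then turns these logits into a training loss via train_utils (not shown above). Here is a minimal sketch of the usual per-pixel softmax cross entropy with an ignore label, written as a standalone hypothetical helper rather than the repo's exact function:

import tensorflow as tf

def softmax_loss_with_ignore(logits, labels, num_classes, ignore_label):
    # logits: [batch, h, w, num_classes]; labels: [batch, h, w, 1]
    labels = tf.to_int32(tf.reshape(labels, shape=[-1]))
    # Pixels labeled ignore_label get weight 0 and do not contribute.
    not_ignore = tf.to_float(tf.not_equal(labels, ignore_label))
    one_hot = tf.one_hot(labels, num_classes)
    logits = tf.reshape(logits, shape=[-1, num_classes])
    return tf.losses.softmax_cross_entropy(one_hot, logits, weights=not_ignore)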
Analysis of eval.py
1. The basic structure is as follows.
def main(_):
# Get the contrib.slim dataset (expanded in get_dataset below)
dataset = segmentation_dataset.get_dataset(FLAGS.dataset, FLAGS.eval_split, dataset_dir=FLAGS.dataset_dir)
# TODO: clarify how tf.Graph().as_default() relates to the default graph
with tf.Graph().as_default():
# Build the evaluation input pipeline
samples = input_generator.get(
dataset,
FLAGS.eval_crop_size,
FLAGS.eval_batch_size,
min_resize_value=FLAGS.min_resize_value,
max_resize_value=FLAGS.max_resize_value,
resize_factor=FLAGS.resize_factor,
dataset_split=FLAGS.eval_split,
is_training=False,
model_variant=FLAGS.model_variant)
# Configure the model
model_options = common.ModelOptions(
outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_classes},
crop_size=FLAGS.eval_crop_size,
atrous_rates=FLAGS.atrous_rates,
output_stride=FLAGS.output_stride)
if tuple(FLAGS.eval_scales) == (1.0,):
tf.logging.info('Performing single-scale test.')
predictions = model.predict_labels(samples[common.IMAGE], model_options, image_pyramid=FLAGS.image_pyramid)
else:
tf.logging.info('Performing multi-scale test.')
predictions = model.predict_labels_multi_scale(
    samples[common.IMAGE],
    model_options=model_options,
    eval_scales=FLAGS.eval_scales,
    add_flipped_images=FLAGS.add_flipped_images)
# Flatten predictions and labels to 1-D
predictions = tf.reshape(predictions, shape=[-1])
labels = tf.reshape(samples[common.LABEL], shape=[-1])
# For mIoU, give ignore_label pixels weight 0 so they do not count
weights = tf.to_float(tf.not_equal(labels, dataset.ignore_label))
# Set ignore_label regions to label 0, because metrics.mean_iou requires
# range of labels = [0, dataset.num_classes). Note the ignore_label regions
# are not evaluated since the corresponding regions contain weights = 0.
labels = tf.where(
tf.equal(labels, dataset.ignore_label), tf.zeros_like(labels), labels)
metric_map = {}
metric_map['miou_1.0'] = tf.metrics.mean_iou(
    predictions, labels, dataset.num_classes, weights=weights)
metrics_to_values, metrics_to_updates = (tf.contrib.metrics.aggregate_metric_map(metric_map))
num_batches = int(math.ceil(dataset.num_samples / float(FLAGS.eval_batch_size)))
num_eval_iters = None  # run forever unless FLAGS.max_number_of_evaluations > 0
if FLAGS.max_number_of_evaluations > 0:
    num_eval_iters = FLAGS.max_number_of_evaluations
tf.contrib.slim.evaluation.evaluation_loop(
master=FLAGS.master,
checkpoint_dir=FLAGS.checkpoint_dir,
logdir=FLAGS.eval_logdir,
num_evals=num_batches,
eval_op=list(metrics_to_updates.values()),
max_number_of_evaluations=num_eval_iters,
eval_interval_secs=FLAGS.eval_interval_secs)
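tf.metrics.mean_iou is a streaming metric: the update op accumulates a num_classes x num_classes confusion matrix over all batches, and the value op reduces it. A sketch of that reduction in plain numpy (illustrative, not the TF source):

import numpy as np

def mean_iou_from_confusion_matrix(cm):
    tp = np.diag(cm).astype(np.float64)
    fp = cm.sum(axis=0) - tp  # predicted as class c, labeled as something else
    fn = cm.sum(axis=1) - tp  # labeled as class c, predicted as something else
    denom = tp + fp + fn
    valid = denom > 0  # classes that never occur are excluded from the mean
    return np.mean(tp[valid] / denom[valid])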
# contrib.slim dataset construction; the concrete values below are for the Cityscapes val split
def deeplab.datasets.segmentation_dataset.get_dataset(dataset_name, split_name, dataset_dir):
splits_to_sizes = {'train': 2975, 'val': 500}
num_classes = 19
ignore_label = 255
file_pattern = '/home/sjming/Documents/deeplearning/semantic-segmentation/cityscapes/tfrecord/val-*'
# tf.FixedLenFeature(shape, dtype, default_value): declares a fixed-length feature of the serialized tf.Example and tells tf how to parse it
keys_to_features = {
'image/encoded': tf.FixedLenFeature(
(), tf.string, default_value=''),
'image/filename': tf.FixedLenFeature(
(), tf.string, default_value=''),
'image/format': tf.FixedLenFeature(
(), tf.string, default_value='jpeg'),
'image/height': tf.FixedLenFeature(
(), tf.int64, default_value=0),
'image/width': tf.FixedLenFeature(
(), tf.int64, default_value=0),
'image/segmentation/class/encoded': tf.FixedLenFeature(
(), tf.string, default_value=''),
'image/segmentation/class/format': tf.FixedLenFeature(
(), tf.string, default_value='png'),
}
# Handlers (from contrib.slim) that decode the parsed features into tensors
items_to_handlers = {
'image': tfexample_decoder.Image(
image_key='image/encoded',
format_key='image/format',
channels=3),
'image_name': tfexample_decoder.Tensor('image/filename'),
'height': tfexample_decoder.Tensor('image/height'),
'width': tfexample_decoder.Tensor('image/width'),
'labels_class': tfexample_decoder.Image(
image_key='image/segmentation/class/encoded',
format_key='image/segmentation/class/format',
channels=1),
}
# Combine keys and handlers into a decoder
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
# Return a contrib.slim Dataset describing this split
return dataset.Dataset(
data_sources=file_pattern,
reader=tf.TFRecordReader,
decoder=decoder,
num_samples=splits_to_sizes[split_name],
items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
ignore_label=ignore_label,
num_classes=num_classes,
name=dataset_name,
multi_label=True)
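To sanity-check that a TFRecord file actually matches keys_to_features, the decoder can be exercised directly on one serialized record. A sketch, reusing the decoder and file_pattern defined above:

import tensorflow as tf

record_path = tf.gfile.Glob(file_pattern)[0]
serialized = next(tf.python_io.tf_record_iterator(record_path))
image, label = decoder.decode(tf.constant(serialized), items=['image', 'labels_class'])
with tf.Session() as sess:
    img, lbl = sess.run([image, label])
    print(img.shape, lbl.shape)  # for Cityscapes val, e.g. (1024, 2048, 3) (1024, 2048, 1)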
# This function gets the dataset split for semantic segmentation. In
# particular, it is a wrapper of (1) dataset_data_provider, which returns the raw
# dataset split, (2) input_preprocess, which preprocesses the raw data, and (3) the
# TensorFlow operation that batches the preprocessed data. The output can then
# be used directly for training, evaluation, or visualization.
def deeplab.utils.input_generator.get(dataset,
crop_size,
batch_size,
min_resize_value=None,
max_resize_value=None,
resize_factor=None,
min_scale_factor=1.,
max_scale_factor=1.,
scale_factor_step_size=0,
num_readers=1,
num_threads=1,
dataset_split=None,
is_training=True,
model_variant=None):
# DatasetDataProvider reads and parses examples from the dataset
data_provider = tf.contrib.slim.dataset_data_provider.DatasetDataProvider(
    dataset,
    num_readers=num_readers,
    num_epochs=None if is_training else 1,
    shuffle=is_training)
image, height, width = data_provider.get([common.IMAGE, common.HEIGHT, common.WIDTH])
label, = data_provider.get([common.LABELS_CLASS])
image_name, = data_provider.get([common.IMAGE_NAME])  # needed for the sample dict below
# Resize, randomly scale (training only), pad, and crop the image and label
original_image, image, label = input_preprocess.preprocess_image_and_label(
image,
label,
crop_height=crop_size[0],
crop_width=crop_size[1],
min_resize_value=min_resize_value,
max_resize_value=max_resize_value,
resize_factor=resize_factor,
min_scale_factor=min_scale_factor,
max_scale_factor=max_scale_factor,
scale_factor_step_size=scale_factor_step_size,
ignore_label=dataset.ignore_label,
is_training=is_training,
model_variant=model_variant)
sample = {
    common.IMAGE: image,
    common.IMAGE_NAME: image_name,
    common.HEIGHT: height,
    common.WIDTH: width,
    common.LABEL: label,
}
return tf.train.batch(
sample,
batch_size=batch_size,
num_threads=num_threads,
capacity=32 * batch_size,
allow_smaller_final_batch=not is_training,
dynamic_pad=True)
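A sketch of driving this pipeline end to end; the crop size matches the official Cityscapes eval setting, while the dataset_dir path is a placeholder:

dataset = segmentation_dataset.get_dataset('cityscapes', 'val', dataset_dir='/path/to/tfrecord')
samples = input_generator.get(dataset, crop_size=[1025, 2049], batch_size=1,
                              dataset_split='val', is_training=False,
                              model_variant='xception_65')
with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    # tf.train.batch is queue-based, so queue runners must be started.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    batch = sess.run(samples)
    print(batch[common.IMAGE].shape, batch[common.LABEL].shape)
    coord.request_stop()
    coord.join(threads)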
# Single-scale prediction
def deeplab.model.predict_labels(images, model_options, image_pyramid=None):
# Merged logits, shape e.g. (?, 129, 129, 19)
outputs_to_scales_to_logits = multi_scale_logits(
images,
model_options=model_options,
image_pyramid=image_pyramid,
is_training=False,
fine_tune_batch_norm=False)
# Bilinearly resize the merged logits to the input size, then argmax over classes (see the sketch below)
return predictions
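The omitted middle of predict_labels does roughly the following (a sketch: the merged logits for each output are bilinearly resized to the input resolution and argmaxed over the class axis; MERGED_LOGITS_SCOPE is the key under which model.py stores the fused logits):

predictions = {}
for output in sorted(outputs_to_scales_to_logits):
    logits = outputs_to_scales_to_logits[output][MERGED_LOGITS_SCOPE]
    logits = tf.image.resize_bilinear(logits, tf.shape(images)[1:3], align_corners=True)
    predictions[output] = tf.argmax(logits, 3)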
def model.predict_labels_multi_scale(images, model_options, eval_scales=(1.0,), add_flipped_images=False):
for i, image_scale in enumerate(eval_scales):
with tf.variable_scope(tf.get_variable_scope(), reuse=True if i else None):
outputs_to_scales_to_logits = multi_scale_logits(
images,
model_options=model_options,
image_pyramid=[image_scale],
is_training=False,
fine_tune_batch_norm=False)
# Per scale: resize logits bilinearly to the input size, softmax, and collect the result (see the sketch below)
for output in sorted(outputs_to_predictions):
predictions = outputs_to_predictions[output]
# Compute average prediction across different scales and flipped images.
predictions = tf.reduce_mean(tf.concat(predictions, 4), axis=4)
outputs_to_predictions[output] = tf.argmax(predictions, 3)
return outputs_to_predictions
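For reference, a sketch of what each scale iteration appends before the averaging above: the per-scale logits are resized to the input resolution, converted to probabilities, and given a trailing axis so the scales (and, optionally, flipped variants) can be concatenated along axis 4:

logits = outputs_to_scales_to_logits[output][MERGED_LOGITS_SCOPE]
logits = tf.image.resize_bilinear(logits, tf.shape(images)[1:3], align_corners=True)
outputs_to_predictions[output].append(tf.expand_dims(tf.nn.softmax(logits), 4))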
2. The code in this part is fairly straightforward; for background, see the material on tf.contrib.slim. The code officially released for DeepLab is written essentially according to the slim library's official documentation and can be used as a template; see also https://blog.csdn.net/u014451076/article/details/80706318.
3. Code modifications: [to be supplemented later]