# inception_utils.py
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains common code shared by all inception models.

Usage of arg scope:
  with slim.arg_scope(inception_arg_scope()):
    logits, end_points = inception.inception_v3(images, num_classes,
                                                is_training=is_training)
"""
  21. from __future__ import absolute_import
  22. from __future__ import division
  23. from __future__ import print_function
  24. import tensorflow as tf
  25. slim = tf.contrib.slim
  26. def inception_arg_scope(weight_decay=0.00004,
  27. use_batch_norm=True,
  28. batch_norm_decay=0.9997,
  29. batch_norm_epsilon=0.001,
  30. activation_fn=tf.nn.relu):
  31. """Defines the default arg scope for inception models.
  32. Args:
  33. weight_decay: The weight decay to use for regularizing the model.
  34. use_batch_norm: "If `True`, batch_norm is applied after each convolution.
  35. batch_norm_decay: Decay for batch norm moving average.
  36. batch_norm_epsilon: Small float added to variance to avoid dividing by zero
  37. in batch norm.
  38. activation_fn: Activation function for conv2d.
  39. Returns:
  40. An `arg_scope` to use for the inception models.
  41. """
  42. batch_norm_params = {
  43. # Decay for the moving averages.
  44. 'decay': batch_norm_decay,
  45. # epsilon to prevent 0s in variance.
  46. 'epsilon': batch_norm_epsilon,
  47. # collection containing update_ops.
  48. 'updates_collections': tf.GraphKeys.UPDATE_OPS,
  49. # use fused batch norm if possible.
  50. 'fused': None,
  51. }
  52. if use_batch_norm:
  53. normalizer_fn = slim.batch_norm
  54. normalizer_params = batch_norm_params
  55. else:
  56. normalizer_fn = None
  57. normalizer_params = {}
  58. # Set weight_decay for weights in Conv and FC layers.
  59. with slim.arg_scope([slim.conv2d, slim.fully_connected],
  60. weights_regularizer=slim.l2_regularizer(weight_decay)):
  61. with slim.arg_scope(
  62. [slim.conv2d],
  63. weights_initializer=slim.variance_scaling_initializer(),
  64. activation_fn=activation_fn,
  65. normalizer_fn=normalizer_fn,
  66. normalizer_params=normalizer_params) as sc:
  67. return sc