# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains a variant of the LeNet model definition."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

slim = tf.contrib.slim


def lenet(images, num_classes=10, is_training=False,
          dropout_keep_prob=0.5,
          prediction_fn=slim.softmax,
          scope='LeNet'):
  """Creates a variant of the LeNet model.

  Note that since the output is a set of 'logits', the values fall in the
  interval of (-infinity, infinity). Consequently, to convert the outputs to a
  probability distribution over the classes, one will need to convert them
  using the softmax function:

        logits = lenet.lenet(images, is_training=False)
        probabilities = tf.nn.softmax(logits)
        predictions = tf.argmax(logits, 1)

  Args:
    images: A batch of `Tensors` of size [batch_size, height, width, channels].
    num_classes: the number of classes in the dataset. If 0 or None, the logits
      layer is omitted and the input features to the logits layer are returned
      instead.
    is_training: specifies whether or not we're currently training the model.
      This variable will determine the behaviour of the dropout layer.
    dropout_keep_prob: the percentage of activation values that are retained.
    prediction_fn: a function to get predictions out of logits.
    scope: Optional variable_scope.

  Returns:
    net: a 2D Tensor with the logits (pre-softmax activations) if num_classes
      is a non-zero integer, or the non-dropped-out input to the logits layer
      if num_classes is 0 or None.
    end_points: a dictionary from components of the network to the
      corresponding activation.
  """
  end_points = {}

  with tf.variable_scope(scope, 'LeNet', [images]):
    net = end_points['conv1'] = slim.conv2d(images, 32, [5, 5], scope='conv1')
    net = end_points['pool1'] = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
    net = end_points['conv2'] = slim.conv2d(net, 64, [5, 5], scope='conv2')
    net = end_points['pool2'] = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
    net = slim.flatten(net)
    end_points['Flatten'] = net

    net = end_points['fc3'] = slim.fully_connected(net, 1024, scope='fc3')
    if not num_classes:
      return net, end_points
    net = end_points['dropout3'] = slim.dropout(
        net, dropout_keep_prob, is_training=is_training, scope='dropout3')
    logits = end_points['Logits'] = slim.fully_connected(
        net, num_classes, activation_fn=None, scope='fc4')

  end_points['Predictions'] = prediction_fn(logits, scope='Predictions')

  return logits, end_points
lenet.default_image_size = 28
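

# A minimal usage sketch (an illustration added here, not part of the original
# model definition): builds the inference graph and converts the returned
# logits into probabilities and hard class predictions, as the docstring
# example above describes. The function name and placeholder shape are
# assumptions; the shape matches single-channel inputs at the default 28x28
# image size.
def _example_inference_graph():
  images = tf.placeholder(tf.float32, [None, 28, 28, 1], name='images')
  logits, end_points = lenet(images, num_classes=10, is_training=False)
  probabilities = tf.nn.softmax(logits)
  predictions = tf.argmax(logits, 1)
  return probabilities, predictions, end_points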


def lenet_arg_scope(weight_decay=0.0):
  """Defines the default lenet argument scope.

  Args:
    weight_decay: The weight decay to use for regularizing the model.

  Returns:
    An `arg_scope` to use for the lenet model.
  """
  with slim.arg_scope(
      [slim.conv2d, slim.fully_connected],
      weights_regularizer=slim.l2_regularizer(weight_decay),
      weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
      activation_fn=tf.nn.relu) as sc:
    return sc
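

# A minimal training-side sketch (an assumption, not part of the original
# file): shows how `lenet_arg_scope` is meant to wrap the model call so that
# the L2 regularizer, truncated-normal initializer, and ReLU defaults above
# apply to every conv2d and fully_connected layer created inside `lenet`. The
# function name and weight_decay value are illustrative.
def _example_training_loss(images, onehot_labels, weight_decay=0.004):
  with slim.arg_scope(lenet_arg_scope(weight_decay=weight_decay)):
    logits, _ = lenet(images, num_classes=10, is_training=True)
  tf.losses.softmax_cross_entropy(onehot_labels=onehot_labels, logits=logits)
  # get_total_loss folds the L2 regularization losses created by the
  # arg_scope into the training objective.
  return tf.losses.get_total_loss()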