inception_v1.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Contains the definition for inception v1 classification network."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import tensorflow as tf
  20. from nets import inception_utils
  21. slim = tf.contrib.slim
  22. trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
  23. def inception_v1_base(inputs,
  24. final_endpoint='Mixed_5c',
  25. scope='InceptionV1'):
  26. """Defines the Inception V1 base architecture.
  27. This architecture is defined in:
  28. Going deeper with convolutions
  29. Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
  30. Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
  31. http://arxiv.org/pdf/1409.4842v1.pdf.
  32. Args:
  33. inputs: a tensor of size [batch_size, height, width, channels].
  34. final_endpoint: specifies the endpoint to construct the network up to. It
  35. can be one of ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1',
  36. 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c',
  37. 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 'Mixed_4e',
  38. 'Mixed_4f', 'MaxPool_5a_2x2', 'Mixed_5b', 'Mixed_5c']
  39. scope: Optional variable_scope.
  40. Returns:
  41. A dictionary from components of the network to the corresponding activation.
  42. Raises:
  43. ValueError: if final_endpoint is not set to one of the predefined values.
  44. """
  45. end_points = {}
  46. with tf.variable_scope(scope, 'InceptionV1', [inputs]):
  47. with slim.arg_scope(
  48. [slim.conv2d, slim.fully_connected],
  49. weights_initializer=trunc_normal(0.01)):
  50. with slim.arg_scope([slim.conv2d, slim.max_pool2d],
  51. stride=1, padding='SAME'):
  52. end_point = 'Conv2d_1a_7x7'
  53. net = slim.conv2d(inputs, 64, [7, 7], stride=2, scope=end_point)
  54. end_points[end_point] = net
  55. if final_endpoint == end_point: return net, end_points
  56. end_point = 'MaxPool_2a_3x3'
  57. net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point)
  58. end_points[end_point] = net
  59. if final_endpoint == end_point: return net, end_points
  60. end_point = 'Conv2d_2b_1x1'
  61. net = slim.conv2d(net, 64, [1, 1], scope=end_point)
  62. end_points[end_point] = net
  63. if final_endpoint == end_point: return net, end_points
  64. end_point = 'Conv2d_2c_3x3'
  65. net = slim.conv2d(net, 192, [3, 3], scope=end_point)
  66. end_points[end_point] = net
  67. if final_endpoint == end_point: return net, end_points
  68. end_point = 'MaxPool_3a_3x3'
  69. net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point)
  70. end_points[end_point] = net
  71. if final_endpoint == end_point: return net, end_points
  72. end_point = 'Mixed_3b'
  73. with tf.variable_scope(end_point):
  74. with tf.variable_scope('Branch_0'):
  75. branch_0 = slim.conv2d(net, 64, [1, 1], scope='Conv2d_0a_1x1')
  76. with tf.variable_scope('Branch_1'):
  77. branch_1 = slim.conv2d(net, 96, [1, 1], scope='Conv2d_0a_1x1')
  78. branch_1 = slim.conv2d(branch_1, 128, [3, 3], scope='Conv2d_0b_3x3')
  79. with tf.variable_scope('Branch_2'):
  80. branch_2 = slim.conv2d(net, 16, [1, 1], scope='Conv2d_0a_1x1')
  81. branch_2 = slim.conv2d(branch_2, 32, [3, 3], scope='Conv2d_0b_3x3')
  82. with tf.variable_scope('Branch_3'):
  83. branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
  84. branch_3 = slim.conv2d(branch_3, 32, [1, 1], scope='Conv2d_0b_1x1')
  85. net = tf.concat(
  86. axis=3, values=[branch_0, branch_1, branch_2, branch_3])
  87. end_points[end_point] = net
  88. if final_endpoint == end_point: return net, end_points
  89. end_point = 'Mixed_3c'
  90. with tf.variable_scope(end_point):
  91. with tf.variable_scope('Branch_0'):
  92. branch_0 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1')
  93. with tf.variable_scope('Branch_1'):
  94. branch_1 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1')
  95. branch_1 = slim.conv2d(branch_1, 192, [3, 3], scope='Conv2d_0b_3x3')
  96. with tf.variable_scope('Branch_2'):
  97. branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1')
  98. branch_2 = slim.conv2d(branch_2, 96, [3, 3], scope='Conv2d_0b_3x3')
  99. with tf.variable_scope('Branch_3'):
  100. branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
  101. branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
  102. net = tf.concat(
  103. axis=3, values=[branch_0, branch_1, branch_2, branch_3])
  104. end_points[end_point] = net
  105. if final_endpoint == end_point: return net, end_points
  106. end_point = 'MaxPool_4a_3x3'
  107. net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point)
  108. end_points[end_point] = net
  109. if final_endpoint == end_point: return net, end_points
  110. end_point = 'Mixed_4b'
  111. with tf.variable_scope(end_point):
  112. with tf.variable_scope('Branch_0'):
  113. branch_0 = slim.conv2d(net, 192, [1, 1], scope='Conv2d_0a_1x1')
  114. with tf.variable_scope('Branch_1'):
  115. branch_1 = slim.conv2d(net, 96, [1, 1], scope='Conv2d_0a_1x1')
  116. branch_1 = slim.conv2d(branch_1, 208, [3, 3], scope='Conv2d_0b_3x3')
  117. with tf.variable_scope('Branch_2'):
  118. branch_2 = slim.conv2d(net, 16, [1, 1], scope='Conv2d_0a_1x1')
  119. branch_2 = slim.conv2d(branch_2, 48, [3, 3], scope='Conv2d_0b_3x3')
  120. with tf.variable_scope('Branch_3'):
  121. branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
  122. branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
  123. net = tf.concat(
  124. axis=3, values=[branch_0, branch_1, branch_2, branch_3])
  125. end_points[end_point] = net
  126. if final_endpoint == end_point: return net, end_points
  127. end_point = 'Mixed_4c'
  128. with tf.variable_scope(end_point):
  129. with tf.variable_scope('Branch_0'):
  130. branch_0 = slim.conv2d(net, 160, [1, 1], scope='Conv2d_0a_1x1')
  131. with tf.variable_scope('Branch_1'):
  132. branch_1 = slim.conv2d(net, 112, [1, 1], scope='Conv2d_0a_1x1')
  133. branch_1 = slim.conv2d(branch_1, 224, [3, 3], scope='Conv2d_0b_3x3')
  134. with tf.variable_scope('Branch_2'):
  135. branch_2 = slim.conv2d(net, 24, [1, 1], scope='Conv2d_0a_1x1')
  136. branch_2 = slim.conv2d(branch_2, 64, [3, 3], scope='Conv2d_0b_3x3')
  137. with tf.variable_scope('Branch_3'):
  138. branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
  139. branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
  140. net = tf.concat(
  141. axis=3, values=[branch_0, branch_1, branch_2, branch_3])
  142. end_points[end_point] = net
  143. if final_endpoint == end_point: return net, end_points
  144. end_point = 'Mixed_4d'
  145. with tf.variable_scope(end_point):
  146. with tf.variable_scope('Branch_0'):
  147. branch_0 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1')
  148. with tf.variable_scope('Branch_1'):
  149. branch_1 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1')
  150. branch_1 = slim.conv2d(branch_1, 256, [3, 3], scope='Conv2d_0b_3x3')
  151. with tf.variable_scope('Branch_2'):
  152. branch_2 = slim.conv2d(net, 24, [1, 1], scope='Conv2d_0a_1x1')
  153. branch_2 = slim.conv2d(branch_2, 64, [3, 3], scope='Conv2d_0b_3x3')
  154. with tf.variable_scope('Branch_3'):
  155. branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
  156. branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
  157. net = tf.concat(
  158. axis=3, values=[branch_0, branch_1, branch_2, branch_3])
  159. end_points[end_point] = net
  160. if final_endpoint == end_point: return net, end_points
  161. end_point = 'Mixed_4e'
  162. with tf.variable_scope(end_point):
  163. with tf.variable_scope('Branch_0'):
  164. branch_0 = slim.conv2d(net, 112, [1, 1], scope='Conv2d_0a_1x1')
  165. with tf.variable_scope('Branch_1'):
  166. branch_1 = slim.conv2d(net, 144, [1, 1], scope='Conv2d_0a_1x1')
  167. branch_1 = slim.conv2d(branch_1, 288, [3, 3], scope='Conv2d_0b_3x3')
  168. with tf.variable_scope('Branch_2'):
  169. branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1')
  170. branch_2 = slim.conv2d(branch_2, 64, [3, 3], scope='Conv2d_0b_3x3')
  171. with tf.variable_scope('Branch_3'):
  172. branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
  173. branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
  174. net = tf.concat(
  175. axis=3, values=[branch_0, branch_1, branch_2, branch_3])
  176. end_points[end_point] = net
  177. if final_endpoint == end_point: return net, end_points
  178. end_point = 'Mixed_4f'
  179. with tf.variable_scope(end_point):
  180. with tf.variable_scope('Branch_0'):
  181. branch_0 = slim.conv2d(net, 256, [1, 1], scope='Conv2d_0a_1x1')
  182. with tf.variable_scope('Branch_1'):
  183. branch_1 = slim.conv2d(net, 160, [1, 1], scope='Conv2d_0a_1x1')
  184. branch_1 = slim.conv2d(branch_1, 320, [3, 3], scope='Conv2d_0b_3x3')
  185. with tf.variable_scope('Branch_2'):
  186. branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1')
  187. branch_2 = slim.conv2d(branch_2, 128, [3, 3], scope='Conv2d_0b_3x3')
  188. with tf.variable_scope('Branch_3'):
  189. branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
  190. branch_3 = slim.conv2d(branch_3, 128, [1, 1], scope='Conv2d_0b_1x1')
  191. net = tf.concat(
  192. axis=3, values=[branch_0, branch_1, branch_2, branch_3])
  193. end_points[end_point] = net
  194. if final_endpoint == end_point: return net, end_points
  195. end_point = 'MaxPool_5a_2x2'
  196. net = slim.max_pool2d(net, [2, 2], stride=2, scope=end_point)
  197. end_points[end_point] = net
  198. if final_endpoint == end_point: return net, end_points
  199. end_point = 'Mixed_5b'
  200. with tf.variable_scope(end_point):
  201. with tf.variable_scope('Branch_0'):
  202. branch_0 = slim.conv2d(net, 256, [1, 1], scope='Conv2d_0a_1x1')
  203. with tf.variable_scope('Branch_1'):
  204. branch_1 = slim.conv2d(net, 160, [1, 1], scope='Conv2d_0a_1x1')
  205. branch_1 = slim.conv2d(branch_1, 320, [3, 3], scope='Conv2d_0b_3x3')
  206. with tf.variable_scope('Branch_2'):
  207. branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1')
  208. branch_2 = slim.conv2d(branch_2, 128, [3, 3], scope='Conv2d_0a_3x3')
  209. with tf.variable_scope('Branch_3'):
  210. branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
  211. branch_3 = slim.conv2d(branch_3, 128, [1, 1], scope='Conv2d_0b_1x1')
  212. net = tf.concat(
  213. axis=3, values=[branch_0, branch_1, branch_2, branch_3])
  214. end_points[end_point] = net
  215. if final_endpoint == end_point: return net, end_points
  216. end_point = 'Mixed_5c'
  217. with tf.variable_scope(end_point):
  218. with tf.variable_scope('Branch_0'):
  219. branch_0 = slim.conv2d(net, 384, [1, 1], scope='Conv2d_0a_1x1')
  220. with tf.variable_scope('Branch_1'):
  221. branch_1 = slim.conv2d(net, 192, [1, 1], scope='Conv2d_0a_1x1')
  222. branch_1 = slim.conv2d(branch_1, 384, [3, 3], scope='Conv2d_0b_3x3')
  223. with tf.variable_scope('Branch_2'):
  224. branch_2 = slim.conv2d(net, 48, [1, 1], scope='Conv2d_0a_1x1')
  225. branch_2 = slim.conv2d(branch_2, 128, [3, 3], scope='Conv2d_0b_3x3')
  226. with tf.variable_scope('Branch_3'):
  227. branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
  228. branch_3 = slim.conv2d(branch_3, 128, [1, 1], scope='Conv2d_0b_1x1')
  229. net = tf.concat(
  230. axis=3, values=[branch_0, branch_1, branch_2, branch_3])
  231. end_points[end_point] = net
  232. if final_endpoint == end_point: return net, end_points
  233. raise ValueError('Unknown final endpoint %s' % final_endpoint)
  234. def inception_v1(inputs,
  235. num_classes=1000,
  236. is_training=True,
  237. dropout_keep_prob=0.8,
  238. prediction_fn=slim.softmax,
  239. spatial_squeeze=True,
  240. reuse=None,
  241. scope='InceptionV1',
  242. global_pool=False):
  243. """Defines the Inception V1 architecture.
  244. This architecture is defined in:
  245. Going deeper with convolutions
  246. Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
  247. Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
  248. http://arxiv.org/pdf/1409.4842v1.pdf.
  249. The default image size used to train this network is 224x224.
  250. Args:
  251. inputs: a tensor of size [batch_size, height, width, channels].
  252. num_classes: number of predicted classes. If 0 or None, the logits layer
  253. is omitted and the input features to the logits layer (before dropout)
  254. are returned instead.
  255. is_training: whether is training or not.
  256. dropout_keep_prob: the percentage of activation values that are retained.
  257. prediction_fn: a function to get predictions out of logits.
  258. spatial_squeeze: if True, logits is of shape [B, C], if false logits is of
  259. shape [B, 1, 1, C], where B is batch_size and C is number of classes.
  260. reuse: whether or not the network and its variables should be reused. To be
  261. able to reuse 'scope' must be given.
  262. scope: Optional variable_scope.
  263. global_pool: Optional boolean flag to control the avgpooling before the
  264. logits layer. If false or unset, pooling is done with a fixed window
  265. that reduces default-sized inputs to 1x1, while larger inputs lead to
  266. larger outputs. If true, any input size is pooled down to 1x1.
  267. Returns:
  268. net: a Tensor with the logits (pre-softmax activations) if num_classes
  269. is a non-zero integer, or the non-dropped-out input to the logits layer
  270. if num_classes is 0 or None.
  271. end_points: a dictionary from components of the network to the corresponding
  272. activation.
  273. """
  274. # Final pooling and prediction
  275. with tf.variable_scope(scope, 'InceptionV1', [inputs], reuse=reuse) as scope:
  276. with slim.arg_scope([slim.batch_norm, slim.dropout],
  277. is_training=is_training):
  278. net, end_points = inception_v1_base(inputs, scope=scope)
  279. with tf.variable_scope('Logits'):
  280. if global_pool:
  281. # Global average pooling.
  282. net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool')
  283. end_points['global_pool'] = net
  284. else:
  285. # Pooling with a fixed kernel size.
  286. net = slim.avg_pool2d(net, [7, 7], stride=1, scope='AvgPool_0a_7x7')
  287. end_points['AvgPool_0a_7x7'] = net
  288. if not num_classes:
  289. return net, end_points
  290. net = slim.dropout(net, dropout_keep_prob, scope='Dropout_0b')
  291. logits = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
  292. normalizer_fn=None, scope='Conv2d_0c_1x1')
  293. if spatial_squeeze:
  294. logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze')
  295. end_points['Logits'] = logits
  296. end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
  297. return logits, end_points
  298. inception_v1.default_image_size = 224
  299. inception_v1_arg_scope = inception_utils.inception_arg_scope