2017-12-10 5 views
0

Wenn ich random_crop verwende, um einen Bilddatensatz zuzuschneiden (TensorFlow: TypeError in random_crop):

tf.random_crop(X, [batch_size, 24, 24, 3]) 

es wirft einen TypeError:

TypeError: Expected int32, got None of type '_Message' instead. 

Code (das Ausführen der drei folgenden Codeblöcke in einem Python-Terminal reproduziert das Problem):

Ich möchte das Bild zufällig zuschneiden, bevor es in das Netz eingespeist wird, daher habe ich random_crop_and_resize geschrieben:

def random_crop_and_resize(): 
    """Randomly crop the batch `X` to 24x24 and resize it back to 32x32.

    `X` is not a parameter; it is resolved from the surrounding scope.
    """
    # batch_size = tf.shape(X)[0] 
    # get_shape() yields the static shape. When X's batch dimension is
    # unknown, as_list() reports it as None, so batch_size is None here.
    batch_size, _, _, _ = X.get_shape().as_list() 
    # Passing that None inside the size list is what makes tf.random_crop
    # fail with "Expected int32, got None".
    return tf.image.resize_images \ 
     (tf.random_crop(X, [batch_size, 24, 24, 3]), [32, 32]) 

und die Modellfunktion ist definiert als:

def my_model(X, y, is_training): 
    """Build the model graph, applying the crop augmentation only in training.

    X: input image batch.
    y: not used in the visible excerpt.
    is_training: boolean tensor used as the tf.cond predicate.
    """
    # augmentation: shape of X: [None, 32, 32, 3] 
    # tf.cond invokes the branch functions while the graph is being built,
    # so random_crop_and_resize runs (and raises) at construction time.
    distorted_img = tf.cond(is_training, 
          random_crop_and_resize, lambda: X) 
    # ... feed distorted_img into network 

Dann wird der Graph definiert:

# Reset the default graph before defining the placeholders.
tf.reset_default_graph() 

# The leading None leaves the batch dimension unspecified in the static shape.
X = tf.placeholder(tf.float32, [None, 32, 32, 3]) 
y = tf.placeholder(tf.int64, [None]) 
# Boolean switch passed to my_model, where it is the tf.cond predicate.
is_training = tf.placeholder(tf.bool) 

# Building the model is the call that raises the TypeError.
y_out, regularizer = my_model(X, y, is_training) 

Aber es wird ein TypeError ausgelöst: Expected int32, got None of type '_Message' instead. Was läuft hier schief?


Weitere Informationen:

Umgebung:

  • TensorFlow 1.2 GPU
  • Python 3.6.3 von Anaconda
  • Ubuntu 14.04.3 LTS

Vollständiger Traceback:

-------------------------------------------------------------------------- 
TypeError         Traceback (most recent call last) 
<ipython-input-63-67c2d74574b6> in <module>() 
    77 is_training = tf.placeholder(tf.bool) 
    78 
---> 79 y_out, regularizer = my_model(X, y, is_training) 
    80 
    81 # regularization 

<ipython-input-63-67c2d74574b6> in my_model(X, y, is_training) 
    11  # augmentation: shape of X: [None, 32, 32, 3] 
    12  distorted_img = tf.cond(is_training, 
---> 13     random_crop_and_resize, lambda: X) 
    14 
    15  regularizer = tf.contrib.layers.l2_regularizer(scale=0.03) 

/home/hyh/anaconda3/envs/cs231n/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs) 
    287    'in a future version' if date is None else ('after %s' % date), 
    288    instructions) 
--> 289  return func(*args, **kwargs) 
    290  return tf_decorator.make_decorator(func, new_func, 'deprecated', 
    291          _add_deprecated_arg_notice_to_docstring(

/home/hyh/anaconda3/envs/cs231n/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py in cond(pred, true_fn, false_fn, strict, name, fn1, fn2) 
    1812  context_t = CondContext(pred, pivot_1, branch=1) 
    1813  context_t.Enter() 
-> 1814  orig_res_t, res_t = context_t.BuildCondBranch(true_fn) 
    1815  if orig_res_t is None: 
    1816  raise ValueError("true_fn must have a return value.") 

/home/hyh/anaconda3/envs/cs231n/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py in BuildCondBranch(self, fn) 
    1687 def BuildCondBranch(self, fn): 
    1688  """Add the subgraph defined by fn() to the graph.""" 
-> 1689  original_result = fn() 
    1690  if original_result is None: 
    1691  return None, None 

<ipython-input-63-67c2d74574b6> in random_crop_and_resize() 
     5  # batch_size = tf.shape(X)[0] 
     6  batch_size, _, _, _ = X.get_shape().as_list() 
----> 7  return tf.image.resize_images   (tf.random_crop(X, [batch_size, 24, 24, 3]), [32, 32]) 
     8 
     9 

/home/hyh/anaconda3/envs/cs231n/lib/python3.6/site-packages/tensorflow/python/ops/random_ops.py in random_crop(value, size, seed, name) 
    297 with ops.name_scope(name, "random_crop", [value, size]) as name: 
    298  value = ops.convert_to_tensor(value, name="value") 
--> 299  size = ops.convert_to_tensor(size, dtype=dtypes.int32, name="size") 
    300  shape = array_ops.shape(value) 
    301  check = control_flow_ops.Assert(

/home/hyh/anaconda3/envs/cs231n/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in convert_to_tensor(value, dtype, name, preferred_dtype) 
    674  name=name, 
    675  preferred_dtype=preferred_dtype, 
--> 676  as_ref=False) 
    677 
    678 

/home/hyh/anaconda3/envs/cs231n/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in internal_convert_to_tensor(value, dtype, name, as_ref, preferred_dtype) 
    739 
    740   if ret is None: 
--> 741   ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref) 
    742 
    743   if ret is NotImplemented: 

/home/hyh/anaconda3/envs/cs231n/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py in _constant_tensor_conversion_function(v, dtype, name, as_ref) 
    111           as_ref=False): 
    112 _ = as_ref 
--> 113 return constant(v, dtype=dtype, name=name) 
    114 
    115 

/home/hyh/anaconda3/envs/cs231n/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py in constant(value, dtype, shape, name, verify_shape) 
    100 tensor_value = attr_value_pb2.AttrValue() 
    101 tensor_value.tensor.CopyFrom(
--> 102  tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape, verify_shape=verify_shape)) 
    103 dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype) 
    104 const_tensor = g.create_op(

/home/hyh/anaconda3/envs/cs231n/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py in make_tensor_proto(values, dtype, shape, verify_shape) 
    372  nparray = np.empty(shape, dtype=np_dt) 
    373  else: 
--> 374  _AssertCompatible(values, dtype) 
    375  nparray = np.array(values, dtype=np_dt) 
    376  # check to them. 

/home/hyh/anaconda3/envs/cs231n/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py in _AssertCompatible(values, dtype) 
    300  else: 
    301  raise TypeError("Expected %s, got %s of type '%s' instead." % 
--> 302      (dtype.name, repr(mismatch), type(mismatch).__name__)) 
    303 
    304 

TypeError: Expected int32, got None of type '_Message' instead. 

Antwort

0

Das Problem liegt an batch_size. In `batch_size, _, _, _ = X.get_shape().as_list()` ist batch_size kein Integer, sondern None, weil die Batch-Dimension des Platzhalters X als None deklariert ist.

Verwenden Sie map_fn(), um die Operation auf jedes Bild einzeln anzuwenden und so die Berechnung von batch_size zu vermeiden:

# Crop each image of the batch separately; the crop size is given per image,
# so no batch_size is needed.
tf.map_fn(lambda img: tf.random_crop(img, [24, 24, 3]), X) 

Referenz:

  1. TensorFlow image operations for batches
Verwandte Themen