Gibt es eine Möglichkeit, einen TF-Tensor innerhalb einer benutzerdefinierten Keras-Verlustfunktion umzuformen? Ich definiere diese benutzerdefinierte Verlustfunktion für ein konvolutionelles neuronales Netzwerk?TensorFlow Tensor in Keras verlustfunktion umformen?
def custom_loss(x, x_hat):
"""
Custom loss function for training background extraction networks (autoencoders)
"""
#flatten x, x_hat before computing mean, median
shape = x_hat.get_shape().as_list()
batch_size = shape[0]
image_size = np.prod(shape[1:])
x = tf.reshape(x, [batch_size, image_size])
x_hat = tf.reshape(x_hat, [batch_size, image_size])
B0 = reduce_median(tf.transpose(x_hat))
# I divide by sigma in the next step. So I add a small float32 to F0
# so as to prevent sigma from becoming 0 or Nan.
F0 = tf.abs(x_hat - B0) + 1e-10
sigma = tf.reduce_mean(tf.sqrt(F0/0.5), axis=0)
background_term = tf.reduce_mean(F0/sigma, axis=-1)
bce = binary_crossentropy(x, x_hat)
loss = bce + background_term
return loss
Zusätzlich zu dem Standard binary_crossentropy
in denen ein zusätzlichen Verlust background_term
wird die Berechnung hinzugefügt. Dieser Begriff regt das Netzwerk an, Bilder vorherzusagen, um den Median einer Charge zu schließen. Da die Ausgänge des CNN sind 2d und reduce_median
funktioniert besser mit 1D-Arrays muss ich die Bilder in 1D-Arrays umformen. Wenn ich versuche, dieses Netz zu trainieren erhalte ich die Fehler
Traceback (most recent call last):
File "stackoverflow.py", line 162, in <module>
autoencoder = build_conv_autoencoder(lambda_W, input_shape, num_filters, optimizer, custom_loss)
File "stackoverflow.py", line 136, in build_conv_autoencoder
autoencoder.compile(optimizer, loss, metrics=[mean_squared_error])
File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 594, in compile
**kwargs)
File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py", line 667, in compile
sample_weight, mask)
File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py", line 318, in weighted
score_array = fn(y_true, y_pred)
File "stackoverflow.py", line 26, in custom_loss
x = tf.reshape(x, [batch_size, image_size])
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 2448, in reshape
name=name)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py", line 494, in apply_op
raise err
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py", line 491, in apply_op
preferred_dtype=default_dtype)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 710, in internal_convert_to_tensor
ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/constant_op.py", line 176, in _constant_tensor_conversion_function
return constant(v, dtype=dtype, name=name)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/constant_op.py", line 165, in constant
tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape, verify_shape=verify_shape))
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/tensor_util.py", line 441, in make_tensor_proto
tensor_proto.string_val.extend([compat.as_bytes(x) for x in proto_values])
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/tensor_util.py", line 441, in <listcomp>
tensor_proto.string_val.extend([compat.as_bytes(x) for x in proto_values])
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/util/compat.py", line 65, in as_bytes
(bytes_or_text,))
TypeError: Expected binary or unicode string, got None
Es ist wie Keras scheint custom_loss
ruft, bevor der TensorFlow Graph instanziiert wird. Dies macht batch_size
keine statt des tatsächlichen Wertes. Gibt es einen richtigen Weg, um Tensoren in Verlustfunktionen umzugestalten, um diesen Fehler zu vermeiden? Sie können den vollständigen Code here ansehen.
Haben Sie versucht, eine 'batch_input_shape' definieren, statt' input_shape' in entweder der ersten Schicht oder 'Input' Schicht? –
Könnten Sie den Wert von 'shape' überprüfen, nachdem Sie' get_shape() verwendet haben? As_list() '? Ich nehme an, 'x' und' x_hat' sind korrekte Tensoren, aber wenn Sie überprüfen könnten, dass sie korrekt sind, wird das Problem erheblich lösen – DarkCygnus