Kürzlich habe ich versucht, ein CNN in TF mit float16 zu trainieren. Zu meiner Überraschung ist es auf verschiedene Arten gebrochen, obwohl TF behauptet, es für eine Weile zu unterstützen. Zum Beispiel führt die float16-Optimierung bereits im zweiten Schritt zu einem NaN-Verlust, unabhängig vom Netzwerk.TensorFlow float16 Unterstützung ist gebrochen
import tensorflow as tf
import numpy as np
slim = tf.contrib.slim
dtype = tf.float16
shape = (4, 16, 16, 3)
inpt = tf.placeholder(dtype, shape, name='input')
net = slim.conv2d(inpt, 16, [3, 3], scope='conv',
weights_initializer=tf.zeros_initializer(),
# normalizer_fn=slim.batch_norm
)
loss = tf.reduce_mean(net)
opt = tf.train.AdamOptimizer(1e-3)
train_op = slim.learning.create_train_op(loss, opt)
val = np.zeros(shape)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(2):
print(sess.run(train_op, feed_dict={inpt: val}))
Zu meinem Verständnis ist es eindeutig ein Fehler: I Null Faltungen auf Null-Eingang gelten, sollte ich Null-Gradienten erhalten, die Null-Verlust nicht ändern. Es kann einfach nicht abweichen. Wenn dtype float32 ist, funktioniert es. NaN-Verluste treten sowohl bei CPU- als auch bei GPU-Versionen auf.
Allerdings habe ich in GH Probleme entlassen wurde, schloss eine zufällige Geck dieses Problem zu sagen, dass es das Verhalten bestimmt ist: https://github.com/tensorflow/tensorflow/issues/7226
Wenn Sie die Zeile mit BN Kommentar-, wird es bereits auf Millimeter Bauzeit brechen, weil BN übernimmt Moving Averages (und Beta, Gamma) sind immer float32 und werfen sie nicht richtig. Dieses Problem wurde auch geschlossen und scheinbar ignoriert: https://github.com/tensorflow/tensorflow/issues/7164
Ich fühle mich wie ich mit einer IT-Unterstützung der ersten Linie eines ISP sprechen.
Kann jemand erklären, wie ich mit float16 trainieren sollte, wenn solch ein einfaches "Netzwerk" fürchterlich versagt? Und was ist die empfohlene Methode, um Fehler jetzt zu melden?