Ich versuche Tensorflow zu lernen und versuche derzeit, eine einfache logistische Regressionsmodell zu tun. Hier ist mein Code, den ich aus verschiedenen Beispielen zusammengesetzt habe, die ich finden konnte.logistische Regression Debug Tensorflow
with tf.Session() as sess:
# Training data
input = tf.constant(tra)
target = tf.constant(np.transpose(data[:,1]).astype(np.float64))
# Set model weights
W = tf.Variable(np.random.randn(10, 1).astype(np.float64))
# Construct model
mat=tf.matmul(input,W)
pred = tf.sigmoid(mat)
# Compute the error
yerror = tf.sub(pred, target)
# We are going to minimize the L2 loss. The L2 loss is the sum of the
# squared error for all our estimates of y. This penalizes large errors
# a lot, but small errors only a little.
loss = tf.nn.l2_loss(yerror)
# Gradient Descent
update_weights = tf.train.GradientDescentOptimizer(0.05).minimize(loss)
# Initializing the variables
tf.initialize_all_variables().run()
for _ in range(50):
# Repeatedly run the operations, updating the TensorFlow variable.
sess.run(update_weights)
print(loss.eval())
so dass der Code läuft, aber die Fehler Dosis verbessert nicht nach jedem ‚sess.run (update_weights)‘ itteration und ich habe versucht, mit diffrent Schrittgrößen.
Ich frage mich, ob das Setup korrekt ist?
Ich bin ein bisschen unsicher, wie man es debuggen kann, da die Berechnung von allem beim Laufbefehl erfolgt. Die Trainingsdaten sind in Ordnung. Wenn einige von euch sehen könnten, was ich falsch mache, in dieser ganzen Session aufbauen oder Vorschläge machen, wie ich das debuggen kann.
Hilfe sehr geschätzt.
Dank für einen Kommentar einiger Fragen: Kann ich irgendwelche Zufallszahl setzen oder sollte es zufällig Gauß- mit Mittelwert 0 und std 1 sein? und sind die N in log (N) die Anzahl der Features, die ich habe oder die Anzahl der Trainingsbeispiele? –
Wenn die Gewichte alle Null sind, ist es nicht wirklich wichtig, was die Gewichte sind, aber ansonsten ja, Sie können alles füttern, das nicht mit der Ausgabe korreliert ist. Das "N" bezieht sich auf die Anzahl der Klassen. – drpng