Ich versuche den F1-Score als benutzerdefinierte Metrik in TensorFlow für eine DNNClassifier
zu definieren. Um das zu tun, schrieb ich eine FunktionBenutzerdefinierte Metrik basierend auf Streaming-Metriken von Tensorflow gibt NaN zurück
def metric_fn(predictions=[], labels=[], weights=[]):
P, _ = tf.contrib.metrics.streaming_precision(predictions, labels)
R, _ = tf.contrib.metrics.streaming_recall(predictions, labels)
if P + R == 0:
return 0
return 2*(P*R)/(P+R)
die streaming_precision
und streaming_recall
von TensorFlow verwendet die F1-Score zu calulate. Danach machte ich einen neuen Eintrag in die validation_metrics:
validation_metrics = {
"accuracy":
tf.contrib.learn.MetricSpec(
metric_fn=tf.contrib.metrics.streaming_accuracy,
prediction_key=tf.contrib.learn.PredictionKey.CLASSES),
"precision":
tf.contrib.learn.MetricSpec(
metric_fn=tf.contrib.metrics.streaming_precision,
prediction_key=tf.contrib.learn.PredictionKey.CLASSES),
"recall":
tf.contrib.learn.MetricSpec(
metric_fn=tf.contrib.metrics.streaming_recall,
prediction_key=tf.contrib.learn.PredictionKey.CLASSES),
"f1score":
tf.contrib.learn.MetricSpec(
metric_fn=metric_fn,
prediction_key=tf.contrib.learn.PredictionKey.CLASSES)
}
Doch obwohl ich richtig Precision und Recall-Werte zu erhalten, f1score
ist immer nan
:
INFO:tensorflow:Saving dict for global step 151: accuracy = 0.982456, accuracy/baseline_label_mean = 0.397661, accuracy/threshold_0.500000_mean = 0.982456, auc = 0.982867, f1score = nan, global_step = 151, labels/actual_label_mean = 0.397661, labels/prediction_mean = 0.406118, loss = 0.310612, precision = 0.971014, precision/positive_threshold_0.500000_mean = 0.971014, recall = 0.985294, recall/positive_threshold_0.500000_mean = 0.985294
Irgend etwas stimmt nicht mit meinem metric_fn
, aber ich kann es nicht herausfinden. Die Werte P
und R
, die durch metric_fn
erhalten werden, haben die Form Tensor("precision/value:0", shape=(), dtype=float32)
. Ich finde das ein bisschen komisch. Ich habe einen skalaren Tensor erwartet.
Jede Hilfe wird geschätzt.
das funktioniert können, danke. Es gibt jedoch einen kleinen Syntaxfehler (fehlt '') '' im ersten Teil der return-Anweisung. – TheWaveLad
@TheWaveLad erledigt, thx – user1735003