3

Ich versuche, eine Multiclass Multilabel-Klassifizierung mit einem CNN in Keras durchzuführen. Ich habe versucht, eine einzelne Etikett Genauigkeit Funktion zu erstellen, basierend auf this function aus einer ähnlichen FrageKeras Multilabel Multiclass Individual Tag Genauigkeit

Der entsprechende Code, den ich versucht habe, ist:

labels = ["dog", "mammal", "cat", "fish", "rock"] #I have more 
interesting_id = [0]*len(labels) 
interesting_id[labels.index("rock")] = 1 #we only care about rock's accuracy 
interesting_label = K.variable(np.array(interesting_label), dtype='float32') 

def single_class_accuracy(interesting_class_id): 
    def single(y_true, y_pred): 
     class_id_true = K.argmax(y_true, axis=-1) 
     class_id_preds = K.argmax(y_pred, axis=-1) 
     # Replace class_id_preds with class_id_true for recall here 
     accuracy_mask = K.cast(K.equal(class_id_preds, interesting_class_id), 'float32') 
     class_acc_tensor = K.cast(K.equal(class_id_true, class_id_preds), 'float32') * accuracy_mask 
     class_acc = K.sum(class_acc_tensor)/K.maximum(K.sum(accuracy_mask), 1) 
     return class_acc 
    return single 

und dann später wird sie als eine Metrik genannt:

model.compile(optimizer=SGD(lr=0.0001, momentum=0.9), 
       loss='binary_crossentropy', metrics=[metrics.binary_accuracy, 
       single_class_accuracy(interesting_id)]) 

Aber der Fehler ich erhalte ist:

> Traceback (most recent call last): 
    File "/share/pkg/tensorflow/r1.3/install/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 490, in apply_op 
    preferred_dtype=default_dtype) 
    File "/share/pkg/tensorflow/r1.3/install/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 676, in internal_convert_to_tensor 
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref) 
    File "/share/pkg/tensorflow/r1.3/install/lib/python3.6/site-packages/tensorflow/python/ops/variables.py", line 677, in _TensorConversionFunction 
    "of type '%s'" % (dtype.name, v.dtype.name)) 
ValueError: Incompatible type conversion requested to type 'int64' for variable of type 'float32_ref' 

During handling of the above exception, another exception occurred: 

> Traceback (most recent call last): 
    File "bottleneck_model.py", line 190, in <module> 
    main() 
    File "bottleneck_model.py", line 171, in main 
    loss='binary_crossentropy', metrics=[metrics.binary_accuracy, binary_accuracy_with_threshold, single_class_accuracy(interesting_label)]) 
    File "/share/pkg/keras/2.0.6/install/lib/python3.6/site-packages/keras/engine/training.py", line 898, in compile 
    metric_result = masked_metric_fn(y_true, y_pred, mask=masks[i]) 
    File "/share/pkg/keras/2.0.6/install/lib/python3.6/site-packages/keras/engine/training.py", line 494, in masked 
    score_array = fn(y_true, y_pred) 
    File "bottleneck_model.py", line 81, in single 
    accuracy_mask = K.cast(K.equal(class_id_preds, interesting_class_id), 'float32') 
    File "/share/pkg/keras/2.0.6/install/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 1516, in equal 
    return tf.equal(x, y) 
    File "/share/pkg/tensorflow/r1.3/install/lib/python3.6/site-packages/tensorflow/python/ops/gen_math_ops.py", line 753, in equal 
    result = _op_def_lib.apply_op("Equal", x=x, y=y, name=name) 
    File "/share/pkg/tensorflow/r1.3/install/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 526, in apply_op 
    inferred_from[input_arg.type_attr])) 
TypeError: Input 'y' of 'Equal' Op has type float32 that does not match type int64 of argument 'x'. 

I h Ich habe versucht, die Typen vergeblich zu ändern.

Antwort

2

Die Eingänge zu K.equal haben unterschiedliche Datentypen. Ich nehme an, dass Sie class_id_preds zu float32 oder interesting_class_id zu int64 werfen sollten. Ist letzteres (Guss sonst die anderen Tensoren) eine ganze Zahl, sollte dies den Fehler beheben:

interesting_class_id = K.cast(interesting_class_id, 'int64')

+0

, die mir ein ungültiges Argument Fehler gibt. Shape 64 stimmt nicht überein 12 –

+0

Ich vermute, das ist ein späterer Fehler, der nach der Lösung dieses Problems aufgetreten ist. Leider kann ich Ihnen mit dem von Ihnen bereitgestellten Code-Fragment keine weitere Hilfe anbieten – rvinas

Verwandte Themen