2017-06-01 3 views
3

Ich verwende das Keras Sequential-Modell, um mehrere Klassenklassifizierer zu trainieren.Anfügen von Klassenetiketten an ein Keras-Modell

Bei der Auswertung gibt Keras einen Vektor der Vertraulichkeiten aus und ich kann die korrekte Klassen-ID aus der mit Argmax ableiten. Ich kann dann eine Nachschlagetabelle verwenden, um die tatsächliche Klassenbezeichnung (z. B. eine Zeichenfolge) zu erhalten.

Bis jetzt ist die Lösung, das trainierte Modell zu laden und dann eine Lookup-Tabelle separat zu laden. Da ich eine ganze Reihe von Klassifikatoren habe, würde ich es vorziehen, beide Strukturen in einer Datei zu behalten.

Also was ich suche ist eine Möglichkeit, den tatsächlichen Label-Lookup-Vektor in das Keras-Modell zu integrieren. Das würde mir erlauben, eine einzige Klassifizierungsdatei zu haben, die in der Lage ist, einige Eingabedaten zu nehmen und die korrekte Klassenbezeichnung für diese Daten zurückzugeben.

Eine Möglichkeit, dies zu lösen, wäre, sowohl das Modell als auch die Nachschlagetabelle in einem Tupel zu speichern und dieses Tupel in eine Beize zu schreiben, aber das scheint nicht sehr elegant zu sein.

Antwort

5

Also habe ich mich selbst an einer Lösung versucht und das scheint zu funktionieren. Ich habe auf etwas einfacheres gehofft.

Das Öffnen der Modelldatei ein zweites Mal ist nicht wirklich optimal, denke ich. Wenn es irgendjemandem besser geht, tue es auf jeden Fall.

import h5py 

from keras.models import load_model 
from keras.models import save_model 


def load_model_ext(filepath, custom_objects=None): 
    model = load_model(filepath, custom_objects=None) 
    f = h5py.File(filepath, mode='r') 
    meta_data = None 
    if 'my_meta_data' in f.attrs: 
     meta_data = f.attrs.get('my_meta_data') 
    f.close() 
    return model, meta_data 


def save_model_ext(model, filepath, overwrite=True, meta_data=None): 
    save_model(model, filepath, overwrite) 
    if meta_data is not None: 
     f = h5py.File(filepath, mode='a') 
     f.attrs['my_meta_data'] = meta_data 
     f.close() 
+0

Akzeptieren meine eigene Antwort für einen Mangel an Alternativen. Wenn jemand eine bessere Lösung findet, nehme ich seine an. – Cerno

+1

Das gleiche Problem versuche ich zu lösen. aber deine lösung funktioniert nicht für mich: '' 'save_model_ext (mod1, filepath = 'test_model.h5', meta_data = {0: 'c1', 1: 'c2'})' ' ergibt ein Fehler: '' ' TypeError: Objekt dtype dtype ('O') hat kein natives HDF5-Äquivalent ' ' Welchen Typ erwartet Ihre Funktion' meta_data'? – slymore

+0

Hallo. Sie müssen Daten verwenden, die in HDF5 konvertiert werden können. dtype = "O" bedeutet, dass Ihre Daten ein Python-Objekt enthalten, das offensichtlich nicht gültig ist. Wenn ich mich erinnere, habe ich Python-Wörterbücher ohne Probleme benutzt. Ist das wirklich der Code, den du ausprobiert hast oder ist die Wahrheit komplexer? – Cerno