2017-08-24 2 views
1

Entschuldigung, wenn dies der falsche Ort ist, um mein Problem anzusprechen (bitte helfen Sie mir mit, wo es am besten ist, wenn es der Fall ist). Ich bin ein Neuling mit Keras und Python, also hoffen Antworten darauf.Wie wird ein CNN mit Keras fit_generator programmiert?

Ich versuche ein CNN-Steuermodell zu trainieren, das Bilder als Eingabe nimmt. Da es sich um ein ziemlich großes Dataset handelt, habe ich einen Datengenerator erstellt, um mit fit_generator() zu arbeiten. Es ist mir nicht klar, wie man diese Methode auf Chargen trainieren lässt, also nahm ich an, dass der Generator Chargen an fit_generator() zurückgeben muss. Der Generator sieht wie folgt aus:

def gen(file_name, batchsz = 64): 
    csvfile = open(file_name) 
    reader = csv.reader(csvfile) 
    batchCount = 0 
    while True: 
     for line in reader: 
      inputs = [] 
      targets = [] 
      temp_image = cv2.imread(line[1]) # line[1] is path to image 
      measurement = line[3] # steering angle 
      inputs.append(temp_image) 
      targets.append(measurement) 
      batchCount += 1 
      if batchCount >= batchsz: 
       batchCount = 0 
       X = np.array(inputs) 
       y = np.array(targets) 
       yield X, y 
     csvfile.seek(0) 

Es liest eine CSV-Datei mit Telemetriedaten (Lenkwinkel usw.) und Pfade zu Bildabtastwerte und gibt Arrays der Größe: batchsz Der Aufruf von fit_generator() wie folgt aussieht:

tgen = gen('h:/Datasets/dataset14-no.zero.speed.trn.csv', batchsz = 128) # Train data generator 
vgen = gen('h:/Datasets/dataset14-no.zero.speed.val.csv', batchsz = 128) # Validation data generator 
try: 
    model.fit_generator(
     tgen, 
     samples_per_epoch=113526, 
     nb_epoch=6, 
     validation_data=vgen, 
     nb_val_samples=20001 
    ) 

Der Datensatz enthält 113.526 Abtastpunkte noch das Training Fortschreibungsausgangs Modell liest dies wie (zum Beispiel):

1020/113526 [..............................] - ETA: 27737s - loss: 0.0080 
    1021/113526 [..............................] - ETA: 27723s - loss: 0.0080 
    1022/113526 [..............................] - ETA: 27709s - loss: 0.0080 
    1023/113526 [..............................] - ETA: 27696s - loss: 0.0080 

Welche scheint Probe für Probe zu trainieren (stochastisch?). Das resultierende Modell ist nutzlos. Ich trainierte zuvor mit .fit() mit dem gesamten Datensatz, der in den Speicher geladen wurde, auf einem viel kleineren Datensatz, und das erzeugte ein Modell, das zumindest dann funktioniert, wenn es schlecht ist. Offensichtlich stimmt etwas mit meinem Ansatz von fit_generator() nicht. Ich werde Ihnen sehr dankbar sein. Diese

+1

'samples_per_epoch' sollte wie in [keras documentation] (https://keras.io/models/sequential/) vorgeschlagen' total_samples/batch_size 'entsprechen. 'samples_per_epoch' gibt an, wie oft der Generator aufgerufen wird, bevor eine Epoche als abgeschlossen betrachtet wird. Sie weiß nicht, welche" Stapelgröße "Sie verwenden – gionni

+0

Danke @gionni. Von Keras 1.0.2 auf den neuesten Stand gebracht. Die fit-generator() Parameter machen mit dieser Version mehr Sinn. – tinyMind

Antwort

2

:

for line in reader: 
    inputs = [] 
    targets = [] 

... ist Batch für jede Zeile in der CSV-Dateien zurückzusetzen. Sie sind nicht mit der ganzen Daten Ausbildung, aber mit nur einer einzigen Probe in 128

Vorschlag:

for line in reader: 

    if batchCount == 0: 
     inputs = [] 
     targets = [] 
    .... 
    .... 

Als jemand kommentiert, der in fit Generator, samples_per_epoch sollte total_samples/batchsz gleich sein

Obwohl ich denke, Ihr Verlust sollte sowieso sinken. Wenn dies nicht der Fall ist, gibt es möglicherweise noch ein anderes Problem im Code, möglicherweise in der Art, wie Sie die Daten laden, oder in der Initialisierung oder Struktur des Modells.

Versuchen Sie Ihre Bilder und drucken Sie die Daten in den Generator plotten:

for X,y in tgen: #careful, this is an infinite loop, make it stop 

    print(X.shape[0]) # is this really the number of batches you expect? 

    for image in X: 
     ...some method to plot X so you can see it, or just print  

    print(y) 

Prüfen, ob die ergaben Werte ok mit dem, was Sie erwarten, dass sie zu sein.

+0

"... setzt Ihren Stapel für jede Zeile in den CSV-Dateien zurück." Dah! hätte das gesehen. Sonderbar, weil ich einen Testcode habe, um die yellowed Arrays auszudrucken, und sie sind Chargen der richtigen Größe und Reihenfolge. – tinyMind

+0

Über den Verlust hatte ich kürzlich ein Problem mit einem "eingefrorenen" Verlust. Ich entschied mich, für viele Epochen immer wieder nur eine Probe zu trainieren und plötzlich fand der Verlust seinen Weg.Dann führte ich schrittweise andere Beispiele ein und es begann richtig zu trainieren. Ich denke, das Modell war zu komplex oder ich hatte meine Gewichte nicht richtig initialisiert, so dass es länger dauerte, einige interessante Ergebnisse zu zeigen. –

+0

Danke Daniel. Scheint jetzt ok zu sein. Die GPU-Last ist jedoch ziemlich niedrig, als würde die GPU auf das Skript warten. – tinyMind