2017-10-15 2 views
0

Ich erstelle ein CNN mit Keras 2.0.8, mit Tensorflow-Backend. Ich versuche, die Gewichtsmatrix der ersten Faltungsschicht zu erhalten, wie unten dargestellt:keras Conv2d Gewicht Matrixgröße ist umgekehrt

model = Sequential() 
model.add(Conv2D(filters=16, kernel_size=(3,3), 
        input_shape= 
(9,9,1),activation='relu',kernel_regularizer =l2(regularization_coef))) 

model.add(Conv2D(filters=64, kernel_size= 
(3,3),activation='relu',kernel_regularizer = l2(regularization_coef))) 

model.add(MaxPooling2D(pool_size=(2,2))) 
model.add(Dropout(0.5)) 

model.add(Flatten()) 
model.add(Dense(128,activation='relu',kernel_regularizer = 
l2(regularization_coef))) 

model.add(Dropout(0.5)) 
model.add(Dense(2,activation='softmax',kernel_regularizer = 
l2(regularization_coef))) 

model.compile(loss='categorical_crossentropy', 
optimizer='adadelta',metrics=['accuracy']) 
model.summary() 

model.fit(X_train, Y_train, batch_size=batch_size, epochs=nb_epoch, 
verbose=0, validation_split=0.1) 

score = model.evaluate(X_test, Y_test, verbose=0) 
print('Test score:', score[0]) 
print('Test accuracy:', score[1]) 

filters= model.layers[0].get_weights()[0] 
print(filters.shape) 

Die erste Schicht, wie man sehen kann eine 2D-Faltungsschicht mit 16 Filtern ist, die Kerngröße (3,3) und 1 Eingangskanal. Die letzte Linie sollte mir also eine Form von (16,1,3,3) geben, aber stattdessen bekomme ich eine Form von (3,3,1,16). Ich möchte die Gewichte als 16 3x3-Matrizen visualisieren, aber ich kann das wegen dieses Formproblems nicht machen. Kann mir bitte jemand helfen? Vielen Dank im Voraus!

Antwort

1

Sie können das Array transponieren, um die 16 an den Anfang zu verschieben und dann zu (16, 3, 3) umzuformen.

filters= model.layers[0].get_weights()[0] 
print(filters.shape) 
# (3,3,1,16) 
filters = filters.transpose(3,0,1,2) 
print(filters.shape) 
# (16, 3, 3, 1) 
filters = filters.reshape((16,3,3)) 
print(filters.shape) 
# (16, 3, 3) 
+0

Vielen Dank. Also kann ich das sehr gut machen? 'filters = filters.transpose (3,2,0,1) drucken (filters.shape) # (16,1,3,3)' – abhih1

Verwandte Themen