Ich benutze Caffe für die Klassifizierung von Nicht-Bilddaten mit einer recht einfachen CNN-Struktur. Ich hatte keine Probleme, mein Netzwerk auf meinen HDF5-Daten mit den Abmessungen n x 1 x 156 x 12 zu trainieren. Ich habe jedoch Schwierigkeiten, neue Daten zu klassifizieren.Vorhersage in Caffe - Ausnahme: Eingabe-Blob-Argumente stimmen nicht mit den Netto-Eingaben überein
Wie mache ich einen einfachen Weiterleitungsdurchlauf ohne Vorverarbeitung? Meine Daten wurden normalisiert und haben korrekte Maße für Caffe (es wurde bereits verwendet, um das Netz zu trainieren). Unten ist mein Code und die CNN-Struktur.
EDIT: Ich habe das Problem auf die Funktion '_Net_forward' in pycaffe.py isoliert und festgestellt, dass das Problem auftritt, da das self.input dict leer ist. Kann mir jemand erklären, warum das so ist? Der Satz soll den Satz aus den neuen Testdaten kommt gleich sein:
if set(kwargs.keys()) != set(self.inputs):
raise Exception('Input blob arguments do not match net inputs.')
Mein Code ein wenig verändert hat, wie ich jetzt die IO-Methoden zur Umwandlung der Daten in Bezug verwenden (siehe unten). Auf diese Weise habe ich die Variable Kwargs mit den richtigen Daten gefüllt.
Selbst kleine Hinweise würden sehr geschätzt werden!
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
# Make sure that caffe is on the python path:
caffe_root = '' # this file is expected to be run from {caffe_root}
import sys
sys.path.insert(0, caffe_root + 'python')
import caffe
import os
import subprocess
import h5py
import shutil
import tempfile
import sklearn
import sklearn.datasets
import sklearn.linear_model
import skimage.io
def LoadFromHDF5(dataset='test_reduced.h5', path='Bjarke/hdf5_classification/data/'):
f = h5py.File(path + dataset, 'r')
dat = f['data'][:]
f.close()
return dat;
def runModelPython():
model_file = 'Bjarke/hdf5_classification/conv_v2_simple.prototxt'
pretrained = 'Bjarke/hdf5_classification/data/train_iter_10000.caffemodel'
test_data = LoadFromHDF5()
net = caffe.Net(model_file, pretrained)
caffe.set_mode_cpu()
caffe.set_phase_test()
user = test_data[0,:,:,:]
datum = caffe.io.array_to_datum(user.astype(np.uint8))
user_dat = caffe.io.datum_to_array(datum)
user_dat = user_dat.astype(np.uint8)
out = net.forward_all(data=np.asarray([user_dat]))
if __name__ == '__main__':
runModelPython()
CNN Proto
name: "CDR-CNN"
layers {
name: "data"
type: HDF5_DATA
top: "data"
top: "label"
hdf5_data_param {
source: "Bjarke/hdf5_classification/data/train.txt"
batch_size: 10
}
include: { phase: TRAIN }
}
layers {
name: "data"
type: HDF5_DATA
top: "data"
top: "label"
hdf5_data_param {
source: "Bjarke/hdf5_classification/data/test.txt"
batch_size: 10
}
include: { phase: TEST }
}
layers {
name: "feature_conv"
type: CONVOLUTION
bottom: "data"
top: "feature_conv"
blobs_lr: 1
blobs_lr: 2
convolution_param {
num_output: 10
kernel_w: 12
kernel_h: 1
stride_w: 1
stride_h: 1
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
}
}
}
layers {
name: "conv1"
type: CONVOLUTION
bottom: "feature_conv"
top: "conv1"
blobs_lr: 1
blobs_lr: 2
convolution_param {
num_output: 14
kernel_w: 1
kernel_h: 4
stride_w: 1
stride_h: 1
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
}
}
}
layers {
name: "pool1"
type: POOLING
bottom: "conv1"
top: "pool1"
pooling_param {
pool: MAX
kernel_w: 1
kernel_h: 3
stride_w: 1
stride_h: 3
}
}
layers {
name: "conv2"
type: CONVOLUTION
bottom: "pool1"
top: "conv2"
blobs_lr: 1
blobs_lr: 2
convolution_param {
num_output: 120
kernel_w: 1
kernel_h: 5
stride_w: 1
stride_h: 1
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
}
}
}
layers {
name: "fc1"
type: INNER_PRODUCT
bottom: "conv2"
top: "fc1"
blobs_lr: 1
blobs_lr: 2
weight_decay: 1
weight_decay: 0
inner_product_param {
num_output: 84
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
value: 0
}
}
}
layers {
name: "accuracy"
type: ACCURACY
bottom: "fc1"
bottom: "label"
top: "accuracy"
include: { phase: TEST }
}
layers {
name: "loss"
type: SOFTMAX_LOSS
bottom: "fc1"
bottom: "label"
top: "loss"
}
auch die Protokolldatei angezeigt würde uns helfen, das Problem einzugrenzen weiter –
Nur damit Sie wissen, ich habe es nicht ein Fehler auf der gesagt Tracker. Ich habe gefragt, wie man es auf der Mailingliste macht, aber bis jetzt keine Antwort erhalten https://groups.google.com/forum/?utm_medium=email&utm_source=footer#!msg/caffe-users/eEhSBlKcjpc/llQi9PTPAYSJ – Mark
Gleiches Problem: https : //groups.google.com/forum/#! topic/caffe-users/aojN_bmbg74 –