0

Ich entwickle gerade ein neuronales Netzwerk, und ich habe alle Daten und ich habe den Code zu dem Punkt bekommen, dass ein Bild zum CNN für das Training gespeist wird. In dem Trainingsprozess erscheint jedoch für das erste Bild ein Fehler mit dem folgenden Code.Ungültiger Argumentfehler Erwartet begin [0] = 0

def convolutional_neural_network(x): 
    weights = {'W_conv1':tf.Variable(tf.random_normal([5,5,1,32])), 
       'W_conv2':tf.Variable(tf.random_normal([5,5,32,64])), 
       'W_fc':tf.Variable(tf.random_normal([7*7*64,1024])), 
       'out':tf.Variable(tf.random_normal([1024, n_classes]))} 

    biases = {'b_conv1':tf.Variable(tf.random_normal([32])), 
       'b_conv2':tf.Variable(tf.random_normal([64])), 
       'b_fc':tf.Variable(tf.random_normal([1024])), 
       'out':tf.Variable(tf.random_normal([n_classes]))} 

    x = tf.reshape(x, shape=[-1, 28, 28, 1]) 

    conv1 = tf.nn.relu(conv2d(x, weights['W_conv1']) + biases['b_conv1']) 
    conv1 = maxpool2d(conv1) 

    conv2 = tf.nn.relu(conv2d(conv1, weights['W_conv2']) + biases['b_conv2']) 
    conv2 = maxpool2d(conv2) 

    fc = tf.reshape(conv2,[-1, 7*7*64]) 
    fc = tf.nn.relu(tf.matmul(fc, weights['W_fc'])+biases['b_fc']) 
    fc = tf.nn.dropout(fc, keep_rate) 

    output = tf.matmul(fc, weights['out'])+biases['out'] 
    print("hi") 
    return output 


def shuffle_unison(images, labels): 
    shuffleLabel = [] 
    shuffleImage = [] 
    shuffleVector = [] 
    for i in range(0, len(images)-1): 
     shuffleVector.append(i) 
    random.shuffle(shuffleLabel) 
    for i in range(0, len(shuffleVector)-1): 
     shuffleImage.append(images[shuffleVector[i]]) 
     shuffleLabel.append(labels[shuffleVector[i]]) 
    return shuffleImage, shuffleLabel 





def train_neural_network(x): 
    prediction = convolutional_neural_network(x) 
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(prediction,y)) 
    optimizer = tf.train.AdamOptimizer().minimize(cost) 

    hm_epochs = 10 
    # step 4: Batching 

    with tf.Session() as sess: 
     init = tf.initialize_all_variables() 
     sess.run(init) 
     tf.train.start_queue_runners() 
     #array of strings and corresponding values 
     image_list, label_list = readImageLables() 
     for epoch in range(hm_epochs): 
      epoch_loss = 0 
      #shuffle every epoch 
      shuffle_image_list, shuffle_label_list = shuffle_unison(image_list, label_list) 
      sampleList = ['/home/sciencefair/Desktop/OrchardData/MachineLearningTesting/RottenOranges/result1.jpg'] 
      for i in range(0,7683): 
       #filename_queue = tf.train.string_input_producer(sampleList) 
       file_contents = tf.read_file(shuffle_image_list[i]) 
       image = tf.image.decode_jpeg(file_contents, channels=1) 
       resized_image = tf.image.resize_images(image, [28,28]) 
       #image_batch, label_batch = tf.train.batch([resized_image, shuffle_label_list[i]], batch_size=batch_size) # does train.batch take individual images or final tensors 
       #if(i>batch_size): 
        #print(label_batch.eval()) 
       a = tf.reshape(resized_image,[1, 784]) 
       print(a.eval()) 
       _, c = sess.run([optimizer, cost], feed_dict={x: tf.reshape(resized_image,[1, 784]).eval(), y: shuffle_label_list[i]}) 
       epoch_loss += c 
       print("ok") 

      print('Epoch', epoch, 'completed out of',hm_epochs,'loss:',epoch_loss) 
     sess.close() 

Die Stack-Trace sah aus wie diese

Caused by op 'Slice_1', defined at: 
    File "revisednet.py", line 128, in <module> 
    train_neural_network(x) 
    File "revisednet.py", line 87, in train_neural_network 
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(prediction,y)) 
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/nn_ops.py", line 670, in softmax_cross_entropy_with_logits 
    labels = _flatten_outer_dims(labels) 
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/nn_ops.py", line 472, in _flatten_outer_dims 
    array_ops.shape(logits), [math_ops.sub(rank, 1)], [1]) 
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/array_ops.py", line 431, in slice 
    return gen_array_ops._slice(input_, begin, size, name=name) 
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 2234, in _slice 
    name=name) 
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py", line 749, in apply_op 
    op_def=op_def) 
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 2380, in create_op 
    original_op=self._default_original_op, op_def=op_def) 
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 1298, in __init__ 
    self._traceback = _extract_stack() 

InvalidArgumentError (see above for traceback): Expected begin[0] == 0 (got -1) and size[0] == 0 (got 1) when input.dim_size(0) == 0 
    [[Node: Slice_1 = Slice[Index=DT_INT32, T=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](Shape_2, Slice_1/begin, Slice_1/size)]] 

Dieser Fehler aus den Daten zu stammen scheint eine gewisse confliction mit der softmax Funktion verursacht. Ich habe jedoch keine Ahnung, was dieses Problem verursacht.

+0

Gibt es eine bestimmte Zeile, die den Fehler verursacht? –

+0

Ja, entsprechend dem Stack, den ich bereitgestellt habe, Zeile 87. TF.reduce_mean ... –

Antwort

1

Ich folgte diesem Tutorial: Sentdex, First pass through Data w/ 3D ConvNet , um ein 3D CNN zu bauen und bekam den gleichen Fehler wie hier.

Dieser Fehler tritt auf, weil die Dimension des Beschriftungsvektors meiner Eingabedaten (z. B. die Position des ersten Beschriftungsvektors in Sentex-Zugdaten train_data[0][1]) dieselbe Nummer sein sollte wie n_classes, die im Lernprogramm 2 ist.

In meinem falschen Versuch, verwende ich nur einen binären Wert 0 oder 1, um es darzustellen, dessen Dimension 1 ist, wo sollte 2. Also die tf.nn.softmax_cross_entropy_with_logits() Funktion wurde durch die falsche Größe des Labelvektors verwirrt.

Versuchen Sie, die Größe Ihrer Beschriftungsvektoren so zu erweitern, dass sie gleich n_classes ist.

Verwandte Themen