2017-08-23 3 views
1

Ich möchte diese Verlustfunktion entwerfen:Tensorflow: tf.argmax und Schneiden

sum((y[argmax(y_)] - y_[argmax(y_)])²) 

finde ich nicht einen Weg y[argmax(y_)] zu tun. Ich versuchte y[k], y[:,k] und y[None,k] keine dieser Arbeit. Dies ist mein Code:

Na = 3 
    x = tf.placeholder(tf.float32, [None, 2]) 
    W = tf.Variable(tf.zeros([2, Na])) 
    b = tf.Variable(tf.zeros([Na])) 
    y = tf.nn.relu(tf.matmul(x, W) + b) 
    y_ = tf.placeholder(tf.float32, [None, 3]) 
    k = tf.argmax(y_, 1) 
    diff = y[k] - y_[k] 
    loss = tf.reduce_sum(tf.square(diff)) 

Und der Fehler:

File "/home/ncarrara/phd/code/cython/robotnavigation/ftq/cftq19.py", line 156, in <module> 
    diff = y[k] - y_[k] 
    File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 499, in _SliceHelper 
    name=name) 
    File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 663, in strided_slice 
    shrink_axis_mask=shrink_axis_mask) 
    File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 3515, in strided_slice 
    shrink_axis_mask=shrink_axis_mask, name=name) 
    File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op 
    op_def=op_def) 
    File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2508, in create_op 
    set_shapes_for_outputs(ret) 
    File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1873, in set_shapes_for_outputs 
    shapes = shape_func(op) 
    File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1823, in call_with_requiring 
    return call_cpp_shape_fn(op, require_shape_fn=True) 
    File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 610, in call_cpp_shape_fn 
    debug_python_shape_fn, require_shape_fn) 
    File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 676, in _call_cpp_shape_fn_impl 
    raise ValueError(err.message) 
ValueError: Shape must be rank 1 but is rank 2 for 'strided_slice' (op: 'StridedSlice') with input shapes: [?,3], [1,?], [1,?], [1]. 

Antwort

0

Die tf.gather_nd getan werden kann, mit:

import tensorflow as tf 

Na = 3 
x = tf.placeholder(tf.float32, [None, 2]) 
W = tf.Variable(tf.zeros([2, Na])) 
b = tf.Variable(tf.zeros([Na])) 
y = tf.nn.relu(tf.matmul(x, W) + b) 
y_ = tf.placeholder(tf.float32, [None, 3]) 
k = tf.argmax(y_, 1) 
# Make index tensor with row and column indices 
num_examples = tf.cast(tf.shape(x)[0], dtype=k.dtype) 
idx = tf.stack([tf.range(num_examples), k], axis=-1) 
diff = tf.gather_nd(y, idx) - tf.gather_nd(y_, idx) 
loss = tf.reduce_sum(tf.square(diff)) 

Erläuterung:

In diesem Fall ist die Idee von tf.gather_nd ist eine Matrix (ein zweidimensionaler Tensor), in der jede Zeile den Index von enthält die Zeile und Spalte in der Ausgabe haben. wenn ich zum Beispiel eine Matrix a enthält:

| 1 2 3 | 
| 4 5 6 | 
| 7 8 9 | 

und eine Matrix i enthält:

| 1 2 | 
| 0 1 | 
| 2 2 | 
| 1 0 | 

Dann wird das Ergebnis der tf.gather_nd(a, i) wäre der Vektor (eindimensionale Tensor) enthält:

| 6 | 
| 2 | 
| 9 | 
| 4 | 

In diesem Fall sind die Spaltenindizes gegeben durch tf.argmax in k; es sagt Ihnen für jede Zeile, welches die Spalte mit dem höchsten Wert ist. Jetzt müssen Sie nur den Zeilenindex mit jedem von diesen setzen. Das erste Element in k ist der Index der Spalte mit dem maximalen Wert in der Zeile 0, das nächste Element das in der Zeile 1 und so weiter. num_examples ist nur die Anzahl der Zeilen in x und tf.range(num_examples) gibt Ihnen dann einen Vektor von 0 bis die Anzahl der Zeilen in x minus 1 (das heißt, alle Reihenfolge der Zeilenindizes). Jetzt müssen Sie nur das zusammen mit k setzen, was tf.stack tut, und das Ergebnis idx ist das Argument für tf.gather_nd.

+0

Sieht gut aus, aber im Moment bin ich nicht sicher genug, um Ihre Antwort zu validieren, danke trotzdem! –

+0

@nicolascarrara Ich habe einige Erklärungen hinzugefügt. – jdehesa

+0

Super danke, es ist jetzt klar! –