2016-04-03 7 views
5

Angenommen, ich habe einen Tensor der Größe BxWxHxD. Ich möchte den Tensor so bearbeiten, dass ich einen neuen BxWxHxD-Tensor habe, bei dem nur das maximale Element in jedem WxH-Schnitt beibehalten wird und alle anderen Werte Null sind. Mit anderen Worten, ich denke, der beste Weg, dies zu erreichen, besteht darin, einen 2D-Argmax über die WxH-Schichten zu nehmen, was BxD-Indextensoren für die Zeilen und Spalten ergibt, die dann in einen einharten BxWxHxD-Tensor umgewandelt werden können als Maske verwendet werden. Wie mache ich das?Tensorflow multi-dimensional argmax

Antwort

1

Sie können die folgende Funktion als Ausgangspunkt verwenden. Er berechnet die Indizes des maximalen Elements für jede Charge und für jeden Kanal. Das resultierende Array hat das Format (Stapelgröße, 2, Anzahl der Kanäle).

def argmax_2d(tensor): 

    # input format: BxHxWxD 
    assert rank(tensor) == 4 

    # flatten the Tensor along the height and width axes 
    flat_tensor = tf.reshape(tensor, (tf.shape(tensor)[0], -1, tf.shape(tensor)[3])) 

    # argmax of the flat tensor 
    argmax = tf.cast(tf.argmax(flat_tensor, axis=1), tf.int32) 

    # convert indexes into 2D coordinates 
    argmax_x = argmax // tf.shape(tensor)[2] 
    argmax_y = argmax % tf.shape(tensor)[2] 

    # stack and return 2D coordinates 
    return tf.stack((argmax_x, argmax_y), axis=1) 

def rank(tensor): 

    # return the rank of a Tensor 
    return len(tensor.get_shape())