2017-12-23 13 views
2

Ich laufe auf dieses Problem bei der Implementierung des vektorisierten SVM-Gradienten für cs231n Zuweisung1. hier ein Beispiel:numpy vektorisierte Möglichkeit, mehrere Zeilen des Arrays zu ändern (Zeilen können wiederholt werden)

ary = np.array([[1,-9,0], 
       [1,2,3], 
       [0,0,0]]) 
ary[[0,1]] += np.ones((2,2),dtype='int') 

und es gibt:

array([[ 2, -8, 1], 
     [ 2, 3, 4], 
     [ 0, 0, 0]]) 

alles bis Reihen fein ist nicht eindeutig ist:

ary[[0,1,1]] += np.ones((3,3),dtype='int') 

obwohl es nicht werfen einen Fehler haben, die Ausgabe war wirklich seltsam:

array([[ 2, -8, 1], 
     [ 2, 3, 4], 
     [ 0, 0, 0]]) 

und ich erwarte, sollte die zweite Zeile [3,4,5] anstatt [2,3,4], die naive Art, wie ich dieses Problem zu lösen verwendet wird, unter Verwendung einer for-Schleife wie folgt aus:

ary = np.array([[ 2, -8, 1], 
       [ 2, 3, 4], 
       [ 0, 0, 0]]) 
# the rows I want to change 
rows = [0,1,2,1,0,1] 
# the change matrix 
change = np.random.randn((6,3)) 
for i,row in enumerate(rows): 
    ary[row] += change[i] 

Also ich weiß wirklich nicht, wie man diese for-Schleife vektorisiert, gibt es eine bessere Möglichkeit, dies in NumPy zu tun? und warum es falsch ist, etwas zu tun, wie dies ?:

ary[rows] += change 

Falls jemand ist neugierig, warum ich so tun wollen, hier meine Implementierung von svm_loss_vectorized Funktion ist, muss ich auf den Etiketten die Gradienten von Gewichten basierend berechnen y:

def svm_loss_vectorized(W, X, y, reg): 
    """ 
    Structured SVM loss function, vectorized implementation. 

    Inputs and outputs are the same as svm_loss_naive. 
    """ 
    loss = 0.0 
    dW = np.zeros(W.shape) # initialize the gradient as zero 

    # transpose X and W 
    # D means input dimensions, N means number of train example 
    # C means number of classes 
    # X.shape will be (D,N) 
    # W.shape will be (C,D) 
    X = X.T 
    W = W.T 
    dW = dW.T 
    num_train = X.shape[1] 
    # transpose W_y shape to (D,N) 
    W_y = W[y].T 
    S_y = np.sum(W_y*X ,axis=0) 
    margins = np.dot(W,X) + 1 - S_y 
    mask = np.array(margins>0) 

    # get the impact of num_train examples made on W's gradient 
    # that is,only when the mask is positive 
    # the train example has impact on W's gradient 
    dW_j = np.dot(mask, X.T) 
    dW += dW_j 
    mul_mask = np.sum(mask, axis=0, keepdims=True).T 

    # dW[y] -= mul_mask * X.T 
    dW_y = mul_mask * X.T 
    for i,label in enumerate(y): 
     dW[label] -= dW_y[i] 

    loss = np.sum(margins*mask) - num_train 
    loss /= num_train 
    dW /= num_train 
    # add regularization term 
    loss += reg * np.sum(W*W) 
    dW += reg * 2 * W 
    dW = dW.T 

    return loss, dW 

Antwort

3

mithilfe integrierter in np.add.at

Die eingebaut ist np.add.at für solche Aufgaben, i, e.

np.add.at(ary, rows, change) 

Aber, da wir mit einem 2D Array arbeiten, die vielleicht nicht die performant sein.

schnell matrix-multiplication

Nutzung Wie sich herausstellt, können wir die sehr effizient matrix-multplication für einen solchen Fall, wie gut und gegeben genug Anzahl von wiederholten Zeilen für Summierung nutzen, könnte wirklich gut sein. Hier ist, wie wir es verwenden können -

mask = rows == np.arange(len(ary))[:,None] 
ary += mask.dot(change) 

Benchmarking

Lassen Sie uns Zeit np.add.at Verfahren gegen matrix-multiplication Basis eines für größere Arrays -

In [681]: ary = np.random.rand(1000,1000) 

In [682]: rows = np.random.randint(0,len(ary),(10000)) 

In [683]: change = np.random.rand(10000,1000) 

In [684]: %timeit np.add.at(ary, rows, change) 
1 loop, best of 3: 604 ms per loop 

In [687]: def matmul_addat(ary, row, change): 
    ...:  mask = rows == np.arange(len(ary))[:,None] 
    ...:  ary += mask.dot(change) 

In [688]: %timeit matmul_addat(ary, rows, change) 
10 loops, best of 3: 158 ms per loop 
+0

ist add.at eine neue Funktion? Ich habe es heute zweimal gesehen – Dark

+1

@Dark Schon seit Ewigkeiten hier. Aber nur zufällig, um heute ziemlich nützlich zu sein :) – Divakar

+1

@Divakar, yeah, danke für die Erinnerung an 'ufunc.at()'! – MaxU

Verwandte Themen