2017-12-13 5 views
3

Ich versuche, TextGAN mit Pytorch "replizieren" und ich bin neu in Pytorch. Mein aktuelles Anliegen ist es, die L_G (eq 7 Seite 3.), Und hier ist mein aktueller Code zu replizieren:Wie implementiert man eine Obergrenze JSD Verlust in Pytorch?

class JSDLoss(nn.Module): 

    def __init__(self): 
     super(JSDLoss,self).__init__() 

    def forward(self, batch_size, f_real, f_synt): 
     assert f_real.size()[1] == f_synt.size()[1] 

     f_num_features = f_real.size()[1] 
     identity = autograd.Variable(torch.eye(f_num_features)*0.1, requires_grad=False) 

     if use_cuda: 
      identity = identity.cuda(gpu) 

     f_real_mean = torch.mean(f_real, 0, keepdim=True) 
     f_synt_mean = torch.mean(f_synt, 0, keepdim=True) 

     dev_f_real = f_real - f_real_mean.expand(batch_size,f_num_features) 
     dev_f_synt = f_synt - f_synt_mean.expand(batch_size,f_num_features) 

     f_real_xx = torch.mm(torch.t(dev_f_real), dev_f_real) 
     f_synt_xx = torch.mm(torch.t(dev_f_synt), dev_f_synt) 

     cov_mat_f_real = (f_real_xx/batch_size) - torch.mm(f_real_mean, torch.t(f_real_mean)) + identity 
     cov_mat_f_synt = (f_synt_xx/batch_size) - torch.mm(f_synt_mean, torch.t(f_synt_mean)) + identity 

     cov_mat_f_real_inv = torch.inverse(cov_mat_f_real) 
     cov_mat_f_synt_inv = torch.inverse(cov_mat_f_synt) 

     temp1 = torch.trace(torch.add(torch.mm(cov_mat_f_synt_inv, cov_mat_f_real), torch.mm(cov_mat_f_real_inv, cov_mat_f_synt))) 
     temp1 = temp1.view(1,1) 
     temp2 = torch.mm(torch.mm((f_synt_mean - f_real_mean), (cov_mat_f_synt_inv + cov_mat_f_real_inv)), torch.t(f_synt_mean - f_real_mean)) 
     loss_g = torch.add(temp1, temp2).mean() 

     return loss_g 

Es funktioniert. Ich habe jedoch den Verdacht, dass es nicht die Möglichkeit ist, einen benutzerdefinierten Verlust zu erstellen. Jede Art von Hilfe wird sehr geschätzt! Vielen Dank im Voraus :)

Antwort

0

Wie ein benutzerdefinierten Verlust in Pytorch

Dies zu schaffen, ist, wie Sie einen benutzerdefinierten Verlust in Pytorch erstellen. Sie müssen die folgenden Anforderungen erfüllen:

  • Der Wert schließlich durch eine Verlustfunktion ein Skalar-Wert muss zurückgegeben werden. Kein Vektor/Tensor.
  • Der zurückgegebene Wert muss eine Variable sein. Dies dient dazu, die Parameter in Ihrem Modell zu aktualisieren. Der beste Weg, dies zu tun ist ist nur um sicherzustellen, dass sowohl x und y übergeben werden Variablen sind. Auf diese Weise wird jede Funktion der beiden auch eine Variable sein.
  • definieren die __init__ und forward Methoden

Sie mehrere Verlustmodule finden können, die Sie als Beispiele in der Pytorch Quellcode verwenden: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/loss.py

Wenn Sie einen Mini-Batch-Tensor sind vorbei zu Ihre Verlustfunktion, dann besteht keine Notwendigkeit, die Minibatchgröße an die forward-Funktion zu übergeben, da die Größe in der forward-Funktion berechnet werden kann.

Wie Sie einen benutzerdefinierten Verlust verwenden

Sobald Sie Ihre Verlustfunktion implementiert haben, können Sie es wie folgt, zum Beispiel:

loss = YourLoss() 
input = autograd.Variable(torch.randn(3, 5), requires_grad=True) 
target = autograd.Variable(torch.randn(3, 5)) 
output = loss(input, target) 
output.backward() 

loss.backward() berechnet dloss/dx für jeden Parameter x in Ihrem Netzwerk, das requires_grad=True hat. Diese werden in x.grad für jeden Parameter x akkumuliert. In Pseudo-Code:

x.grad += dloss/dx 

optimizer.step aktualisiert den Wert von x den Gradienten x.grad. Zum Beispiel führt die SGD-Optimierer:

x += -lr * x.grad 

optimizer.zero_grad() löscht x.grad für jeden Parameter x im Optimierer. Es ist wichtig, dies vorher loss.backward() zu nennen, sonst sammeln Sie die Gradienten aus mehreren Durchgängen.

+0

vielen Dank für den Kommentar @ JMA. Ich wusste schon ziemlich genau, wie man Verlustfunktionen macht und alles, meine Frage konzentrierte sich eher darauf, ob ich die richtige Gleichung oder die Gleichung des Papiers benutze oder nicht. Es würde die ganze Welt bedeuten, wenn Sie mir dabei helfen können :) –

+0

Brauchen Sie noch Hilfe, die das bestätigt? – JMA

+0

Ja, ich bin auch in Ordnung, wenn Sie die Diskussion auf E-Mail verschieben möchten –

Verwandte Themen