2017-12-12 2 views
0

Simple multi-task network can be done here. Aber ich möchte so etwas enter image description here. Nun konstruieren ich das Modell wie folgt:Wie macht man Multi-Task-Lernen in Fackel7?

model = nn.Sequential() 
model:add(nn.Linear(3,5)) 
prl1 = nn.ConcatTable() 
prl1:add(nn.Linear(5,1)) 
prl2 = nn.ConcatTable() 
prl2:add(nn.Linear(5,1)) 
prl2:add(nn.Linear(5,1)) 
prl1:add(prl2) 
model:add(prl1) 

Und meine Ausgabe lautet:

input = torch.rand(5,3) 
output = model:forward(input) 
output 
{ 
    1 : DoubleTensor - size: 5x1 
    2 : 
    { 
     1 : DoubleTensor - size: 5x1 
     2 : DoubleTensor - size: 5x1 
    } 
} 

Wie soll ich mein Kriterium konstruieren?

Antwort

0

Ich scheine es durch zwei Schritte, um herauszufinden:

1.Verwenden nn.Concat statt nn.ConcatTable in dem obigen Netzwerk, das die Ausgabe macht ein einfaches NxM Tensor zu sein, z.B. Ein 5x3-Tensor wird in das obige Netzwerk gelangen, während nn.Concat anstelle von nn.ConcatTable verwendet wird.

2.Nach dem Erhalt eines NxM-Tensors verwende ich die Kombination von nn.ConcatTable, nn.Concat und nn.Select, um die Ausgabe als einfache Tabelle zu machen, die jedes Ergebnis Tensor enthält.

Hier ist ein einfaches Beispiel für Schritt 2:

model = nn.Sequential() 
model:add(nn.Linear(3,5)) 

prl = nn.ConcatTable() 

spl1 = nn.Concat(2) 

seq1 = nn.Sequential() 
seq1:add(nn.Select(2, 1)) 
seq1:add(nn.Reshape(1)) 

seq2 = nn.Sequential() 
seq2:add(nn.Select(2, 2)) 
seq2:add(nn.Reshape(1)) 

seq3 = nn.Sequential() 
seq3:add(nn.Select(2, 3)) 
seq3:add(nn.Reshape(1)) 

spl1:add(seq1) 
spl1:add(seq2) 
spl1:add(seq3) 
prl:add(spl1) 

spl2 = nn.Concat(2) 

seq4 = nn.Sequential() 
seq4:add(nn.Select(2, 4)) 
seq4:add(nn.Reshape(1)) 

seq5 = nn.Sequential() 
seq5:add(nn.Select(2, 5)) 
seq5:add(nn.Reshape(1)) 

spl2:add(seq4) 
spl2:add(seq5) 
prl:add(spl2) 

model:add(prl) 

input = torch.rand(5,3) 
output = model:forward(input) 

Die Ausgabe aussehen wird:

th> output 
{ 
    1 : DoubleTensor - size: 5x3 
    2 : DoubleTensor - size: 5x2 
}