sum(x.*y',2)
ist eine saubere kurze Lösung.
Es hat auch gute Geschwindigkeit und Speichereigenschaften. Der Trick besteht darin, die Matrix-Vektor-Multiplikation als eine lineare Kombination von Matrixspalten zu betrachten, die durch die Vektorelemente skaliert werden. Anstatt jede lineare Kombination für die Matrix x [:,:, i] zu verwenden, verwenden wir dieselbe Skalierung y [i] für x [:, i ,:]. In Code:
const x = rand(6,6,2^10);
const y = rand(6,1);
function tst(x,y)
z = zeros(6,1,2^10)
for i in 1:2^10
z[:,:,i] = x[:,:,i]*y
end
return z
end
tst2(x,y) = mapslices(i->i*y,x,(1,2))
tst3(x,y) = sum(x.*y',2)
Benchmarking gibt:
julia> using BenchmarkTools
julia> z = tst(x,y); z2 = tst2(x,y); z3 = tst3(x,y);
julia> @benchmark tst(x,y)
BenchmarkTools.Trial:
memory estimate: 688.11 KiB
allocs estimate: 8196
--------------
median time: 759.545 μs (0.00% GC)
samples: 6068
julia> @benchmark tst2(x,y)
BenchmarkTools.Trial:
memory estimate: 426.81 KiB
allocs estimate: 10798
--------------
median time: 1.634 ms (0.00% GC)
samples: 2869
julia> @benchmark tst3(x,y)
BenchmarkTools.Trial:
memory estimate: 336.41 KiB
allocs estimate: 12
--------------
median time: 114.060 μs (0.00% GC)
samples: 10000
So tst3
mit sum
eine bessere Leistung (~ 7x über tst
und ~ 15x über tst2
) hat.
Die Verwendung von StaticArrays
wie von @DNF vorgeschlagen ist auch eine Option, und es wäre schön, es mit den Lösungen hier zu vergleichen.
Wie funktioniert '_' im Verständnis? –
Es ist nur eine Dummy-Variable. Ich hätte zum Beispiel "i" verwenden können, aber es ist üblich, dass "_" eine Wegwerfvariable bezeichnet, die nicht weiter verwendet wird, und welcher Name unwichtig ist. – DNF