2017-01-19 4 views
2

Ich habe eine RelationalGroupedDataset durch den Aufruf instances.groupBy(instances.col("property_name")) erstellt:Wie findet man den Mittelwert von gruppierten Vector-Spalten in Spark SQL?

val x = instances.groupBy(instances.col("property_name")) 

Wie komponiere ich user-defined aggregate function ein Statistics.colStats().mean für jede Gruppe durchführen?

Danke!

+0

versuchen Sie nur, den Mittelwert einer Spalte zu erhalten? Könnten Sie erklären, was für Input und Output Sie erwarten? Was fehlt auch an den von Ihnen bereitgestellten Links? –

+0

Jede Zeile hat eine Bezeichnung und einen Merkmalsvektor. Ich gruppiere die Zeilen nach Label und möchte einen Vektormittelwert der Merkmalsvektoren verwenden. Die Lösung fehlt in den von mir bereitgestellten Links. –

+0

was ist falsch mit instances.groupBy (instances.col ("property_name")). Agg (avg ("col1"), avg ("col2") ...) –

Antwort

8

Sie können nicht UserDefinedAggregateFunction verwenden, aber Sie können einen Aggregator mit dem gleichen MultivariateOnlineSummarizer erstellen:

import org.apache.spark.sql.Row 
import org.apache.spark.sql.expressions.Aggregator 
import org.apache.spark.mllib.linalg.{Vector, Vectors} 
import org.apache.spark.sql.{Encoder, Encoders} 
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder 
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer 

type Summarizer = MultivariateOnlineSummarizer 

case class VectorSumarizer(f: String) extends Aggregator[Row, Summarizer, Vector] 
    with Serializable { 
    def zero = new Summarizer 
    def reduce(acc: Summarizer, x: Row) = acc.add(x.getAs[Vector](f)) 
    def merge(acc1: Summarizer, acc2: Summarizer) = acc1.merge(acc2) 

    // This can be easily generalized to support additional statistics 
    def finish(acc: Summarizer) = acc.mean 

    def bufferEncoder: Encoder[Summarizer] = Encoders.kryo[Summarizer] 
    def outputEncoder: Encoder[Vector] = ExpressionEncoder() 
} 

Beispiel Nutzung:

import org.apache.spark.mllib.random.RandomRDDs.logNormalVectorRDD 

val df = spark.sparkContext.union((1 to 10).map(i => 
    logNormalVectorRDD(spark.sparkContext, i, 10, 10000, 3, 1).map((i, _)) 
)).toDF("group", "features") 

df 
.groupBy($"group") 
.agg(VectorSumarizer("features").toColumn.alias("means")) 
.show(10, false) 

Das Ergebnis:

+-----+---------------------------------------------------------------------+ 
|group|means                | 
+-----+---------------------------------------------------------------------+ 
|1 |[1.0495089547176625E15,3.057434217141363E13,8.180842267228103E13] | 
|6 |[8.578684690153061E15,1.865830977115807E14,1.0690831496167929E15] | 
|3 |[1.0347016972600206E14,4.952536828257269E15,8.498944924018858E13] | 
|5 |[2.2135916061736424E16,1.5137112888230388E14,8.154750681129871E14] | 
|9 |[6.496030194110956E15,6.2697260327708368E16,3.7282521260607136E16] | 
|4 |[2.4518629692233766E14,1.959083619621557E13,5.278689364420169E13] | 
|8 |[1.806052212008392E16,2.0410654639336184E16,6.409495244104527E15] | 
|7 |[1.32896092658714784E17,1.2074042288752348E15,1.10951746294648096E17]| 
|10 |[1.6131199347666342E19,1.24546214832341616E17,8.5265750194040304E16] | 
|2 |[4.330324858747168E12,6.19671483053885E12,2.2416578004282832E13]  | 
+-----+---------------------------------------------------------------------+ 

Hinweis:

  • Bitte beachten Sie, dass MultivariateOnlineSummarizer erfordert "alten Stil" mllib.linalg.Vector. Es funktioniert nicht mit ml.linalg.Vector. Um diese zu unterstützen, müssen Sie convert between new and old types.
  • Leistung weise werden Sie wahrscheinlich better off with RDDs sein.
Verwandte Themen