2016-03-24 15 views
0
Actually I am working on pyspark code. My dataframe is 

+-------+--------+--------+--------+--------+ 
|element|collect1|collect2|collect3|collect4| 
+-------+--------+--------+--------+--------+ 
|A1  | 1.02 | 2.6 | 5.21 | 3.6 | 
|A2  | 1.61 | 2.42 | 4.88 | 6.08 | 
|B1  | 1.66 | 2.01 | 5.0 | 4.3 | 
|C2  | 2.01 | 1.85 | 3.42 | 4.44 | 
+-------+--------+--------+--------+--------+ 

Ich muss den Mittelwert und Stddev für jedes Element finden, indem Sie alle CollectX-Spalten aggregieren. Das Endergebnis sollte wie folgt sein.Spark-Datenframe-Aggregat für mehrere Spalten

+-------+--------+--------+ 
|element|mean |stddev | 
+-------+--------+--------+ 
|A1  | 3.11 | 1.76 | 
|A2  | 3.75 | 2.09 | 
|B1  | 3.24 | 1.66 | 
|C2  | 2.93 | 1.23 | 
+-------+--------+--------+ 

Der Code unten Bruchs alle der Mittelwert an den einzelnen Spalten df.groupBy ("Element"). Bedeuten(). Show(). Kann nicht für jede Spalte ein Rollup für alle Spalten durchgeführt werden?

+-------+-------------+-------------+-------------+-------------+ 
|element|avg(collect1)|avg(collect2)|avg(collect3)|avg(collect4)| 
+-------+-------------+-------------+-------------+-------------+ 
|A1  | 1.02  | 2.6  | 5.21  | 3.6  | 
|A2  | 1.61  | 2.42  | 4.88  | 6.08  | 
|B1  | 1.66  | 2.01  | 5.0  | 4.3  | 
|C2  | 2.01  | 1.85  | 3.42  | 4.44  | 
+-------+-------------+-------------+-------------+-------------+ 

Ich versuchte Nutzung der Funktion beschreiben zu machen, da sie die vollständigen Aggregationsfunktionen haben aber nach wie vor als einzelne Spalte df.groupBy ("Element"). Mittelwert(). Describe(). Show() gezeigt

dank

Antwort

0

Haben Sie die Spalten zusammen nur versuchen, das Hinzufügen und ggf. durch 4 dividiert?

SELECT avg((collect1 + collect2 + collect3 + collect4)/4), 
    stddev((collect1 + collect2 + collect3 + collect4)/4) 

Das wird nicht genau das tun, was Sie wollen, aber bekommen Sie die Idee.

Nicht sicher Ihre Sprache, aber man kann immer die Abfrage im Fluge bauen, wenn Sie nicht zufrieden sind mit hartkodierte:

val collectColumns = df.columns.filter(_.startsWith("collect")) 
val stmnt = "SELECT avg((" + collectColumns.mkString(" + ") + ")/" + collectColumns.length + "))" 

Sie bekommen die Idee.

+0

Tatsächlich kann der Datenrahmen mehr oder weniger collectX-Spalten haben. Hardcoded zu tun ist nicht die bevorzugte Wahl. – Chn

+0

Erstellen Sie einfach die Abfrage im laufenden Betrieb. Siehe Änderungen. –

0

Spark ermöglicht es Ihnen, alle Arten von Statistiken pro Spalte zu sammeln. Sie versuchen, Statistiken pro Zeile zu berechnen. In diesem Fall können Sie etwas mit udf hacken. Hier ein Beispiel: D

$ pyspark 
>>> from pyspark.sql.types import DoubleType 
>>> from pyspark.sql.functions import array, udf 
>>> 
>>> mean = udf(lambda v: sum(v)/len(v), DoubleType()) 
>>> df = sc.parallelize([['A1', 1.02, 2.6, 5.21, 3.6], ['A2', 1.61, 2.42, 4.88, 6.08]]).toDF(['element', 'collect1', 'collect2', 'collect3', 'collect4']) 
>>> df.show() 
+-------+--------+--------+--------+--------+ 
|element|collect1|collect2|collect3|collect4| 
+-------+--------+--------+--------+--------+ 
|  A1| 1.02|  2.6| 5.21|  3.6| 
|  A2| 1.61| 2.42| 4.88| 6.08| 
+-------+--------+--------+--------+--------+ 
>>> df.select('element', mean(array(df.columns[1:])).alias('mean')).show() 
+-------+------+ 
|element| mean| 
+-------+------+ 
|  A1|3.1075| 
|  A2|3.7475| 
+-------+------+