2015-11-23 16 views
28

gibt es eine Möglichkeit, eine Aggregatfunktion auf alle (oder eine Liste von) Spalten eines Datenrahmens anzuwenden, wenn Sie eine Gruppe durch? Mit anderen Worten, gibt es eine Möglichkeit, dies für jede Spalte zu vermeiden:SparkSQL: anwenden Aggregatfunktionen auf eine Liste der Spalte

df.groupBy("col1") 
.agg(sum("col2").alias("col2"), sum("col3").alias("col3"), ...) 

vielen Dank!

Antwort

55

Es gibt mehrere Möglichkeiten, Aggregatfunktionen auf mehrere Spalten anzuwenden.

GroupedData Klasse bietet eine Anzahl von Methoden für die am häufigsten verwendeten Funktionen, einschließlich count, max, min, mean und sum, die direkt verwendet werden können, wie folgt:

  • Python:

    df = sqlContext.createDataFrame(
        [(1.0, 0.3, 1.0), (1.0, 0.5, 0.0), (-1.0, 0.6, 0.5), (-1.0, 5.6, 0.2)], 
        ("col1", "col2", "col3")) 
    
    df.groupBy("col1").sum() 
    
    ## +----+---------+-----------------+---------+ 
    ## |col1|sum(col1)|  sum(col2)|sum(col3)| 
    ## +----+---------+-----------------+---------+ 
    ## | 1.0|  2.0|    0.8|  1.0| 
    ## |-1.0|  -2.0|6.199999999999999|  0.7| 
    ## +----+---------+-----------------+---------+ 
    
  • Scala

    val df = sc.parallelize(Seq(
        (1.0, 0.3, 1.0), (1.0, 0.5, 0.0), 
        (-1.0, 0.6, 0.5), (-1.0, 5.6, 0.2)) 
    ).toDF("col1", "col2", "col3") 
    
    df.groupBy($"col1").min().show 
    
    // +----+---------+---------+---------+ 
    // |col1|min(col1)|min(col2)|min(col3)| 
    // +----+---------+---------+---------+ 
    // | 1.0|  1.0|  0.3|  0.0| 
    // |-1.0|  -1.0|  0.6|  0.2| 
    // +----+---------+---------+---------+ 
    

Optional können Sie eine Liste der Spalten übergeben, die

df.groupBy("col1").sum("col2", "col3") 

Sie auch Wörterbuch/Karte mit Spalten eine der Tasten und Funktionen wie die Werte passieren können aggregiert werden sollten:

  • Python

    exprs = {x: "sum" for x in df.columns} 
    df.groupBy("col1").agg(exprs).show() 
    
    ## +----+---------+ 
    ## |col1|avg(col3)| 
    ## +----+---------+ 
    ## | 1.0|  0.5| 
    ## |-1.0|  0.35| 
    ## +----+---------+ 
    
  • Scala
    val exprs = df.columns.map((_ -> "mean")).toMap 
    df.groupBy($"col1").agg(exprs).show() 
    
    // +----+---------+------------------+---------+ 
    // |col1|avg(col1)|   avg(col2)|avg(col3)| 
    // +----+---------+------------------+---------+ 
    // | 1.0|  1.0|    0.4|  0.5| 
    // |-1.0|  -1.0|3.0999999999999996|  0.35| 
    // +----+---------+------------------+---------+ 
    

Schließlich können Sie verwenden varargs:

  • Python

    from pyspark.sql.functions import min 
    
    exprs = [min(x) for x in df.columns] 
    df.groupBy("col1").agg(*exprs).show() 
    
  • Scala

    import org.apache.spark.sql.functions.sum 
    
    val exprs = df.columns.map(sum(_)) 
    df.groupBy($"col1").agg(exprs.head, exprs.tail: _*) 
    

Es gibt andere Möglichkeiten, einen ähnlichen Effekt zu erzielen, aber diese sollten mehr als genug die meiste Zeit sein.

+0

Es scheint 'aggregateBy' wäre hier anwendbar. Es ist schneller (viel schneller) als "groupBy".Oh, Moment mal - der 'DataFrame' zeigt 'aggregateBy' nicht an -' agg' zeigt auf 'groupBy'. Nun, das bedeutet, 'DataFrames' sind * langsam * .. – javadba

+0

@javadba Nein, es bedeutet nur, dass' Dataset.groupBy'/'Dataset.groupByKey' und' RDD.groupBy'/'RDD.groupByKey' im Allgemeinen unterschiedliche Semantik. Im Falle einfacher 'DataFrame'-Aggregationen [dies prüfen] (http://stackoverflow.com/a/32903568/1560062). Da ist mehr, aber hier ist es nicht wichtig. – zero323

+0

Schöne Info! Upvoted Ihre andere Antwort – javadba

6

Ein weiteres Beispiel für das gleiche Konzept - aber sagen - Sie zwei verschiedene Spalten haben - und Sie wollen jeden von ihnen, dh unterschiedliche agg Funktionen anwenden

f.groupBy("col1").agg(sum("col2").alias("col2"), avg("col3").alias("col3"), ...) 

Hier ist der Weg, es zu erreichen - obwohl ich noch nicht wissen, wie Sie den Alias ​​in diesem Fall hinzufügen

Beispiel unten sehen - Verwenden von Karten

val Claim1 = StructType(Seq(StructField("pid", StringType, true),StructField("diag1", StringType, true),StructField("diag2", StringType, true), StructField("allowed", IntegerType, true), StructField("allowed1", IntegerType, true))) 
val claimsData1 = Seq(("PID1", "diag1", "diag2", 100, 200), ("PID1", "diag2", "diag3", 300, 600), ("PID1", "diag1", "diag5", 340, 680), ("PID2", "diag3", "diag4", 245, 490), ("PID2", "diag2", "diag1", 124, 248)) 

val claimRDD1 = sc.parallelize(claimsData1) 
val claimRDDRow1 = claimRDD1.map(p => Row(p._1, p._2, p._3, p._4, p._5)) 
val claimRDD2DF1 = sqlContext.createDataFrame(claimRDDRow1, Claim1) 

val l = List("allowed", "allowed1") 
val exprs = l.map((_ -> "sum")).toMap 
claimRDD2DF1.groupBy("pid").agg(exprs) show false 
val exprs = Map("allowed" -> "sum", "allowed1" -> "avg") 

claimRDD2DF1.groupBy("pid").agg(exprs) show false 
Verwandte Themen