2015-12-10 5 views
15

Ich habe einen Datenrahmen mit Schema als solche:Wie fasst man Werte nach groupBy in einer Sammlung zusammen?

[visitorId: string, trackingIds: array<string>, emailIds: array<string>] 

zu Gruppe nach einem Weg suchen (oder vielleicht Rollup?) Dieser Datenrahmen von visitorid wo die trackingIds und emailIds Spalte zusammenhänge würden. So zum Beispiel, wenn meine erste df wie folgt aussieht:

visitorId |trackingIds|emailIds 
+-----------+------------+-------- 
|a158|  [666b]  | [12] 
|7g21|  [c0b5]  | [45] 
|7g21|  [c0b4]  | [87] 
|a158|  [666b, 777c]| [] 

würde ich meine Ausgabe df wie diese groupBy und agg Operatoren zu verwenden Versuch

visitorId |trackingIds|emailIds 
+-----------+------------+-------- 
|a158|  [666b,666b,777c]|  [12,''] 
|7g21|  [c0b5,c0b4]  |  [45, 87] 

aussehen mögen, aber nicht viel Glück haben.

Antwort

17

Spark-2.x

Es ist möglich, aber recht teuer. Mit den Daten Sie zur Verfügung gestellt haben:

case class Record(
    visitorId: String, trackingIds: Array[String], emailIds: Array[String]) 

val df = Seq(
    Record("a158", Array("666b"), Array("12")), 
    Record("7g21", Array("c0b5"), Array("45")), 
    Record("7g21", Array("c0b4"), Array("87")), 
    Record("a158", Array("666b", "777c"), Array.empty[String])).toDF 

und eine Hilfsfunktion:

import org.apache.spark.sql.functions.udf 

val flatten = udf((xs: Seq[Seq[String]]) => xs.flatten) 

können wir die Rohlinge mit Platzhalter füllen:

import org.apache.spark.sql.functions.{array, lit, when} 

val dfWithPlaceholders = df.withColumn(
    "emailIds", 
    when(size($"emailIds") === 0, array(lit(""))).otherwise($"emailIds")) 

collect_lists und flatten:

import org.apache.spark.sql.functions.{array, collect_listn} 

val emailIds = flatten(collect_list($"emailIds")).alias("emailIds") 
val trackingIds = flatten(collect_list($"trackingIds")).alias("trackingIds") 

df 
    .groupBy($"visitorId") 
    .agg(trackingIds, emailIds) 

// +---------+------------------+--------+ 
// |visitorId|  trackingIds|emailIds| 
// +---------+------------------+--------+ 
// |  a158|[666b, 666b, 777c]| [12, ]| 
// |  7g21|  [c0b5, c0b4]|[45, 87]| 
// +---------+------------------+--------+ 

Mit getippt statisch Dataset:

df.as[Record] 
    .groupByKey(_.visitorId) 
    .mapGroups { case (key, vs) => 
    vs.map(v => (v.trackingIds, v.emailIds)).toArray.unzip match { 
     case (trackingIds, emailIds) => 
     Record(key, trackingIds.flatten, emailIds.flatten) 
    }} 

// +---------+------------------+--------+ 
// |visitorId|  trackingIds|emailIds| 
// +---------+------------------+--------+ 
// |  a158|[666b, 666b, 777c]| [12, ]| 
// |  7g21|  [c0b5, c0b4]|[45, 87]| 
// +---------+------------------+--------+ 

Spark-1.x

Sie nach RDD und Gruppe umwandeln kann

import org.apache.spark.sql.Row 

dfWithPlaceholders.rdd 
    .map { 
    case Row(id: String, 
     trcks: Seq[String @ unchecked], 
     emails: Seq[String @ unchecked]) => (id, (trcks, emails)) 
    } 
    .groupByKey 
    .map {case (key, vs) => vs.toArray.unzip match { 
    case (trackingIds, emailIds) => 
     Record(key, trackingIds.flatten, emailIds.flatten) 
    }} 
    .toDF 

// +---------+------------------+--------+ 
// |visitorId|  trackingIds|emailIds| 
// +---------+------------------+--------+ 
// |  7g21|  [c0b5, c0b4]|[45, 87]| 
// |  a158|[666b, 666b, 777c]| [12, ]| 
// +---------+------------------+--------+ 
+0

, was diese Methode nicht abflachen genau tun? – xXxpRoGrAmmErxXx

+0

Was, wenn wir Dubletten in 'trackingIds' entfernen müssen? – puru

6

@ zero323 Antwort ist ziemlich viel vollständig, aber Funken gibt uns noch mehr Flexibilität. Wie wäre es mit der folgenden Lösung?

import org.apache.spark.sql.functions._ 
inventory 
    .select($"*", explode($"trackingIds") as "tracking_id") 
    .select($"*", explode($"emailIds") as "email_id") 
    .groupBy("visitorId") 
    .agg(
    collect_list("tracking_id") as "trackingIds", 
    collect_list("email_id") as "emailIds") 

Das aber lässt alle leeren Sammlungen aus (so gibt es noch Raum für Verbesserungen :))

+1

In dieser Lösung ist es möglich, eine orderBy() nach der groupBy und vor dem agg()? Oder in dieser Situation wird Ordnung nicht deterministisch sein? –

+0

Meiner Meinung nach antworten Sie, es ist nicht neu für die folgenden Gründe a) explodieren ist in funken.22.b) collect_list auf einem sehr großen Dataset kann den Treiberprozess mit OutOfMemoryError – xXxpRoGrAmmErxXx

+0

@xXxpRoGrAmmErxXx zum Absturz bringen Bitte lassen Sie sich nicht mit dem 'explode' Operator und der' explode' Funktion verwirren. Für b) möglicherweise. –

0

Sie benutzerdefinierte aggregierte Funktionen nutzen können.

1) Erstellen Sie eine benutzerdefinierte UDAF mithilfe der Scala-Klasse namens customAggregation.

package com.package.name 

import org.apache.spark.sql.Row 
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} 
import org.apache.spark.sql.types._ 
import scala.collection.JavaConverters._ 

class CustomAggregation() extends UserDefinedAggregateFunction { 

// Input Data Type Schema 
def inputSchema: StructType = StructType(Array(StructField("col5", ArrayType(StringType)))) 

// Intermediate Schema 
def bufferSchema = StructType(Array(
StructField("col5_collapsed", ArrayType(StringType)))) 

// Returned Data Type . 
def dataType: DataType = ArrayType(StringType) 

// Self-explaining 
def deterministic = true 

// This function is called whenever key changes 
def initialize(buffer: MutableAggregationBuffer) = { 
buffer(0) = Array.empty[String] // initialize array 
} 

// Iterate over each entry of a group 
def update(buffer: MutableAggregationBuffer, input: Row) = { 
buffer(0) = 
    if(!input.isNullAt(0)) 
    buffer.getList[String](0).toArray ++ input.getList[String](0).toArray 
    else 
    buffer.getList[String](0).toArray 
} 

    // Merge two partial aggregates 
def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = { 
buffer1(0) = buffer1.getList[String](0).toArray ++ buffer2.getList[String](0).toArray 
} 

// Called after all the entries are exhausted. 
def evaluate(buffer: Row) = { 
    buffer.getList[String](0).asScala.toList.distinct 
} 
} 

2) Dann die UDAF als

in Ihrem Code verwenden
//define UDAF 
val CustomAggregation = new CustomAggregation() 
DataFrame 
    .groupBy(col1,col2,col3) 
    .agg(CustomAggregation(DataFrame(col5))).show() 
Verwandte Themen