2015-03-31 16 views
12

Angenommen, ich habe einen DataFrame (den ich von einem csv auf HDFS eingelesen habe) und ich möchte einige Algorithmen über MLlib darin trainieren. Wie konvertiere ich die Zeilen in LabeledPoints oder verwende MLlib in diesem Datensatz? Verwenden von DataFrame mit MLlib

+1

Sie nicht den Datentyp der Spalten erwähnt haben, aber wenn sie numerisch sind (integer, double, etc) Sie können [VectorAssembler] (http: //spark.apache .org/docs/latest/ml-features.html # vectorasembler) um die Feature-Spalten in eine Spalte von [Vector] zu konvertieren (http://spark.apache.org/docs/latest/mllib-data-types.html) . – Ben

Antwort

5

Sie Angenommen Scala mit:

Lassen Sie uns sagen, dass Ihre die DataFrame erhalten wie folgt:

val results : DataFrame = sqlContext.sql(...) 

Schritt 1: call results.printSchema() - dies zeigt Ihnen nicht nur die Spalten in der DataFrame und (das ist wichtig) ihre Reihenfolge, aber auch, was Spark SQL für ihre Typen hält. Sobald Sie diese Ausgabe sehen, werden die Dinge viel weniger geheimnisvoll.

Schritt 2: ein RDD[Row] aus den DataFrame Get:

val rows: RDD[Row] = results.rdd 

Schritt 3: Jetzt ist es nur eine Frage des Ziehens, was Interesse Sie aus den einzelnen Zeilen Felder. Dafür müssen Sie die 0-basierte Position jedes Felds und dessen Typ kennen, und zum Glück haben Sie alles in Schritt 1 oben erhalten. Zum Beispiel lasst uns sagen, dass Sie eine SELECT x, y, z, w FROM ... tat und Drucken des Schemas ergab

root 
|-- x double (nullable = ...) 
|-- y string (nullable = ...) 
|-- z integer (nullable = ...) 
|-- w binary (nullable = ...) 

Und lassen Sie uns alle sagen, Sie x und z verwenden wollte. Sie können sie in eine RDD[(Double, Integer)] herausziehen wie folgt:

rows.map(row => { 
    // x has position 0 and type double 
    // z has position 2 and type integer 
    (row.getDouble(0), row.getInt(2)) 
}) 

Von hier aus nur Sie Core-Spark-verwenden, um die relevanten MLlib Objekte zu erstellen. Die Dinge könnten etwas komplizierter werden, wenn Ihr SQL Spalten des Array-Typs zurückgibt. In diesem Fall müssen Sie getList(...) für diese Spalte aufrufen.

2

Angenommen, Sie sind mit JAVA (Spark-Version 1.6.2): ​​

Hier ist ein einfaches Beispiel für JAVA-Code für maschinelles Lernen mit Datenrahmen.

  • Es lädt ein JSON mit der folgenden Struktur,

    [{ "label": 1 "att2": 5,037089672359123 "att1": 2,4100883023159456}, ...]

  • teilt die Daten in Trainings- und Test,

  • Zug des Modell der Bahndaten,
  • das Modell zu den Testdaten und
  • stor gilt Es sind die Ergebnisse.

Außerdem nach den official documentation die "Datenrahmen-basierten API ist die primäre API" für MLlib seit der aktuellen Version 2.0.0. So können Sie mit DataFrame mehrere Beispiele finden.

Der Code:

SparkConf conf = new SparkConf().setAppName("MyApp").setMaster("local[2]"); 
SparkContext sc = new SparkContext(conf); 
String path = "F:\\SparkApp\\test.json"; 
String outputPath = "F:\\SparkApp\\justTest"; 

System.setProperty("hadoop.home.dir", "C:\\winutils\\"); 

SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); 

DataFrame df = sqlContext.read().json(path); 
df.registerTempTable("tmp"); 
DataFrame newDF = df.sqlContext().sql("SELECT att1, att2, label FROM tmp"); 
DataFrame dataFixed = newDF.withColumn("label", newDF.col("label").cast("Double")); 

VectorAssembler assembler = new VectorAssembler().setInputCols(new String[]{"att1", "att2"}).setOutputCol("features"); 
StringIndexer indexer = new StringIndexer().setInputCol("label").setOutputCol("labelIndexed"); 

// Split the data into training and test 
DataFrame[] splits = dataFixed.randomSplit(new double[] {0.7, 0.3}); 
DataFrame trainingData = splits[0]; 
DataFrame testData = splits[1]; 

DecisionTreeClassifier dt = new DecisionTreeClassifier().setLabelCol("labelIndexed").setFeaturesCol("features"); 
Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {assembler, indexer, dt}); 
// Train model 
PipelineModel model = pipeline.fit(trainingData); 

// Make predictions 
DataFrame predictions = model.transform(testData); 
predictions.rdd().coalesce(1,true,null).saveAsTextFile("justPlay.txt" +System.currentTimeMillis()); 
Verwandte Themen