2016-03-30 4 views

Antwort

6

Eigentlich hat es trees Attribut:

import org.apache.spark.ml.attribute.NominalAttribute 
import org.apache.spark.ml.classification.{ 
    RandomForestClassificationModel, RandomForestClassifier, 
    DecisionTreeClassificationModel 
} 

val meta = NominalAttribute 
    .defaultAttr 
    .withName("label") 
    .withValues("0.0", "1.0") 
    .toMetadata 

val data = sqlContext.read.format("libsvm") 
    .load("data/mllib/sample_libsvm_data.txt") 
    .withColumn("label", $"label".as("label", meta)) 

val rf: RandomForestClassifier = new RandomForestClassifier() 
    .setLabelCol("label") 
    .setFeaturesCol("features") 

val trees: Array[DecisionTreeClassificationModel] = rf.fit(data).trees.collect { 
    case t: DecisionTreeClassificationModel => t 
} 

Wie Sie das einzige Problem zu sehen ist Typen richtig zu machen, damit wir diese tatsächlich nutzen:

trees.head.transform(data).show(3) 
// +-----+--------------------+-------------+-----------+----------+ 
// |label|   features|rawPrediction|probability|prediction| 
// +-----+--------------------+-------------+-----------+----------+ 
// | 0.0|(692,[127,128,129...| [33.0,0.0]| [1.0,0.0]|  0.0| 
// | 1.0|(692,[158,159,160...| [0.0,59.0]| [0.0,1.0]|  1.0| 
// | 1.0|(692,[124,125,126...| [0.0,59.0]| [0.0,1.0]|  1.0| 
// +-----+--------------------+-------------+-----------+----------+ 
// only showing top 3 rows 

Hinweis:

Wenn Sie mit Pipelines arbeiten, können Sie auch einzelne Bäume extrahieren:

import org.apache.spark.ml.Pipeline 

val model = new Pipeline().setStages(Array(rf)).fit(data) 

// There is only one stage and know its type 
// but lets be thorough 
val rfModelOption = model.stages.headOption match { 
    case Some(m: RandomForestClassificationModel) => Some(m) 
    case _ => None 
} 

val trees = rfModelOption.map { 
    _.trees // ... as before 
}.getOrElse(Array()) 
+0

Hallo Zero323, danke für deine Hilfe. Ich habe eine Follow-up-Frage. Ich möchte Regeln aus Baumknoten mit hohen Vorhersagewahrscheinlichkeiten extrahieren (sagen wir über 0,3). In 'spark.ml' ist das Objekt" instantationStats "in einem internen Knoten des Baums privat und das sind auch die Methoden toOld und fromOld. Ich brauche diese Details (auf die ich nicht zugreifen kann, da sie privat sind), um etwas extrahieren zu können. In ähnlicher Weise liefert die Aufteilung des Knotens keine Informationen über seine Kategorien und Merkmalschwellenwerte. Gibt es eine Möglichkeit, Regeln aus Knoten mit hoher Wahrscheinlichkeit in 'spark.ml' zu extrahieren? –

+0

Mir ist keine triviale Lösung bekannt. Sie sollten es getrennt fragen - vielleicht hat jemand bereits eine Lösung an Ort und Stelle. Wenn Sie dies tun, bitte ping mich mit einem Link. – zero323

+0

Danke zero323. Ich habe gerade die Frage "Wie man Regeln von Spark ML RandomForestClassifier Modell (Scala-Version) extrahieren?". Bleib auf dem laufenden, wenn ich eine Antwort bekomme. –

Verwandte Themen