Das Entwickler-API-Beispiel (https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala) gibt ein einfaches Implementierungsbeispiel für die Funktion predictRaw() in einem Klassifizierungsmodell. Dies ist eine Funktion innerhalb der abstrakten Klasse ClassificationModel, die in der konkreten Klasse implementiert werden muss. Nach Angaben der Entwickler-API Beispiel können Sie es wie folgt berechnen:Der Versuch, PredictRaw() für das Klassifizierungsmodell in Apache Spark zu implementieren
override def predictRaw(features: Features.Type): Vector = {
val margin = BLAS.dot(features, coefficients)
Vectors.dense(-margin, margin) // Binary classification so we return a length-2 vector, where index i corresponds to class i (i = 0, 1).
}
Mein Verständnis von BLAS.dot(features, coefficients)
ist, dass dies einfach die Matrix Punktprodukt des Merkmalsvektor ist (der Länge NumFeatures) durch den Koeffizienten-Vektor (der Länge numFeatures), so dass jedes "Feature" cols mit einem Koeffizienten multipliziert und dann summiert wird, um val margin
zu erhalten. Spark stellt jedoch nicht mehr den Zugriff auf die BLAS-Bibliothek zur Verfügung, da es in MLlib privat ist. Stattdessen wird Matrix-Multiplikation in der Matrix-Eigenschaft bereitgestellt, wo es verschiedene Factory-Methoden für die Multiplikation gibt.
Mein Verständnis davon, wie die Matrix Factory-Methoden zu implementieren predictRaw()
ist wie folgt:
override def predictRaw(features: Vector): Vector = {
//coefficients is a Vector of length numFeatures: val coefficients = Vectors.zeros(numFeatures)
val coefficientsArray = coefficients.toArray
val coefficientsMatrix: SparkDenseMatrix = new SparkDenseMatrix(numFeatures, 1, coefficientsArray)
val margin: Array[Double] = coefficientsMatrix.multiply(features).toArray // contains a single element
val rawPredictions: Array[Double] = Array(-margin(0),margin(0))
new SparkDenseVector(rawPredictions)
}
Dies wird den Aufwand für die Umwandlung der Datenstrukturen Arrays erfordern. Gibt es einen besseren Weg? Es erscheint seltsam, dass BLAS jetzt privat ist. NB. Code nicht getestet! Im Moment val coefficients: Vector
ist nur ein Vektor von Nullen, aber sobald ich den Lernalgorithmus implementiert habe, würde dies die Ergebnisse enthalten.