Wenn Sie einen großen Datensatz haben und Cluster nach Bedarf extrahieren müssen, wird eine Beschleunigung angezeigt, wenn Sie numpy.where
verwenden. Hier ein Beispiel zum Iris-Dataset:
from sklearn.cluster import KMeans
from sklearn import datasets
import numpy as np
centers = [[1, 1], [-1, -1], [1, -1]]
iris = datasets.load_iris()
X = iris.data
y = iris.target
km = KMeans(n_clusters=3)
km.fit(X)
Definieren Sie eine Funktion zum Extrahieren der Indizes der von Ihnen angegebenen Cluster-ID. (Hier sind zwei Funktionen, für Benchmarking, sie beide geben die gleichen Werte):
def ClusterIndicesNumpy(clustNum, labels_array): #numpy
return np.where(labels_array == clustNum)[0]
def ClusterIndicesComp(clustNum, labels_array): #list comprehension
return np.array([i for i, x in enumerate(labels_array) if x == clustNum])
Angenommen, Sie alle Proben möchten, die in Cluster sind 2
:
ClusterIndicesNumpy(2, km.labels_)
array([ 52, 77, 100, 102, 103, 104, 105, 107, 108, 109, 110, 111, 112,
115, 116, 117, 118, 120, 122, 124, 125, 128, 129, 130, 131, 132,
134, 135, 136, 137, 139, 140, 141, 143, 144, 145, 147, 148])
Numpy gewinnt den Benchmark:
%timeit ClusterIndicesNumpy(2,km.labels_)
100000 loops, best of 3: 4 µs per loop
%timeit ClusterIndicesComp(2,km.labels_)
1000 loops, best of 3: 479 µs per loop
Jetzt können Sie alle Ihre Cluster-2-Datenpunkte extrahieren etwa so:
X[ClusterIndicesNumpy(2,km.labels_)]
array([[ 6.9, 3.1, 4.9, 1.5],
[ 6.7, 3. , 5. , 1.7],
[ 6.3, 3.3, 6. , 2.5],
... #truncated
doppelt überprüfen Sie die ersten drei Indizes aus dem abgeschnittenen Array oben:
print X[52], km.labels_[52]
print X[77], km.labels_[77]
print X[100], km.labels_[100]
[ 6.9 3.1 4.9 1.5] 2
[ 6.7 3. 5. 1.7] 2
[ 6.3 3.3 6. 2.5] 2
Ja diese Methode funktionieren würde zu filtern. aber wenn es viele Datenpunkte gibt, die alle durchlaufen, um die Etiketten zu bekommen, ist das nicht effizient. Ich war nur die Liste der Datenpunkte für einen bestimmten Cluster. Gibt es nicht einen anderen Weg, dies zu tun? – user77005