2017-05-05 1 views
2

Ich habe einen Pandas Datenrahmen mit 3 Klassen und Datenpunkte von n Features.Pandas Legende für Streu-Matrix

Der folgende Code erstellt eine Streumatrix mit Histogrammen in der Diagonalen von 4 der Features im Datenframe.

colums = ['n1','n2','n3','n4'] 
grr = pd.scatter_matrix(
dataframe[columns], c=y_train, figsize=(15,15), label=['B','N','O'], marker='.', 
    hist_kwds={'bins':20}, s=10, alpha=.8, cmap='brg') 
plt.legend() 
plt.show() 

wie folgt aus:

Scatter matrix of this dataframe

Das Problem, das ich habe, ist, dass plt.legend() scheint nicht zu funktionieren, es keine Legende überhaupt angezeigt (oder es ist die kleine 'le8' kaum sichtbar in der ersten Spalte der zweiten Zeile ...)

Was ich gerne hätte, wäre eine einzelne Legende, die nur zeigt, welche Farbe welche Klasse ist.

Ich habe alle vorgeschlagenen Fragen ausprobiert, aber keine hat eine Lösung. Ich versuchte auch, wie dies die Beschriftungen in der Legende Funktionsparameter zu setzen:

plt.legend(label=['B','N','O'], loc=1) 

aber ohne Erfolg ..

Was mache ich falsch?

+0

Ich habe selber nie 'pd.scatter_matrix' Streumatrix Plot zu zeichnen, aber Seaborn nützlich sein könnte, wenn Sie wollen. Hier ein Beispiel mit der Legende: https://seaborn.pydata.org/examples/scatterplot_matrix.html –

Antwort

3

Die Pandas scatter_matrix ist ein Wrapper für mehrere Matplotlib scatter Plots. Argumente werden an die scatter Funktion weitergeleitet. Die Streuung wird jedoch normalerweise für die Verwendung mit einer Colormap und nicht für eine Legende mit diskreten beschrifteten Punkten verwendet, sodass für die automatische Erstellung einer Legende kein Argument verfügbar ist.

Ich habe Angst, Sie müssen die Legende manuell erstellen. Zu diesem Zweck können Sie die Punkte aus der Streuung mit der Funktion plot von matplotlib (mit leeren Daten) erstellen und sie der Legende als Punkte hinzufügen.

import pandas as pd 
import numpy as np 
import matplotlib.pyplot as plt 
plt.rcParams["figure.subplot.right"] = 0.8 

v= np.random.rayleigh(size=(30,5)) 
v[:,4] = np.random.randint(1,4,size=30)/3. 
dataframe= pd.DataFrame(v, columns=['n1','n2','n3','n4',"c"]) 

columns = ['n1','n2','n3','n4'] 
grr = pd.scatter_matrix(
dataframe[columns], c=dataframe["c"], figsize=(7,5), label=['B','N','O'], marker='.', 
    hist_kwds={'bins':20}, s=10, alpha=.8, cmap='brg') 

handles = [plt.plot([],[],color=plt.cm.brg(i/2.), ls="", marker=".", \ 
        markersize=np.sqrt(10))[0] for i in range(3)] 
labels=["Label A", "Label B", "Label C"] 
plt.legend(handles, labels, loc=(1.02,0)) 
plt.show() 

enter image description here

+0

Ja, thx !, das macht den Trick! –

Verwandte Themen