2017-02-14 1 views
2

Ich habe ein folgenden Code linear_model.Lasso:Python - k fach Kreuzvalidierung für linear_model.Lasso

X_train, X_test, y_train, y_test = cross_validation.train_test_split(X,y,test_size=0.2) 
clf = linear_model.Lasso() 
clf.fit(X_train,y_train) 
accuracy = clf.score(X_test,y_test) 
print(accuracy) 

Ich will k-fache (10-mal genau zu sein) auszuführen cross_validation. Was wäre der richtige Code, um das zu tun?

Antwort

2

können Sie führen das 10-fache des model_selection Modul:

# for 0.18 version or newer, use: 
from sklearn.model_selection import cross_val_score 

# for pre-0.18 versions of scikit, use: 
from sklearn.cross_validation import cross_val_score 

X = # Some features 
y = # Some classes 

clf = linear_model.Lasso() 
scores = cross_val_score(clf, X, y, cv=10) 

Dieser Code 10 verschiedene Noten zurück. Sie können den Mittelwert leicht zu bekommen:

scores.mean() 
+0

Ich benutze Python3, und ich habe die folgende Fehlermeldung: ImportError: Kein Modul mit dem Namen 'sklearn.model_selection'. Irgendwelche Hinweise? – Ryo

+0

Welche scikit-learn Version verwendest du? Die model_selection-Modul-Funktionalität, die im cross_validation-Modul verwendet wurde – Elisha

+0

Ich verwende '0.17.1' – Ryo

4

hier ist der Code, den ich auf einem linearen Regressionsmodell Kreuzvalidierung durchführen verwenden und auch die Details zu erhalten:

from sklearn.model_selection import cross_val_score 
scores = cross_val_score(clf, X_Train, Y_Train, scoring="neg_mean_squared_error", cv=10) 
rmse_scores = np.sqrt(-scores) 

Wie gesagt in this Buch auf Seite 108 dies ist der Grund, warum wir -score verwenden:

Scikit-Learn cross-validation features expect a utility function (greater is better) rather than a cost function (lower is better), so the scoring function is actually the opposite of the MSE (i.e., a negative value), which is why the preceding code computes -scores before calculating the square root.

und das Ergebnis verwendet diese einfache Funktion sichtbar zu machen:

def display_scores(scores): 
    print("Scores:", scores) 
    print("Mean:", scores.mean()) 
    print("Standard deviation:", scores.std()) 
Verwandte Themen