2017-06-30 1 views
1

Ich habe begonnen, an einem Projekt zu arbeiten, in dem ich trainierbare Parameter für einen gegebenen scikit-learn Schätzer erkennen muss, und wenn möglich, zulässige Werte für kategoriale Variablen (und angemessene Intervalle für kontinuierliche) zu finden.Wie erkennt man, welche Werte in einem Parameterraster zulässig sind? (sklearn)

Ich kann ein Wörterbuch mit Parametern unter Verwendung estimator.get_params() holen und dann einen Wert unter Verwendung estimator.set_params(**{'var1':val1, 'var2':val2}), und so weiter einstellen.

Zum Beispiel haben wir für einen KNN-Klassifikator das folgende dict von params: {'metric': 'minkowski', 'algorithm': 'auto', 'n_neighbors': 10, 'n_jobs': 1, 'p': 2, 'metric_params': None, 'weights': 'uniform', 'leaf_size': 30}.

Nun kann ich die Typen von den Werten abzuleiten, die kategorischen sind (str Typen), kontinuierlich (float), diskrete (int) und so weiter. Ein möglicherweise damit verbundenes Problem sind Parameter, für die der Standardwert auf NoneType festgelegt ist, aber ich kann diese aus einem guten Grund sowieso nicht berühren.

Die Herausforderung besteht nun darin, ein Parametergitter abzuleiten und zu definieren, das z. RandomizedSearchCV. Für diskrete und kontinuierliche Variablen ist das Problem unter Verwendung z.B. eine Kombination von try - except Blöcke zusammen mit dem scipy.stats Modul, möglich, das Intervall zu beschränken, um in der "Nähe" um den Standardwert zu liegen (aber gleichzeitig darauf achten, z. B. n_jobs nicht auf irgendeinen verrückten Wert einzustellen - das muss möglicherweise fest codiert oder später explizit festgelegt werden). Wenn Sie Erfahrung mit etwas Ähnlichem haben und ein paar Tipps/Tricks im Ärmel haben, würde ich gerne von ihnen hören.

Aber das eigentliche Problem ist jetzt: wie z. algorithm dass die zulässigen Werte tatsächlich {‘auto’, ‘ball_tree’, ‘kd_tree’, ‘brute’} ??

Ich habe gerade angefangen, das Problem zu untersuchen, und vielleicht können wir die Fehlermeldung parsen, die wir bekommen, wenn wir versuchen, sie auf einen nicht zulässigen Wert zu setzen? Ich bin auf der Suche nach guten Ideen hier, wie ich vermeiden möchte, dies manuell zu tun (ich werde, wenn ich muss, aber es scheint eher unelegant ...)

Vielen Dank!

+0

Hinweis zu sich selbst: Dies könnte ein sehr schwieriges/unlösbares Problem sein. Ich habe in der API und im Quellcode herumgestöbert und geschaut, wie z. Auto-Sklearn löst dies. Es scheint, dass eine manuelle (hart codierte) Lösung der Weg ist, um jetzt zu gehen. – Magnus

+0

Interessantes Problem, das Sie dort haben.Abgesehen von [Parsing der Signatur und der Standardparameter] (https://stackoverflow.com/questions/2677185/how-can-i-read-a-functions-signature- including-default- argument- values) würde ich wohl versuchen Scripten von scikit-learn wie [this] (https://stackoverflow.com/questions/713138/getting-the-docstring-from-a-function) analysieren. Eine andere Sache, die zu versuchen wäre, wäre das Analysieren der verknüpften Funktion, z. "__init__" des Schätzers, aber das ist eine - chaotisch - lange Zeit, da ich keine Kontrollen dort sehe, und es gibt eine ganze Hierarchie, die Sie betrachten müssen. – mkaran

+0

Hallo! Schön, dass Sie das Thema interessant finden. Ja, das war/ist eine der Optionen, die ich in Betracht gezogen habe (das Parsing des Dokuments). Aber was mich beunruhigt, ist die Konsistenz in der Art, wie die Docstrings geschrieben werden, und es gibt keine erzwungenen Konventionen (aber ich könnte mich irren), die man ausnutzen könnte. Ich könnte einfach ein wenig Zeit damit verbringen, einen Parser zu implementieren und es an einer Reihe von Docstrings zu testen ... – Magnus

Antwort

0

Ich fand eine Lösung für das spezielle Beispiel, das ich betrachtete, aber es verallgemeinert sich nicht gut auf andere Doc-Strings, da es keine festgelegte Konvention gibt, wie sie für jeden Schätzer in sklearn geschrieben werden.

Deshalb poste ich meine "Lösung", damit andere übernehmen und möglicherweise verbessern können. Sehen Sie im folgenden Code-Schnipsel:

import re 
from pprint import pprint 
from sklearn.neighbors import KNeighborsClassifier 
knn = KNeighborsClassifier() 
doc = knn.__doc__ # Get the doc string 
#from sklearn.svm import SVC 
#svc = SVC() 
#doc = svc.__doc__ 
pattern = "([a-zA-Z_]+\s:\s)|(-\s*)'([a-zA-Z_]+)'" # Define search pattern 
re.compile(pattern) 
matches = re.findall(pattern, doc) 

clf_params = {} 
previous_param = '' 
for param, _, value in matches: 
    if ":" in param and param[-4]!="_": # 'Hack-y' 
     if param not in clf_params.keys(): 
      clf_params[param] = list() 
      previous_param = param 
     else: 
      if len(value)>0: 
       clf_params[previous_param].append(value) 
pprint(clf_params) 

Dieser Ausschnitt druckt

{'algorithm : ': ['ball_tree', 'kd_tree', 'brute', 'auto'], 
'leaf_size : ': [], 
'metric : ': [], 
'metric_params : ': [], 
'n_jobs : ': [], 
'n_neighbors : ': [], 
'p : ': [], 
'weights : ': ['uniform', 'distance']} 

Was ist richtig.

Wenn wir jedoch das gleiche Verfahren für SVC().__doc__ wiederholen, werden wir sehen, dass es fehlschlägt.

Ich hoffe jemand findet das etwas nützlich.

Verwandte Themen