2016-10-21 3 views
0

Ich bin Durchführung einer Gridsearch mit H2O das Python-API mit dem folgenden Code,H2O Python API: abrufen besten Modelle von Gridsearch

from h2o.estimators.random_forest import H2ORandomForestEstimator 
from h2o.grid import H2OGridSearch 

hyper_parameters = {'ntrees':[10, 50, 100, 200], 'max_depth':[5, 10, 15, 20, 25], 'balance_classes':[True, False]} 

search_criteria = { 
    "strategy": "RandomDiscrete", 
    "max_runtime_secs": 600, 
    "max_models": 30, 
    "stopping_metric": 'AUTO', 
    "stopping_tolerance": 0.0001, 
    'seed': 42 
} 

grid_search = H2OGridSearch(H2ORandomForestEstimator, hyper_parameters, search_criteria=search_criteria) 
grid_search.train(x=events_names_x, 
        y="total_rsvps", 
        training_frame=train, 
        validation_frame=test) 

Einmal betreibe ich die Modelle drucken möchten, und in der Reihenfolge der AUC vorhersagen,

grid_search.sort_by('auc', False) 

ich folgende Fehlermeldung erhalten,

--------------------------------------------------------------------------- 
KeyError         Traceback (most recent call last) 
<ipython-input-272-b250bf2b838e> in <module>() 
----> 1 grid_search.sort_by('auc', False) 

/Users/stereo/.pyenv/versions/3.5.2/lib/python3.5/site-packages/h2o/grid/grid_search.py in sort_by(self, metric, increasing) 
    663 
    664   if metric[-1] != ')': metric += '()' 
--> 665   c_values = [list(x) for x in zip(*sorted(eval('self.' + metric + '.items()'), key=lambda k_v: k_v[1]))] 
    666   c_values.insert(1, [self.get_hyperparams(model_id, display=False) for model_id in c_values[0]]) 
    667   if not increasing: 

/Users/stereo/.pyenv/versions/3.5.2/lib/python3.5/site-packages/h2o/grid/grid_search.py in <module>() 

/Users/stereo/.pyenv/versions/3.5.2/lib/python3.5/site-packages/h2o/grid/grid_search.py in auc(self, train, valid, xval) 
    606   :return: The AUC. 
    607   """ 
--> 608   return {model.model_id: model.auc(train, valid, xval) for model in self.models} 
    609 
    610  def aic(self, train=False, valid=False, xval=False): 

/Users/stereo/.pyenv/versions/3.5.2/lib/python3.5/site-packages/h2o/grid/grid_search.py in <dictcomp>(.0) 
    606   :return: The AUC. 
    607   """ 
--> 608   return {model.model_id: model.auc(train, valid, xval) for model in self.models} 
    609 
    610  def aic(self, train=False, valid=False, xval=False): 

/Users/stereo/.pyenv/versions/3.5.2/lib/python3.5/site-packages/h2o/model/model_base.py in auc(self, train, valid, xval) 
    669   tm = ModelBase._get_metrics(self, train, valid, xval) 
    670   m = {} 
--> 671   for k, v in viewitems(tm): m[k] = None if v is None else v.auc() 
    672   return list(m.values())[0] if len(m) == 1 else m 
    673 

/Users/stereo/.pyenv/versions/3.5.2/lib/python3.5/site-packages/h2o/model/metrics_base.py in auc(self) 
    158   :return: Retrieve the AUC for this set of metrics. 
    159   """ 
--> 160   return self._metric_json['AUC'] 
    161 
    162  def aic(self): 

KeyError: 'AUC' 

Jede beraten auf:

  • die Modelle in der Reihenfolge der Leistung drucken
  • Prognose mit dem Modell mit der höchsten AUC

Antwort

2

, was Sie brauchen, ist

sorted_grid = grid_search.get_grid(sort_by='auc',decreasing=True) print(sorted_grid)

Sie auf False abnehmend ändern wenn Sie bevorzugen würden

+0

So 'auc' ist scheinbar nicht mehr unterstützt und musste 'mse' oder' r2'. Danke für die Beratung! – Stereo

+0

Können wir diese Ergebnisse irgendwie in eine Datei exportieren? – user90772

+0

auc sollte unterstützt werden, solange Sie ein binäres Klassifizierungsproblem machen, stellen Sie bitte sicher, dass Ihre Zielvariable eine enum mit '.asfactor()' ist (Sie können Beispiele in der h2o Benutzerführung finden). – Lauren

Verwandte Themen