2017-07-24 3 views
1

Ich versuche, meine Tensorflow-Code-Profil (Lauf und Speicherverbrauch der einzelnen Schichten im Netzwerk) durch die Laufzeitstatistik-Anweisung here zu erhalten. Soweit ich verstehe, muss ich laufen Optionen erstellen und Metadaten wie dieseGet Runtime-Statistik mit überwachtem Trainingssitzung in Tensorflow

run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) 
run_metadata = tf.RunMetadata() 

und sie an sess.run

Allerdings laufen, wie ich es auch zu benutzen tf.train.MonitoredTrainingSession versuche ich weiß nicht, ob ich kann dasselbe in diese Klasse weitergeben. Ein plausibler Ansatz könnte Hooks nutzen, aber ich weiß nicht, wie ich das machen soll. Ich bin immer noch sehr neu zu ihnen

Antwort

2

Sie können einfach einen benutzerdefinierten Haken erstellen und an die MonitoredTrainingSession übergeben. Es besteht keine Notwendigkeit, Ihre eigene tf.RunMetadata() Instanz an den Aufruf zu übergeben. Hier

ist ein Beispiel Haken, die Metadaten alle N Schritte ckptdir speichert:

import tensorflow as tf 

class TraceHook(tf.train.SessionRunHook): 
    """Hook to perform Traces every N steps.""" 

    def __init__(self, ckptdir, every_step=50, trace_level=tf.RunOptions.FULL_TRACE): 
     self._trace = every_step == 1 
     self.writer = tf.summary.FileWriter(ckptdir) 
     self.trace_level = trace_level 
     self.every_step = every_step 

    def begin(self): 
     self._global_step_tensor = tf.train.get_global_step() 
     if self._global_step_tensor is None: 
      raise RuntimeError("Global step should be created to use _TraceHook.") 

    def before_run(self, run_context): 
     if self._trace: 
      options = tf.RunOptions(trace_level=self.trace_level) 
     else: 
      options = None 
     return tf.train.SessionRunArgs(fetches=self._global_step_tensor, 
             options=options) 

    def after_run(self, run_context, run_values): 
     global_step = run_values.results - 1 
     if self._trace: 
      self._trace = False 
      self.writer.add_run_metadata(run_values.run_metadata, 
             f'{global_step}', global_step) 
     if not (global_step + 1) % self.every_step: 
      self._trace = True 

Es in before_run überprüft, ob sie verfolgen hat oder nicht und wenn ja, fügt die RunOptions. In after_run überprüft es, ob der nächste Aufruf ausgeführt werden muss, und wenn ja, setzt er _trace erneut auf True. Außerdem speichert es die Metadaten, wenn sie verfügbar sind.