Sie können einfach einen benutzerdefinierten Haken erstellen und an die MonitoredTrainingSession
übergeben. Es besteht keine Notwendigkeit, Ihre eigene tf.RunMetadata()
Instanz an den Aufruf zu übergeben. Hier
ist ein Beispiel Haken, die Metadaten alle N Schritte ckptdir speichert:
import tensorflow as tf
class TraceHook(tf.train.SessionRunHook):
"""Hook to perform Traces every N steps."""
def __init__(self, ckptdir, every_step=50, trace_level=tf.RunOptions.FULL_TRACE):
self._trace = every_step == 1
self.writer = tf.summary.FileWriter(ckptdir)
self.trace_level = trace_level
self.every_step = every_step
def begin(self):
self._global_step_tensor = tf.train.get_global_step()
if self._global_step_tensor is None:
raise RuntimeError("Global step should be created to use _TraceHook.")
def before_run(self, run_context):
if self._trace:
options = tf.RunOptions(trace_level=self.trace_level)
else:
options = None
return tf.train.SessionRunArgs(fetches=self._global_step_tensor,
options=options)
def after_run(self, run_context, run_values):
global_step = run_values.results - 1
if self._trace:
self._trace = False
self.writer.add_run_metadata(run_values.run_metadata,
f'{global_step}', global_step)
if not (global_step + 1) % self.every_step:
self._trace = True
Es in before_run
überprüft, ob sie verfolgen hat oder nicht und wenn ja, fügt die RunOptions. In after_run
überprüft es, ob der nächste Aufruf ausgeführt werden muss, und wenn ja, setzt er _trace
erneut auf True. Außerdem speichert es die Metadaten, wenn sie verfügbar sind.