
I have some code (see here) that uses slice_X() (located in keras.engine.training) to modify the stock TensorBoard callback, which otherwise leads to OOM on the GPU, so that TensorBoard can be used during GPU training. Unfortunately, since my upgrade to Keras 2.0.0 this results in an import error, because keras.engine.training no longer contains slice_X(). Where did it go? What alternative solution is possible?

I appreciate your help very much.

EDIT:

I have updated the code (see here) to Keras 2.0.0 and TensorFlow r1.0.

import numpy as np
import keras
from keras import backend as K
from pkg_resources import parse_version


class TensorBoard(keras.callbacks.Callback):
    '''
    Avoids OOM problem.
    Adapted from: https://github.com/Vladimir-Yashin/keras/blob/13e6a1f99f33a3cc7bc0a44d285fda457cc808e4/keras/callbacks.py
    Updated according to this discussion:
    http://stackoverflow.com/questions/42852495/where-did-keras-function-slice-x-go/42855104?noredirect=1#42855104

    Tensorboard basic visualizations.
    This callback writes a log for TensorBoard, which allows
    you to visualize dynamic graphs of your training and test
    metrics, as well as activation histograms for the different
    layers in your model.
    TensorBoard is a visualization tool provided with TensorFlow.
    If you have installed TensorFlow with pip, you should be able
    to launch TensorBoard from the command line:
    ```
    tensorboard --logdir=/full_path_to_your_logs
    ```
    You can find more information about TensorBoard
    [here](https://www.tensorflow.org/versions/master/how_tos/summaries_and_tensorboard/index.html).

    # Arguments
        log_dir: the path of the directory where to save the log
            files to be parsed by TensorBoard.
        histogram_freq: frequency (in epochs) at which to compute activation
            histograms for the layers of the model. If set to 0,
            histograms won't be computed.
        write_graph: whether to visualize the graph in TensorBoard.
            The log file can become quite large when
            write_graph is set to True.
    '''

    def __init__(self, log_dir='./logs', histogram_freq=0, write_graph=True, write_images=False):
        super(TensorBoard, self).__init__()
        if K.backend() != 'tensorflow':
            raise RuntimeError('TensorBoard callback only works '
                               'with the TensorFlow backend.')
        self.log_dir = log_dir
        self.histogram_freq = histogram_freq
        self.merged = None
        self.write_graph = write_graph
        self.write_images = write_images

    def set_model(self, model):
        import tensorflow as tf
        import keras.backend.tensorflow_backend as KTF

        self.model = model
        self.sess = KTF.get_session()
        if self.histogram_freq and self.merged is None:
            for layer in self.model.layers:
                for weight in layer.weights:
                    tf.summary.histogram(weight.name, weight)
                    if self.write_images:
                        w_img = tf.squeeze(weight)
                        shape = w_img.get_shape()
                        if len(shape) > 1 and shape[0] > shape[1]:
                            w_img = tf.transpose(w_img)
                        if len(shape) == 1:
                            w_img = tf.expand_dims(w_img, 0)
                        w_img = tf.expand_dims(tf.expand_dims(w_img, 0), -1)
                        tf.summary.image(weight.name, w_img)
                if hasattr(layer, 'output'):
                    tf.summary.histogram('{}_out'.format(layer.name),
                                         layer.output)
        if parse_version(tf.__version__) >= parse_version('0.12.0'):
            self.merged = tf.summary.merge_all()
        else:
            self.merged = tf.merge_all_summaries()
        if self.write_graph:
            if parse_version(tf.__version__) >= parse_version('0.12.0'):
                self.writer = tf.summary.FileWriter(self.log_dir,
                                                    self.sess.graph)
            elif parse_version(tf.__version__) >= parse_version('0.8.0'):
                self.writer = tf.train.SummaryWriter(self.log_dir,
                                                     self.sess.graph)
            else:
                self.writer = tf.train.SummaryWriter(self.log_dir,
                                                     self.sess.graph_def)
        else:
            if parse_version(tf.__version__) >= parse_version('0.12.0'):
                self.writer = tf.summary.FileWriter(self.log_dir)
            else:
                self.writer = tf.train.SummaryWriter(self.log_dir)

    def on_epoch_end(self, epoch, logs=None):
        import tensorflow as tf
        from keras.engine.training import _slice_arrays  # originally: from keras.engine.training import slice_X
        logs = logs or {}
        tf_session = K.get_session()

        if self.validation_data and self.histogram_freq:
            if epoch % self.histogram_freq == 0:
                if self.model.uses_learning_phase:
                    cut_v_data = len(self.model.inputs)
                    val_data = self.validation_data[:cut_v_data] + [0]
                    tensors = self.model.inputs + [K.learning_phase()]
                else:
                    val_data = self.validation_data
                    tensors = self.model.inputs
                # Sample one batch of validation data to avoid OOM on the GPU
                if 'batch_size' in self.params:
                    index_array = np.arange(len(val_data[0]))
                    batch_ids = np.random.choice(index_array, self.params['batch_size'])
                    if self.model.uses_learning_phase:
                        ins_batch = _slice_arrays(val_data[:-1], batch_ids) + [val_data[-1]]  # originally: slice_X(val_data[:-1], batch_ids) + [val_data[-1]]
                    else:
                        ins_batch = _slice_arrays(val_data, batch_ids)  # originally: slice_X(val_data, batch_ids)
                else:
                    # Generators yield one batch at a time and don't provide batch_size
                    ins_batch = val_data
                my_feed_dict = dict(zip(tensors, ins_batch))

                result = tf_session.run([self.merged], feed_dict=my_feed_dict)
                summary_str = result[0]
                self.writer.add_summary(summary_str, epoch)

        for name, value in logs.items():
            if name in ['batch', 'size']:
                continue
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value.item()
            summary_value.tag = name
            self.writer.add_summary(summary, epoch)
        self.writer.flush()

    def on_train_end(self, _):
        self.writer.close()

Answer

It seems that slice_X() no longer exists, but there is an internal function in keras.engine.training, _slice_arrays(), that does the slicing work. See the code here.

If you have further questions, don't hesitate to ask.

EDIT:

Here are the two functions. The old one:

def slice_X(X, start=None, stop=None): 
    """This takes an array-like, or a list of 
    array-likes, and outputs: 
     - X[start:stop] if X is an array-like 
     - [x[start:stop] for x in X] if X in a list 
    Can also work on list/array of indices: `slice_X(x, indices)` 
    # Arguments 
     start: can be an integer index (start index) 
      or a list/array of indices 
     stop: integer (stop index); should be None if 
      `start` was a list. 
    """ 

The new one:

def _slice_arrays(arrays, start=None, stop=None): 
    """Slice an array or list of arrays. 
    This takes an array-like, or a list of 
    array-likes, and outputs: 
     - arrays[start:stop] if `arrays` is an array-like 
     - [x[start:stop] for x in arrays] if `arrays` is a list 
    Can also work on list/array of indices: `_slice_arrays(x, indices)` 
    # Arguments 
     arrays: Single array or list of arrays. 
     start: can be an integer index (start index) 
      or a list/array of indices 
     stop: integer (stop index); should be None if 
      `start` was a list. 
    # Returns 
     A slice of the array(s). 
    """ 

What is important to understand here is that they basically just changed the name. You can fix Vladimir's code simply by replacing slice_X() with _slice_arrays(), called with the same arguments. Also change the import to

from keras.engine.training import _slice_arrays 

I hope it works now.


Thank you for your kind help! To be honest, I don't know what this 'slice_X()' function did. I treated the code using 'slice_X()' as a black box. Now I feel rather uneasy mimicking the old function with the new '_slice_arrays()', since the old one had 2 params and the new one has 3 params ... :( –


Maybe post the piece of code where it is used? –


It is the modified TensorBoard callback by Vladimir Yashin. [See here](https://github.com/Vladimir-Yashin/keras/blob/13e6a1f99f33a3cc7bc0a44d285fda457cc808e4/keras/callbacks.py) –