2017-02-26 5 views
6

Ich bin neu in Python & tensorFlow, und ich bin auf this MNIST tutorial auf TensorFlow Dokumentation folgen.FLAGS = Keine Bedeutung?

Im ersten Bit, ich weiß nicht, was FLAGS = None hier tut. Ich suchte in Google und kam leer zurück. Scheint, als ob das für andere zu offensichtlich ist?

from __future__ import absolute_import 
from __future__ import division 
from __future__ import print_function 

import argparse 
import sys 

from tensorflow.examples.tutorials.mnist import input_data 

import tensorflow as tf 

FLAGS = None 


def main(_): 
    # Import data 
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) 

Also was ist FLAGS und wie es benutzt wird? z. B. FLAGS.data_dir

Jede Hilfe wäre willkommen!

Antwort

3

initialisieren FLAGS=None ist nur ein Weg, Initialisierung der globalen Konstante. Wenn es so belassen wird wie es ist, wird es einen Fehler in main auslösen, da None keine Attribute hat.

Aber wenn über ein argparseparser wie in den volleren Beispielen gezeigt eingestellt, ist es ein einfaches Objekt mit einer Vielzahl von Attributen. main nimmt an, dass eine dieser Attribute data_dir heißt.

Wenn nach dem

FLAGS, unparsed = parser.parse_known_args() 
print(FLAGS) 

Sie Namespace(data_dir='a directory', ....) sehen sollte, wobei der Wert für data_dir von der Kommandozeile analysiert wurde.

4

Dies war der vollständige Code Sie suchen an: Ich werde erklären:

from __future__ import absolute_import 
from __future__ import division 
from __future__ import print_function 

import argparse 
import sys 

from tensorflow.examples.tutorials.mnist import input_data 

import tensorflow as tf 

FLAGS = None #Adds a default value to FLAGS 


def main(_): #Everything inside the function is not checked until it's called 
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) #FLAGS is not None anymore because it got changed below 

    x = tf.placeholder(tf.float32, [None, 784]) 
    W = tf.Variable(tf.zeros([784, 10])) 
    b = tf.Variable(tf.zeros([10])) 
    y = tf.matmul(x, W) + b 

    y_ = tf.placeholder(tf.float32, [None, 10]) 

    cross_entropy = tf.reduce_mean(
     tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)) 
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) 

    sess = tf.InteractiveSession() 
    tf.global_variables_initializer().run() 
    # Train 
    for _ in range(1000): 
    batch_xs, batch_ys = mnist.train.next_batch(100) 
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) 

    # Test trained model 
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) 
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
    print(sess.run(accuracy, feed_dict={x: mnist.test.images, 
             y_: mnist.test.labels})) 

if __name__ == '__main__': 
    parser = argparse.ArgumentParser() 
    parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data', 
         help='Directory for storing input data') 

    FLAGS, unparsed = parser.parse_known_args() #Here it changed the value of FLAGS to the first thing returned from parser.parse_known_args() 

    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) #runs the app (calling main) 

passiert, was ist, dass FLAGS hier geändert wurde: FLAGS, unparsed = parser.parse_known_args()