Perhaps the labels in your dataset are one-hot vectors. (Here I use the MNIST dataset to illustrate; see also https://www.tensorflow.org/versions/r0.7/tutorials/mnist/beginners/index.html)
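A quick way to test this suspicion is to inspect the label array itself. The helper `looks_one_hot` below is hypothetical (it is not part of TensorFlow or skflow); it is a heuristic sketch that assumes the labels are NumPy arrays, and it uses small stand-in arrays instead of the real MNIST labels.

```python
import numpy as np

def looks_one_hot(labels):
    """Heuristic (hypothetical helper): 2-D array of 0/1 values with exactly one 1 per row."""
    labels = np.asarray(labels)
    if labels.ndim != 2:
        return False
    is_binary = np.isin(labels, (0, 1)).all()
    one_per_row = (labels.sum(axis=1) == 1).all()
    return bool(is_binary and one_per_row)

# Small stand-ins for labels read with one_hot=True and one_hot=False
onehot = np.eye(10)[[7, 3, 4]]                  # shape (3, 10)
integers = np.array([7, 3, 4], dtype=np.uint8)  # shape (3,)

print(looks_one_hot(onehot))    # True
print(looks_one_hot(integers))  # False
```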
In [1]: from tensorflow.examples.tutorials.mnist import input_data
In [2]: mnist = input_data.read_data_sets("MNIST_DATA/", one_hot=True)
Extracting MNIST_DATA/train-images-idx3-ubyte.gz
Extracting MNIST_DATA/train-labels-idx1-ubyte.gz
Extracting MNIST_DATA/t10k-images-idx3-ubyte.gz
Extracting MNIST_DATA/t10k-labels-idx1-ubyte.gz
In [3]: mnist.train.labels
Out[3]:
array([[ 0., 0., 0., ..., 1., 0., 0.],
       [ 0., 0., 0., ..., 0., 0., 0.],
       [ 0., 0., 0., ..., 0., 0., 0.],
       ...,
       [ 0., 0., 0., ..., 0., 0., 0.],
       [ 0., 0., 0., ..., 0., 0., 0.],
       [ 0., 0., 0., ..., 0., 1., 0.]])
In [4]: mnist.train.labels.shape
Out[4]: (55000, 10)
In [5]: import skflow
In [6]: classifier = skflow.TensorFlowDNNClassifier(hidden_units=[10, 10, 10], n_classes=10, batch_size=100, learning_rate=0.05)
In [7]: classifier.fit(mnist.train.images, mnist.train.labels)
Then I get the same error:
ValueError: Shapes (?, 10) and (?, 10, 10) must have the same rank
But skflow expects the labels to be integers between 0 and 9 (one_hot=False).
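Why does the error mention a rank-3 shape? A plausible reading (not verified against the skflow source) is that skflow one-hot-encodes the targets itself, appending an `n_classes` axis to whatever shape the labels already have. Labels that are already one-hot, shape (N, 10), would then become (N, 10, 10), which matches the error. The sketch below reproduces only that shape arithmetic in NumPy; `one_hot_encode` is a hypothetical helper, not skflow's code.

```python
import numpy as np

def one_hot_encode(labels, n_classes):
    """Hypothetical helper: one-hot encode integer labels by appending a class axis."""
    labels = np.asarray(labels, dtype=np.int64)
    # Indexing the identity matrix with an integer array returns
    # an array of shape labels.shape + (n_classes,).
    return np.eye(n_classes)[labels]

int_labels = np.array([7, 3, 4])                    # shape (3,)
onehot_labels = one_hot_encode(int_labels, 10)      # shape (3, 10): what skflow expects to build
double_encoded = one_hot_encode(onehot_labels, 10)  # shape (3, 10, 10): encoding one-hot labels again

print(onehot_labels.shape, double_encoded.shape)    # (3, 10) (3, 10, 10)
```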
In [5]: mnist = input_data.read_data_sets("MNIST_DATA/", one_hot=False)
Extracting MNIST_DATA/train-images-idx3-ubyte.gz
Extracting MNIST_DATA/train-labels-idx1-ubyte.gz
Extracting MNIST_DATA/t10k-images-idx3-ubyte.gz
Extracting MNIST_DATA/t10k-labels-idx1-ubyte.gz
In [6]: mnist.train.labels
Out[6]: array([7, 3, 4, ..., 5, 6, 8], dtype=uint8)
In [7]: classifier.fit(mnist.train.images, mnist.train.labels)
Step #99, avg. train loss: 2.31658
Step #199, avg. train loss: 1.63361
Out[7]:
TensorFlowDNNClassifier(batch_size=100, class_weight=None,
                        config_addon=<skflow.addons.config_addon.ConfigAddon object at 0x11cf7eb90>,
                        continue_training=False, hidden_units=[10, 10, 10],
                        keep_checkpoint_every_n_hours=10000, learning_rate=0.05,
                        max_to_keep=5, n_classes=10, optimizer='SGD', steps=200,
                        tf_master='', tf_random_seed=42, verbose=1)
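If the data has already been loaded with one_hot=True, re-reading it may not be necessary: taking the argmax along the class axis should recover the same integer labels. A minimal sketch, using a small stand-in array rather than the real `mnist.train.labels`:

```python
import numpy as np

# Stand-in for labels read with one_hot=True (three rows instead of 55000)
onehot_labels = np.eye(10)[[7, 3, 4]]

# The column holding the single 1 in each row is the digit class;
# cast to uint8 to match what input_data returns with one_hot=False.
int_labels = np.argmax(onehot_labels, axis=1).astype(np.uint8)
print(int_labels)  # [7 3 4]
```

With the real data this would presumably become `classifier.fit(mnist.train.images, np.argmax(mnist.train.labels, axis=1))`; that variant is untested here.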
Please give it a try.
Which versions of TensorFlow and skflow are you using? – user728291
tensorflow 0.71 – hmmmbob
tensorflow 0.71 and skflow 0.10 – hmmmbob