2017-03-17 7 views
0

Ich habe versucht, mxnet aus dem Tutorial zu lernen, während das Laden von Daten Ich bekomme 'int' hat nicht 'getitem', aber ich bin nicht in der Lage finden sich die Position des Fehlers helfen mir bitte dank:'int' Objekt hat kein Attribut '__getitem__' mxnet

import mxnet as mx 
import numpy as np 

class SimpleData : 
    def __init__(self,data,label,pad = 0): 
     self.data = data 
     self.label = label 
     self.pad = pad 

class SimpleIter: 
     def __init__(self,mean,std,data_shape,label_shape,num_of_classes,num_batch = 10): 
     self._provide_data = zip(['data'],data_shape[0]) 
     self._provide_label = zip(['softmax_label'],label_shape[0]) 
     self.cur_batch = 0 
     self.num_batch = 10 
     self.mean = mean 
     self.std = std 
     self.data_shape = data_shape[0] 
     self.label_shape = label_shape[0] 
     self.num_of_classes = num_of_classes 

    def __iter__(self): 
     return self 

    def __next__(self): 
     return self.next() 

    def reset(self): 
     self.cur_batch = 0 

    @property 
    def provide_data(self): 
     return self._provide_data 

    @property 
    def provide_label(self): 
     return self._provide_label 

    def next(self): 
     if(self.cur_batch < self.num_batch): 
      self.cur_batch += 1 
      data = [mx.nd.array(np.random.normal(self.mean,self.std, ((self.data_shape)[0][0]/self.num_batch,self.data_shape[0][1])))] 
      label = [mx.nd.array(np.random.randint(0,10, ((self.data_shape)[0][1]/self.num_batch)))] 
      print data 
      print label 
      return SimpleBatch(data,label) 
     else: 
      raise StopIteration 

class SyntheticData: 
    def  __init__(self,mean,std,num_records,num_of_features,num_classes): 
     self.mean = mean 
     self.std = std 
     self.data_shape = zip(num_records,num_of_features) 
     self.label_shape = zip(num_records,) 
     self.num_classes = num_classes 

     def get_iter(self): 
      return  SimpleIter(self.mean,self.std,self.data_shape,self.label_shape,self.num_classes) 
net = mx.sym.Variable('data') 
net = mx.sym.FullyConnected(data = net,name = 'fc1',num_hidden = 64) 
net = mx.sym.Activation(data = net,name = 'relu_1',act_type = 'relu') 
net = mx.sym.FullyConnected(data = net,name = 'fc2',num_hidden = 10) 
net = mx.sym.SoftmaxOutput(data = net,name = 'softmax') 
data = SyntheticData(10,128,[100],[100],10) 
mod.fit(data.get_iter(), 
    eval_data=data.get_iter(), 
    optimizer='sgd', 
    optimizer_params={'learning_rate':0.1}, 
    eval_metric='acc', 
    num_epoch = 5) 

der Fehler ist:

TypeError         Traceback (most recent call last) 
<ipython-input-273-a7375f022406> in <module>() 
     4   optimizer_params={'learning_rate':0.1}, 
     5   eval_metric='acc', 
----> 6   num_epoch = 5) 

/usr/local/lib/python2.7/dist-packages/mxnet-0.9.4-py2.7.egg/mxnet/module/base_module.pyc in fit(self, train_data, eval_data, eval_metric, epoch_end_callback, batch_end_callback, kvstore, optimizer, optimizer_params, eval_end_callback, eval_batch_end_callback, initializer, arg_params, aux_params, allow_missing, force_rebind, force_init, begin_epoch, num_epoch, validation_metric, monitor) 
    440 
    441   self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label, 
--> 442     for_training=True, force_rebind=force_rebind) 
    443   if monitor is not None: 
    444    self.install_monitor(monitor) 

/usr/local/lib/python2.7/dist-packages/mxnet-0.9.4-py2.7.egg/mxnet/module/module.pyc in bind(self, data_shapes, label_shapes, for_training, inputs_need_grad, force_rebind, shared_module, grad_req) 
    386              fixed_param_names=self._fixed_param_names, 
    387              grad_req=grad_req, 
--> 388              state_names=self._state_names) 
    389   self._total_exec_bytes = self._exec_group._total_exec_bytes 
    390   if shared_module is not None: 

/usr/local/lib/python2.7/dist-packages/mxnet-0.9.4-py2.7.egg/mxnet/module/executor_group.pyc in __init__(self, symbol, contexts, workload, data_shapes, label_shapes, param_names, for_training, inputs_need_grad, shared_group, logger, fixed_param_names, grad_req, state_names) 
    203        for name in self.symbol.list_outputs()] 
    204 
--> 205   self.bind_exec(data_shapes, label_shapes, shared_group) 
    206 
    207  def decide_slices(self, data_shapes): 

/usr/local/lib/python2.7/dist-packages/mxnet-0.9.4-py2.7.egg/mxnet/module/executor_group.pyc in bind_exec(self, data_shapes, label_shapes, shared_group, reshape) 
    282 
    283   # calculate workload and bind executors 
--> 284   self.data_layouts = self.decide_slices(data_shapes) 
    285   if label_shapes is not None: 
    286    # call it to make sure labels has the same batch size  as data 

/usr/local/lib/python2.7/dist-packages/mxnet-0.9.4- py2.7.egg/mxnet/module/executor_group.pyc in decide_slices(self,  data_shapes) 
     220     continue 
     221 
-->  222    batch_size = shape[axis] 
     223    if self.batch_size is not None: 
     224     assert batch_size == self.batch_size, ("all data  must have the same batch size: " 

TypeError: 'int' object has no attribute '__getitem__' 
+0

Sie scheinen nicht den Code zu zeigen, der den Fehler tatsächlich erzeugte (der obere Abschnitt des Traceback). Irgendwo sieht es so aus, als würden Sie eine 'fit()' Methode aufrufen, und es sieht so aus, als ob der erste Parameter ('train_data') nicht das ist, was er erwartet. – glibdud

+0

oh sorry danke, dass ich es bemerkt habe, jetzt habe ich den kompletten Code hinzugefügt – adithya

+0

Was ist 'mod'? Auch die Argumente für die Fit-Funktion sind nicht korrekt. Können Sie ein Beispiel von mxnet github auswählen und es an Ihre Bedürfnisse anpassen? Hier ist ein Beispiel für Daten-Iterator, wenn das ist, was Sie suchen: https://github.com/dmlc/mxnet/blob/master/example/recommenders/movielen_data.py –

Antwort

0

ich denke, Ihr Problem in der Definition Ihrer data_shape ist.

self.data_shape = data_shape[0]

Wie Sie es definiert, self.data_shape ist nur ein int. In Ihrem Fall, ich denke, es einfach sein sollte:

self.data_shape = data_shape

So dass, wenn shape[axis] von decide_slices zugegriffen wird die Anzahl der Elemente erhalten.

+0

es klappte, danke – adithya

Verwandte Themen