2012-12-06 9 views
8

Standardmäßig wird durch das Beizen eines numpy View-Arrays die View-Beziehung verloren, auch wenn die Array-Basis ebenfalls gebeizt ist. Meine Situation ist, dass ich einige komplexe Containerobjekte habe, die gebeizt werden. Und in einigen Fällen sind einige enthaltene Daten Ansichten in einigen anderen. Das Speichern eines unabhängigen Arrays jeder Ansicht ist nicht nur ein Platzverlust, sondern die neu geladenen Daten haben auch die Sichtbeziehung verloren.Speichern der numpigen Ansicht beim Beizen

Ein einfaches Beispiel wäre (aber in meinem Fall sind die Behälter komplexer als ein Wörterbuch):

d1 before: {'a': array([ 0., 0.]), 'b': array([ 0., 0.])} 
d1 after: {'a': array([ 1., 1.]), 'b': array([ 1., 1.])} 
d2 before: {'a': array([ 0., 0.]), 'b': array([ 0., 0.])} 
d2 after: {'a': array([ 0., 0.]), 'b': array([ 1., 1.])} # not a view anymore 

Meine Frage:

import numpy as np 
import cPickle 

tmp = np.zeros(2) 
d1 = dict(a=tmp,b=tmp[:]) # d1 to be saved: b is a view on a 

pickled = cPickle.dumps(d1) 
d2 = cPickle.loads(pickled) # d2 reloaded copy of d1 container 

print 'd1 before:', d1 
d1['b'][:] = 1 
print 'd1 after: ', d1 

print 'd2 before:', d2 
d2['b'][:] = 1 
print 'd2 after: ', d2 

, die gedruckt werden würden

(1) Gibt es einen Weg, es zu bewahren? (2) (noch besser) ist es eine Möglichkeit, es nur zu tun, wenn die Basis

Für das gebeizte (1) Ich denke, es die __setstate__, eine Möglichkeit sein kann __reduce_ex_ durch Veränderung, etc ... von der Array anzeigen. Aber ich traue mich erst jetzt nicht damit. Für die (2) habe ich keine Ahnung.

Antwort

7

Dies wird nicht in NumPy richtig gemacht, da es nicht immer sinnvoll ist, das Basisarray zu beizen, und Pickle bietet nicht die Möglichkeit zu prüfen, ob ein anderes Objekt als Teil seiner API auch gebeizt wird.

Diese Überprüfung kann jedoch in einem benutzerdefinierten Container für NumPy-Arrays durchgeführt werden. Zum Beispiel:

import numpy as np 
import pickle 

def byte_offset(array, source): 
    return array.__array_interface__['data'][0] - np.byte_bounds(source)[0] 

class SharedPickleList(object): 
    def __init__(self, arrays): 
     self.arrays = list(arrays) 

    def __getstate__(self): 
     unique_ids = {id(array) for array in self.arrays} 
     source_arrays = {} 
     view_tuples = {} 
     for array in self.arrays: 
      if array.base is None or id(array.base) not in unique_ids: 
       # only use views if the base is also being pickled 
       source_arrays[id(array)] = array 
      else: 
       view_tuples[id(array)] = (array.shape, 
              array.dtype, 
              id(array.base), 
              byte_offset(array, array.base), 
              array.strides) 
     order = [id(array) for array in self.arrays] 
     return (source_arrays, view_tuples, order) 

    def __setstate__(self, state): 
     source_arrays, view_tuples, order = state 
     view_arrays = {} 
     for k, view_state in view_tuples.items(): 
      (shape, dtype, source_id, offset, strides) = view_state 
      buffer = source_arrays[source_id].data 
      array = np.ndarray(shape, dtype, buffer, offset, strides) 
      view_arrays[k] = array 
     self.arrays = [source_arrays[i] 
         if i in source_arrays 
         else view_arrays[i] 
         for i in order] 

# unit tests 
def check_roundtrip(arrays): 
    unpickled_arrays = pickle.loads(pickle.dumps(
     SharedPickleList(arrays))).arrays 
    assert all(a.shape == b.shape and (a == b).all() 
       for a, b in zip(arrays, unpickled_arrays)) 

indexers = [0, None, slice(None), slice(2), slice(None, -1), 
      slice(None, None, -1), slice(None, 6, 2)] 

source0 = np.random.randint(100, size=10) 
arrays0 = [np.asarray(source0[k1]) for k1 in indexers] 
check_roundtrip([source0] + arrays0) 

source1 = np.random.randint(100, size=(8, 10)) 
arrays1 = [np.asarray(source1[k1, k2]) for k1 in indexers for k2 in indexers] 
check_roundtrip([source1] + arrays1) 

Dies führt zu erheblichen Platzersparnis:

source = np.random.rand(1000) 
arrays = [source] + [source[n:] for n in range(99)] 
print(len(pickle.dumps(arrays, protocol=-1))) 
# 766372 
print(len(pickle.dumps(SharedPickleList(arrays), protocol=-1))) 
# 11833 
Verwandte Themen