Hier ist eine wiederverwendbare Funktion basierend auf dem Code von @ Ophir-carmi:
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.collections import PatchCollection
import itertools
import numpy as np
def gridshow(grid_x, grid_y, data, **kwargs):
vmin = kwargs.pop("vmin", None)
vmax = kwargs.pop("vmax", None)
data = np.array(data).reshape(-1)
# there should be data for (n-1)x(m-1) cells
assert (grid_x.shape[0] - 1) * (grid_y.shape[0] - 1) == data.shape[0], "Wrong number of data points. grid_x=%s, grid_y=%s, data=%s" % (grid_x.shape, grid_y.shape, data.shape)
ptchs = []
for j, i in itertools.product(xrange(len(grid_y) - 1), xrange(len(grid_x) - 1)):
xy = grid_x[i], grid_y[j]
width = grid_x[i+1] - grid_x[i]
height = grid_y[j+1] - grid_y[j]
ptchs.append(Rectangle(xy=xy, width=width, height=height, rasterized=True, linewidth=0, linestyle="None"))
p = PatchCollection(ptchs, linewidth=0, **kwargs)
p.set_array(np.array(data))
p.set_clim(vmin, vmax)
ax = plt.gca()
ax.set_aspect("equal")
plt.xlim([grid_x[0], grid_x[-1]])
plt.ylim([grid_y[0], grid_y[-1]])
ret = ax.add_collection(p)
plt.sci(ret)
return ret
if __name__ == "__main__":
grid_x = np.linspace(0, 20, 21) + np.random.randn(21)/5.0
grid_y = np.linspace(0, 18, 19) + np.random.randn(19)/5.0
grid_x = np.round(grid_x, 2)
grid_y = np.round(grid_y, 2)
data = np.random.randn((grid_x.shape[0] -1) * (grid_y.shape[0] -1))
fig = plt.figure()
ax = fig.add_subplot(111)
gridshow(grid_x, grid_y, data, alpha=1.0)
plt.savefig("test.png")
Ich bin nicht ganz sicher über die Leistung für große Netze und wenn die **kwargs
sollte auf PatchCollection
angewendet werden. Und zwischen einigen Rechtecken scheint 1px Leerzeichen zu sein, wahrscheinlich aufgrund schlechter Rundung. Vielleicht benötigt die dx, width, height
eine konsistente floor
/ceil
zum nächsten vollen Pixel.
Eine andere Lösung mit rtree
und imshow
:
import matplotlib.pyplot as plt
import numpy as np
from rtree import index
def gridshow(grid_x, grid_y, data, rows=200, cols=200, eps=1e-3, **kwargs):
grid_x1, grid_y1 = np.meshgrid(grid_x, grid_y)
grid_x2 = grid_x1[:-1, :-1].flat
grid_y2 = grid_y1[:-1, :-1].flat
grid_x3 = grid_x1[1:, 1:].flat
grid_y3 = grid_y1[1:, 1:].flat
grid_j = np.linspace(grid_x[0], grid_x[-1], cols)
grid_i = np.linspace(grid_y[0], grid_y[-1], rows)
j, i = np.meshgrid(grid_j, grid_i)
i = i.flat
j = j.flat
im = np.empty((rows, cols), dtype=np.float64)
idx = index.Index()
for m, (x0, y0, x1, y1) in enumerate(zip(grid_x2, grid_y2, grid_x3, grid_y3)):
idx.insert(m, (x0, y0, x1, y1))
for k, (i0, j0) in enumerate(zip(i, j)):
ind = next(idx.intersection((j0-eps, i0-eps, j0+eps, i0+eps)))
im[np.unravel_index(k, im.shape)] = data[np.unravel_index(ind, data.shape)]
fig = plt.gca()
return plt.imshow(im, interpolation="nearest")
if __name__ == "__main__":
grid_x = np.linspace(0, 200, 201) + np.random.randn(201)/5.0
grid_y = np.linspace(0, 108, 109) + np.random.randn(109)/5.0
grid_x = np.round(grid_x, 2)
grid_y = np.round(grid_y, 2)
data = np.random.randn((grid_x.shape[0] -1) * (grid_y.shape[0] -1))
fig = plt.figure()
ax = fig.add_subplot(111)
gridshow(grid_x, grid_y, data, alpha=1.0)
plt.savefig("test.png")
Was ist die Bedeutung von 'grid_x' ist? Soll das zweite 'grid_x'' grid_y' sein? –
Die '' grid''-Arrays sind die Schritte auf den x/y-Achsen. So entspricht ein Punkt "A [i, j]" der Position "grid_x [i], grid_y [j]" auf dem Gitter (oder dem Rechteck "grid_x [i: i + 1], grid_y [ j: j + 1] ''). – allo
Wenn die Gitterschritte Punkte definieren, haben Sie '' n + 1'' Einträge in '' grid_x'' für '' n'' Rechtecke von links nach rechts. '' imshow'' zeigt nur ein Quadrat für jeden Punkt, also dort können Sie die Einträge von '' A'' entweder als die oberen linken Punkte der Quadrate oder als die mittleren Punkte sehen, solange sie äquidistant sind, tut es nicht Es ist egal. – allo