2016-07-12 5 views
1

ich derzeit matplotlib.pyplot verwenden einige 2D-Daten zu visualisieren:pyplot.imshow für Rechtecken

from matplotlib import pyplot as plt 
import numpy as np 
A=np.matrix("1 2 1;3 0 3;1 2 0") # 3x3 matrix with 2D data 
plt.imshow(A, interpolation="nearest") # draws one square per matrix entry 
plt.show() 

Jetzt bewegte ich die Daten von den Quadraten zu Rechtecke, ich habe zwei zusätzliche Arrays bedeuten zum Beispiel:

grid_x = np.array([0.0, 1.0, 4.0, 5.0]) # points on the x-axis 
grid_x = np.array([0.0, 2.5, 4.0, 5.0]) # points on the y-axis 

jetzt will ich ein Raster mit Rechtecke:

  • linke obere Ecke: (grid_x[i], grid_y[j])
  • rechtse untere Ecke: (grid_x[i+1], grid_y[j+1])
  • Daten (Farbe): A[i,j]

Was ist eine einfache Möglichkeit, die Daten auf das neue Raster zu zeichnen? imshow scheint verwendbar zu sein, schaute ich auf pcolormesh, aber es ist verwirrend mit dem Gitter als 2D-Array, mit zwei Matrizen wie np.mgrid[0:5:0.5,0:5:0.5] für das regelmäßige Gitter und Aufbau etwas ähnliches für die unregelmäßige.

Was ist ein einfacher Weg für die Visualisierung der Rechtecke?

+0

Was ist die Bedeutung von 'grid_x' ist? Soll das zweite 'grid_x'' grid_y' sein? –

+0

Die '' grid''-Arrays sind die Schritte auf den x/y-Achsen. So entspricht ein Punkt "A [i, j]" der Position "grid_x [i], grid_y [j]" auf dem Gitter (oder dem Rechteck "grid_x [i: i + 1], grid_y [ j: j + 1] ''). – allo

+0

Wenn die Gitterschritte Punkte definieren, haben Sie '' n + 1'' Einträge in '' grid_x'' für '' n'' Rechtecke von links nach rechts. '' imshow'' zeigt nur ein Quadrat für jeden Punkt, also dort können Sie die Einträge von '' A'' entweder als die oberen linken Punkte der Quadrate oder als die mittleren Punkte sehen, solange sie äquidistant sind, tut es nicht Es ist egal. – allo

Antwort

1
import matplotlib.pyplot as plt 
from matplotlib.patches import Rectangle 
import matplotlib.cm as cm 
from matplotlib.collections import PatchCollection 
import numpy as np 

A = np.matrix("1 2 1;3 0 3;1 2 0;4 1 2") # 4x3 matrix with 2D data 

grid_x0 = np.array([0.0, 1.0, 4.0, 6.7]) 
grid_y0 = np.array([0.0, 2.5, 4.0, 7.8, 12.4]) 

grid_x1, grid_y1 = np.meshgrid(grid_x0, grid_y0) 
grid_x2 = grid_x1[:-1, :-1].flat 
grid_y2 = grid_y1[:-1, :-1].flat 
widths = np.tile(np.diff(grid_x0)[np.newaxis], (len(grid_y0)-1, 1)).flat 
heights = np.tile(np.diff(grid_y0)[np.newaxis].T, (1, len(grid_x0)-1)).flat 

fig = plt.figure() 
ax = fig.add_subplot(111) 
ptchs = [] 
for x0, y0, w, h in zip(grid_x2, grid_y2, widths, heights): 
    ptchs.append(Rectangle(
     (x0, y0), w, h, 
    )) 
p = PatchCollection(ptchs, cmap=cm.viridis, alpha=0.4) 
p.set_array(np.ravel(A)) 
ax.add_collection(p) 
plt.xlim([0, 8]) 
plt.ylim([0, 13]) 
plt.show() 

enter image description here

Hier ist eine andere Art und Weise, mit Bild und R-Baum und imshow mit colorbar, müssen Sie das ändern x-ticks und y-ticks (Es gibt eine Menge von SO Q & A darüber, wie es zu tun).

from rtree import index 
import matplotlib.pyplot as plt 
import numpy as np 

eps = 1e-3 

A = np.matrix("1 2 1;3 0 3;1 2 0;4 1 2") # 4x3 matrix with 2D data 
grid_x0 = np.array([0.0, 1.0, 4.0, 6.7]) 
grid_y0 = np.array([0.0, 2.5, 4.0, 7.8, 12.4]) 

grid_x1, grid_y1 = np.meshgrid(grid_x0, grid_y0) 
grid_x2 = grid_x1[:-1, :-1].flat 
grid_y2 = grid_y1[:-1, :-1].flat 
grid_x3 = grid_x1[1:, 1:].flat 
grid_y3 = grid_y1[1:, 1:].flat 

fig = plt.figure() 

rows = 100 
cols = 200 
im = np.zeros((rows, cols), dtype=np.int8) 
grid_j = np.linspace(grid_x0[0], grid_x0[-1], cols) 
grid_i = np.linspace(grid_y0[0], grid_y0[-1], rows) 
j, i = np.meshgrid(grid_j, grid_i) 

i = i.flat 
j = j.flat 

idx = index.Index() 

for m, (x0, y0, x1, y1) in enumerate(zip(grid_x2, grid_y2, grid_x3, grid_y3)): 
    idx.insert(m, (x0, y0, x1, y1)) 


for k, (i0, j0) in enumerate(zip(i, j)): 
    ind = next(idx.intersection((j0-eps, i0-eps, j0+eps, i0+eps))) 

    im[np.unravel_index(k, im.shape)] = A[np.unravel_index(ind, A.shape)] 
plt.imshow(im) 
plt.colorbar() 
plt.show() 

enter image description here

+0

Sie zeichnen nur zwei Rechtecke, während das 3x3 Gitter 9 haben sollte (mit den Werten aus der Matrix '' A'' in meinem Beispiel). du zeichnest nie diagonale Punkte mit '' zip'' auf den beiden Achsen ab. – allo

+0

sehe meine bearbeitete Antwort –

+0

Sieht mehr wie ich, was ich wollte. Ich habe gerade etwas Ähnliches gemacht, werde in einem Moment einige gebrauchsfertige Funktionen veröffentlichen. – allo

1

Hier ist eine wiederverwendbare Funktion basierend auf dem Code von @ Ophir-carmi:

import matplotlib.pyplot as plt 
from matplotlib.patches import Rectangle 
from matplotlib.collections import PatchCollection 
import itertools 
import numpy as np 

def gridshow(grid_x, grid_y, data, **kwargs): 
    vmin = kwargs.pop("vmin", None) 
    vmax = kwargs.pop("vmax", None) 
    data = np.array(data).reshape(-1) 
    # there should be data for (n-1)x(m-1) cells 
    assert (grid_x.shape[0] - 1) * (grid_y.shape[0] - 1) == data.shape[0], "Wrong number of data points. grid_x=%s, grid_y=%s, data=%s" % (grid_x.shape, grid_y.shape, data.shape) 
    ptchs = [] 
    for j, i in itertools.product(xrange(len(grid_y) - 1), xrange(len(grid_x) - 1)): 
     xy = grid_x[i], grid_y[j] 
     width = grid_x[i+1] - grid_x[i] 
     height = grid_y[j+1] - grid_y[j] 
     ptchs.append(Rectangle(xy=xy, width=width, height=height, rasterized=True, linewidth=0, linestyle="None")) 
    p = PatchCollection(ptchs, linewidth=0, **kwargs) 
    p.set_array(np.array(data)) 
    p.set_clim(vmin, vmax) 
    ax = plt.gca() 
    ax.set_aspect("equal") 
    plt.xlim([grid_x[0], grid_x[-1]]) 
    plt.ylim([grid_y[0], grid_y[-1]]) 
    ret = ax.add_collection(p) 
    plt.sci(ret) 
    return ret 

if __name__ == "__main__": 
    grid_x = np.linspace(0, 20, 21) + np.random.randn(21)/5.0 
    grid_y = np.linspace(0, 18, 19) + np.random.randn(19)/5.0 
    grid_x = np.round(grid_x, 2) 
    grid_y = np.round(grid_y, 2) 
    data = np.random.randn((grid_x.shape[0] -1) * (grid_y.shape[0] -1)) 

    fig = plt.figure() 
    ax = fig.add_subplot(111) 
    gridshow(grid_x, grid_y, data, alpha=1.0) 
    plt.savefig("test.png") 

example image from the <code>gridshow</code> function

Ich bin nicht ganz sicher über die Leistung für große Netze und wenn die **kwargs sollte auf PatchCollection angewendet werden. Und zwischen einigen Rechtecken scheint 1px Leerzeichen zu sein, wahrscheinlich aufgrund schlechter Rundung. Vielleicht benötigt die dx, width, height eine konsistente floor/ceil zum nächsten vollen Pixel.

Eine andere Lösung mit rtree und imshow:

import matplotlib.pyplot as plt 
import numpy as np 
from rtree import index 

def gridshow(grid_x, grid_y, data, rows=200, cols=200, eps=1e-3, **kwargs): 
    grid_x1, grid_y1 = np.meshgrid(grid_x, grid_y) 
    grid_x2 = grid_x1[:-1, :-1].flat 
    grid_y2 = grid_y1[:-1, :-1].flat 
    grid_x3 = grid_x1[1:, 1:].flat 
    grid_y3 = grid_y1[1:, 1:].flat 

    grid_j = np.linspace(grid_x[0], grid_x[-1], cols) 
    grid_i = np.linspace(grid_y[0], grid_y[-1], rows) 
    j, i = np.meshgrid(grid_j, grid_i) 
    i = i.flat 
    j = j.flat 

    im = np.empty((rows, cols), dtype=np.float64) 

    idx = index.Index() 
    for m, (x0, y0, x1, y1) in enumerate(zip(grid_x2, grid_y2, grid_x3, grid_y3)): 
     idx.insert(m, (x0, y0, x1, y1)) 
    for k, (i0, j0) in enumerate(zip(i, j)): 
     ind = next(idx.intersection((j0-eps, i0-eps, j0+eps, i0+eps))) 
     im[np.unravel_index(k, im.shape)] = data[np.unravel_index(ind, data.shape)] 

    fig = plt.gca() 
    return plt.imshow(im, interpolation="nearest") 


if __name__ == "__main__": 
    grid_x = np.linspace(0, 200, 201) + np.random.randn(201)/5.0 
    grid_y = np.linspace(0, 108, 109) + np.random.randn(109)/5.0 
    grid_x = np.round(grid_x, 2) 
    grid_y = np.round(grid_y, 2) 
    data = np.random.randn((grid_x.shape[0] -1) * (grid_y.shape[0] -1)) 

    fig = plt.figure() 
    ax = fig.add_subplot(111) 
    gridshow(grid_x, grid_y, data, alpha=1.0) 
    plt.savefig("test.png") 

image with imshow and rtree

+0

Sie können immer ein Bild mit ausreichender Auflösung und R-Struktur verwenden, um seine Werte zu ändern und dann 'imshow'. –

+0

Immer noch daran arbeiten, scheint die Indizierung nicht richtig. Weißt du, wie man es in eine Funktion wie "imshow" steckt? Zum Beispiel funktioniert colorbar nicht '' RuntimeError: Sie müssen zuerst ein Bild definieren, zB mit imshow''. Wenn möglich, sollte es auch nicht notwendig sein, ein eigenes Teilplot zu definieren, um die Achsen zu erhalten. – allo

+0

Die '' imshow'' Methode funktioniert jetzt für mich. Nimmt 3-6 Sekunden für eine Datenauflösung von 200x100 hier, wird für größere Auflösungen viel schlechter. Sollte aber wahrscheinlich für meine Zwecke ausreichen. – allo

Verwandte Themen