matplotlib中的色条将0到1之间的数字映射为一种颜色。为了将其他数字映射为颜色,您需要先将其归一化到范围[0,1]。这通常是从最小和最大数据自动完成的,或者通过使用相应绘图函数的vmin和vmax参数来完成。在内部,使用一个归一化实例matplotlib.colors.Normalize来执行归一化,默认情况下假定vmin和vmax之间的线性比例尺。
在这里,您需要一个非线性比例尺,它可以(a)将中间点移动到某个指定值,并且(b)压缩该值周围的颜色。
现在的想法是子类化matplotlib.colors.Normalize并让它返回满足条件(a)和(b)的映射。
一种选项可能是组合两个根函数,如下所示。
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors
class SqueezedNorm(matplotlib.colors.Normalize):
def __init__(self, vmin=None, vmax=None, mid=0, s1=2, s2=2, clip=False):
self.vmin = vmin
self.mid = mid
self.vmax = vmax
self.s1=s1; self.s2=s2
f = lambda x, zero,vmax,s: np.abs((x-zero)/(vmax-zero))**(1./s)*0.5
self.g = lambda x, zero,vmin,vmax, s1,s2: f(x,zero,vmax,s1)*(x>=zero) - \
f(x,zero,vmin,s2)*(x<zero)+0.5
matplotlib.colors.Normalize.__init__(self, vmin, vmax, clip)
def __call__(self, value, clip=None):
r = self.g(value, self.mid,self.vmin,self.vmax, self.s1,self.s2)
return np.ma.masked_array(r)
fig, (ax, ax2, ax3) = plt.subplots(nrows=3,
gridspec_kw={"height_ratios":[3,2,1], "hspace":0.25})
x = np.linspace(-13,4, 110)
norm=SqueezedNorm(vmin=-13, vmax=4, mid=0, s1=1.7, s2=4)
line, = ax.plot(x, norm(x))
ax.margins(0)
ax.set_ylim(0,1)
im = ax2.imshow(np.atleast_2d(x).T, cmap="Spectral_r", norm=norm, aspect="auto")
cbar = fig.colorbar(im ,cax=ax3,ax=ax2, orientation="horizontal")
![enter image description here](https://istack.dev59.com/d8FPe.webp)
该函数被选择为独立于其参数将任何范围映射到范围
[0,1]
,以便可以使用颜色地图。参数
mid
确定应将哪个值映射到颜色地图的中间。在这种情况下,这将是
0
。参数
s1
和
s2
确定颜色地图在两个方向上的压缩程度。
设置
mid = np.mean(vmin,vmax),s1 = 1,s2 = 1
将恢复原始比例。
![enter image description here](https://istack.dev59.com/Jt0ka.webp)
为了选择好的参数,可以使用一些滑块来查看实时更新的图表。
![enter image description here](https://istack.dev59.com/50cy8.webp)
from matplotlib.widgets import Slider
midax = plt.axes([0.1, 0.04, 0.2, 0.03], facecolor="lightblue")
s1ax = plt.axes([0.4, 0.04, 0.2, 0.03], facecolor="lightblue")
s2ax = plt.axes([0.7, 0.04, 0.2, 0.03], facecolor="lightblue")
mid = Slider(midax, 'Midpoint', x[0], x[-1], valinit=0)
s1 = Slider(s1ax, 'S1', 0.5, 6, valinit=1.7)
s2 = Slider(s2ax, 'S2', 0.5, 6, valinit=4)
def update(val):
norm=SqueezedNorm(vmin=-13, vmax=4, mid=mid.val, s1=s1.val, s2=s2.val)
im.set_norm(norm)
cbar.update_bruteforce(im)
line.set_ydata(norm(x))
fig.canvas.draw_idle()
mid.on_changed(update)
s1.on_changed(update)
s2.on_changed(update)
fig.subplots_adjust(bottom=0.15)