我一直使用下面的函数来扩充输入数据，使扩充后的数据实际上满足周期边界条件。与修改现有插值算法相比，扩充数据有一个明显的优势：扩充后的数据可以直接交给任意插值算法处理。代价是每个周期维度上的数据量变为原来的 3 倍。示例见下文。
def augment_with_periodic_bc(points, values, domain):
    """
    Augment gridded data so that it satisfies periodic boundary conditions.

    Along every dimension with a finite period ``d`` the grid is extended by
    one full period on each side (``x - d``, ``x``, ``x + d``) and the values
    are repeated accordingly. Any regular-grid interpolator can then be used
    on the result without special handling of the boundaries.

    Parameters
    ----------
    points : sequence of array_like of float, with shapes (m1, ), ..., (mn, )
        The points defining the regular grid in n dimensions. Along a
        periodic dimension the points should cover one period without
        repeating the end point (e.g. ``np.linspace(0, d, m, endpoint=False)``);
        otherwise the augmented grid contains duplicate points.
    values : array_like, shape (m1, ..., mn, ...)
        The data on the regular grid in n dimensions. Any trailing dimensions
        beyond the first n are passed through unchanged.
    domain : float or None or array_like of shape (n, )
        The size of the domain (the period) along each of the n dimensions,
        or a uniform domain size along all dimensions if a scalar. None
        specifies aperiodic boundary conditions: along all dimensions if
        given as a scalar, or along a single dimension if used as an entry.

    Returns
    -------
    points : list of ndarray of float, with shapes (k1, ), ..., (kn, )
        The augmented grid, where ``ki = 3 * mi`` along periodic dimensions
        and ``ki = mi`` along aperiodic ones.
    values : ndarray, shape (k1, ..., kn, ...)
        The data on the augmented grid.

    Raises
    ------
    ValueError
        If `domain` is neither a scalar nor of length n, or if `values` has
        fewer than n dimensions.
    """
    n = len(points)
    if np.ndim(domain) == 0:
        domain = [domain] * n
    if np.shape(domain) != (n,):
        raise ValueError("`domain` must be a scalar or have the same "
                         "length as `points`")
    values = np.asarray(values)
    if values.ndim < n:
        # np.tile would silently prepend axes and produce a mismatched array.
        raise ValueError("`values` must have at least as many dimensions "
                         "as there are arrays in `points`")
    # Convert each axis so plain lists also support `x - d` / `x + d`.
    points = [x if d is None else np.concatenate([x - d, x, x + d])
              for x, d in zip(map(np.asarray, points), domain)]
    # np.tile aligns `reps` with the *trailing* axes of its input, so pad it
    # with ones for any trailing (non-grid) dimensions; otherwise data with
    # extra dimensions would be tiled along the wrong axes.
    reps = [1 if d is None else 3 for d in domain]
    reps += [1] * (values.ndim - n)
    values = np.tile(values, reps)
    return points, values
示例
下面的示例演示了一维周期边界条件下的插值，但上面的函数适用于任意维度。注意：沿周期维度的网格点不应同时包含周期的两个端点（示例中使用了 `endpoint=False`），否则扩充后的网格会出现重复点。
![example of periodic interpolation](https://istack.dev59.com/YTQnF.webp)
# Example: compare interpolation with and without periodic augmentation.
# Assumes numpy (as np), matplotlib.pyplot (as plt) and scipy.interpolate
# (as interpolate) are imported, and augment_with_periodic_bc is defined.
# sharex/sharey are keyword-only in current Matplotlib (passing them
# positionally raises TypeError); dpi is set on this figure only instead of
# mutating the global rcParams.
fig, axes = plt.subplots(2, 2, sharex=True, sharey=True, dpi=144)
rng = np.random.default_rng(0)

# One period of a sine, sampled on a grid that excludes the end point so the
# augmented grid has no duplicate points.
x = np.linspace(0, 1, 10, endpoint=False)
y = np.sin(2 * np.pi * x)
ax = axes[0, 0]
ax.plot(x, y, marker='.')
ax.set_title('Points to interpolate')

sampled = rng.uniform(0, 1, 100)

# Without augmentation, samples beyond the last grid point are out of bounds
# and come back as NaN (interpn's default fill_value); draw them at 0 in a
# different color.
y_sampled = interpolate.interpn([x], y, sampled, bounds_error=False)
valid = ~np.isnan(y_sampled)
ax = axes[0, 1]
ax.scatter(sampled, np.where(valid, y_sampled, 0), marker='.',
           c=np.where(valid, 'C0', 'C1'))
ax.set_title('interpn w/o periodic bc')

# With a period of 1 every sample lies inside the augmented grid.
[x_bc], y_bc = augment_with_periodic_bc([x], y, domain=1.0)
y_sampled_bc = interpolate.interpn([x_bc], y_bc, sampled)
ax = axes[1, 0]
ax.scatter(sampled, y_sampled_bc, marker='.')
ax.set_title('interpn w/ periodic bc')

# The same augmented data works unchanged with a different interpolator.
y_sampled_bc_cubic = interpolate.interp1d(x_bc, y_bc, kind='cubic')(sampled)
ax = axes[1, 1]
ax.scatter(sampled, y_sampled_bc_cubic, marker='.')
ax.set_title('cubic interp1d w/ periodic bc')
fig.tight_layout()