🎯 Towards tackling the Spectral Bias of Neural Operators!
Marimuthu Kalimuthu1,2,3, David Holzmüller 4, Mathias Niepert1,2,5
1Universität Stuttgart,
2Stuttgart Center for Simulation Science - SimTech,
3International Max Planck Research School for Intelligent Systems (IMPRS-IS),
4INRIA Paris, École Normale Supérieure, PSL University,
5NEC Labs Europe
Modeling high-frequency information is a critical challenge in Scientific Machine Learning. For instance, fully turbulent flow simulations of the Navier-Stokes equations at Reynolds numbers 3500 and above can generate high-frequency signals due to swirling fluid motions caused by eddies and vortices. Faithfully modeling such signals with neural networks depends on the accurate reconstruction of moderate to high frequencies. However, it is well known that deep neural networks exhibit the so-called spectral bias toward learning low-frequency components. Meanwhile, Fourier Neural Operators (FNOs) have emerged in recent years as a popular class of data-driven models for solving Partial Differential Equations (PDEs) and for surrogate modeling in general. Although impressive results have been achieved on several PDE benchmark problems, FNOs often perform poorly in learning non-dominant frequencies characterized by local features. This limitation stems from the spectral bias inherent in neural networks and the explicit exclusion of high-frequency modes in FNO and its variants. Therefore, to mitigate these issues and improve FNO's capability to represent a broad range of frequency components, we propose two key architectural enhancements: (i) a parallel branch performing local spectral convolutions and (ii) a high-frequency propagation module. Moreover, we propose a novel frequency-sensitive loss term based on radially binned spectral errors. The introduction of the parallel branch for local convolutions reduces the number of trainable parameters by up to 50% while matching the accuracy of the baseline FNO, which relies solely on global convolutions. Experiments on three challenging PDE problems in fluid mechanics (Kolmogorov Flow 2D and Turbulent Radiative Layer 3D) and biological pattern formation (Diffusion-Reaction 2D), together with qualitative and spectral analyses of the predictions, demonstrate the effectiveness of our method over state-of-the-art neural operator baselines.
The core idea of LOGLO-FNO for achieving local spectral convolution in FNO is to first partition the domain \(D\) into \(M\) non-overlapping hypercubes, called patches, such that $$\bigcup_{m=1}^{M} P_m = D,$$ and then to perform (learnable) spectral convolutions on these sub-domains without any truncation of Fourier modes.
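To make the idea concrete, below is a minimal PyTorch sketch of the patch decomposition and of a per-patch spectral convolution that retains all Fourier modes of each patch. The helper names (patchify_2d, local_spectral_conv) and shapes are illustrative assumptions, not the exact implementation.

import torch

def patchify_2d(z, ps):
    # z: (nb, nc, nx, ny) -> (nb * num_patches, nc, ps, ps)
    nb, nc, nx, ny = z.shape
    z = z.unfold(2, ps, ps).unfold(3, ps, ps)            # (nb, nc, nx/ps, ny/ps, ps, ps)
    z = z.permute(0, 2, 3, 1, 4, 5).contiguous()
    return z.view(-1, nc, ps, ps)

def local_spectral_conv(z, weights, ps=16):
    # per-patch spectral convolution WITHOUT truncation of Fourier modes:
    # all ps x (ps//2 + 1) rFFT modes of every patch are kept and learned
    patches = patchify_2d(z, ps)                          # (B, nc, ps, ps)
    z_hat = torch.fft.rfft2(patches, norm="ortho")        # (B, nc, ps, ps//2 + 1)
    out_hat = torch.einsum("bixy,ioxy->boxy", z_hat, weights)
    return torch.fft.irfft2(out_hat, s=(ps, ps), norm="ortho")

# learnable complex weights over ALL patch modes
nc, ps = 32, 16
weights = torch.nn.Parameter(torch.randn(nc, nc, ps, ps // 2 + 1, dtype=torch.cfloat) / nc)
out = local_spectral_conv(torch.randn(4, nc, 64, 64), weights, ps)   # (4 * 16, nc, 16, 16)

Reassembling the patch outputs onto the global grid simply reverses the reshaping, so the local branch introduces no mode truncation, only local mixing.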
Let \( \mathbf{Z} \in \mathbb{R}^{\cdots \times N_c \times N_x \times N_y[\times N_z]}\) be the output of the (initial) lifting layer. Then, \[ \mathbf{\Upsilon} := \sigma \Big[ \mathcal{K}(\mathbf{Z}) + \mathbf{W}_{f}*\mathbf{Z} + \mathbf{b}_{f} \Big] \] \[ \mathcal{L}(\mathbf{Z}) := \mathbf{W}_{c2}*\Big( \sigma \left[ \mathbf{W}_{c1} * \mathbf{\Upsilon} + \mathbf{b}_{c1} \right] \Big) + \mathbf{b}_{c2} + \left[ \mathbf{W}_{c} \odot \mathbf{Z} + \mathbf{b}_{c} \right] \] \[ \text{where}~\mathbf{W}_{\square}~\text{and}~\mathbf{b}_{\square}~\text{are learnable parameters and}~ \mathcal{K}(\cdot)~\text{is the global kernel integral operator.} \]
\[ \mathbf{\Upsilon}_g := ~\sigma \Big[ \mathcal{K}_g(\mathbf{Z}) + \mathbf{W}_{g}*\mathbf{Z} + \mathbf{b}_{g} \Big], \qquad \mathbf{\Upsilon}_l := \sigma \Big[ \mathcal{K}_l(\mathbf{\hat{Z}}) + \mathbf{W}_{l}*\mathbf{\hat{Z}} + \mathbf{b}_{l} \Big], \] \[ \mathcal{L}(\mathbf{Z}, \mathbf{\hat{Z}}, \mathbf{Z^{'}}) := \mathbf{W}_{gc2}*\Big( \sigma \left[ \mathbf{W}_{gc1}* \mathbf{\Upsilon}_g + \mathbf{b}_{gc1} \right] \Big) + \mathbf{b}_{gc2} + \left[ \mathbf{W}_{gc} \odot \mathbf{Z} + \mathbf{b}_{gc} \right] ~\textcolor{blue}{+} \] \[ \qquad \qquad \qquad \left[ \mathbf{W}_{lc} \odot \mathbf{\hat{Z}} + \mathbf{b}_{lc} \right] + \mathbf{W}_{lc2}*\Big( \sigma \left[ \mathbf{W}_{lc1}* \mathbf{\Upsilon}_l + \mathbf{b}_{lc1} \right] \Big) + \mathbf{b}_{lc2} ~~\textcolor{blue}{+} \] \[ \mathbf{W}_{hfc2}*\Big( \sigma \left[ \mathbf{W}_{hfc1}* \mathbf{Z^{'}} + \mathbf{b}_{hfc1} \right] \Big) + \mathbf{b}_{hfc2}, \] where the subscripts \(g\), \(l\), and \(hf\) denote the global, local, and high-frequency branches, \(\mathbf{\hat{Z}}\) is the patch-partitioned representation of \(\mathbf{Z}\) processed by the local branch, and \(\mathbf{Z^{'}}\) is the output of the high-frequency propagation module.
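For readability, the following is a purely structural PyTorch sketch of how the three branches in the equation above are combined. The spectral kernels \(\mathcal{K}_g\) and \(\mathcal{K}_l\) are replaced by 1x1 convolution placeholders and all inputs are assumed to live on the same grid, so this is a shape-level illustration rather than the actual LOGLO-FNO implementation.

import torch
import torch.nn as nn

class LogloBlockSketch(nn.Module):
    # global branch + local (patch) branch + high-frequency branch, summed as above
    def __init__(self, nc):
        super().__init__()
        self.K_g = nn.Conv2d(nc, nc, 1)   # placeholder for the global spectral convolution
        self.K_l = nn.Conv2d(nc, nc, 1)   # placeholder for the local (per-patch) spectral convolution
        self.W_g = nn.Conv2d(nc, nc, 1)
        self.W_l = nn.Conv2d(nc, nc, 1)
        self.W_gc = nn.Conv2d(nc, nc, 1)  # stands in for the elementwise term W_gc (.) Z
        self.W_lc = nn.Conv2d(nc, nc, 1)  # stands in for the elementwise term W_lc (.) Z_hat
        mlp = lambda: nn.Sequential(nn.Conv2d(nc, nc, 1), nn.GELU(), nn.Conv2d(nc, nc, 1))
        self.mlp_g, self.mlp_l, self.mlp_hf = mlp(), mlp(), mlp()
        self.act = nn.GELU()

    def forward(self, z, z_hat, z_hf):
        y_g = self.act(self.K_g(z) + self.W_g(z))          # global branch
        y_l = self.act(self.K_l(z_hat) + self.W_l(z_hat))  # local branch
        return (self.mlp_g(y_g) + self.W_gc(z)
                + self.mlp_l(y_l) + self.W_lc(z_hat)
                + self.mlp_hf(z_hf))                        # high-frequency branch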
Let \(N_b\) be the batch size, \(N_x\) and \(N_y\) the resolutions of the spatial dimensions, \(N_c\) the width (i.e., the number of hidden channels), and \(C_{\text{in}}\) and \(C_{\text{out}}\) the input and output channel dimensions, which are typically set to the same value (e.g., 128). \[ \text{FLOPs}_{\text{FFT}} = 5 \cdot N_b \cdot C_{\text{in}} \cdot N_x \cdot N_y \cdot \log_2(N_x \cdot N_y) \] \[ \text{FLOPs}_{\text{IFFT}} = 5 \cdot N_b \cdot C_{\text{out}} \cdot N_x \cdot N_y \cdot \log_2(N_x \cdot N_y) \] Therefore, the FFT computation on the global branch at the full 2D spatial resolution has a per-channel complexity of $$\mathcal{O} \Big( N_x \cdot N_y \cdot \log_2(N_x \cdot N_y) \Big),$$ making it expensive for large spatial resolutions such as \(2048\times2048\) or higher.
Let \(N_p\) be the number of patches, given by \( \frac{N_x \cdot N_y}{P_s^2}\), \(N_x\) and \(N_y\) the resolutions of the spatial dimensions, \(P_s \times P_s\) the patch size (e.g., \(16 \times 16\)), \(N_c\) the width (i.e., the number of hidden channels), and \(C_{\text{in}}\) and \(C_{\text{out}}\) the input and output channel dimensions, which are typically set to the same value (e.g., 128). \begin{align*} \text{FLOPs}_{\text{FFT}} &= 5 \cdot N_b \cdot C_{\text{in}} \cdot N_p \cdot P_s \cdot P_s \cdot \log_2(P_s \cdot P_s) \\ &= 5 \cdot N_b \cdot C_{\text{in}} \cdot N_x \cdot N_y \cdot \log_2(P_s \cdot P_s) \end{align*} \begin{align*} \text{FLOPs}_{\text{IFFT}} &= 5 \cdot N_b \cdot C_{\text{out}} \cdot N_p \cdot P_s \cdot P_s \cdot \log_2(P_s \cdot P_s) \\ &= 5 \cdot N_b \cdot C_{\text{out}} \cdot N_x \cdot N_y \cdot \log_2(P_s \cdot P_s) \end{align*} Thus, the FFT and IFFT computations on the local branch, which operates on patches, have a per-channel computational complexity of $$\mathcal{O} \Big( N_x \cdot N_y \cdot \log_2(P_s \cdot P_s)\Big).$$ Since, in practice, $$\mathcal{O}\Big(N_x \cdot N_y \cdot \log_2(P_s \cdot P_s)\Big) \ll \mathcal{O}\Big(N_x \cdot N_y \cdot \log_2(N_x \cdot N_y)\Big),$$ the FFT computations are significantly cheaper on the local branch. Moreover, the computation is highly parallelizable, since each patch's FFT can be processed independently, leveraging modern accelerators such as GPUs and TPUs.
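As a sanity check on these counts, the snippet below applies the same 5·N·log2(N) FLOP model to compare the global and local per-branch FFT costs; the concrete grid, patch, and channel sizes are only examples.

import math

def fft_flops(batch, channels, n_points, transform_len):
    # 5 * N * log2(N) cost model per (I)FFT, applied to n_points grid values
    return 5 * batch * channels * n_points * math.log2(transform_len)

nb, nc, nx, ny, ps = 1, 128, 2048, 2048, 16
global_cost = fft_flops(nb, nc, nx * ny, nx * ny)   # log2(2048^2) = 22
local_cost  = fft_flops(nb, nc, nx * ny, ps * ps)   # log2(16^2)   = 8
print(global_cost / local_cost)                     # 22 / 8 = 2.75

For a fixed patch size, this ratio grows as the grid resolution increases.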
Let \(N_b\) be the batch size, \(N_x\), \(N_y\), and \(N_z\) the resolutions of the spatial dimensions, \(N_c\) the width (i.e., the number of hidden channels), and \(C_{\text{in}}\) and \(C_{\text{out}}\) the input and output channel dimensions, which are typically set to the same value (e.g., 128). \[ \text{FLOPs}_{\text{FFT}} = 5 \cdot N_b \cdot C_{\text{in}} \cdot N_x \cdot N_y \cdot N_z \cdot \log_2(N_x \cdot N_y \cdot N_z) \] \[ \text{FLOPs}_{\text{IFFT}} = 5 \cdot N_b \cdot C_{\text{out}} \cdot N_x \cdot N_y \cdot N_z \cdot \log_2(N_x \cdot N_y \cdot N_z) \] Therefore, the FFT computation on the global branch at the full 3D spatial resolution has a per-channel complexity of $$\mathcal{O}\Big(N_x \cdot N_y \cdot N_z \cdot \log_2(N_x \cdot N_y \cdot N_z)\Big),$$ making it highly expensive for large spatial resolutions such as \(512\times512\times512\) or higher.
Let \(N_p\) be the number of patches, given by \( \frac{N_x \cdot N_y \cdot N_z}{P_s^3}\), \(N_x\), \(N_y\), and \(N_z\) the resolutions of the spatial dimensions, \(P_s \times P_s \times P_s\) the patch size (e.g., \(16 \times 16 \times 16\), \(32 \times 32 \times 32\), etc.), \(N_c\) the width (i.e., the number of hidden channels), and \(C_{\text{in}}\) and \(C_{\text{out}}\) the input and output channel dimensions, which are typically set to the same value (e.g., 128). \begin{align*} \text{FLOPs}_{\text{FFT}} &= 5 \cdot N_b \cdot C_{\text{in}} \cdot N_p \cdot P_s \cdot P_s \cdot P_s \cdot \log_2(P_s \cdot P_s \cdot P_s) \\ &= 5 \cdot N_b \cdot C_{\text{in}} \cdot N_x \cdot N_y \cdot N_z \cdot \log_2(P_s \cdot P_s \cdot P_s) \end{align*} \begin{align*} \text{FLOPs}_{\text{IFFT}} &= 5 \cdot N_b \cdot C_{\text{out}} \cdot N_p \cdot P_s \cdot P_s \cdot P_s \cdot \log_2(P_s \cdot P_s \cdot P_s) \\ &= 5 \cdot N_b \cdot C_{\text{out}} \cdot N_x \cdot N_y \cdot N_z \cdot \log_2(P_s \cdot P_s \cdot P_s) \end{align*} Thus, the FFT and IFFT computations on the local branch, which operates on 3D patches, have a per-channel computational complexity of $$\mathcal{O}\Big(N_x \cdot N_y \cdot N_z \cdot \log_2(P_s \cdot P_s \cdot P_s)\Big).$$ Since, in practice, $$\mathcal{O}\Big(N_x \cdot N_y \cdot N_z \cdot \log_2(P_s \cdot P_s \cdot P_s)\Big) \ll \mathcal{O}\Big(N_x \cdot N_y \cdot N_z \cdot \log_2(N_x \cdot N_y \cdot N_z)\Big),$$ the FFT computations are significantly cheaper on the local branch. Furthermore, the computation is highly parallelizable, since each patch's FFT can be processed independently, leveraging modern accelerators such as GPUs and TPUs.
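As a concrete example, for a \(512\times512\times512\) grid with \(16\times16\times16\) patches, \(\log_2(512^3) = 27\) whereas \(\log_2(16^3) = 12\), so under the cost model above the local branch's FFTs require 2.25x fewer FLOPs per channel than the global branch's.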
In addition, we propose a spectral loss term based on radial binning of the spectral energy of the errors, implemented as follows:
import torch

def RadialBinnedSpectralLoss(preds, target, reduction="mean"):
    # input data shape and params
    nb, nc, nx, ny, nt = target.size()
    iLow, iHigh = 4, 12          # radial bin boundaries for the low/mid/high bands
    Lx, Ly = 1.0, 1.0            # physical domain lengths
    # Compute error in Fourier space
    err_phys = preds - target
    err_fft = torch.fft.fftn(err_phys, dim=[2, 3])
    err_fft_sq = torch.abs(err_fft)**2
    # keep only the non-redundant quadrant of the spectrum
    err_fft_sq_h = err_fft_sq[Ellipsis, :nx//2, :ny//2, :]
    # Create radial indices
    x = torch.arange(nx//2, device=preds.device)
    y = torch.arange(ny//2, device=preds.device)
    X, Y = torch.meshgrid(x, y, indexing="ij")
    radii = torch.sqrt((X**2 + Y**2).float()).floor().to(torch.long)  # Radial dist.
    max_radius = int(torch.max(radii))
    # flatten radii for binary mask
    radii_flat = radii.flatten()                     # (nx//2 * ny//2,)
    # Spatially flatten Fourier-space error; (nb, nc, nx//2 * ny//2, nt)
    err_fft_sq_flat = err_fft_sq_h.contiguous().reshape(nb, nc, -1, nt)
    # initialize output tensor to hold the Fourier error
    # for each radial bin at distance r from the origin
    err_F_vect_full = torch.zeros(nb, nc, max_radius + 1, nt,
                                  dtype=err_fft_sq_flat.dtype, device=preds.device)
    # Apply `index_add_` for all radii and accumulate the errors
    valid_r = radii_flat <= max_radius               # binary mask to find valid radii
    # Sum for all valid radial indices
    err_F_vect_full.index_add_(2,
                               radii_flat[valid_r],
                               err_fft_sq_flat[:, :, valid_r]
    )
    # Normalize & compute mean over batch; (nc, max_radius + 1, nt)
    nrm = (nx * ny) * Lx * Ly
    _err_F = torch.sqrt(torch.mean(err_F_vect_full, dim=0)) / nrm
    # Classify Fourier-space error into three bands
    err_F = torch.zeros([nc, 3, nt], device=preds.device)
    err_F[:, 0] += torch.mean(_err_F[:, :iLow], dim=1)        # low freqs
    err_F[:, 1] += torch.mean(_err_F[:, iLow:iHigh], dim=1)   # mid freqs
    err_F[:, 2] += torch.mean(_err_F[:, iHigh:], dim=1)       # high freqs
    # mean or sum over the channels and time dimensions
    if reduction == "mean":
        freq_loss = torch.mean(err_F, dim=[0, -1])
    elif reduction == "sum":
        freq_loss = torch.sum(err_F, dim=[0, -1])
    return freq_loss
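For illustration, a call with random tensors (the shapes below are only an example) looks like:

# (batch, channels, x, y, time)
preds  = torch.randn(8, 2, 64, 64, 10)
target = torch.randn(8, 2, 64, 64, 10)
band_errors = RadialBinnedSpectralLoss(preds, target)   # shape (3,): low / mid / high frequency bands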
The 1-step training loss for N trajectories, each comprising T timesteps, is given by,
$$\theta^{*} = \arg\min_\theta \sum_{n=1}^{N} \sum_{t=1}^{T-1} \mathcal{C}(\mathcal{N}_{\theta}(u^t), u^{t+1}), \qquad \quad \mathcal{C} = \mathcal{C}_{\mathrm{MSE}} + \lambda \cdot \mathcal{C}_{\mathrm{freq}}, \quad 0 \leq \lambda \leq 1 $$
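Below is a minimal sketch of how this composite objective could be assembled for one training step, reusing the RadialBinnedSpectralLoss defined above; the model, the value of \(\lambda\), and the aggregation over the three radial bands are illustrative assumptions rather than prescribed choices.

import torch.nn.functional as F

lam = 0.1   # weight of the frequency-sensitive term, 0 <= lam <= 1

def one_step_loss(model, u_t, u_tp1):
    # C = C_MSE + lam * C_freq, with C_freq aggregated over the three radial bands
    pred = model(u_t)
    c_mse = F.mse_loss(pred, u_tp1)
    c_freq = RadialBinnedSpectralLoss(pred, u_tp1).sum()
    return c_mse + lam * c_freq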
Let \(\mathbf{X}^{hf} \in \mathbb{R}^{N_b \times N_x \times N_y \times (N_t \cdot N_c)}\) denote the input tensor of HF features, where \(N_b\) is the batch size, \(N_x\) and \(N_y\) are the resolutions of the spatial dimensions, and \(N_t \cdot N_c\) represents the combined temporal and channel dimensions. The high-frequency feature adaptive Gaussian noise \(\mathbf{N}_{dynamic}\) is then computed as follows:
1. Compute the per-sample mean \(\mu_b\) and standard deviation \(\sigma_b\) of the high-frequency features: \[ \mu_b = \frac{1}{N_x \cdot N_y \cdot N_t \cdot N_c} \sum_{i=1}^{N_x} \sum_{j=1}^{N_y} \sum_{k=1}^{N_t N_c} \mathbf{X}_{b,i,j,k}^{\mathrm{hf}} \] \[ \sigma_b = \sqrt{\frac{1}{N_x \cdot N_y \cdot N_t \cdot N_c} \sum_{i=1}^{N_x} \sum_{j=1}^{N_y} \sum_{k=1}^{N_t N_c} (\mathbf{X}_{b,i,j,k}^{\mathrm{hf}} - \mu_b)^2} + \epsilon \] where \(\epsilon\) is a small constant added for numerical stability, and \(\mu\) and \(\sigma\) are obtained by stacking the per-sample statistics along the batch dimension.
2. Generate standard Gaussian noise: \[ \mathbf{N} \sim \mathcal{N}(0, 1) \]
3. Scale the noise dynamically: \[ \mathbf{N}_{dynamic} = \mu + \alpha \cdot \sigma \cdot \mathbf{N} \] where \(\alpha\) is a small value such as 0.025 and \(\mathbf{N}_{dynamic}\) has the same shape as the input \(\mathbf{X}^{hf}\). \(\mathbf{N}_{dynamic}\) can then be added to the batch of inputs to the global and local branches during the training phase of LOGLO-FNO.
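A minimal PyTorch sketch of this noise computation is given below; the function name hf_adaptive_noise and the default values are illustrative.

import torch

def hf_adaptive_noise(x_hf, alpha=0.025, eps=1e-8):
    # x_hf: (N_b, N_x, N_y, N_t * N_c) tensor of high-frequency features
    dims = tuple(range(1, x_hf.ndim))                                # all non-batch dimensions
    mu = x_hf.mean(dim=dims, keepdim=True)                           # per-sample mean
    sigma = x_hf.std(dim=dims, unbiased=False, keepdim=True) + eps   # per-sample std + eps
    noise = torch.randn_like(x_hf)                                   # N ~ N(0, 1)
    return mu + alpha * sigma * noise                                # same shape as x_hf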
In the slideshow below, we visualize the radially binned spectral energy errors of the predictions of the considered neural operator baselines and LOGLO-FNO on the Kolmogorov Flow 2D PDE. Note that we show only alternate radial bins, prioritizing an uncluttered presentation over completeness.
Radial Spectral Loss of the base FNO model's predictions on the turbulent Kolmogorov Flow 2D
We evaluate LOGLO-FNO on the challenging turbulent version of the Kolmogorov Flow 2D benchmark
We evaluate LOGLO-FNO on the challenging setup of the Turbulent Radiative Mixing Layer 3D benchmark
The setup comprises training with a 1-step loss and evaluating 1-step predictions on a host of metrics. The results are compared against a diverse set of competitive neural operator baselines, such as Modern UNet, ConvNext U-Net, and FNO.
We evaluate LOGLO-FNO on the challenging coupled Diffusion-Reaction 2D PDE benchmark
We visualize the predictions of FNO and LOGLO-FNO on the time-dependent Turbulent Radiative Mixing Layer 3D PDE.
@inproceedings{loglo-fno-kalimuthu:2025,
title={{LOGLO}-{FNO}: Efficient Learning of Local and Global Features in Fourier Neural Operators},
author={Marimuthu Kalimuthu and David Holzm{\"u}ller and Mathias Niepert},
booktitle={ICLR 2025 Workshop on Machine Learning Multiscale Processes},
year={2025},
url={https://openreview.net/forum?id=OCM7OkVg9C}
}