torch.rfft – fft-based convolution creating different output than spatial convolution

我在 Pytorch 中实现了基于 FFT 的卷积,并通过 conv2d() 函数将结果与空间卷积进行了比较。使用的卷积滤波器是平均滤波器。 conv2d() 函数由于预期的平均滤波而产生了平滑的输出,但基于 fft 的卷积返回了更模糊的输出。
我已在此处附加代码和输出 –

空间卷积-

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from PIL import Image , ImageOps
import torch
from matplotlib import pyplot as plt
from torchvision. transforms import ToTensor
import torch. nn. functional as F
import numpy as np

im = Image. open ( “/kaggle/input/tiger.jpg” )
im = im. resize ( ( 256 , 256 ) )
gray_im = im. convert ( ‘L’ )
gray_im = ToTensor ( ) (gray_im )
gray_im = gray_im. squeeze ( )

fil = torch. tensor ( [ [ 1/ 9 , 1/ 9 , 1/ 9 ] , [ 1/ 9 , 1/ 9 , 1/ 9 ] , [ 1/ 9 , 1/ 9 , 1/ 9 ] ] )

conv_gray_im = gray_im. unsqueeze ( 0 ). unsqueeze ( 0 )
conv_fil = fil. unsqueeze ( 0 ). unsqueeze ( 0 )

conv_op = F. conv2d (conv_gray_im ,conv_fil )

conv_op = conv_op. squeeze ( )

plt. figure ( )
plt. imshow (conv_op , cmap = ‘gray’ )

基于 FFT 的卷积 –

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def fftshift (image ):
    sh = image. shape
    x = np. arange ( 0 , sh [ 2 ] , 1 )
    y = np. arange ( 0 , sh [ 3 ] , 1 )
    xm , ym   = np. meshgrid (x ,y )
    shifter = (1 )** (xm + ym )
    shifter = torch. from_numpy (shifter )
    return image*shifter

shift_im = fftshift (conv_gray_im )
padded_fil = F. pad (conv_fil , ( 0 , gray_im. shape [ 0 ]-fil. shape [ 0 ] , 0 , gray_im. shape [ 1 ]-fil. shape [ 1 ] ) )
shift_fil = fftshift (padded_fil )
fft_shift_im = torch. rfft (shift_im , 2 , onesided = False )
fft_shift_fil = torch. rfft (shift_fil , 2 , onesided = False )
shift_prod = fft_shift_im*fft_shift_fil
shift_fft_conv = fftshift (torch. irfft (shift_prod , 2 , onesided = False ) )

fft_op = shift_fft_conv. squeeze ( )
plt. figure ( ‘shifted fft’ )
plt. imshow (fft_op , cmap = ‘gray’ )

原图-

基于fft的卷积输出-



相关讨论

  • 你如何生成 padded_fil ?请参阅最小的可重现示例!
  • 哦,对不起,错过了那条线。我已经更新了代码。


您的代码的主要问题是 Torch 不处理复数,其 FFT 的输出是一个 3D 数组,第 3 维有两个值,一个用于实部,一个用于虚部。因此,乘法不会进行复数乘法。

目前在 Torch 中没有定义复数乘法(参见本期),我们必须自己定义。

一个小问题,但如果你想比较两个卷积操作也很重要,如下:

FFT 在第一个元素(图像的左上角像素)中获取其输入的原点。为避免输出偏移,您需要生成一个填充内核,其中内核的原点是左上角的像素。这很棘手,实际上…

您当前的代码:

1
2
3
fil = torch. tensor ( [ [ 1/ 9 , 1/ 9 , 1/ 9 ] , [ 1/ 9 , 1/ 9 , 1/ 9 ] , [ 1/ 9 , 1/ 9 , 1/ 9 ] ] )
conv_fil = fil. unsqueeze ( 0 ). unsqueeze ( 0 )
padded_fil = F. pad (conv_fil , ( 0 , gray_im. shape [ 0 ]-fil. shape [ 0 ] , 0 , gray_im. shape [ 1 ]-fil. shape [ 1 ] ) )

生成一个填充内核,其中原点以像素 (1,1) 为单位,而不是 (0,0)。它需要在每个方向上移动一个像素。 NumPy 有一个函数 roll 对此很有用,我不知道 Torch 的等价物(我对 Torch 一点也不熟悉)。这应该工作:

1
2
3
4
5
fil = torch. tensor ( [ [ 1/ 9 , 1/ 9 , 1/ 9 ] , [ 1/ 9 , 1/ 9 , 1/ 9 ] , [ 1/ 9 , 1/ 9 , 1/ 9 ] ] )
padded_fil = fil. unsqueeze ( 0 ). unsqueeze ( 0 ). numpy ( )
padded_fil = np. pad (padded_fil , ( ( 0 , gray_im. shape [ 0 ]-fil. shape [ 0 ] ) , ( 0 , gray_im. shape [ 1 ]-fil. shape [ 1 ] ) ) )
padded_fil = np. roll (padded_fil ,1 , axis = ( 0 , 1 ) )
padded_fil = torch. from_numpy (padded_fil )

最后,应用于空间域图像的 fftshift 函数会导致频域图像(应用于图像的 FFT 的结果)发生偏移,使得原点位于图像的中间,而不是左上角。这种转变在查看 FFT 的输出时很有用,但在计算卷积时毫无意义。

把这些东西放在一起,现在的卷积是:

1
2
3
4
5
6
7
8
def complex_multiplication (t1 , t2 ):
  real1 , imag1 = t1 [: ,: , 0 ] , t1 [: ,: , 1 ]
  real2 , imag2 = t2 [: ,: , 0 ] , t2 [: ,: , 1 ]
  return torch. stack ( [real1 * real2 – imag1 * imag2 , real1 * imag2 + imag1 * real2 ] , dim =1 )

fft_im = torch. rfft (gray_im , 2 , onesided = False )
fft_fil = torch. rfft (padded_fil , 2 , onesided = False )
fft_conv = torch. irfft (complex_multiplication (fft_im , fft_fil ) , 2 , onesided = False )

请注意,您可以进行单边 FFT 以节省一点计算时间:

1
2
3
fft_im = torch. rfft (gray_im , 2 , onesided = True )
fft_fil = torch. rfft (padded_fil , 2 , onesided = True )
fft_conv = torch. irfft (complex_multiplication (fft_im , fft_fil ) , 2 , onesided = True , signal_sizes =gray_im. shape )

这里的频域大小大约是完整 FFT 的一半,但它只是省略了冗余部分。卷积的结果不变。



相关讨论

  • 感谢您的回答,但我打印了填充内核,它的第一个值位于 (0,0),所以在执行 np.roll 之后,它会将一些值转移到图像的最后一列和最后一行。所以我看不到内核有任何问题。使用此左卷内核的代码的 fft 输出是一些半倒置和半直立的混合图像。此外,我首先在没有 fftshift 函数的情况下完成了基于 fft 的卷积,它给出了与问题中所示相同的额外模糊输出,但反转了一个(180 度)。所以我做了 fftshift 部分,至少得到了一个直立的输出。
  • @psj:内核的原点是它的中心,如果你定义它,否则你会看到一个移位的输出。将内核的中心放在 (0,0) 会导致它的一部分(在本例中为 1 个像素)出现在图像的右端和底端。对于 FFT,图像是周期性的。
  • @psj:好的,我安装了 Torch 来弄清楚发生了什么。事实证明,Torch 不理解复数,这使得提供 FFT 毫无意义。我已经用工作代码更新了这个答案,但是有了更好的工具,这一切都会变得容易得多。直接使用 NumPy,或者任何真实的图像处理包。
  • 谢谢,这有效! IFFT 输出图像现在??看起来类似于 conv 输出。但是当我打印两个矩阵 – conv_op 和 fft_conv 时(我尝试了两种裁剪 fft_conv 以获得等于”有效”卷积的输出——一个来自中心,一个来自左上角),它们似乎没有是平等的——即使在很小的误差范围内。这是两种方法之间的近似误差吗?此外,我应该从 IFFT 输出中选择哪种作物——我尝试了中心和左上角,但不知道应该从逻辑上选择哪一种。
  • @psj:执行的计算非常不同,因此结果会因此而在数值上有所不同。差异会随着图像边缘的增加而增加,其中 FFT 卷积的作用与空间域卷积不同。您应该显示两个图像 plt.imshow(conv_op-fft_conv) 之间的差异。这应该显示没有原始图像任何细节的图像,只是(结构化)噪声。
  • 对于 (256,256) 图像和 (3,3) 内核,”有效”卷积给出 (254,254) 输出。当我使用 FFT 方法通过将内核填充到上面的 (256,256) 并裁剪 IFFT 输出的中心 (254,254) 部分时,它与卷积输出相同,误差为 10^-8。此外,当我通过将图像和内核都填充到 (256 3-1,256 3-1) 来使用 FFT 方法时,IFFT 输出的裁剪 (1:255,1:255) 给出与 \\’ 相同的输出有效\\’卷积输出,误差为 10^-8。你能告诉我如何选择IFFT的作物吗?
  • psj:没有正确的方法。什么是正确的取决于您的需求。计算错误预计为 10^-8。
  • 是的,我对误差幅度很好,因为它可以忽略不计。我想问的是我们应该裁剪 IFFT 输出的哪一部分以获得与”有效”卷积相同的输出?我通过反复试验尝试了不同的作物,上面指定的那些给了我与”有效”卷积相同的输出——但我不明白它背后的逻辑,比如为什么答案就在这部分IFFT 输出而不是其他部分。
  • a€?valida€?卷积通常被定义为内核没有延伸到输入图像区域之外的输出区域。如果卷积将内核原点定义为在内核中间,那么它是围绕一半内核宽度的图像的边界是无效的,并被裁剪掉。但这一切都取决于实施者选择的定义。既然您已经弄清楚了两个输出的哪些部分是等价的,那么您就知道如何为这个特定的库做这件事了。其他库可能会这样做,通常他们会做出相同的选择,但您总是需要检查。


声明:本站(华域联盟www.cnhackhy.com)所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。