【论文阅读】什么是PixelShuffle?为什么它是图像生成任务中更优的上采样方法?

【论文阅读】什么是PixelShuffle?为什么它是图像生成任务中更优的上采样方法?

admin
7月29日发布
温馨提示:
本文最后更新于2025年07月29日,已超过81天没有更新,若内容或图片失效,请留言反馈。

PixelShuffle是一种高效的无参数上采样方法,通过将通道维度信息重新排列到空间维度,显著提升图像生成任务的性能。相比传统转置卷积,PixelShuffle完全消除了因卷积核不均匀重叠导致的棋盘伪影,同时计算速度更快、内存占用更低,适用于实时超分辨率重建。该技术已在PyTorch中集成(nn.PixelShuffle),广泛应用于图像超分辨率、生成对抗网络(GAN)和语义分割等任务,能够有效增强细节生成能力并优化计算效率。实验证明,PixelShuffle在保持高视觉质量的同时,显著加速上采样过程,使其成为深度学习图像生成任务中的理想选择。

原始论文

Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network

这篇论文提出了一种亚像素卷积的方法来对图像进行超分辨率重建,速度特别快。虽然论文里面称提出的方法为亚像素卷积(sub-pixel convolution),但是实际上并不涉及到卷积运算,是一种高效、快速、无参的像素重排列的上采样方式。由于很快,直接用在视频超分中也可以做到实时。

PixelShuffle原始论文

基本原理如下图。PixelShuffle的核心思想是:将通道维度的信息重新排列到空间维度

基本原理图

假设我们要将图像放大r倍:输入的特征图形状(B, C×r², H, W),输出:(B, C, H×r, W×r),其中B:批大小,C:输出通道数,r:放大因子,H, W:输入的高度和宽度。

以2倍放大为例:输入特征图形状:(1, 12, 4, 4) (假设要输出3通道图像),重排为:(1, 3, 4, 4, 2, 2),最终输出:(1, 3, 8, 8),每个输出像素的值来自于输入的不同通道,这样就实现了分辨率的提升。

主要优势

(1)主要优势在于消除棋盘效应,通过像素重排而非卷积操作实现上采样,从根本上避免了转置卷积的棋盘效应问题。转置卷积产生棋盘效应的根本原因在于卷积核的不均匀重叠。当stride不能被kernel_size整除时,某些像素位置会被更多次地"填充",而另一些位置填充次数较少,导致输出图像呈现规律性的格子状伪影。下图就是棋盘效应的例子。

消除棋盘示例

(2)非常快。由于带计算的操作都是在低分辨率空间中进行的,所以速度相对会快很多。PixelShuffle本身不包含可学习参数,只进行张量重排,相比转置卷积,内存占用更少,重排操作的计算复杂度远低于卷积操作。论文里对图片也展示了这一点,可以看到遥遥领先!

运行非常快速

(3)更好的特征利用。通过将通道维度的信息转换为空间维度,PixelShuffle能够更好地利用网络学习到的特征表示,生成更丰富的细节。

PyTorch中使用

在PyTorch中已经进行了集成,使用起来非常简单:

import torch
import torch.nn as nn

# 创建PixelShuffle层
pixel_shuffle = nn.PixelShuffle(upscale_factor=2)

# 输入张量 (batch_size, channels*upscale_factor^2, height, width)
input_tensor = torch.randn(1, 12, 16, 16)  # 要输出3通道,所以输入3*2^2=12通道

# 执行PixelShuffle
output = pixel_shuffle(input_tensor)
print(f"输入形状: {input_tensor.shape}")
print(f"输出形状: {output.shape}")
# 输入形状: torch.Size([1, 12, 16, 16])
# 输出形状: torch.Size([1, 3, 32, 32])

基于底层原理,也可以手动实现:

def pixel_shuffle_manual(input_tensor, upscale_factor):
    """
    手动实现PixelShuffle操作
    """
    batch_size, channels, height, width = input_tensor.size()
    channels_output = channels // (upscale_factor ** 2)
    
    # 重新排列张量
    input_view = input_tensor.contiguous().view(
        batch_size, channels_output, upscale_factor, upscale_factor, height, width
    )
    
    # 调整维度顺序并重新整形
    output = input_view.permute(0, 1, 4, 2, 5, 3).contiguous()
    output = output.view(
        batch_size, channels_output, height * upscale_factor, width * upscale_factor
    )
    
    return output

# 测试自定义实现
test_input = torch.randn(1, 12, 8, 8)
manual_output = pixel_shuffle_manual(test_input, 2)
torch_output = nn.PixelShuffle(2)(test_input)

print(f"手动实现输出形状: {manual_output.shape}")
print(f"PyTorch实现输出形状: {torch_output.shape}")
print(f"结果是否相等: {torch.allclose(manual_output, torch_output)}")

主要应用场景

(1)图像超分辨率(Super-Resolution)

这是PixelShuffle最典型的应用场景,用于将低分辨率图像恢复为高分辨率图像。

典型网络结构

  • ESPCN(Efficient Sub-Pixel CNN)
  • SRCNN的改进版本
  • Real-ESRGAN等现代超分网络

(2)图像生成任务

在生成对抗网络(GAN)中,PixelShuffle常用于生成器的上采样层。

应用示例

  • StyleGAN的生成器
  • Pix2Pix网络
  • CycleGAN等图像转换任务

(3)语义分割

在需要恢复原始分辨率的分割任务中,PixelShuffle可以替代转置卷积。

典型网络

  • U-Net的解码器部分
  • DeepLab系列网络
  • PSPNet等密集预测任务
© 版权声明
THE END
喜欢就支持一下吧
点赞 0 分享 赞赏
评论 抢沙发
上传图片
OωO
取消