🍠Python非线性光学映射数据压缩算法

🏈指点迷津 | Brief

🎯要点

🎯多重散射腔内的被动非线性光学映射 | 🎯积分球构建多重散射腔 | 🎯改变腔内的光散射增强非线性阶数 | 🎯使用Fashion MNIST图像集,面部关键点数据集和行人检测数据集评估算法

🍪语言内容分比

🍇Python图像抖动算法优化

我们可以使用此算法将灰度图像转换为黑色和白色两种颜色。该算法将像素值四舍五入为两个极值中的最接近值(0 表示黑色,255 表示白色)。原始值与四舍五入值之间的差值(称为误差)将添加到相邻像素,分布如下:

 [ ...            , current pixel (rounded)   , + 7/16 of error, ... ]
 [ + 3/16 of error, + 5/16 of error           , + 1/16 of error, ... ]

因此,该行上的下一个像素的误差为 7/16,而下一行的像素的误差为 5/16,依此类推。处理完当前像素后,算法将转到下一个像素,该像素现在包含前一个像素的部分误差。优化此算法的一个关键问题是,并行处理像素可能是不可能的:每个像素的最终值都会受到对先前像素进行的计算的影响。这表明使用多个线程进行并行处理可能很困难或不可能。

让我们加载将要使用的库以及测试图像(一个 400×400 NumPy 的 uint8 数组)。

 from numba import njit
 import numpy as np
 from skimage import io
 ​
 image = io.imread("images/hallway.jpg")

如果这不仅仅是一个示例,我们希望使用各种图像和尺寸对代码进行基准测试,以匹配我们期望遇到的各种输入。然而,为了简单起见,我们将坚持使用这张图片。

以下,我将从一个简单的实现开始,使其更快,减少内存使用量,然后再进行进一步优化。为清晰起见,省略了一些中间步骤和失败的实验。

最简单实现:

代码将临时结果存储在 16 位整数中,因为添加错误可能会使某些像素为负数或大于 255。这两种情况都不适合无符号 8 位整数。最后,我将结果转换为 8 位图像,这就是函数应该返回的内容。

用 numba.njit 修饰的代码看起来像 Python 代码,但是实际上在运行时编译成机器代码。这是快速、低水平的代码!

 @njit
 def dither(img):
     result = img.astype(np.int16)
     y_size = img.shape[0]
     x_size = img.shape[1]
     last_y = y_size - 1
     last_x = x_size - 1
     for y in range(y_size):
         for x in range(x_size):
             old_value = result[y, x]
             if old_value < 0:
                 new_value = 0
             elif old_value > 255:
                 new_value = 255
             else:
                 new_value = np.uint8(np.round(old_value / 255.0)) * 255
             result[y, x] = new_value
             error = np.int16(old_value) - new_value
             if x < last_x:
                 result[y, x + 1] += error * 7 // 16
             if y < last_y and x > 0:
                 result[y + 1, x - 1] += error * 3 // 16
             if y < last_y:
                 result[y + 1, x] += error * 5 // 16
             if y < last_y and x < last_x:
                 result[y + 1, x + 1] += error // 16
 ​
     return result.astype(np.uint8)
 ​
 baseline_result = dither(image)

运行时间:2.3ms

一般来说,我们希望使循环的内部部分尽可能快。查看代码中的误差扩散部分,指令级并行性应该有助于加快代码的运行速度,因为每个计算都是独立的:

 if x < last_x:
     result[y, x + 1] += error * 7 // 16
 if y < last_y and x > 0:
     result[y + 1, x - 1] += error * 3 // 16
 if y < last_y:
     result[y + 1, x] += error * 5 // 16
 if y < last_y and x < last_x:
     result[y + 1, x + 1] += error // 16

分支预测错误会导致它们变慢吗?稍微思考一下就会发现,这些分支非常容易预测,因为它们仅取决于像素位置。考虑一个 6×6 图像:根据图像中像素的位置,将采用不同的分支组合。

 1 2 2 2 2 3
 1 2 2 2 2 3
 1 2 2 2 2 3
 1 2 2 2 2 3
 1 2 2 2 2 3
 4 5 5 5 5 6

例如,区域 1 和 4 中的像素无法将误差扩散到前一列,因为没有前一列。因此不会采用相关分支(如果 y < last_y 且 x > 0)。

在较大的图像中,几乎所有像素都将位于区域 2 中,并采用完全相同的分支,因此 CPU 应该能够可靠地预测这些分支。因此,合理的假设是,即使在当前状态下,代码的扩散部分也以不错的速度运行。

相反,误差本身的计算似乎可能很慢:

 if old_value < 0:
     new_value = 0
 elif old_value > 255:
     new_value = 255
 else:
     new_value = np.uint8(np.round(old_value / 255.0)) * 255

首先,分支取决于像素值,因此 CPU 可能难以预测,而且不清楚编译器是否会生成无分支代码。其次,这其中涉及一些相对复杂的数学运算:对浮点数进行舍入似乎会很慢。

优化

所有这些都可以简化为一个简单的检查:测量像素值是否小于或大于中间点。在 Python 中:new_value = 0 if old_value < 128 else 255。由于 new_value 在任何一种情况下都会获得一个值集,因此希望编译器将其转换为无分支代码,这样我们就不必担心分支预测错误的代价。

 @njit
 def dither2(img):
     result = img.astype(np.int16)
     y_size = img.shape[0]
     x_size = img.shape[1]
     last_y = y_size - 1
     last_x = x_size - 1
     for y in range(y_size):
         for x in range(x_size):
             old_value = result[y, x]
             # Branchless, simple rounding:
             new_value = 0 if old_value < 128 else 255
             result[y, x] = new_value
             error = old_value - new_value
             if x < last_x:
                 result[y, x + 1] += error * 7 // 16
             if y < last_y and x > 0:
                 result[y + 1, x - 1] += error * 3 // 16
             if y < last_y:
                 result[y + 1, x] += error * 5 // 16
             if y < last_y and x < last_x:
                 result[y + 1, x + 1] += error // 16
     return result.astype(np.uint8)
 ​
 assert np.array_equal(dither2(image), baseline_result)

运行时间:830.7μ\mus

优化内存

 @njit
 def dither3(img):
     result = np.empty(img.shape, dtype=np.uint8)
 ​
     staging = img[0:2].astype(np.int16)
     y_size = img.shape[0]
     x_size = img.shape[1]
     last_x = x_size - 1
     for y in range(y_size):
         for x in range(x_size):
             old_value = staging[0, x]
             new_value = 0 if old_value < 128 else 255
             staging[0, x] = new_value
             error = old_value - new_value
             if x < last_x:
                 staging[0, x + 1] += error * 7 // 16
             if x > 0:
                 staging[1, x - 1] += error * 3 // 16
             staging[1, x] += error * 5 // 16
             if x < last_x:
                 staging[1, x + 1] += error // 16
 ​
         result[y,:] = staging[0,:]
 ​
         staging[0,:] = staging[1,:]
         if y < y_size - 2:
             staging[1,:] = img[y + 2,:]
     return result
 ​
 assert np.array_equal(dither3(image), baseline_result)

运行时间:909.6μ\mus

优化读写

 @njit
 def dither6(img):
     result = np.empty(img.shape, dtype=np.uint8)
     y_size = img.shape[0]
     x_size = img.shape[1]
     staging_current = np.zeros(x_size + 2, np.int16)
     staging_current[1:-1] = img[0]
     staging_next = np.zeros(x_size + 2, np.int16)
 ​
     for y in range(y_size):
         right_pixel_error = 0
         downleft_prev_error = 0
         downleft_prevprev_error = 0
         for x in range(x_size):
             old_value = staging_current[x + 1] + right_pixel_error
             new_value = 0 if old_value < 128 else 255
             result[y, x] = new_value
             error = old_value - new_value
             right_pixel_error = error * 7 // 16
 ​
             staging_next[x] = (
                 img[y + 1, x - 1] + downleft_prev_error + error * 3 // 16
             )
 ​
             downleft_prev_error = downleft_prevprev_error + error * 5 // 16
             downleft_prevprev_error = error // 16
 ​
         staging_next[x_size] = img[y + 1, x_size - 1] + downleft_prev_error
 ​
         staging_current, staging_next = staging_next, staging_current
 ​
     return result
 ​
 assert np.array_equal(dither6(image), baseline_result)

运行时间:602.7μ\mus

Last updated