🏈 指点迷津 | Brief 🎯要点
🎯多重散射腔内的被动非线性光学映射 | 🎯积分球构建多重散射腔 | 🎯改变腔内的光散射增强非线性阶数 | 🎯使用Fashion MNIST图像集,面部关键点数据集和行人检测数据集评估算法
🍪语言内容分比
🍇Python图像抖动算法优化
我们可以使用此算法将灰度图像转换为黑色和白色两种颜色。该算法将像素值四舍五入为两个极值中的最接近值(0 表示黑色,255 表示白色)。原始值与四舍五入值之间的差值(称为误差)将添加到相邻像素,分布如下:
Copy [ ... , current pixel (rounded) , + 7/16 of error, ... ]
[ + 3/16 of error, + 5/16 of error , + 1/16 of error, ... ]
因此,该行上的下一个像素的误差为 7/16,而下一行的像素的误差为 5/16,依此类推。处理完当前像素后,算法将转到下一个像素,该像素现在包含前一个像素的部分误差。优化此算法的一个关键问题是,并行处理像素可能是不可能的:每个像素的最终值都会受到对先前像素进行的计算的影响。这表明使用多个线程进行并行处理可能很困难或不可能。
让我们加载将要使用的库以及测试图像(一个 400×400 NumPy 的 uint8 数组)。
Copy from numba import njit
import numpy as np
from skimage import io
image = io.imread("images/hallway.jpg")
如果这不仅仅是一个示例,我们希望使用各种图像和尺寸对代码进行基准测试,以匹配我们期望遇到的各种输入。然而,为了简单起见,我们将坚持使用这张图片。
以下,我将从一个简单的实现开始,使其更快,减少内存使用量,然后再进行进一步优化。为清晰起见,省略了一些中间步骤和失败的实验。
最简单实现:
代码将临时结果存储在 16 位整数中,因为添加错误可能会使某些像素为负数或大于 255。这两种情况都不适合无符号 8 位整数。最后,我将结果转换为 8 位图像,这就是函数应该返回的内容。
用 numba.njit 修饰的代码看起来像 Python 代码,但是实际上在运行时编译成机器代码。这是快速、低水平的代码!
Copy @njit
def dither(img):
result = img.astype(np.int16)
y_size = img.shape[0]
x_size = img.shape[1]
last_y = y_size - 1
last_x = x_size - 1
for y in range(y_size):
for x in range(x_size):
old_value = result[y, x]
if old_value < 0:
new_value = 0
elif old_value > 255:
new_value = 255
else:
new_value = np.uint8(np.round(old_value / 255.0)) * 255
result[y, x] = new_value
error = np.int16(old_value) - new_value
if x < last_x:
result[y, x + 1] += error * 7 // 16
if y < last_y and x > 0:
result[y + 1, x - 1] += error * 3 // 16
if y < last_y:
result[y + 1, x] += error * 5 // 16
if y < last_y and x < last_x:
result[y + 1, x + 1] += error // 16
return result.astype(np.uint8)
baseline_result = dither(image)
运行时间:2.3ms
一般来说,我们希望使循环的内部部分尽可能快。查看代码中的误差扩散部分,指令级并行性应该有助于加快代码的运行速度,因为每个计算都是独立的:
Copy if x < last_x:
result[y, x + 1] += error * 7 // 16
if y < last_y and x > 0:
result[y + 1, x - 1] += error * 3 // 16
if y < last_y:
result[y + 1, x] += error * 5 // 16
if y < last_y and x < last_x:
result[y + 1, x + 1] += error // 16
分支预测错误会导致它们变慢吗?稍微思考一下就会发现,这些分支非常容易预测,因为它们仅取决于像素位置。考虑一个 6×6 图像:根据图像中像素的位置,将采用不同的分支组合。
Copy 1 2 2 2 2 3
1 2 2 2 2 3
1 2 2 2 2 3
1 2 2 2 2 3
1 2 2 2 2 3
4 5 5 5 5 6
例如,区域 1 和 4 中的像素无法将误差扩散到前一列,因为没有前一列。因此不会采用相关分支(如果 y < last_y 且 x > 0)。
在较大的图像中,几乎所有像素都将位于区域 2 中,并采用完全相同的分支,因此 CPU 应该能够可靠地预测这些分支。因此,合理的假设是,即使在当前状态下,代码的扩散部分也以不错的速度运行。
相反,误差本身的计算似乎可能很慢:
Copy if old_value < 0:
new_value = 0
elif old_value > 255:
new_value = 255
else:
new_value = np.uint8(np.round(old_value / 255.0)) * 255
首先,分支取决于像素值,因此 CPU 可能难以预测,而且不清楚编译器是否会生成无分支代码。其次,这其中涉及一些相对复杂的数学运算:对浮点数进行舍入似乎会很慢。
优化
所有这些都可以简化为一个简单的检查:测量像素值是否小于或大于中间点。在 Python 中:new_value = 0 if old_value < 128 else 255。由于 new_value 在任何一种情况下都会获得一个值集,因此希望编译器将其转换为无分支代码,这样我们就不必担心分支预测错误的代价。
Copy @njit
def dither2(img):
result = img.astype(np.int16)
y_size = img.shape[0]
x_size = img.shape[1]
last_y = y_size - 1
last_x = x_size - 1
for y in range(y_size):
for x in range(x_size):
old_value = result[y, x]
# Branchless, simple rounding:
new_value = 0 if old_value < 128 else 255
result[y, x] = new_value
error = old_value - new_value
if x < last_x:
result[y, x + 1] += error * 7 // 16
if y < last_y and x > 0:
result[y + 1, x - 1] += error * 3 // 16
if y < last_y:
result[y + 1, x] += error * 5 // 16
if y < last_y and x < last_x:
result[y + 1, x + 1] += error // 16
return result.astype(np.uint8)
assert np.array_equal(dither2(image), baseline_result)
运行时间:830.7μ \mu μ s
优化内存
Copy @njit
def dither3(img):
result = np.empty(img.shape, dtype=np.uint8)
staging = img[0:2].astype(np.int16)
y_size = img.shape[0]
x_size = img.shape[1]
last_x = x_size - 1
for y in range(y_size):
for x in range(x_size):
old_value = staging[0, x]
new_value = 0 if old_value < 128 else 255
staging[0, x] = new_value
error = old_value - new_value
if x < last_x:
staging[0, x + 1] += error * 7 // 16
if x > 0:
staging[1, x - 1] += error * 3 // 16
staging[1, x] += error * 5 // 16
if x < last_x:
staging[1, x + 1] += error // 16
result[y,:] = staging[0,:]
staging[0,:] = staging[1,:]
if y < y_size - 2:
staging[1,:] = img[y + 2,:]
return result
assert np.array_equal(dither3(image), baseline_result)
运行时间:909.6μ \mu μ s
优化读写
Copy @njit
def dither6(img):
result = np.empty(img.shape, dtype=np.uint8)
y_size = img.shape[0]
x_size = img.shape[1]
staging_current = np.zeros(x_size + 2, np.int16)
staging_current[1:-1] = img[0]
staging_next = np.zeros(x_size + 2, np.int16)
for y in range(y_size):
right_pixel_error = 0
downleft_prev_error = 0
downleft_prevprev_error = 0
for x in range(x_size):
old_value = staging_current[x + 1] + right_pixel_error
new_value = 0 if old_value < 128 else 255
result[y, x] = new_value
error = old_value - new_value
right_pixel_error = error * 7 // 16
staging_next[x] = (
img[y + 1, x - 1] + downleft_prev_error + error * 3 // 16
)
downleft_prev_error = downleft_prevprev_error + error * 5 // 16
downleft_prevprev_error = error // 16
staging_next[x_size] = img[y + 1, x_size - 1] + downleft_prev_error
staging_current, staging_next = staging_next, staging_current
return result
assert np.array_equal(dither6(image), baseline_result)
运行时间:602.7μ \mu μ s