# 🎯图模型和消息传递推理算法 | 🎯消息传递推理和循环消息传递推理算法 | 🎯空间人工智能算法多维姿势估计 | 🎯超图结构解码算法量子计算 | 🎯GPU处理变分推理消息传递贝叶斯网络算法 | 🎯高斯消息传递数学形式和算法代码 | 🎯蜂窝通信Wi-Fi大量数据传输和存储信道编码解码算法 | 🎯图消息传递推理算法暴力哈希加密协议
import numpy as np
from collections import namedtuple
# An array paired with one label per axis (e.g. ['v1', 'h1']), so callers can
# find an axis by variable name instead of by position.
LabeledArray = namedtuple('LabeledArray', 'array axes_labels')
def name_to_axis_mapping(labeled_array):
    """Map each axis label of ``labeled_array`` to its positional axis index.

    If a label appears more than once, the last position wins.
    """
    label_to_axis = {}
    for position, label in enumerate(labeled_array.axes_labels):
        label_to_axis[label] = position
    return label_to_axis
def other_axes_from_labeled_axes(labeled_array, axis_label):
    """Return, as a tuple, the indices of every axis not labeled ``axis_label``.

    Indices are in ascending order. If no axis carries ``axis_label``, all
    axis indices are returned.
    """
    remaining = []
    for position, label in enumerate(labeled_array.axes_labels):
        if label == axis_label:
            continue
        remaining.append(position)
    return tuple(remaining)
def is_conditional_prob(labeled_array, var_name):
    """Check whether ``labeled_array`` normalizes over the axis named ``var_name``.

    Sums the array along that axis and reports whether every resulting total
    is close to 1 (``np.isclose`` default tolerances). This is the condition
    for a table of P(var_name | other axes). Raises ``KeyError`` if
    ``var_name`` is not one of the axis labels.
    """
    var_axis = name_to_axis_mapping(labeled_array)[var_name]
    totals_over_var = np.sum(labeled_array.array, axis=var_axis)
    return np.all(np.isclose(totals_over_var, 1.0))
def is_joint_prob(labeled_array):
    """Check whether all entries of ``labeled_array.array`` sum to 1.

    Uses ``np.isclose`` default tolerances, so tiny floating-point error is
    accepted.
    """
    grand_total = np.sum(labeled_array.array)
    return np.all(np.isclose(grand_total, 1.0))
# P(v1 | h1): axis 0 is v1, axis 1 is h1; every column sums to 1.
p_v1_given_h1 = LabeledArray(
    array=np.array([
        [0.4, 0.8, 0.9],
        [0.6, 0.2, 0.1],
    ]),
    axes_labels=['v1', 'h1'],
)
# P(h1): a distribution over the three states of h1.
p_h1 = LabeledArray(array=np.array([0.6, 0.3, 0.1]), axes_labels=['h1'])
# P(v1 | h1, h2): axis 0 is v1; summing over it gives 1 for each (h1, h2).
p_v1_given_many = LabeledArray(
    array=np.array([
        [[0.9, 0.2],
         [0.3, 0.2]],
        [[0.1, 0.8],
         [0.7, 0.8]],
    ]),
    axes_labels=['v1', 'h1', 'h2'],
)

# Sanity checks: each table normalizes over its own variable, but only the
# single-variable P(h1) also sums to 1 as a whole.
assert is_conditional_prob(p_v1_given_h1, 'v1')
assert not is_joint_prob(p_v1_given_h1)
assert is_conditional_prob(p_h1, 'h1')
assert is_joint_prob(p_h1)
assert is_conditional_prob(p_v1_given_many, 'v1')
assert not is_joint_prob(p_v1_given_many)
def tile_to_shape_along_axis(arr, target_shape, target_axis):
    """Repeat ``arr`` to fill ``target_shape``, varying only along ``target_axis``.

    The result ``out`` satisfies ``out[..., i, ...] == arr[i]``, where ``i``
    indexes ``target_axis``. Every other axis is a plain repetition. A 0-d
    ``arr`` is simply repeated everywhere.

    Args:
        arr: A 0-d (scalar) or 1-d array-like. A 1-d input must have length
            ``target_shape[target_axis]``.
        target_shape: Shape of the output; any sequence of ints, e.g. a
            tuple or a list.
        target_axis: Axis of ``target_shape`` that ``arr`` runs along.
            Negative values count from the end, as in NumPy.

    Returns:
        A new, writable ``np.ndarray`` of shape ``tuple(target_shape)`` with
        ``arr``'s dtype.

    Raises:
        IndexError: ``target_axis`` is out of range for ``target_shape``.
        ValueError: a 1-d ``arr`` does not match the target axis length.
        NotImplementedError: ``arr`` has more than one dimension.
    """
    # Normalise so list shapes compare/behave like the tuple ndarray.shape.
    target_shape = tuple(target_shape)
    ndim = len(target_shape)
    if not -ndim <= target_axis < ndim:
        raise IndexError(
            f"target_axis {target_axis} is out of range for shape {target_shape}"
        )
    target_axis %= ndim  # support negative axes

    arr = np.asarray(arr)
    if arr.ndim == 0:
        source = arr
    elif arr.ndim == 1:
        if arr.shape[0] != target_shape[target_axis]:
            raise ValueError(
                f"array of length {arr.shape[0]} cannot be tiled along axis "
                f"{target_axis} of shape {target_shape}"
            )
        # Length-1 axes everywhere except target_axis, so broadcasting
        # repeats arr along all the other axes.
        singleton_shape = [1] * ndim
        singleton_shape[target_axis] = arr.shape[0]
        source = arr.reshape(singleton_shape)
    else:
        raise NotImplementedError(
            f"only 0-d or 1-d arrays can be tiled, got a {arr.ndim}-d array"
        )
    # broadcast_to returns a read-only view; copy so the caller gets an
    # independent, writable array (as np.tile would produce).
    return np.broadcast_to(source, target_shape).copy()
def tile_to_other_dist_along_axis_name(tiling_labeled_array, target_array):
    """Broadcast a single-axis labeled table onto ``target_array``'s axes.

    ``tiling_labeled_array`` must have exactly one axis label. That label is
    located among ``target_array``'s labels, and the values are repeated
    along every other axis of ``target_array``. The returned LabeledArray
    has ``target_array``'s shape and reuses its ``axes_labels`` object, so
    its array can be multiplied elementwise with ``target_array.array``.
    """
    assert len(tiling_labeled_array.axes_labels) == 1
    shared_label = tiling_labeled_array.axes_labels[0]
    axis_in_target = name_to_axis_mapping(target_array)[shared_label]
    tiled_values = tile_to_shape_along_axis(
        tiling_labeled_array.array,
        target_array.array.shape,
        axis_in_target,
    )
    return LabeledArray(tiled_values, axes_labels=target_array.axes_labels)
# Line P(h1) up with the (v1, h1) table; the elementwise product is the joint
# P(v1, h1), whose entries must sum to 1.
tiled_p_h1 = tile_to_other_dist_along_axis_name(p_h1, p_v1_given_h1)
assert np.isclose((p_v1_given_h1.array * tiled_p_h1.array).sum(), 1.0)
def _variable_to_factor_messages(variable, factor):
    """Compute the message ``variable`` sends to the neighbouring ``factor``.

    Multiplies elementwise (``np.prod`` over axis 0) the messages arriving at
    ``variable`` from each factor in ``variable.neighbors``, except those
    whose ``name`` equals ``factor.name``. Each incoming message comes from
    ``_factor_to_variable_message``, which is not defined in this section.

    NOTE(review): when no other neighbour exists, ``np.prod`` of an empty
    list returns the scalar ``1.0`` rather than a vector of ones. Presumably
    callers rely on broadcasting here; confirm.
    """
    messages_in = []
    for other_factor in variable.neighbors:
        if other_factor.name != factor.name:
            messages_in.append(_factor_to_variable_message(other_factor, variable))
    return np.prod(messages_in, axis=0)