import logging
import torch
from torch_cluster import radius_graph
from torch_geometric.data import Data, DataLoader
from torch_scatter import scatter
def tblock():
    """Build the point coordinates and label rows for eight 4-point shapes.

    Each shape is a list of four integer (x, y, z) points. The first two
    shapes differ only in the sign of the last point's y coordinate and share
    label column 0 with opposite signs (+1 / -1); every other shape has its
    own one-hot label column.

    Returns:
        A ``(pos, labels)`` tuple of tensors in the current default dtype:
        ``pos`` has shape ``(8, 4, 3)`` and ``labels`` has shape ``(8, 7)``.
    """
    pos = [
        [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 1, 0)],
        [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, -1, 0)],
        [(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)],
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3)],
        [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)],
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0)],
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 1)],
        [(0, 0, 0), (1, 0, 0), (1, 1, 0), (2, 1, 0)],
    ]
    pos = torch.tensor(pos, dtype=torch.get_default_dtype())

    labels = torch.tensor(
        [
            [+1, 0, 0, 0, 0, 0, 0],
            [-1, 0, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 0],
            [0, 0, 0, 1, 0, 0, 0],
            [0, 0, 0, 0, 1, 0, 0],
            [0, 0, 0, 0, 0, 1, 0],
            [0, 0, 0, 0, 0, 0, 1],
        ],
        dtype=torch.get_default_dtype(),
    )

    # Previously the function ended here without a return statement, so both
    # tensors were built and then thrown away.
    return pos, labels
def mean_std(name, x) -> None:
    """Print a one-line summary of ``x`` prefixed by ``name``.

    The line shows, each with one decimal place: the overall mean, the square
    root of the mean of the variance taken along dimension 0, and the overall
    standard deviation.
    """
    overall_mean = x.mean()
    # Variance along dim 0, averaged over what remains, then square-rooted.
    spread_dim0 = x.var(0).mean().sqrt()
    overall_std = x.std()
    print(f"{name} \t{overall_mean:.1f} ± ({spread_dim0:.1f}|{overall_std:.1f})")
class Convolution(torch.nn.Module):
    """Edge-wise tensor-product convolution with per-edge weights.

    A ``FullyConnectedTensorProduct`` combines the source-node features of
    each edge with that edge's attributes. Its weights are not stored in the
    tensor product (``internal_weights=False``, ``shared_weights=False``);
    they are produced for every edge by ``self.fc`` from ``edge_scalars``.
    The per-edge results are summed onto the destination nodes and divided by
    ``sqrt(num_neighbors)``.

    Args:
        irreps_in: Irreps of the incoming node features (first tensor-product
            input).
        irreps_sh: Irreps of the edge attributes (second tensor-product
            input).
        irreps_out: Requested output irreps.
        num_neighbors: Normalisation constant; the summed messages are
            divided by its square root.

    Attributes set in ``__init__``:
        fc: Network built from ``[3, 256, tp.weight_numel]`` with
            ``torch.relu`` — presumably layer widths, i.e. 3 scalars per edge
            in and one value per tensor-product weight out (confirm against
            ``FullyConnectedNet``).
        tp: The tensor product itself.
        irreps_out: Output irreps as reported by ``tp``.
    """

    def __init__(self, irreps_in, irreps_sh, irreps_out, num_neighbors) -> None:
        super().__init__()
        self.num_neighbors = num_neighbors

        tp = FullyConnectedTensorProduct(
            irreps_in1=irreps_in,
            irreps_in2=irreps_sh,
            irreps_out=irreps_out,
            internal_weights=False,
            shared_weights=False,
        )
        # Sub-modules are registered in the same order as before (fc, then tp)
        # so parameter ordering and state-dict layout are unchanged.
        self.fc = FullyConnectedNet([3, 256, tp.weight_numel], torch.relu)
        self.tp = tp
        self.irreps_out = self.tp.irreps_out

    def forward(self, node_features, edge_src, edge_dst, edge_attr, edge_scalars) -> torch.Tensor:
        """Aggregate tensor-product messages onto destination nodes.

        Args:
            node_features: Per-node features; dimension 0 is indexed by
                ``edge_src``.
            edge_src: Source-node index of every edge.
            edge_dst: Destination-node index of every edge.
            edge_attr: Per-edge second input of the tensor product.
            edge_scalars: Per-edge input of ``self.fc``.

        Returns:
            A tensor with exactly ``node_features.shape[0]`` rows; nodes that
            receive no edge get an all-zero row.
        """
        weight = self.fc(edge_scalars)
        edge_features = self.tp(node_features[edge_src], edge_attr, weight)

        # Pin the output size to the number of input nodes. Without dim_size,
        # scatter sizes the output as edge_dst.max() + 1, which drops trailing
        # nodes that have no incoming edge and misaligns later per-node ops.
        num_nodes = node_features.shape[0]
        aggregated = scatter(edge_features, edge_dst, dim=0, dim_size=num_nodes)
        return aggregated.div(self.num_neighbors**0.5)
class Network(torch.nn.Module):
    """Model made of a gated convolution followed by a final convolution.

    ``__init__`` takes no arguments. It stores the neighbour normalisation
    constant, the spherical-harmonics irreps, the two ``Convolution`` layers,
    the ``Gate`` between them, and the output irreps of the final layer.
    """

    def __init__(self) -> None:
        super().__init__()

        self.num_neighbors = 3.8
        self.irreps_sh = o3.Irreps.spherical_harmonics(3)

        # First layer with gate
        gate_block = Gate(
            "16x0e + 16x0o",
            [torch.relu, torch.abs],
            "8x0e + 8x0o + 8x0e + 8x0o",
            [torch.relu, torch.tanh, torch.relu, torch.tanh],  # gates (scalars)
            "16x1o + 16x1e",
        )
        # The first convolution takes the spherical-harmonics irreps as its
        # node-feature irreps and emits whatever the gate expects as input.
        self.conv = Convolution(
            irreps_in=self.irreps_sh,
            irreps_sh=self.irreps_sh,
            irreps_out=gate_block.irreps_in,
            num_neighbors=self.num_neighbors,
        )
        self.gate = gate_block

        # Final layer
        self.final = Convolution(
            irreps_in=self.gate.irreps_out,
            irreps_sh=self.irreps_sh,
            irreps_out="0o + 6x0e",
            num_neighbors=self.num_neighbors,
        )
        self.irreps_out = self.final.irreps_out