# Source code for organ.structure.constraints

"""Some generic differentiable constraints on structures."""

import torch
import torch.nn.functional as F


def edge_consistent(nodes, edges):
    r"""Penalize edges that are incident to non-existent nodes.

    The intended constraint is :math:`y_{ij} \le x_i x_j`, where
    :math:`x_i` and :math:`x_j` are the presence of nodes at positions
    :math:`i` and :math:`j`, and :math:`y_{ij}` is the presence of an edge
    between them. As a penalty this becomes

    .. math:: \mathrm{ReLU}(y_{ij} - x_i x_j)

    The penalty is evaluated separately for every non-zero edge type and
    summed over the batch, both node positions and those edge types.

    Parameters
    ----------
    nodes : torch.Tensor
        Batch of node descriptions, shape (batch, nodes, node_types).
        Assumes the last dimension sums to 1 and that node type 0 means
        "no node".
    edges : torch.Tensor
        Batch of edge descriptions, shape (batch, nodes, nodes, edge_types).
        Assumes the last dimension sums to 1 and that edge type 0 means
        "no edge".

    Returns
    -------
    torch.Tensor
        Zero-dimensional tensor holding the total edge-inconsistency penalty.
    """
    # Probability that each node slot is occupied: the mass on every node
    # type except type 0 (the "no node" type).
    presence = nodes[:, :, 1:].sum(dim=-1)

    # Batched outer product: joint[b, i, j] = presence[b, i] * presence[b, j].
    joint = presence.unsqueeze(2) * presence.unsqueeze(1)

    # Edge-type probabilities with type 0 ("no edge") dropped.
    edge_probs = edges[:, :, :, 1:]

    # NOTE(review): the ReLU is applied per edge type rather than to the total
    # edge presence sum_k y_ij^k, so for soft (non one-hot) edges the penalty
    # can be smaller than the formula in the docstring implies — confirm this
    # is intended.
    violation = F.relu(edge_probs - joint[..., None])
    return violation.sum()
def edge_symmetric(edges):
    """Penalize adjacency tensors that are not symmetric.

    Takes the norm (``torch.norm`` with its default arguments, i.e. over all
    elements) of the antisymmetric part ``(E - E^T) / 2``. Here the transpose
    exchanges the two node dimensions. The result is zero exactly when
    ``edges[b, i, j, t] == edges[b, j, i, t]`` for every entry.

    Parameters
    ----------
    edges : torch.Tensor
        Batch of edge descriptions, shape (batch, nodes, nodes, edge_types).

    Returns
    -------
    torch.Tensor
        Zero-dimensional tensor with the penalty for a non-symmetric
        adjacency matrix.
    """
    # Swap the node axes (i <-> j) while keeping batch and edge-type axes.
    transposed = edges.permute(0, 2, 1, 3)
    antisymmetric_part = (edges - transposed) / 2.0
    return torch.norm(antisymmetric_part)