diff --git a/tests/extensions.py b/tests/extensions.py index 48ab35c0..742ba7c5 100644 --- a/tests/extensions.py +++ b/tests/extensions.py @@ -26,7 +26,17 @@ def enumerate(semiring, edge, lengths=None): semiring = semiring ssize = semiring.size() edge, batch, N, C, lengths = model._check_potentials(edge, lengths) - chains = [[([c], semiring.one_(torch.zeros(ssize, batch))) for c in range(C)]] + chains = [ + [ + ( + [c], + semiring.fill( + torch.zeros(ssize, batch), torch.tensor(True), semiring.one + ), + ) + for c in range(C) + ] + ] enum_lengths = torch.LongTensor(lengths.shape) for n in range(1, N): @@ -128,7 +138,13 @@ def enumerate(semiring, edge): edge = semiring.convert(edge) chains = {} chains[0] = [ - ([(c, 0)], semiring.one_(torch.zeros(ssize, batch))) for c in range(C) + ( + [(c, 0)], + semiring.fill( + torch.zeros(ssize, batch), torch.tensor(True), semiring.one + ), + ) + for c in range(C) ] for n in range(1, N + 1): diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index ec40e031..69ddbc41 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -263,7 +263,7 @@ def test_generic_lengths(model_test, data): part = model().sum(vals, lengths=lengths) # Check that max is correct - assert (maxes <= part).all() + assert (maxes <= part + 1e-3).all() m_part = model(MaxSemiring).sum(vals, lengths=lengths) assert (torch.isclose(maxes, m_part)).all(), maxes - m_part diff --git a/torch_struct/autoregressive.py b/torch_struct/autoregressive.py index e750da15..e3eb49f8 100644 --- a/torch_struct/autoregressive.py +++ b/torch_struct/autoregressive.py @@ -118,8 +118,10 @@ def log_prob(self, value, sparse=False): return wrap(scores, sample) def _beam_search(self, semiring, gumbel=False): - beam = semiring.one_( - torch.zeros((semiring.size(),) + self.batch_shape, device=self.device) + beam = semiring.fill( + torch.zeros((semiring.size(),) + self.batch_shape, device=self.device), + torch.tensor(True), + semiring.one, ) ssize = semiring.size() diff --git a/torch_struct/deptree.py b/torch_struct/deptree.py index c8cb4baa..8747d7e3 100644 --- a/torch_struct/deptree.py +++ b/torch_struct/deptree.py @@ -66,10 +66,22 @@ def logpartition(self, arc_scores_in, lengths=None, force_grad=False): ] for _ in range(2) ] - semiring.one_(alpha[A][C][L].data[:, :, :, 0].data) - semiring.one_(alpha[A][C][R].data[:, :, :, 0].data) - semiring.one_(alpha[B][C][L].data[:, :, :, -1].data) - semiring.one_(alpha[B][C][R].data[:, :, :, -1].data) + mask = torch.zeros(alpha[A][C][L].data.shape).bool() + mask[:, :, :, 0].fill_(True) + alpha[A][C][L].data[:] = semiring.fill( + alpha[A][C][L].data[:], mask, semiring.one + ) + alpha[A][C][R].data[:] = semiring.fill( + alpha[A][C][R].data[:], mask, semiring.one + ) + mask = torch.zeros(alpha[B][C][L].data[:].shape).bool() + mask[:, :, :, -1].fill_(True) + alpha[B][C][L].data[:] = semiring.fill( + alpha[B][C][L].data[:], mask, semiring.one + ) + alpha[B][C][R].data[:] = semiring.fill( + alpha[B][C][R].data[:], mask, semiring.one + ) if multiroot: start_idx = 0 @@ -119,10 +131,13 @@ def _check_potentials(self, arc_scores, lengths=None): lengths = torch.LongTensor([N - 1] * batch).to(arc_scores.device) assert max(lengths) <= N, "Length longer than N" arc_scores = semiring.convert(arc_scores) - for b in range(batch): - semiring.zero_(arc_scores[:, b, lengths[b] + 1 :, :]) - semiring.zero_(arc_scores[:, b, :, lengths[b] + 1 :]) + # Set the extra elements of the log-potentials to zero. 
+ keep = torch.ones_like(arc_scores).bool() + for b in range(batch): + keep[:, b, lengths[b] + 1 :, :].fill_(0.0) + keep[:, b, :, lengths[b] + 1 :].fill_(0.0) + arc_scores = semiring.fill(arc_scores, ~keep, semiring.zero) return arc_scores, batch, N, lengths def _arrange_marginals(self, grads): diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py index b6f233e5..9ad89c56 100644 --- a/torch_struct/distributions.py +++ b/torch_struct/distributions.py @@ -36,6 +36,7 @@ class StructDistribution(Distribution): log_potentials (tensor, batch_shape x event_shape) : log-potentials :math:`\phi` lengths (long tensor, batch_shape) : integers for length masking """ + validate_args = False def __init__(self, log_potentials, lengths=None, args={}): batch_shape = log_potentials.shape[:1] diff --git a/torch_struct/helpers.py b/torch_struct/helpers.py index 3b7c0a1a..a8d2848c 100644 --- a/torch_struct/helpers.py +++ b/torch_struct/helpers.py @@ -5,13 +5,14 @@ class Chart: def __init__(self, size, potentials, semiring): - self.data = semiring.zero_( - torch.zeros( - *((semiring.size(),) + size), - dtype=potentials.dtype, - device=potentials.device - ) + c = torch.zeros( + *((semiring.size(),) + size), + dtype=potentials.dtype, + device=potentials.device ) + c[:] = semiring.zero.view((semiring.size(),) + len(size) * (1,)) + + self.data = c self.grad = self.data.detach().clone().fill_(0.0) def __getitem__(self, ind): @@ -50,18 +51,17 @@ def _chart(self, size, potentials, force_grad): return self._make_chart(1, size, potentials, force_grad)[0] def _make_chart(self, N, size, potentials, force_grad=False): - return [ - ( - self.semiring.zero_( - torch.zeros( - *((self.semiring.size(),) + size), - dtype=potentials.dtype, - device=potentials.device - ) - ).requires_grad_(force_grad and not potentials.requires_grad) + chart = [] + for _ in range(N): + c = torch.zeros( + *((self.semiring.size(),) + size), + dtype=potentials.dtype, + device=potentials.device ) - for _ in range(N) - ] + c[:] = self.semiring.zero.view((self.semiring.size(),) + len(size) * (1,)) + c.requires_grad_(force_grad and not potentials.requires_grad) + chart.append(c) + return chart def sum(self, logpotentials, lengths=None, _raw=False): """ diff --git a/torch_struct/linearchain.py b/torch_struct/linearchain.py index 593b2404..5329f8ad 100644 --- a/torch_struct/linearchain.py +++ b/torch_struct/linearchain.py @@ -53,7 +53,9 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False): chart = self._chart((batch, bin_N, C, C), log_potentials, force_grad) # Init - semiring.one_(chart[:, :, :].diagonal(0, 3, 4)) + init = torch.zeros(*chart.shape).bool() + init.diagonal(0, 3, 4).fill_(True) + chart = semiring.fill(chart, init, semiring.one) # Length mask big = torch.zeros( @@ -71,8 +73,8 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False): mask = torch.arange(bin_N).view(1, bin_N).expand(batch, bin_N).type_as(c) mask = mask >= (lengths - 1).view(batch, 1) mask = mask.view(batch * bin_N, 1, 1).to(lp.device) - semiring.zero_mask_(lp.data, mask) - semiring.zero_mask_(c.data, (~mask)) + lp.data[:] = semiring.fill(lp.data, mask, semiring.zero) + c.data[:] = semiring.fill(c.data, ~mask, semiring.zero) c[:] = semiring.sum(torch.stack([c.data, lp], dim=-1)) diff --git a/torch_struct/semimarkov.py b/torch_struct/semimarkov.py index 2e802c5a..3c5a4bd3 100644 --- a/torch_struct/semimarkov.py +++ b/torch_struct/semimarkov.py @@ -34,7 +34,9 @@ def logpartition(self, log_potentials, lengths=None, 
force_grad=False): ) # Init. - semiring.one_(init.data[:, :, :, 0, 0].diagonal(0, -2, -1)) + mask = torch.zeros(*init.shape).bool() + mask[:, :, :, 0, 0].diagonal(0, -2, -1).fill_(True) + init = semiring.fill(init, mask, semiring.one) # Length mask big = torch.zeros( @@ -54,16 +56,16 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False): mask = mask.to(log_potentials.device) mask = mask >= (lengths - 1).view(batch, 1) mask = mask.view(batch * bin_N, 1, 1, 1).to(lp.device) - semiring.zero_mask_(lp.data, mask) - semiring.zero_mask_(c.data[:, :, :, 0], (~mask)) + lp.data[:] = semiring.fill(lp.data, mask, semiring.zero) + c.data[:, :, :, 0] = semiring.fill(c.data[:, :, :, 0], (~mask), semiring.zero) c[:, :, : K - 1, 0] = semiring.sum( torch.stack([c.data[:, :, : K - 1, 0], lp[:, :, 1:K]], dim=-1) ) end = torch.min(lengths) - 1 + mask = torch.zeros(*init.shape).bool() for k in range(1, K - 1): - semiring.one_( - init.data[:, :, : end - (k - 1), k - 1, k].diagonal(0, -2, -1) - ) + mask[:, :, : end - (k - 1), k - 1, k].diagonal(0, -2, -1).fill_(True) + init = semiring.fill(init, mask, semiring.one) K_1 = K - 1 diff --git a/torch_struct/semirings/checkpoint.py b/torch_struct/semirings/checkpoint.py index c4e10c4f..f65eec19 100644 --- a/torch_struct/semirings/checkpoint.py +++ b/torch_struct/semirings/checkpoint.py @@ -4,6 +4,7 @@ try: import genbmm from genbmm import BandedMatrix + has_genbmm = True except ImportError: pass diff --git a/torch_struct/semirings/semirings.py b/torch_struct/semirings/semirings.py index bb7b9ec1..cfc2311c 100644 --- a/torch_struct/semirings/semirings.py +++ b/torch_struct/semirings/semirings.py @@ -47,6 +47,12 @@ def dot(cls, a, b): b = b.unsqueeze(-1) return cls.matmul(a, b).squeeze(-1).squeeze(-1) + @staticmethod + def fill(c, mask, v): + return torch.where( + mask, v.type_as(c).view((-1,) + (1,) * (len(c.shape) - 1)), c + ) + @classmethod def times(cls, *ls): "Multiply a list of tensors together" @@ -65,21 +71,6 @@ def unconvert(cls, potentials): "Unconvert from semiring by removing extra first dimension." return potentials.squeeze(0) - @staticmethod - def zero_(xs): - "Fill *ssize x ...* tensor with additive identity." - raise NotImplementedError() - - @classmethod - def zero_mask_(cls, xs, mask): - "Fill *ssize x ...* tensor with additive identity." - xs.masked_fill_(mask.unsqueeze(0), cls.zero) - - @staticmethod - def one_(xs): - "Fill *ssize x ...* tensor with multiplicative identity." - raise NotImplementedError() - @staticmethod def sum(xs, dim=-1): "Sum over *dim* of tensor." @@ -91,7 +82,8 @@ def plus(cls, a, b): class _Base(Semiring): - zero = 0 + zero = torch.tensor(0.0) + one = torch.tensor(1.0) @staticmethod def mul(a, b): @@ -101,17 +93,10 @@ def mul(a, b): def prod(a, dim=-1): return torch.prod(a, dim=dim) - @staticmethod - def zero_(xs): - return xs.fill_(0) - - @staticmethod - def one_(xs): - return xs.fill_(1) - class _BaseLog(Semiring): - zero = -1e9 + zero = torch.tensor(-1e5) + one = torch.tensor(-0.0) @staticmethod def sum(xs, dim=-1): @@ -121,14 +106,6 @@ def sum(xs, dim=-1): def mul(a, b): return a + b - @staticmethod - def zero_(xs): - return xs.fill_(-1e5) - - @staticmethod - def one_(xs): - return xs.fill_(0.0) - @staticmethod def prod(a, dim=-1): return torch.sum(a, dim=dim) @@ -200,6 +177,10 @@ def KMaxSemiring(k): "Implements the k-max semiring (kmax, +, [-inf, -inf..], [0, -inf, ...])." 
class KMaxSemiring(_BaseLog): + + zero = torch.tensor([-1e5 for i in range(k)]) + one = torch.tensor([0 if i == 0 else -1e5 for i in range(k)]) + @staticmethod def size(): return k @@ -211,16 +192,10 @@ def convert(cls, orig_potentials): dtype=orig_potentials.dtype, device=orig_potentials.device, ) - cls.zero_(potentials) + potentials = cls.fill(potentials, torch.tensor(True), cls.zero) potentials[0] = orig_potentials return potentials - @classmethod - def one_(cls, xs): - cls.zero_(xs) - xs[0].fill_(0) - return xs - @staticmethod def unconvert(potentials): return potentials[0] @@ -277,7 +252,8 @@ class KLDivergenceSemiring(Semiring): """ - zero = 0 + zero = torch.tensor([-1e5, -1e5, 0.0]) + one = torch.tensor([0.0, 0.0, 0.0]) @staticmethod def size(): @@ -322,27 +298,6 @@ def mul(a, b): def prod(cls, xs, dim=-1): return xs.sum(dim) - @classmethod - def zero_mask_(cls, xs, mask): - "Fill *ssize x ...* tensor with additive identity." - xs[0].masked_fill_(mask, -1e5) - xs[1].masked_fill_(mask, -1e5) - xs[2].masked_fill_(mask, 0) - - @staticmethod - def zero_(xs): - xs[0].fill_(-1e5) - xs[1].fill_(-1e5) - xs[2].fill_(0) - return xs - - @staticmethod - def one_(xs): - xs[0].fill_(0) - xs[1].fill_(0) - xs[2].fill_(0) - return xs - class CrossEntropySemiring(Semiring): """ @@ -357,7 +312,8 @@ class CrossEntropySemiring(Semiring): * Sample Selection for Statistical Grammar Induction :cite:`hwa2000samplesf` """ - zero = 0 + zero = torch.tensor([-1e5, -1e5, 0.0]) + one = torch.tensor([0.0, 0.0, 0.0]) @staticmethod def size(): @@ -396,27 +352,6 @@ def mul(a, b): def prod(cls, xs, dim=-1): return xs.sum(dim) - @classmethod - def zero_mask_(cls, xs, mask): - "Fill *ssize x ...* tensor with additive identity." - xs[0].masked_fill_(mask, -1e5) - xs[1].masked_fill_(mask, -1e5) - xs[2].masked_fill_(mask, 0) - - @staticmethod - def zero_(xs): - xs[0].fill_(-1e5) - xs[1].fill_(-1e5) - xs[2].fill_(0) - return xs - - @staticmethod - def one_(xs): - xs[0].fill_(0) - xs[1].fill_(0) - xs[2].fill_(0) - return xs - class EntropySemiring(Semiring): """ @@ -431,7 +366,8 @@ class EntropySemiring(Semiring): * Sample Selection for Statistical Grammar Induction :cite:`hwa2000samplesf` """ - zero = 0 + zero = torch.tensor([-1e5, 0.0]) + one = torch.tensor([0.0, 0.0]) @staticmethod def size(): @@ -465,24 +401,6 @@ def mul(a, b): def prod(cls, xs, dim=-1): return xs.sum(dim) - @classmethod - def zero_mask_(cls, xs, mask): - "Fill *ssize x ...* tensor with additive identity." - xs[0].masked_fill_(mask, -1e5) - xs[1].masked_fill_(mask, 0) - - @staticmethod - def zero_(xs): - xs[0].fill_(-1e5) - xs[1].fill_(0) - return xs - - @staticmethod - def one_(xs): - xs[0].fill_(0) - xs[1].fill_(0) - return xs - def TempMax(alpha): class _TempMax(_BaseLog):
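
Note: below is a minimal standalone sketch of the out-of-place fill pattern that this diff introduces in semirings.py in place of the removed zero_, one_, and zero_mask_ in-place helpers. The class name LogSemiringSketch and the toy shapes are illustrative assumptions, not part of the library; only the torch.where-based fill and the tensor-valued zero/one identities mirror the diff above.

import torch

class LogSemiringSketch:
    # Identities stored as tensors, as in the diff
    # (the library uses -1e5 as a finite stand-in for -inf).
    zero = torch.tensor(-1e5)
    one = torch.tensor(-0.0)

    @staticmethod
    def fill(c, mask, v):
        # Out-of-place masked fill: broadcast the identity `v` over the
        # leading semiring dimension and write it wherever `mask` is True.
        return torch.where(
            mask, v.type_as(c).view((-1,) + (1,) * (len(c.shape) - 1)), c
        )

# Example: initialize a linear-chain-style chart the way logpartition now does.
ssize, batch, bin_N, C = 1, 2, 4, 3  # toy sizes, purely illustrative
chart = torch.full((ssize, batch, bin_N, C, C), float("nan"))

# Whole-tensor fill with the semiring zero (a scalar True mask broadcasts).
chart = LogSemiringSketch.fill(chart, torch.tensor(True), LogSemiringSketch.zero)

# Masked fill of the transition diagonal with the semiring one,
# without mutating the chart in place.
init = torch.zeros_like(chart).bool()
init.diagonal(0, 3, 4).fill_(True)
chart = LogSemiringSketch.fill(chart, init, LogSemiringSketch.one)

assert (chart.diagonal(0, 3, 4) == 0.0).all()

Because fill returns a new tensor via torch.where instead of mutating its input, charts that require grad no longer need in-place writes, and making zero/one tensors (scalar here, vector-valued for KMaxSemiring and the entropy-style semirings in the diff) lets this single fill cover multi-component semirings that previously needed bespoke zero_/one_/zero_mask_ overrides.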