diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py
index b8e6b009d8..390e4379ee 100644
--- a/pytensor/scan/rewriting.py
+++ b/pytensor/scan/rewriting.py
@@ -1186,7 +1186,7 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
     return subtensor_merge_replacements
 
 
-def _is_default_scan_buffer(x: TensorVariable) -> bool:
+def _is_default_scan_buffer(x: TensorVariable, taps: int) -> bool:
     node = x.owner
 
     if node is None:
@@ -1218,7 +1218,7 @@ def _is_default_scan_buffer(x: TensorVariable) -> bool:
     # 1. alloc_empty(2 + nsteps)[:2].broadcastable == x.broadcastable
     # But due to laziness we use the slightly more conservative check:
     # 2. alloc_empty(2 + nsteps).broadcastable == x.broadcastable
-    if broadcasted_by(y, x):
+    if (taps > 1) and broadcasted_by(y, x):
         return False
 
     return True
@@ -1574,15 +1574,16 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
                 # If the memory for this output has been pre-allocated
                 # before going into the scan op (by an alloc node)
                 if idx < op_info.n_mit_sot + op_info.n_sit_sot:
+                    taps = init_l[i]
                     nw_input = nw_inputs[offset + idx]
 
                     # Recreate default buffers with new size
-                    if _is_default_scan_buffer(nw_input):
-                        extra_size = 1 if required_orphan else val - init_l[i]
+                    if _is_default_scan_buffer(nw_input, taps):
+                        extra_size = 1 if required_orphan else val - taps
                         nw_input = expand_empty(nw_input.owner.inputs[1], extra_size)
                     # Otherwise, just trim with a slice
                     else:
-                        stop = init_l[i] if required_orphan else val
+                        stop = taps if required_orphan else val
                         nw_input = nw_input[:stop]
 
                     nw_inputs[offset + idx] = nw_input
@@ -1626,14 +1627,13 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
                 # val == 0 means that we want to keep all intermediate
                 # results for that state, including the initial values.
                 if idx < op_info.n_mit_sot + op_info.n_sit_sot:
+                    taps = init_l[op_info.n_mit_mot + idx]
                     in_idx = offset + idx
                     nw_input = nw_inputs[in_idx]
-                    if _is_default_scan_buffer(nw_input):
+                    if _is_default_scan_buffer(nw_input, taps):
                         nw_input = expand_empty(nw_input.owner.inputs[1], nw_steps)
                     else:
-                        # Number of steps in the initial state
-                        init_l_pt = pt.as_tensor(init_l[op_info.n_mit_mot + idx])
-                        nw_input = nw_input[: (init_l_pt + nw_steps)]
+                        nw_input = nw_input[: (taps + nw_steps)]
                     nw_inputs[in_idx] = nw_input
 
                 elif (
diff --git a/tests/scan/test_rewriting.py b/tests/scan/test_rewriting.py
index 1b687afcdc..9100ad70ce 100644
--- a/tests/scan/test_rewriting.py
+++ b/tests/scan/test_rewriting.py
@@ -9,13 +9,14 @@
 from pytensor.compile.mode import get_default_mode
 from pytensor.configdefaults import config
 from pytensor.gradient import grad, jacobian
-from pytensor.graph.basic import Constant, equal_computations
+from pytensor.graph.basic import Constant, ancestors, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.replace import clone_replace
 from pytensor.scan.op import Scan
 from pytensor.scan.rewriting import ScanInplaceOptimizer, ScanMerge
 from pytensor.scan.utils import until
 from pytensor.tensor import stack
+from pytensor.tensor.basic import AllocEmpty
 from pytensor.tensor.blas import Dot22
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.math import Dot, dot, sigmoid, tanh
@@ -1207,7 +1208,7 @@ def test_inplace3(self):
 
 
 class TestSaveMem:
-    mode = get_default_mode().including("scan_save_mem")
+    mode = get_default_mode().including("scan_save_mem").excluding("scan_pushout")
 
     def test_save_mem(self):
         rng = np.random.default_rng(utt.fetch_seed())
@@ -1371,7 +1372,7 @@ def test_save_mem_cannot_reduce_constant_number_of_steps(self):
         )
 
     def test_save_mem_store_steps(self):
-        def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
+        def step(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
             return (
                 u_t + 1.0,
                 u_t + 2.0,
@@ -1388,7 +1389,7 @@ def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
         x30 = vector("x30")
         x40 = scalar("x40")
         [x1, x2, x3, x4, x5, x6, x7], updates = scan(
-            f_rnn,
+            step,
             u,
             [
                 None,
@@ -1404,7 +1405,7 @@ def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
             go_backwards=False,
         )
 
-        f2 = function(
+        f = function(
             [u, x10, x20, x30, x40],
             [x1[-7], x2[-3:-1], x3[-6:], x4[-1], x5[-1]],
             updates=updates,
@@ -1417,13 +1418,49 @@ def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
         v_u = rng.uniform(-5.0, 5.0, size=(20,))
 
         # compute the output in numpy
-        tx1, tx2, tx3, tx4, tx5 = f2(v_u, [0, 0], 0, [0, 0], 0)
-
-        utt.assert_allclose(tx1, v_u[-7] + 1.0)
-        utt.assert_allclose(tx2, v_u[-3:-1] + 2.0)
-        utt.assert_allclose(tx3, v_u[-6:] + 3.0)
-        utt.assert_allclose(tx4, v_u[-1] + 4.0)
-        utt.assert_allclose(tx5, v_u[-1] + 5.0)
+        tx1, tx2, tx3, tx4, tx5 = f(v_u, [0, 0], 0, [0, 0], 0)
+        np.testing.assert_allclose(tx1, v_u[-7] + 1.0)
+        np.testing.assert_allclose(tx2, v_u[-3:-1] + 2.0)
+        np.testing.assert_allclose(tx3, v_u[-6:] + 3.0)
+        np.testing.assert_allclose(tx4, v_u[-1] + 4.0)
+        np.testing.assert_allclose(tx5, v_u[-1] + 5.0)
+
+        # Confirm reduction in buffer sizes
+        [scan_node] = [
+            node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
+        ]
+        # x6 and x7 are dropped because they are not used
+        [n_steps, seq, x4_buffer, x5_buffer, x1_len, x2_len, x3_len] = scan_node.inputs
+        [x4_underlying_alloc] = [
+            var
+            for var in ancestors([x4_buffer])
+            if var.owner and isinstance(var.owner.op, AllocEmpty)
+        ]
+        [x5_underlying_alloc] = [
+            var
+            for var in ancestors([x5_buffer])
+            if var.owner and isinstance(var.owner.op, AllocEmpty)
+        ]
+        buffer_lengths = pytensor.function(
+            [u, x10, x20, x30, x40],
+            [
+                x1_len,
+                x2_len,
+                x3_len,
+                x4_underlying_alloc.shape[0],
+                x5_underlying_alloc.shape[0],
+            ],
+            accept_inplace=True,
+            on_unused_input="ignore",
+        )(v_u, [0, 0], 0, [0, 0], 0)
+        # ScanSaveMem keeps +1 entries to handle taps with preallocated outputs
+        assert [int(i) for i in buffer_lengths] == [
+            7,  # entry -7 of a map variable is kept, we need at least that many
+            3,  # entries [-3, -2] of a map variable are kept, we need at least 3
+            6,  # last six entries of a map variable are kept
+            2 + 1,  # last entry of a double tap variable is kept
+            1 + 1,  # last entry of a single tap variable is kept
+        ]
 
     def test_savemem_does_not_duplicate_number_of_scan_nodes(self):
         var = pt.ones(())
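
Editor's note (not part of the patch): _is_default_scan_buffer matches the buffer Scan allocates by default for a tapped output, namely an AllocEmpty of length taps + nsteps whose first `taps` entries are filled in by a set-subtensor. A rough illustration of that pattern, with made-up names (`nsteps`, `taps`, `init`):

    import pytensor.tensor as pt

    nsteps, taps = 10, 2
    init = pt.vector("init")
    # Scan preallocates empty storage for taps + nsteps entries and writes
    # the initial taps into the leading slice of the buffer:
    buffer = pt.set_subtensor(pt.empty((taps + nsteps,))[:taps], init)

The patch threads the number of taps into this check so that the broadcasted_by(y, x) bail-out only applies to multi-tap (mit_sot) buffers; single-tap (sit_sot) buffers are now treated as default buffers even when the initial value was broadcast into them.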
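A minimal end-to-end sketch of the buffer-size behavior the new test assertions pin down (illustrative only; assumes a PyTensor build whose default mode includes scan_save_mem):

    import numpy as np
    import pytensor.tensor as pt
    from pytensor import function, scan

    x0 = pt.scalar("x0")
    # Single-tap recurrence: x_t = x_{t-1} + 1
    xs, _ = scan(lambda x_tm1: x_tm1 + 1.0, outputs_info=[x0], n_steps=10)
    # Only the last step is requested, so ScanSaveMem can shrink the internal
    # buffer to 1 (tap) + 1 (new step) entries instead of 1 + 10.
    f = function([x0], xs[-1])
    np.testing.assert_allclose(f(0.0), 10.0)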