Skip to content

Avoid default allocation for taps of length 1 in ScanSaveMem #1395

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions pytensor/scan/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1186,7 +1186,7 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
return subtensor_merge_replacements


def _is_default_scan_buffer(x: TensorVariable) -> bool:
def _is_default_scan_buffer(x: TensorVariable, taps: int) -> bool:
Copy link
Preview

Copilot AI May 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ensure that all callers of _is_default_scan_buffer supply the correct 'taps' value so that the default buffer check correctly distinguishes between single and multiple taps.

Suggested change
def _is_default_scan_buffer(x: TensorVariable, taps: int) -> bool:
def _is_default_scan_buffer(x: TensorVariable, taps: int) -> bool:
"""
Determine if a scan buffer is the default buffer.
Parameters:
x (TensorVariable): The tensor variable to check.
taps (int): The number of taps (time steps) associated with the buffer.
Must be correctly supplied by the caller to ensure accurate checks.
Returns:
bool: True if the buffer is the default scan buffer, False otherwise.
"""

Copilot uses AI. Check for mistakes.

node = x.owner

if node is None:
Expand Down Expand Up @@ -1218,7 +1218,7 @@ def _is_default_scan_buffer(x: TensorVariable) -> bool:
# 1. alloc_empty(2 + nsteps)[:2].broadcastable == x.broadcastable
# But due to laziness we use the slightly more conservative check:
# 2. alloc_empty(2 + nsteps).broadcastable == x.broadcastable
if broadcasted_by(y, x):
if (taps > 1) and broadcasted_by(y, x):
return False

return True
Expand Down Expand Up @@ -1574,15 +1574,16 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
# If the memory for this output has been pre-allocated
# before going into the scan op (by an alloc node)
if idx < op_info.n_mit_sot + op_info.n_sit_sot:
Copy link
Preview

Copilot AI May 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Verify that deriving 'taps' from init_l[i] accurately reflects the intended tap count, and that this value is consistently used to compute extra_size in buffer expansion.

Suggested change
if idx < op_info.n_mit_sot + op_info.n_sit_sot:
if idx < op_info.n_mit_sot + op_info.n_sit_sot:
# Validate init_l[i] before using it to derive taps
if not isinstance(init_l[i], int) or init_l[i] < 0:
raise ValueError(f"Invalid tap count in init_l[{i}]: {init_l[i]}")

Copilot uses AI. Check for mistakes.

taps = init_l[i]
nw_input = nw_inputs[offset + idx]

# Recreate default buffers with new size
if _is_default_scan_buffer(nw_input):
extra_size = 1 if required_orphan else val - init_l[i]
if _is_default_scan_buffer(nw_input, taps):
extra_size = 1 if required_orphan else val - taps
nw_input = expand_empty(nw_input.owner.inputs[1], extra_size)
# Otherwise, just trim with a slice
else:
stop = init_l[i] if required_orphan else val
stop = taps if required_orphan else val
nw_input = nw_input[:stop]

nw_inputs[offset + idx] = nw_input
Expand Down Expand Up @@ -1626,14 +1627,13 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
# val == 0 means that we want to keep all intermediate
# results for that state, including the initial values.
if idx < op_info.n_mit_sot + op_info.n_sit_sot:
taps = init_l[op_info.n_mit_mot + idx]
Copy link
Preview

Copilot AI May 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Confirm that updating the slice boundary to use 'taps' (instead of init_l) maintains the intended behavior for buffer trimming in ScanSaveMem.

Suggested change
taps = init_l[op_info.n_mit_mot + idx]
taps = taps[op_info.n_mit_mot + idx]

Copilot uses AI. Check for mistakes.

in_idx = offset + idx
nw_input = nw_inputs[in_idx]
if _is_default_scan_buffer(nw_input):
if _is_default_scan_buffer(nw_input, taps):
nw_input = expand_empty(nw_input.owner.inputs[1], nw_steps)
else:
# Number of steps in the initial state
init_l_pt = pt.as_tensor(init_l[op_info.n_mit_mot + idx])
nw_input = nw_input[: (init_l_pt + nw_steps)]
nw_input = nw_input[: (taps + nw_steps)]
nw_inputs[in_idx] = nw_input

elif (
Expand Down
61 changes: 49 additions & 12 deletions tests/scan/test_rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
from pytensor.compile.mode import get_default_mode
from pytensor.configdefaults import config
from pytensor.gradient import grad, jacobian
from pytensor.graph.basic import Constant, equal_computations
from pytensor.graph.basic import Constant, ancestors, equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import clone_replace
from pytensor.scan.op import Scan
from pytensor.scan.rewriting import ScanInplaceOptimizer, ScanMerge
from pytensor.scan.utils import until
from pytensor.tensor import stack
from pytensor.tensor.basic import AllocEmpty
from pytensor.tensor.blas import Dot22
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.math import Dot, dot, sigmoid, tanh
Expand Down Expand Up @@ -1207,7 +1208,7 @@ def test_inplace3(self):


class TestSaveMem:
mode = get_default_mode().including("scan_save_mem")
mode = get_default_mode().including("scan_save_mem").excluding("scan_pushout")
Copy link
Preview

Copilot AI May 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Confirm that excluding 'scan_pushout' aligns with the intended optimization behavior and does not conflict with other scan optimizations.

Copilot uses AI. Check for mistakes.


def test_save_mem(self):
rng = np.random.default_rng(utt.fetch_seed())
Expand Down Expand Up @@ -1371,7 +1372,7 @@ def test_save_mem_cannot_reduce_constant_number_of_steps(self):
)

def test_save_mem_store_steps(self):
def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
def step(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
return (
u_t + 1.0,
u_t + 2.0,
Expand All @@ -1388,7 +1389,7 @@ def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
x30 = vector("x30")
x40 = scalar("x40")
[x1, x2, x3, x4, x5, x6, x7], updates = scan(
f_rnn,
step,
u,
[
None,
Expand All @@ -1404,7 +1405,7 @@ def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
go_backwards=False,
)

f2 = function(
f = function(
[u, x10, x20, x30, x40],
[x1[-7], x2[-3:-1], x3[-6:], x4[-1], x5[-1]],
updates=updates,
Expand All @@ -1417,13 +1418,49 @@ def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
v_u = rng.uniform(-5.0, 5.0, size=(20,))

# compute the output in numpy
tx1, tx2, tx3, tx4, tx5 = f2(v_u, [0, 0], 0, [0, 0], 0)

utt.assert_allclose(tx1, v_u[-7] + 1.0)
utt.assert_allclose(tx2, v_u[-3:-1] + 2.0)
utt.assert_allclose(tx3, v_u[-6:] + 3.0)
utt.assert_allclose(tx4, v_u[-1] + 4.0)
utt.assert_allclose(tx5, v_u[-1] + 5.0)
tx1, tx2, tx3, tx4, tx5 = f(v_u, [0, 0], 0, [0, 0], 0)
np.testing.assert_allclose(tx1, v_u[-7] + 1.0)
np.testing.assert_allclose(tx2, v_u[-3:-1] + 2.0)
np.testing.assert_allclose(tx3, v_u[-6:] + 3.0)
np.testing.assert_allclose(tx4, v_u[-1] + 4.0)
np.testing.assert_allclose(tx5, v_u[-1] + 5.0)

# Confirm reduction in buffer sizes
[scan_node] = [
node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
]
# x6 and x7 are dropped because they are not used
[n_steps, seq, x4_buffer, x5_buffer, x1_len, x2_len, x3_len] = scan_node.inputs
[x4_underlying_alloc] = [
var
for var in ancestors([x4_buffer])
if var.owner and isinstance(var.owner.op, AllocEmpty)
]
[x5_underlying_alloc] = [
var
for var in ancestors([x5_buffer])
if var.owner and isinstance(var.owner.op, AllocEmpty)
]
buffer_lengths = pytensor.function(
[u, x10, x20, x30, x40],
[
x1_len,
x2_len,
x3_len,
x4_underlying_alloc.shape[0],
x5_underlying_alloc.shape[0],
],
accept_inplace=True,
on_unused_input="ignore",
)(v_u, [0, 0], 0, [0, 0], 0)
# ScanSaveMem keeps +1 entries to handle taps with preallocated outputs
assert [int(i) for i in buffer_lengths] == [
7, # entry -7 of a map variable is kept, we need at least that many
3, # entries [-3, -2] of a map variable are kept, we need at least 3
6, # last six entries of a map variable are kept
2 + 1, # last entry of a double tap variable is kept
1 + 1, # last entry of a single tap variable is kept
]

def test_savemem_does_not_duplicate_number_of_scan_nodes(self):
var = pt.ones(())
Expand Down
Loading