Avoid default allocation for taps of length 1 in ScanSaveMem #1395
This PR threads the tap count of each recurrent state into _is_default_scan_buffer so that ScanSaveMem applies its conservative broadcast check only to states with more than one tap; a single-tap state can then keep its default AllocEmpty buffer instead of falling back to an explicit slice. The tap count is also reused for the buffer-size arithmetic inside scan_save_mem_rewrite.

Changes from all commits (base: main):
File: pytensor/scan/rewriting.py

@@ -1186,7 +1186,7 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
     return subtensor_merge_replacements
 
 
-def _is_default_scan_buffer(x: TensorVariable) -> bool:
+def _is_default_scan_buffer(x: TensorVariable, taps: int) -> bool:
     node = x.owner
 
     if node is None:

@@ -1218,7 +1218,7 @@ def _is_default_scan_buffer(x: TensorVariable) -> bool:
     # 1. alloc_empty(2 + nsteps)[:2].broadcastable == x.broadcastable
     # But due to laziness we use the slightly more conservative check:
     # 2. alloc_empty(2 + nsteps).broadcastable == x.broadcastable
-    if broadcasted_by(y, x):
+    if (taps > 1) and broadcasted_by(y, x):
        return False
 
     return True
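For context on the guard being relaxed above, here is a minimal sketch of the buffer shape this predicate matches, built with expand_empty from pytensor.scan.utils (the same helper the rewrite calls below). The shapes and the taps == 1 reading are illustrative assumptions, not part of this diff:

    # A "default" scan buffer: an AllocEmpty of taps + n_steps rows whose
    # first `taps` rows were filled with the initial state via set_subtensor.
    import pytensor.tensor as pt
    from pytensor.scan.utils import expand_empty

    init = pt.zeros((1, 3))         # single tap -> exactly one initial row
    buffer = expand_empty(init, 9)  # 1 + 9 rows; rows [1:] stay uninitialized
    # With taps == 1 the broadcasted_by(init, buffer) check is skipped: the
    # initial state occupies a single row, so resizing the underlying
    # AllocEmpty cannot broadcast it across multiple tap entries.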
@@ -1574,15 +1574,16 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
                 # If the memory for this output has been pre-allocated
                 # before going into the scan op (by an alloc node)
                 if idx < op_info.n_mit_sot + op_info.n_sit_sot:
+                    taps = init_l[i]
                     nw_input = nw_inputs[offset + idx]
 
                     # Recreate default buffers with new size
-                    if _is_default_scan_buffer(nw_input):
-                        extra_size = 1 if required_orphan else val - init_l[i]
+                    if _is_default_scan_buffer(nw_input, taps):
+                        extra_size = 1 if required_orphan else val - taps
                         nw_input = expand_empty(nw_input.owner.inputs[1], extra_size)
                     # Otherwise, just trim with a slice
                     else:
-                        stop = init_l[i] if required_orphan else val
+                        stop = taps if required_orphan else val
                         nw_input = nw_input[:stop]
                     nw_inputs[offset + idx] = nw_input

[Copilot review] Verify that deriving 'taps' from init_l[i] accurately reflects the intended tap count, and that this value is consistently used to compute extra_size in buffer expansion.
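To make the new size bookkeeping concrete, a worked example with assumed numbers (taps, val, and required_orphan are illustrative values, not taken from the PR):

    # A multi-tap state with two initial entries whose last needed index is val = 5.
    taps, val, required_orphan = 2, 5, False
    extra_size = 1 if required_orphan else val - taps  # 3 rows appended
    stop = taps if required_orphan else val            # slice bound when trimming
    assert (taps + extra_size, stop) == (5, 5)
    # Either branch yields a 5-row buffer: 2 initial taps + 3 computed entries.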
@@ -1626,14 +1627,13 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
                 # val == 0 means that we want to keep all intermediate
                 # results for that state, including the initial values.
                 if idx < op_info.n_mit_sot + op_info.n_sit_sot:
+                    taps = init_l[op_info.n_mit_mot + idx]
                     in_idx = offset + idx
                     nw_input = nw_inputs[in_idx]
-                    if _is_default_scan_buffer(nw_input):
+                    if _is_default_scan_buffer(nw_input, taps):
                         nw_input = expand_empty(nw_input.owner.inputs[1], nw_steps)
                     else:
-                        # Number of steps in the initial state
-                        init_l_pt = pt.as_tensor(init_l[op_info.n_mit_mot + idx])
-                        nw_input = nw_input[: (init_l_pt + nw_steps)]
+                        nw_input = nw_input[: (taps + nw_steps)]
                     nw_inputs[in_idx] = nw_input
 
                 elif (

[Copilot review] Confirm that updating the slice boundary to use 'taps' (instead of init_l) maintains the intended behavior for buffer trimming in ScanSaveMem.
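Note that the else branch now slices with the Python int taps rather than the symbolic pt.as_tensor(init_l[...]), so the boundary is a compile-time constant. A toy version of this val == 0 branch (all steps kept), with assumed numbers:

    # val == 0 keeps every intermediate result: initial rows plus one per step.
    taps, nw_steps = 2, 7          # assumed double-tap state run for 7 steps
    buffer_rows = taps + nw_steps  # nw_input[: taps + nw_steps]
    assert buffer_rows == 9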
File: tests/scan/test_rewriting.py

@@ -9,13 +9,14 @@
 from pytensor.compile.mode import get_default_mode
 from pytensor.configdefaults import config
 from pytensor.gradient import grad, jacobian
-from pytensor.graph.basic import Constant, equal_computations
+from pytensor.graph.basic import Constant, ancestors, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.replace import clone_replace
 from pytensor.scan.op import Scan
 from pytensor.scan.rewriting import ScanInplaceOptimizer, ScanMerge
 from pytensor.scan.utils import until
 from pytensor.tensor import stack
+from pytensor.tensor.basic import AllocEmpty
 from pytensor.tensor.blas import Dot22
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.math import Dot, dot, sigmoid, tanh
@@ -1207,7 +1208,7 @@ def test_inplace3(self):
 
 
 class TestSaveMem:
-    mode = get_default_mode().including("scan_save_mem")
+    mode = get_default_mode().including("scan_save_mem").excluding("scan_pushout")
 
     def test_save_mem(self):
         rng = np.random.default_rng(utt.fetch_seed())

[Copilot review] Confirm that excluding 'scan_pushout' aligns with the intended optimization behavior and does not conflict with other scan optimizations.
@@ -1371,7 +1372,7 @@ def test_save_mem_cannot_reduce_constant_number_of_steps(self):
         )
 
     def test_save_mem_store_steps(self):
-        def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
+        def step(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
             return (
                 u_t + 1.0,
                 u_t + 2.0,
@@ -1388,7 +1389,7 @@ def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
         x30 = vector("x30")
         x40 = scalar("x40")
         [x1, x2, x3, x4, x5, x6, x7], updates = scan(
-            f_rnn,
+            step,
             u,
             [
                 None,
@@ -1404,7 +1405,7 @@ def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
             go_backwards=False,
         )
 
-        f2 = function(
+        f = function(
             [u, x10, x20, x30, x40],
             [x1[-7], x2[-3:-1], x3[-6:], x4[-1], x5[-1]],
             updates=updates,
@@ -1417,13 +1418,49 @@ def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
         v_u = rng.uniform(-5.0, 5.0, size=(20,))
 
         # compute the output in numpy
-        tx1, tx2, tx3, tx4, tx5 = f2(v_u, [0, 0], 0, [0, 0], 0)
-
-        utt.assert_allclose(tx1, v_u[-7] + 1.0)
-        utt.assert_allclose(tx2, v_u[-3:-1] + 2.0)
-        utt.assert_allclose(tx3, v_u[-6:] + 3.0)
-        utt.assert_allclose(tx4, v_u[-1] + 4.0)
-        utt.assert_allclose(tx5, v_u[-1] + 5.0)
+        tx1, tx2, tx3, tx4, tx5 = f(v_u, [0, 0], 0, [0, 0], 0)
+        np.testing.assert_allclose(tx1, v_u[-7] + 1.0)
+        np.testing.assert_allclose(tx2, v_u[-3:-1] + 2.0)
+        np.testing.assert_allclose(tx3, v_u[-6:] + 3.0)
+        np.testing.assert_allclose(tx4, v_u[-1] + 4.0)
+        np.testing.assert_allclose(tx5, v_u[-1] + 5.0)
+
+        # Confirm reduction in buffer sizes
+        [scan_node] = [
+            node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
+        ]
+        # x6 and x7 are dropped because they are not used
+        [n_steps, seq, x4_buffer, x5_buffer, x1_len, x2_len, x3_len] = scan_node.inputs
+        [x4_underlying_alloc] = [
+            var
+            for var in ancestors([x4_buffer])
+            if var.owner and isinstance(var.owner.op, AllocEmpty)
+        ]
+        [x5_underlying_alloc] = [
+            var
+            for var in ancestors([x5_buffer])
+            if var.owner and isinstance(var.owner.op, AllocEmpty)
+        ]
+        buffer_lengths = pytensor.function(
+            [u, x10, x20, x30, x40],
+            [
+                x1_len,
+                x2_len,
+                x3_len,
+                x4_underlying_alloc.shape[0],
+                x5_underlying_alloc.shape[0],
+            ],
+            accept_inplace=True,
+            on_unused_input="ignore",
+        )(v_u, [0, 0], 0, [0, 0], 0)
+        # ScanSaveMem keeps +1 entries to handle taps with preallocated outputs
+        assert [int(i) for i in buffer_lengths] == [
+            7,  # entry -7 of a map variable is kept, we need at least that many
+            3,  # entries [-3, -2] of a map variable are kept, we need at least 3
+            6,  # last six entries of a map variable are kept
+            2 + 1,  # last entry of a double tap variable is kept
+            1 + 1,  # last entry of a single tap variable is kept
+        ]
 
     def test_savemem_does_not_duplicate_number_of_scan_nodes(self):
         var = pt.ones(())
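The expected lengths in the new assertion follow one rule: map outputs keep exactly the window they read, while tap states keep taps + 1 rows (their initial entries plus one preallocated output slot). A hedged restatement of that arithmetic (the per-output comments mirror the inline ones above):

    expected = [
        7,      # x1[-7] is read       -> at least 7 entries
        3,      # x2[-3:-1] is read    -> entries -3 and -2 need 3 entries
        6,      # x3[-6:] is read      -> the last 6 entries
        2 + 1,  # x4, double-tap state -> 2 initial rows + 1 output slot
        1 + 1,  # x5, single-tap state -> 1 initial row + 1 output slot
    ]
    assert expected == [7, 3, 6, 3, 2]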
[Copilot review] Ensure that all callers of _is_default_scan_buffer supply the correct 'taps' value so that the default buffer check correctly distinguishes between single and multiple taps.