From e94b42d0f2b704cd8e293ac372d4980061d5a870 Mon Sep 17 00:00:00 2001 From: lanzhiwang Date: Sat, 25 Jan 2020 22:23:30 +0800 Subject: [PATCH 01/13] enhanced segment tree implementation and more pythonic enhanced segment tree implementation and more pythonic --- .../binary_tree/segment_tree_other.py | 249 ++++++++++++++++++ 1 file changed, 249 insertions(+) create mode 100644 data_structures/binary_tree/segment_tree_other.py diff --git a/data_structures/binary_tree/segment_tree_other.py b/data_structures/binary_tree/segment_tree_other.py new file mode 100644 index 000000000000..b0cf9da380fb --- /dev/null +++ b/data_structures/binary_tree/segment_tree_other.py @@ -0,0 +1,249 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from queue import Queue + + +class SegmentTreeNode(object): + def __init__(self, start, end, val, left=None, right=None): + self.start = start + self.end = end + self.val = val + self.mid = (start + end) // 2 + self.left = left + self.right = right + + def __str__(self): + return 'val: %s, start: %s, end: %s' % (self.val, self.start, self.end) + + +class NumArray: + def __init__(self, nums): + self.nums = nums + if self.nums: + self.root = self._build_tree(0, len(nums) - 1) + + def update(self, i, val): + self._update_tree(self.root, i, val) + + def sum_range(self, i, j): + return self._sum_range(self.root, i, j) + + def _build_tree(self, start, end): + if start == end: + return SegmentTreeNode(start, end, self.nums[start]) + mid = (start + end) // 2 + left = self._build_tree(start, mid) + right = self._build_tree(mid + 1, end) + return SegmentTreeNode(start, end, left.val + right.val, left, right) + + def _update_tree(self, root, i, val): + if root.start == i and root.end == i: + root.val = val + return + if i <= root.mid: + self._update_tree(root.left, i, val) + else: + self._update_tree(root.right, i, val) + root.val = root.left.val + root.right.val + + def _sum_range(self, root, i, j): + if root.start == i and root.end == j: + return root.val + """ + [i, j] [i, j] [i, j] + [start mid] [mid+1 end] + """ + if j <= root.mid: + return self._sum_range(root.left, i, j) + elif i > root.mid: + return self._sum_range(root.right, i, j) + else: + return self._sum_range(root.left, i, root.mid) + self._sum_range(root.right, root.mid + 1, j) + + def traverse(self): + result = [] + if self.root is not None: + queue = Queue() + queue.put(self.root) + while not queue.empty(): + node = queue.get() + result.append(node) + + if node.left is not None: + queue.put(node.left) + + if node.right is not None: + queue.put(node.right) + return result + + +class MaxArray: + def __init__(self, nums): + self.nums = nums + if self.nums: + self.root = self._build_tree(0, len(nums) - 1) + + def update(self, i, val): + self._update_tree(self.root, i, val) + + def max_range(self, i, j): + return self._max_range(self.root, i, j) + + def _build_tree(self, start, end): + if start == end: + return SegmentTreeNode(start, end, self.nums[start]) + mid = (start + end) // 2 + left = self._build_tree(start, mid) + right = self._build_tree(mid + 1, end) + return SegmentTreeNode(start, end, max([left.val, right.val]), left, right) + + def _update_tree(self, root, i, val): + if root.start == i and root.end == i: + root.val = val + return + if i <= root.mid: + self._update_tree(root.left, i, val) + else: + self._update_tree(root.right, i, val) + root.val = max([root.left.val, root.right.val]) + + def _max_range(self, root, i, j): + if root.start == i and root.end == j: + return root.val + """ + [i, j] [i, j] [i, j] + [start mid] [mid+1 end] + """ + if j <= root.mid: + return self._max_range(root.left, i, j) + elif i > root.mid: + return self._max_range(root.right, i, j) + else: + return max([self._max_range(root.left, i, root.mid), self._max_range(root.right, root.mid + 1, j)]) + + def traverse(self): + result = [] + if self.root is not None: + queue = Queue() + queue.put(self.root) + while not queue.empty(): + node = queue.get() + result.append(node) + + if node.left is not None: + queue.put(node.left) + + if node.right is not None: + queue.put(node.right) + return result + + +class MinArray: + def __init__(self, nums): + self.nums = nums + if self.nums: + self.root = self._build_tree(0, len(nums) - 1) + + def update(self, i, val): + self._update_tree(self.root, i, val) + + def min_range(self, i, j): + return self._min_range(self.root, i, j) + + def _build_tree(self, start, end): + if start == end: + return SegmentTreeNode(start, end, self.nums[start]) + mid = (start + end) // 2 + left = self._build_tree(start, mid) + right = self._build_tree(mid + 1, end) + return SegmentTreeNode(start, end, min([left.val, right.val]), left, right) + + def _update_tree(self, root, i, val): + if root.start == i and root.end == i: + root.val = val + return + if i <= root.mid: + self._update_tree(root.left, i, val) + else: + self._update_tree(root.right, i, val) + root.val = min([root.left.val, root.right.val]) + + def _min_range(self, root, i, j): + if root.start == i and root.end == j: + return root.val + """ + [i, j] [i, j] [i, j] + [start mid] [mid+1 end] + """ + if j <= root.mid: + return self._min_range(root.left, i, j) + elif i > root.mid: + return self._min_range(root.right, i, j) + else: + return min([self._min_range(root.left, i, root.mid), self._min_range(root.right, root.mid + 1, j)]) + + def traverse(self): + result = [] + if self.root is not None: + queue = Queue() + queue.put(self.root) + while not queue.empty(): + node = queue.get() + result.append(node) + + if node.left is not None: + queue.put(node.left) + + if node.right is not None: + queue.put(node.right) + return result + + +if __name__ == '__main__': + print('求和线段树') + num_arr = NumArray([2, 1, 5, 3, 4]) + for node in num_arr.traverse(): + print(node) + print() + + num_arr.update(1, 5) + for node in num_arr.traverse(): + print(node) + print() + + print(num_arr.sum_range(3, 4)) # 7 + print(num_arr.sum_range(2, 2)) # 5 + print(num_arr.sum_range(1, 3)) # 13 + + print() + print('求最大值线段树') + max_arr = MaxArray([2, 1, 5, 3, 4]) + for node in max_arr.traverse(): + print(node) + print() + + max_arr.update(1, 5) + for node in max_arr.traverse(): + print(node) + print() + + print(max_arr.max_range(3, 4)) # 4 + print(max_arr.max_range(2, 2)) # 5 + print(max_arr.max_range(1, 3)) # 5 + + print() + print('求最小值线段树') + min_arr = MinArray([2, 1, 5, 3, 4]) + for node in min_arr.traverse(): + print(node) + print() + + min_arr.update(1, 5) + for node in min_arr.traverse(): + print(node) + print() + + print(min_arr.min_range(3, 4)) # 3 + print(min_arr.min_range(2, 2)) # 5 + print(min_arr.min_range(1, 3)) # 3 From 2880d47ce9bcb396f32a794e3aabee26d3709eaf Mon Sep 17 00:00:00 2001 From: lanzhiwang Date: Sun, 26 Jan 2020 10:08:50 +0800 Subject: [PATCH 02/13] add doctests for segment tree --- .../binary_tree/segment_tree_other.py | 111 ++++++++++++++++++ 1 file changed, 111 insertions(+) diff --git a/data_structures/binary_tree/segment_tree_other.py b/data_structures/binary_tree/segment_tree_other.py index b0cf9da380fb..496a415cbc1c 100644 --- a/data_structures/binary_tree/segment_tree_other.py +++ b/data_structures/binary_tree/segment_tree_other.py @@ -18,6 +18,43 @@ def __str__(self): class NumArray: + """NumArray is sum tree of object. Parent node is sum of two child nodes. + >>> num_arr = NumArray([2, 1, 5, 3, 4]) + >>> for node in num_arr.traverse(): + ... print(node) + ... + val: 15, start: 0, end: 4 + val: 8, start: 0, end: 2 + val: 7, start: 3, end: 4 + val: 3, start: 0, end: 1 + val: 5, start: 2, end: 2 + val: 3, start: 3, end: 3 + val: 4, start: 4, end: 4 + val: 2, start: 0, end: 0 + val: 1, start: 1, end: 1 + >>> + >>> num_arr.update(1, 5) + >>> for node in num_arr.traverse(): + ... print(node) + ... + val: 19, start: 0, end: 4 + val: 12, start: 0, end: 2 + val: 7, start: 3, end: 4 + val: 7, start: 0, end: 1 + val: 5, start: 2, end: 2 + val: 3, start: 3, end: 3 + val: 4, start: 4, end: 4 + val: 2, start: 0, end: 0 + val: 5, start: 1, end: 1 + >>> + >>> num_arr.sum_range(3, 4) + 7 + >>> num_arr.sum_range(2, 2) + 5 + >>> num_arr.sum_range(1, 3) + 13 + >>> + """ def __init__(self, nums): self.nums = nums if self.nums: @@ -79,6 +116,43 @@ def traverse(self): class MaxArray: + """MaxArray is max tree of object. Parent node is max of two child nodes. + >>> max_arr = MaxArray([2, 1, 5, 3, 4]) + >>> for node in max_arr.traverse(): + ... print(node) + ... + val: 5, start: 0, end: 4 + val: 5, start: 0, end: 2 + val: 4, start: 3, end: 4 + val: 2, start: 0, end: 1 + val: 5, start: 2, end: 2 + val: 3, start: 3, end: 3 + val: 4, start: 4, end: 4 + val: 2, start: 0, end: 0 + val: 1, start: 1, end: 1 + >>> + >>> max_arr.update(1, 5) + >>> for node in max_arr.traverse(): + ... print(node) + ... + val: 5, start: 0, end: 4 + val: 5, start: 0, end: 2 + val: 4, start: 3, end: 4 + val: 5, start: 0, end: 1 + val: 5, start: 2, end: 2 + val: 3, start: 3, end: 3 + val: 4, start: 4, end: 4 + val: 2, start: 0, end: 0 + val: 5, start: 1, end: 1 + >>> + >>> max_arr.max_range(3, 4) + 4 + >>> max_arr.max_range(2, 2) + 5 + >>> max_arr.max_range(1, 3) + 5 + >>> + """ def __init__(self, nums): self.nums = nums if self.nums: @@ -140,6 +214,43 @@ def traverse(self): class MinArray: + """MinArray is min tree of object. Parent node is min of two child nodes. + >>> min_arr = MinArray([2, 1, 5, 3, 4]) + >>> for node in min_arr.traverse(): + ... print(node) + ... + val: 1, start: 0, end: 4 + val: 1, start: 0, end: 2 + val: 3, start: 3, end: 4 + val: 1, start: 0, end: 1 + val: 5, start: 2, end: 2 + val: 3, start: 3, end: 3 + val: 4, start: 4, end: 4 + val: 2, start: 0, end: 0 + val: 1, start: 1, end: 1 + >>> + >>> min_arr.update(1, 5) + >>> for node in min_arr.traverse(): + ... print(node) + ... + val: 2, start: 0, end: 4 + val: 2, start: 0, end: 2 + val: 3, start: 3, end: 4 + val: 2, start: 0, end: 1 + val: 5, start: 2, end: 2 + val: 3, start: 3, end: 3 + val: 4, start: 4, end: 4 + val: 2, start: 0, end: 0 + val: 5, start: 1, end: 1 + >>> + >>> min_arr.min_range(3, 4) + 3 + >>> min_arr.min_range(2, 2) + 5 + >>> min_arr.min_range(1, 3) + 3 + >>> + """ def __init__(self, nums): self.nums = nums if self.nums: From 030a55d9d84313fc6963a19f94793a017bc2ef9b Mon Sep 17 00:00:00 2001 From: lanzhiwang Date: Sun, 26 Jan 2020 10:22:16 +0800 Subject: [PATCH 03/13] add type annotations --- .../binary_tree/segment_tree_other.py | 31 ++++++++++--------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/data_structures/binary_tree/segment_tree_other.py b/data_structures/binary_tree/segment_tree_other.py index 496a415cbc1c..3d2c5b86e3ca 100644 --- a/data_structures/binary_tree/segment_tree_other.py +++ b/data_structures/binary_tree/segment_tree_other.py @@ -2,6 +2,7 @@ # -*- coding: utf-8 -*- from queue import Queue +from collections.abc import Sequence class SegmentTreeNode(object): @@ -55,10 +56,10 @@ class NumArray: 13 >>> """ - def __init__(self, nums): - self.nums = nums - if self.nums: - self.root = self._build_tree(0, len(nums) - 1) + def __init__(self, collection: Sequence): + self.collection = collection + if self.collection: + self.root = self._build_tree(0, len(collection) - 1) def update(self, i, val): self._update_tree(self.root, i, val) @@ -68,7 +69,7 @@ def sum_range(self, i, j): def _build_tree(self, start, end): if start == end: - return SegmentTreeNode(start, end, self.nums[start]) + return SegmentTreeNode(start, end, self.collection[start]) mid = (start + end) // 2 left = self._build_tree(start, mid) right = self._build_tree(mid + 1, end) @@ -153,10 +154,10 @@ class MaxArray: 5 >>> """ - def __init__(self, nums): - self.nums = nums - if self.nums: - self.root = self._build_tree(0, len(nums) - 1) + def __init__(self, collection: Sequence): + self.collection = collection + if self.collection: + self.root = self._build_tree(0, len(collection) - 1) def update(self, i, val): self._update_tree(self.root, i, val) @@ -166,7 +167,7 @@ def max_range(self, i, j): def _build_tree(self, start, end): if start == end: - return SegmentTreeNode(start, end, self.nums[start]) + return SegmentTreeNode(start, end, self.collection[start]) mid = (start + end) // 2 left = self._build_tree(start, mid) right = self._build_tree(mid + 1, end) @@ -251,10 +252,10 @@ class MinArray: 3 >>> """ - def __init__(self, nums): - self.nums = nums - if self.nums: - self.root = self._build_tree(0, len(nums) - 1) + def __init__(self, collection: Sequence): + self.collection = collection + if self.collection: + self.root = self._build_tree(0, len(collection) - 1) def update(self, i, val): self._update_tree(self.root, i, val) @@ -264,7 +265,7 @@ def min_range(self, i, j): def _build_tree(self, start, end): if start == end: - return SegmentTreeNode(start, end, self.nums[start]) + return SegmentTreeNode(start, end, self.collection[start]) mid = (start + end) // 2 left = self._build_tree(start, mid) right = self._build_tree(mid + 1, end) From e0c1b062a3587a2a700a5d04f110a62e486cafb8 Mon Sep 17 00:00:00 2001 From: lanzhiwang Date: Sun, 26 Jan 2020 18:56:59 +0800 Subject: [PATCH 04/13] unified processing sum min max segment tre --- .../binary_tree/segment_tree_other.py | 243 ++++-------------- 1 file changed, 48 insertions(+), 195 deletions(-) diff --git a/data_structures/binary_tree/segment_tree_other.py b/data_structures/binary_tree/segment_tree_other.py index 3d2c5b86e3ca..69903d38ccc4 100644 --- a/data_structures/binary_tree/segment_tree_other.py +++ b/data_structures/binary_tree/segment_tree_other.py @@ -1,5 +1,10 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +""" +Segment_tree creates a segment tree with a given array and function, +allowing queries to be done later in log(N) time +function takes 2 values and returns a same type value +""" from queue import Queue from collections.abc import Sequence @@ -18,9 +23,10 @@ def __str__(self): return 'val: %s, start: %s, end: %s' % (self.val, self.start, self.end) -class NumArray: - """NumArray is sum tree of object. Parent node is sum of two child nodes. - >>> num_arr = NumArray([2, 1, 5, 3, 4]) +class SegmentTree(object): + """ + >>> import operator + >>> num_arr = SegmentTree([2, 1, 5, 3, 4], operator.add) >>> for node in num_arr.traverse(): ... print(node) ... @@ -48,77 +54,14 @@ class NumArray: val: 2, start: 0, end: 0 val: 5, start: 1, end: 1 >>> - >>> num_arr.sum_range(3, 4) + >>> num_arr.query_range(3, 4) 7 - >>> num_arr.sum_range(2, 2) + >>> num_arr.query_range(2, 2) 5 - >>> num_arr.sum_range(1, 3) + >>> num_arr.query_range(1, 3) 13 >>> - """ - def __init__(self, collection: Sequence): - self.collection = collection - if self.collection: - self.root = self._build_tree(0, len(collection) - 1) - - def update(self, i, val): - self._update_tree(self.root, i, val) - - def sum_range(self, i, j): - return self._sum_range(self.root, i, j) - - def _build_tree(self, start, end): - if start == end: - return SegmentTreeNode(start, end, self.collection[start]) - mid = (start + end) // 2 - left = self._build_tree(start, mid) - right = self._build_tree(mid + 1, end) - return SegmentTreeNode(start, end, left.val + right.val, left, right) - - def _update_tree(self, root, i, val): - if root.start == i and root.end == i: - root.val = val - return - if i <= root.mid: - self._update_tree(root.left, i, val) - else: - self._update_tree(root.right, i, val) - root.val = root.left.val + root.right.val - - def _sum_range(self, root, i, j): - if root.start == i and root.end == j: - return root.val - """ - [i, j] [i, j] [i, j] - [start mid] [mid+1 end] - """ - if j <= root.mid: - return self._sum_range(root.left, i, j) - elif i > root.mid: - return self._sum_range(root.right, i, j) - else: - return self._sum_range(root.left, i, root.mid) + self._sum_range(root.right, root.mid + 1, j) - - def traverse(self): - result = [] - if self.root is not None: - queue = Queue() - queue.put(self.root) - while not queue.empty(): - node = queue.get() - result.append(node) - - if node.left is not None: - queue.put(node.left) - - if node.right is not None: - queue.put(node.right) - return result - - -class MaxArray: - """MaxArray is max tree of object. Parent node is max of two child nodes. - >>> max_arr = MaxArray([2, 1, 5, 3, 4]) + >>> max_arr = SegmentTree([2, 1, 5, 3, 4], max) >>> for node in max_arr.traverse(): ... print(node) ... @@ -146,77 +89,14 @@ class MaxArray: val: 2, start: 0, end: 0 val: 5, start: 1, end: 1 >>> - >>> max_arr.max_range(3, 4) + >>> max_arr.query_range(3, 4) 4 - >>> max_arr.max_range(2, 2) + >>> max_arr.query_range(2, 2) 5 - >>> max_arr.max_range(1, 3) + >>> max_arr.query_range(1, 3) 5 >>> - """ - def __init__(self, collection: Sequence): - self.collection = collection - if self.collection: - self.root = self._build_tree(0, len(collection) - 1) - - def update(self, i, val): - self._update_tree(self.root, i, val) - - def max_range(self, i, j): - return self._max_range(self.root, i, j) - - def _build_tree(self, start, end): - if start == end: - return SegmentTreeNode(start, end, self.collection[start]) - mid = (start + end) // 2 - left = self._build_tree(start, mid) - right = self._build_tree(mid + 1, end) - return SegmentTreeNode(start, end, max([left.val, right.val]), left, right) - - def _update_tree(self, root, i, val): - if root.start == i and root.end == i: - root.val = val - return - if i <= root.mid: - self._update_tree(root.left, i, val) - else: - self._update_tree(root.right, i, val) - root.val = max([root.left.val, root.right.val]) - - def _max_range(self, root, i, j): - if root.start == i and root.end == j: - return root.val - """ - [i, j] [i, j] [i, j] - [start mid] [mid+1 end] - """ - if j <= root.mid: - return self._max_range(root.left, i, j) - elif i > root.mid: - return self._max_range(root.right, i, j) - else: - return max([self._max_range(root.left, i, root.mid), self._max_range(root.right, root.mid + 1, j)]) - - def traverse(self): - result = [] - if self.root is not None: - queue = Queue() - queue.put(self.root) - while not queue.empty(): - node = queue.get() - result.append(node) - - if node.left is not None: - queue.put(node.left) - - if node.right is not None: - queue.put(node.right) - return result - - -class MinArray: - """MinArray is min tree of object. Parent node is min of two child nodes. - >>> min_arr = MinArray([2, 1, 5, 3, 4]) + >>> min_arr = SegmentTree([2, 1, 5, 3, 4], min) >>> for node in min_arr.traverse(): ... print(node) ... @@ -244,24 +124,26 @@ class MinArray: val: 2, start: 0, end: 0 val: 5, start: 1, end: 1 >>> - >>> min_arr.min_range(3, 4) + >>> min_arr.query_range(3, 4) 3 - >>> min_arr.min_range(2, 2) + >>> min_arr.query_range(2, 2) 5 - >>> min_arr.min_range(1, 3) + >>> min_arr.query_range(1, 3) 3 >>> + """ - def __init__(self, collection: Sequence): + def __init__(self, collection: Sequence, function): self.collection = collection + self.fn = function if self.collection: self.root = self._build_tree(0, len(collection) - 1) def update(self, i, val): self._update_tree(self.root, i, val) - def min_range(self, i, j): - return self._min_range(self.root, i, j) + def query_range(self, i, j): + return self._query_range(self.root, i, j) def _build_tree(self, start, end): if start == end: @@ -269,7 +151,7 @@ def _build_tree(self, start, end): mid = (start + end) // 2 left = self._build_tree(start, mid) right = self._build_tree(mid + 1, end) - return SegmentTreeNode(start, end, min([left.val, right.val]), left, right) + return SegmentTreeNode(start, end, self.fn(left.val, right.val), left, right) def _update_tree(self, root, i, val): if root.start == i and root.end == i: @@ -279,9 +161,9 @@ def _update_tree(self, root, i, val): self._update_tree(root.left, i, val) else: self._update_tree(root.right, i, val) - root.val = min([root.left.val, root.right.val]) + root.val = self.fn(root.left.val, root.right.val) - def _min_range(self, root, i, j): + def _query_range(self, root, i, j): if root.start == i and root.end == j: return root.val """ @@ -289,11 +171,11 @@ def _min_range(self, root, i, j): [start mid] [mid+1 end] """ if j <= root.mid: - return self._min_range(root.left, i, j) + return self._query_range(root.left, i, j) elif i > root.mid: - return self._min_range(root.right, i, j) + return self._query_range(root.right, i, j) else: - return min([self._min_range(root.left, i, root.mid), self._min_range(root.right, root.mid + 1, j)]) + return self.fn(self._query_range(root.left, i, root.mid), self._query_range(root.right, root.mid + 1, j)) def traverse(self): result = [] @@ -313,49 +195,20 @@ def traverse(self): if __name__ == '__main__': - print('求和线段树') - num_arr = NumArray([2, 1, 5, 3, 4]) - for node in num_arr.traverse(): - print(node) - print() - - num_arr.update(1, 5) - for node in num_arr.traverse(): - print(node) - print() - - print(num_arr.sum_range(3, 4)) # 7 - print(num_arr.sum_range(2, 2)) # 5 - print(num_arr.sum_range(1, 3)) # 13 - - print() - print('求最大值线段树') - max_arr = MaxArray([2, 1, 5, 3, 4]) - for node in max_arr.traverse(): - print(node) - print() - - max_arr.update(1, 5) - for node in max_arr.traverse(): - print(node) - print() - - print(max_arr.max_range(3, 4)) # 4 - print(max_arr.max_range(2, 2)) # 5 - print(max_arr.max_range(1, 3)) # 5 - - print() - print('求最小值线段树') - min_arr = MinArray([2, 1, 5, 3, 4]) - for node in min_arr.traverse(): - print(node) - print() - - min_arr.update(1, 5) - for node in min_arr.traverse(): - print(node) - print() - - print(min_arr.min_range(3, 4)) # 3 - print(min_arr.min_range(2, 2)) # 5 - print(min_arr.min_range(1, 3)) # 3 + import operator + for fn in [operator.add, max, min]: + print('*' * 50) + arr = SegmentTree([2, 1, 5, 3, 4], fn) + for node in arr.traverse(): + print(node) + print() + + arr.update(1, 5) + for node in arr.traverse(): + print(node) + print() + + print(arr.query_range(3, 4)) # 7 + print(arr.query_range(2, 2)) # 5 + print(arr.query_range(1, 3)) # 13 + print() From 1e2b002458c3b58aed1226b42eec7c644781de1f Mon Sep 17 00:00:00 2001 From: lanzhiwang Date: Mon, 27 Jan 2020 14:29:13 +0800 Subject: [PATCH 05/13] delete source encoding in segment tree --- data_structures/binary_tree/segment_tree_other.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/data_structures/binary_tree/segment_tree_other.py b/data_structures/binary_tree/segment_tree_other.py index 69903d38ccc4..d743ac8a7a3f 100644 --- a/data_structures/binary_tree/segment_tree_other.py +++ b/data_structures/binary_tree/segment_tree_other.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- """ Segment_tree creates a segment tree with a given array and function, allowing queries to be done later in log(N) time From 78134bf1d1009087e236019e44866ac69b08baca Mon Sep 17 00:00:00 2001 From: lanzhiwang Date: Mon, 27 Jan 2020 18:58:15 +0800 Subject: [PATCH 06/13] use a generator function instead of returning --- data_structures/binary_tree/segment_tree_other.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/data_structures/binary_tree/segment_tree_other.py b/data_structures/binary_tree/segment_tree_other.py index d743ac8a7a3f..822b65808246 100644 --- a/data_structures/binary_tree/segment_tree_other.py +++ b/data_structures/binary_tree/segment_tree_other.py @@ -176,20 +176,18 @@ def _query_range(self, root, i, j): return self.fn(self._query_range(root.left, i, root.mid), self._query_range(root.right, root.mid + 1, j)) def traverse(self): - result = [] if self.root is not None: queue = Queue() queue.put(self.root) while not queue.empty(): node = queue.get() - result.append(node) + yield node if node.left is not None: queue.put(node.left) if node.right is not None: queue.put(node.right) - return result if __name__ == '__main__': From ac028db7cdcee3855712cba35c856b78b8559218 Mon Sep 17 00:00:00 2001 From: lanzhiwang Date: Mon, 27 Jan 2020 19:30:28 +0800 Subject: [PATCH 07/13] add doctests for methods --- .../binary_tree/segment_tree_other.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/data_structures/binary_tree/segment_tree_other.py b/data_structures/binary_tree/segment_tree_other.py index 822b65808246..6f14fa31e72d 100644 --- a/data_structures/binary_tree/segment_tree_other.py +++ b/data_structures/binary_tree/segment_tree_other.py @@ -138,9 +138,32 @@ def __init__(self, collection: Sequence, function): self.root = self._build_tree(0, len(collection) - 1) def update(self, i, val): + """ + update value in collection + :param i: index of collection + :param val: new value + :return: + >>> num_arr = SegmentTree([2, 1, 5, 3, 4], operator.add) + >>> num_arr.update(1, 5) + """ self._update_tree(self.root, i, val) def query_range(self, i, j): + """ + Sum, Max, Min operation in intervals i and j([i, j]) + :param i: left index + :param j: right index + :return: Sum, Max, Min + >>> num_arr = SegmentTree([2, 1, 5, 3, 4], operator.add) + >>> num_arr.update(1, 5) + >>> num_arr.query_range(3, 4) + 7 + >>> num_arr.query_range(2, 2) + 5 + >>> num_arr.query_range(1, 3) + 13 + >>> + """ return self._query_range(self.root, i, j) def _build_tree(self, start, end): From 6928b3287fac1598a96a1e31135e61bbae0e7187 Mon Sep 17 00:00:00 2001 From: lanzhiwang Date: Mon, 27 Jan 2020 19:45:26 +0800 Subject: [PATCH 08/13] add doctests for methods --- data_structures/binary_tree/segment_tree_other.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/data_structures/binary_tree/segment_tree_other.py b/data_structures/binary_tree/segment_tree_other.py index 6f14fa31e72d..36a1185acb0c 100644 --- a/data_structures/binary_tree/segment_tree_other.py +++ b/data_structures/binary_tree/segment_tree_other.py @@ -143,6 +143,7 @@ def update(self, i, val): :param i: index of collection :param val: new value :return: + >>> import operator >>> num_arr = SegmentTree([2, 1, 5, 3, 4], operator.add) >>> num_arr.update(1, 5) """ @@ -154,6 +155,7 @@ def query_range(self, i, j): :param i: left index :param j: right index :return: Sum, Max, Min + >>> import operator >>> num_arr = SegmentTree([2, 1, 5, 3, 4], operator.add) >>> num_arr.update(1, 5) >>> num_arr.query_range(3, 4) From 0508a9914b0c1d3d344ca274675e346d2861f192 Mon Sep 17 00:00:00 2001 From: lanzhiwang Date: Wed, 29 Jan 2020 14:12:07 +0800 Subject: [PATCH 09/13] add doctests --- .../binary_tree/segment_tree_other.py | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/data_structures/binary_tree/segment_tree_other.py b/data_structures/binary_tree/segment_tree_other.py index 36a1185acb0c..0f45f40f2af2 100644 --- a/data_structures/binary_tree/segment_tree_other.py +++ b/data_structures/binary_tree/segment_tree_other.py @@ -169,6 +169,40 @@ def query_range(self, i, j): return self._query_range(self.root, i, j) def _build_tree(self, start, end): + r"""build segment tree + collection: [2, 1, 5, 3, 4] + + _build_tree(0, 4) -> 15 = 8 + 7 + _build_tree(0, 2) -> 8 = 3 + 5 + _build_tree(0, 1) -> 3 = 2 + 1 + _build_tree(0, 0) -> 2 + _build_tree(1, 1) -> 1 + _build_tree(2, 2) -> 5 + _build_tree(3, 4) -> 7 = 3 + 4 + _build_tree(3, 3) -> 3 + _build_tree(4, 4) -> 4 + + 1. determine the interval of each node + + 0-4(2) + / \ + 0-2(1) 3-4(3) + / \ / \ + 0-1(0) 2-2 3-3 4-4 + / \ + 0-0 1-1 + + 2. determine the value of each node. + + 15 + / \ + 8 7 + / \ / \ + 3 5 3 4 + / \ + 2 1 + + """ if start == end: return SegmentTreeNode(start, end, self.collection[start]) mid = (start + end) // 2 @@ -177,6 +211,12 @@ def _build_tree(self, start, end): return SegmentTreeNode(start, end, self.fn(left.val, right.val), left, right) def _update_tree(self, root, i, val): + r"""update segment tree + _update_tree(15, 1, 5) -> update value of node with value of left child add value of right child + _update_tree(8, 1, 5) -> update value of node with value of left child add value of right child + _update_tree(3, 1, 5) -> update value of node with value of left child add value of right child + _update_tree(1, 1, 5) -> update value of leaf node + """ if root.start == i and root.end == i: root.val = val return @@ -193,10 +233,13 @@ def _query_range(self, root, i, j): [i, j] [i, j] [i, j] [start mid] [mid+1 end] """ + # interval in left child tree if j <= root.mid: return self._query_range(root.left, i, j) + # interval in child child tree elif i > root.mid: return self._query_range(root.right, i, j) + # interval in left child tree and right child tree else: return self.fn(self._query_range(root.left, i, root.mid), self._query_range(root.right, root.mid + 1, j)) From e6254a9ae6f9f6b459d7e59194ed3c31cdff04ba Mon Sep 17 00:00:00 2001 From: lanzhiwang Date: Thu, 30 Jan 2020 15:40:50 +0800 Subject: [PATCH 10/13] fix doctest --- .../binary_tree/segment_tree_other.py | 65 ++++--------------- 1 file changed, 11 insertions(+), 54 deletions(-) diff --git a/data_structures/binary_tree/segment_tree_other.py b/data_structures/binary_tree/segment_tree_other.py index 0f45f40f2af2..8250c82035b9 100644 --- a/data_structures/binary_tree/segment_tree_other.py +++ b/data_structures/binary_tree/segment_tree_other.py @@ -139,22 +139,23 @@ def __init__(self, collection: Sequence, function): def update(self, i, val): """ - update value in collection - :param i: index of collection + Update an element in log(N) time + :param i: position to be update :param val: new value - :return: >>> import operator >>> num_arr = SegmentTree([2, 1, 5, 3, 4], operator.add) >>> num_arr.update(1, 5) + >>> max_arr.query_range(1, 3) + 5 """ self._update_tree(self.root, i, val) def query_range(self, i, j): """ - Sum, Max, Min operation in intervals i and j([i, j]) - :param i: left index - :param j: right index - :return: Sum, Max, Min + Get range query value in log(N) time + :param i: left element index + :param j: right element index + :return: element combined in the range [i, j] >>> import operator >>> num_arr = SegmentTree([2, 1, 5, 3, 4], operator.add) >>> num_arr.update(1, 5) @@ -169,40 +170,6 @@ def query_range(self, i, j): return self._query_range(self.root, i, j) def _build_tree(self, start, end): - r"""build segment tree - collection: [2, 1, 5, 3, 4] - - _build_tree(0, 4) -> 15 = 8 + 7 - _build_tree(0, 2) -> 8 = 3 + 5 - _build_tree(0, 1) -> 3 = 2 + 1 - _build_tree(0, 0) -> 2 - _build_tree(1, 1) -> 1 - _build_tree(2, 2) -> 5 - _build_tree(3, 4) -> 7 = 3 + 4 - _build_tree(3, 3) -> 3 - _build_tree(4, 4) -> 4 - - 1. determine the interval of each node - - 0-4(2) - / \ - 0-2(1) 3-4(3) - / \ / \ - 0-1(0) 2-2 3-3 4-4 - / \ - 0-0 1-1 - - 2. determine the value of each node. - - 15 - / \ - 8 7 - / \ / \ - 3 5 3 4 - / \ - 2 1 - - """ if start == end: return SegmentTreeNode(start, end, self.collection[start]) mid = (start + end) // 2 @@ -211,12 +178,6 @@ def _build_tree(self, start, end): return SegmentTreeNode(start, end, self.fn(left.val, right.val), left, right) def _update_tree(self, root, i, val): - r"""update segment tree - _update_tree(15, 1, 5) -> update value of node with value of left child add value of right child - _update_tree(8, 1, 5) -> update value of node with value of left child add value of right child - _update_tree(3, 1, 5) -> update value of node with value of left child add value of right child - _update_tree(1, 1, 5) -> update value of leaf node - """ if root.start == i and root.end == i: root.val = val return @@ -229,17 +190,13 @@ def _update_tree(self, root, i, val): def _query_range(self, root, i, j): if root.start == i and root.end == j: return root.val - """ - [i, j] [i, j] [i, j] - [start mid] [mid+1 end] - """ - # interval in left child tree + # range in left child tree if j <= root.mid: return self._query_range(root.left, i, j) - # interval in child child tree + # range in child child tree elif i > root.mid: return self._query_range(root.right, i, j) - # interval in left child tree and right child tree + # range in left child tree and right child tree else: return self.fn(self._query_range(root.left, i, root.mid), self._query_range(root.right, root.mid + 1, j)) From 4197e8bb7b77801b9db3262ed7677e2399190067 Mon Sep 17 00:00:00 2001 From: lanzhiwang Date: Thu, 30 Jan 2020 16:46:49 +0800 Subject: [PATCH 11/13] fix doctest --- data_structures/binary_tree/segment_tree_other.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_structures/binary_tree/segment_tree_other.py b/data_structures/binary_tree/segment_tree_other.py index 8250c82035b9..2c3c1895fa17 100644 --- a/data_structures/binary_tree/segment_tree_other.py +++ b/data_structures/binary_tree/segment_tree_other.py @@ -145,7 +145,7 @@ def update(self, i, val): >>> import operator >>> num_arr = SegmentTree([2, 1, 5, 3, 4], operator.add) >>> num_arr.update(1, 5) - >>> max_arr.query_range(1, 3) + >>> num_arr.query_range(1, 3) 5 """ self._update_tree(self.root, i, val) From d3365b152c4e0924925c6af548cf98e715369c87 Mon Sep 17 00:00:00 2001 From: lanzhiwang Date: Thu, 30 Jan 2020 17:09:35 +0800 Subject: [PATCH 12/13] fix doctest --- data_structures/binary_tree/segment_tree_other.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_structures/binary_tree/segment_tree_other.py b/data_structures/binary_tree/segment_tree_other.py index 2c3c1895fa17..7871efd23027 100644 --- a/data_structures/binary_tree/segment_tree_other.py +++ b/data_structures/binary_tree/segment_tree_other.py @@ -146,7 +146,7 @@ def update(self, i, val): >>> num_arr = SegmentTree([2, 1, 5, 3, 4], operator.add) >>> num_arr.update(1, 5) >>> num_arr.query_range(1, 3) - 5 + 13 """ self._update_tree(self.root, i, val) From cc9da55b04093cb9713a408e0f87dd275a915921 Mon Sep 17 00:00:00 2001 From: lanzhiwang Date: Tue, 3 Mar 2020 18:17:24 +0800 Subject: [PATCH 13/13] fix function parameter and fix determine conditions --- .../binary_tree/segment_tree_other.py | 40 ++++++++++--------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/data_structures/binary_tree/segment_tree_other.py b/data_structures/binary_tree/segment_tree_other.py index 7871efd23027..93b603cdc7a2 100644 --- a/data_structures/binary_tree/segment_tree_other.py +++ b/data_structures/binary_tree/segment_tree_other.py @@ -177,28 +177,30 @@ def _build_tree(self, start, end): right = self._build_tree(mid + 1, end) return SegmentTreeNode(start, end, self.fn(left.val, right.val), left, right) - def _update_tree(self, root, i, val): - if root.start == i and root.end == i: - root.val = val + def _update_tree(self, node, i, val): + if node.start == i and node.end == i: + node.val = val return - if i <= root.mid: - self._update_tree(root.left, i, val) + if i <= node.mid: + self._update_tree(node.left, i, val) else: - self._update_tree(root.right, i, val) - root.val = self.fn(root.left.val, root.right.val) - - def _query_range(self, root, i, j): - if root.start == i and root.end == j: - return root.val - # range in left child tree - if j <= root.mid: - return self._query_range(root.left, i, j) - # range in child child tree - elif i > root.mid: - return self._query_range(root.right, i, j) - # range in left child tree and right child tree + self._update_tree(node.right, i, val) + node.val = self.fn(node.left.val, node.right.val) + + def _query_range(self, node, i, j): + if node.start == i and node.end == j: + return node.val + + if i <= node.mid: + if j <= node.mid: + # range in left child tree + return self._query_range(node.left, i, j) + else: + # range in left child tree and right child tree + return self.fn(self._query_range(node.left, i, node.mid), self._query_range(node.right, node.mid + 1, j)) else: - return self.fn(self._query_range(root.left, i, root.mid), self._query_range(root.right, root.mid + 1, j)) + # range in right child tree + return self._query_range(node.right, i, j) def traverse(self): if self.root is not None: