Skip to content

fix bug: edge case of avl delete #4001

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 82 additions & 48 deletions data_structures/binary_tree/avl_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,84 +8,86 @@

import math
import random
import unittest
from typing import Any


class my_queue:
def __init__(self):
def __init__(self) -> None:
self.data = []
self.head = 0
self.tail = 0

def is_empty(self):
def is_empty(self) -> bool:
return self.head == self.tail

def push(self, data):
def push(self, data: Any) -> None:
self.data.append(data)
self.tail = self.tail + 1

def pop(self):
def pop(self) -> Any:
ret = self.data[self.head]
self.head = self.head + 1
return ret

def count(self):
def count(self) -> int:
return self.tail - self.head

def print(self):
def print(self) -> None:
print(self.data)
print("**************")
print(self.data[self.head : self.tail])


class my_node:
def __init__(self, data):
def __init__(self, data: Any) -> None:
self.data = data
self.left = None
self.right = None
self.height = 1

def get_data(self):
def get_data(self) -> Any:
return self.data

def get_left(self):
def get_left(self) -> "my_node":
return self.left

def get_right(self):
def get_right(self) -> "my_node":
return self.right

def get_height(self):
def get_height(self) -> int:
return self.height

def set_data(self, data):
def set_data(self, data: Any) -> None:
self.data = data
return

def set_left(self, node):
def set_left(self, node: "my_node") -> None:
self.left = node
return

def set_right(self, node):
def set_right(self, node: "my_node") -> None:
self.right = node
return

def set_height(self, height):
def set_height(self, height: int) -> None:
self.height = height
return


def get_height(node):
def get_height(node: "my_node") -> int:
if node is None:
return 0
return node.get_height()


def my_max(a, b):
def my_max(a: Any, b: Any) -> Any:
if a > b:
return a
return b


def right_rotation(node):
def right_rotation(node: "my_node") -> "my_node":
r"""
A B
/ \ / \
Expand All @@ -107,7 +109,7 @@ def right_rotation(node):
return ret


def left_rotation(node):
def left_rotation(node: "my_node") -> "my_node":
"""
a mirror symmetry rotation of the left_rotation
"""
Expand All @@ -122,7 +124,7 @@ def left_rotation(node):
return ret


def lr_rotation(node):
def lr_rotation(node: "my_node") -> "my_node":
r"""
A A Br
/ \ / \ / \
Expand All @@ -137,12 +139,12 @@ def lr_rotation(node):
return right_rotation(node)


def rl_rotation(node):
def rl_rotation(node: "my_node") -> "my_node":
node.set_right(right_rotation(node.get_right()))
return left_rotation(node)


def insert_node(node, data):
def insert_node(node: "my_node", data: Any) -> "my_node":
if node is None:
return my_node(data)
if data < node.get_data():
Expand All @@ -168,19 +170,19 @@ def insert_node(node, data):
return node


def get_rightMost(root):
def get_rightMost(root: "my_node") -> "my_node":
while root.get_right() is not None:
root = root.get_right()
return root.get_data()


def get_leftMost(root):
def get_leftMost(root: "my_node") -> "my_node":
while root.get_left() is not None:
root = root.get_left()
return root.get_data()


def del_node(root, data):
def del_node(root: "my_node", data: Any) -> "my_node":
if root.get_data() == data:
if root.get_left() is not None and root.get_right() is not None:
temp_data = get_leftMost(root.get_right())
Expand All @@ -204,14 +206,14 @@ def del_node(root, data):
if root is None:
return root
if get_height(root.get_right()) - get_height(root.get_left()) == 2:
if get_height(root.get_right().get_right()) > get_height(
if get_height(root.get_right().get_right()) >= get_height(
root.get_right().get_left()
):
root = left_rotation(root)
else:
root = rl_rotation(root)
elif get_height(root.get_right()) - get_height(root.get_left()) == -2:
if get_height(root.get_left().get_left()) > get_height(
if get_height(root.get_left().get_left()) >= get_height(
root.get_left().get_right()
):
root = right_rotation(root)
Expand Down Expand Up @@ -256,25 +258,27 @@ class AVLtree:
*************************************
"""

def __init__(self):
def __init__(self) -> None:
self.root = None

def get_height(self):
# print("yyy")
def get_height(self) -> int:
return get_height(self.root)

def insert(self, data):
def insert(self, data: Any) -> None:
print("insert:" + str(data))
self.root = insert_node(self.root, data)

def del_node(self, data):
def del_node(self, data: Any) -> None:
print("delete:" + str(data))
if self.root is None:
print("Tree is empty!")
return
self.root = del_node(self.root, data)

def __str__(self): # a level traversale, gives a more intuitive look on the tree
def __str__(self) -> str:
"""
A level traversale, gives a more intuitive look on the tree
"""
output = ""
q = my_queue()
q.push(self.root)
Expand Down Expand Up @@ -308,21 +312,51 @@ def __str__(self): # a level traversale, gives a more intuitive look on the tre
return output


def _test():
import doctest

doctest.testmod()
class Test(unittest.TestCase):
def _is_balance(self, avl: AVLtree):
if avl.root is None:
return True
dfs = [avl.root]
while dfs:
now = dfs.pop()
if now.left:
left_height = now.left.height
dfs.append(now.left)
else:
left_height = 0
if now.right:
right_height = now.right.height
dfs.append(now.right)
else:
right_height = 0
if abs(left_height - right_height) > 1:
return False
return True

def test_delete(self):
avl = AVLtree()
for i in [8, 7, 4, 3, 9, 10, 11, 13, 6, 0, 2, 12, 1, 14, 5]:
avl.insert(i)
self.assertTrue(self._is_balance(avl))

for v in [8, 7, 4, 3, 9, 10, 11, 13, 6, 0, 2, 12, 1, 14, 5]:
avl.del_node(v)
print(avl)
self.assertTrue(self._is_balance(avl))

def test_delete_random(self):
avl = AVLtree()
random.seed(0)
values = list(range(1000))
random.shuffle(values)
for i in values:
avl.insert(i)
self.assertTrue(self._is_balance(avl))
random.shuffle(values)
for i in values:
avl.del_node(i)
self.assertTrue(self._is_balance(avl))


if __name__ == "__main__":
_test()
t = AVLtree()
lst = list(range(10))
random.shuffle(lst)
for i in lst:
t.insert(i)
print(str(t))
random.shuffle(lst)
for i in lst:
t.del_node(i)
print(str(t))
unittest.main()