In: Computer Science
I finished this class called AVLTree(), but the last test case does not pass. Can you help me out with it?
class AVLTree:
class Node:
def __init__(self, val, left=None, right=None):
self.val = val
self.left = left
self.right = right
def rotate_right(self):
n = self.left
self.val, n.val = n.val, self.val
self.left, n.left, self.right, n.right = n.left, n.right, n,
self.right
def rotate_left(self):
n = self.right
self.val, n.val = n.val, self.val
self.right, n.right, self.left, n.left = n.right, n.left, n,
self.left
@staticmethod
def height(n):
if not n:
return 0
else:
return min(1+AVLTree.Node.height(n.left),
1+AVLTree.Node.height(n.right))
def __init__(self):
self.size = 0
self.root = None
@staticmethod
def rebalance(t):
if AVLTree.Node.height(t.left) >
AVLTree.Node.height(t.right):
if AVLTree.Node.height(t.left.left) >=
AVLTree.Node.height(t.left.right):
t.rotate_right()
else:
t.left.rotate_left()
t.rotate_right()
else:
if AVLTree.Node.height(t.right.right) >=
AVLTree.Node.height(t.right.left):
t.rotate_left()
else:
t.right.rotate_right()
t.rotate_left()
def add(self, val):
assert(val not in self)
def add_rec(n):
if not n:
return AVLTree.Node(val)
elif val < n.val:
n.left = add_rec(n.left)
else:
n.right = add_rec(n.right)
if abs(AVLTree.Node.height(n.left)-AVLTree.Node.height(n.right))
>=1:
AVLTree.rebalance(n)
return n
self.root = add_rec(self.root)
self.size+=1
def __delitem__(self, val):
assert(val in self)
re_bal = []
def delitem_rec(n):
if not n:
return None
elif val < n.val:
n.left = delitem_rec(n.left)
elif val > n.val:
n.right = delitem_rec(n.right)
else:
if not n.left and not n.right:
return None
elif n.left and not n.right:
return n.left
elif n.right and not n.left:
return n.right
else:
n1 = n.left
re_bal.append(n1)
if not n1.right:
n.val = n1.val
n.left = n1.left
else:
n2 = n1
while n2.right.right:
n2 = n2.right
re_bal.append(n2)
re_bal.append(n2)
n1 = n2.right
n2.right = n1.left
n.val = n1.val
while re_bal:
s = re_bal.pop()
if abs(AVLTree.Node.height(s.left) - AVLTree.Node.height(s.right))
>= 2:
AVLTree.rebalance(s)
if abs(AVLTree.Node.height(n.left) - AVLTree.Node.height(n.right))
>= 2:
AVLTree.rebalance(n)
return n
self.root = delitem_rec(self.root)
self.size-=1
def __contains__(self, val):
def contains_rec(node):
if not node:
return False
elif val < node.val:
return contains_rec(node.left)
elif val > node.val:
return contains_rec(node.right)
else:
return True
return contains_rec(self.root)
def __len__(self):
return self.size
def __iter__(self):
def iter_rec(node):
if node:
yield from iter_rec(node.left)
yield node.val
yield from iter_rec(node.right)
yield from iter_rec(self.root)
def pprint(self, width=64):
height = self.height()
nodes = [(self.root, 0)]
prev_level = 0
repr_str = ''
while nodes:
n,level = nodes.pop(0)
if prev_level != level:
prev_level = level
repr_str += '\n'
if not n:
if level < height-1:
nodes.extend([(None, level+1), (None, level+1)])
repr_str += '{val:^{width}}'.format(val='-',
width=width//2**level)
elif n:
if n.left or level < height-1:
nodes.append((n.left, level+1))
if n.right or level < height-1:
nodes.append((n.right, level+1))
repr_str += '{val:^{width}}'.format(val=n.val,
width=width//2**level)
print(repr_str)
def height(self):
def height_rec(t):
if not t:
return 0
else:
return max(1+height_rec(t.left), 1+height_rec(t.right))
return height_rec(self.root)
HERE IS THE TEST CASE THAT FAILS, I ADDED THE ASSERTION ERROR SO YOU CAN SEE WHERE IT IS NOT WORKING
from unittest import TestCase
import random
tc = TestCase()
def traverse(t, fn):
if t:
fn(t)
traverse(t.left, fn)
traverse(t.right, fn)
def height(t):
if not t:
return 0
else:
return max(1+height(t.left), 1+height(t.right))
def check_balance(t):
tc.assertLess(abs(height(t.left) - height(t.right)), 2, 'Tree is
out of balance')
t = AVLTree()
vals = list(range(1000))
random.shuffle(vals)
for i in range(len(vals)):
t.add(vals[i])
for x in vals[:i+1]:
tc.assertIn(x, t, 'Element added not in tree')
traverse(t.root, check_balance)
random.shuffle(vals)
for i in range(len(vals)):
del t[vals[i]]
for x in vals[i+1:]:
tc.assertIn(x, t, 'Incorrect element removed from tree')
for x in vals[:i+1]:
tc.assertNotIn(x, t, 'Element removed still in tree')
traverse(t.root, check_balance)
--------------------------------------------------------------------------- AssertionError Traceback (most recent call last) <ipython-input-120-8428ca646bca> in <module> 29 for x in vals[:i+1]: 30 tc.assertIn(x, t, 'Element added not in tree') ---> 31 traverse(t.root, check_balance) 32 33 random.shuffle(vals) <ipython-input-120-8428ca646bca> in traverse(t, fn) 9 def traverse(t, fn): 10 if t: ---> 11 fn(t) 12 traverse(t.left, fn) 13 traverse(t.right, fn) <ipython-input-120-8428ca646bca> in check_balance(t) 20 21 def check_balance(t): ---> 22 tc.assertLess(abs(height(t.left) - height(t.right)), 2, 'Tree is out of balance') 23 24 t = AVLTree() ~\Anaconda3\lib\unittest\case.py in assertLess(self, a, b, msg) 1224 if not a < b: 1225 standardMsg = '%s not less than %s' % (safe_repr(a), safe_repr(b)) -> 1226 self.fail(self._formatMessage(msg, standardMsg)) 1227 1228 def assertLessEqual(self, a, b, msg=None): ~\Anaconda3\lib\unittest\case.py in fail(self, msg) 678 def fail(self, msg=None): 679 """Fail immediately, with the given message.""" --> 680 raise self.failureException(msg) 681 682 def assertFalse(self, expr, msg=None): AssertionError: 2 not less than 2 : Tree is out of balance
class AVLTree:
class Node:
def _init_(self, val, left=None, right=None):
self.val = val
self.left = left
self.right = right
def rotate_right(self):
n = self.left
self.val, n.val = n.val, self.val
self.left, n.left, self.right, n.right = n.left, n.right, n,
self.right
def rotate_left(self):
n = self.right
self.val, n.val = n.val, self.val
self.right, n.right, self.left, n.left = n.right, n.left, n,
self.left
@staticmethod
def height(n):
if not n:
return 0
else:
return max(1 + AVLTree.Node.height(n.left), 1 +
AVLTree.Node.height(n.right))
def _init_(self):
self.size = 0
self.root = None
@staticmethod
def rebalance(t):
if AVLTree.Node.height(t.left) >
AVLTree.Node.height(t.right):
if AVLTree.Node.height(t.left.left) >=
AVLTree.Node.height(t.left.right): # LL
t.rotate_right()
else:
t.left.rotate_left()
t.rotate_right()
else:
if AVLTree.Node.height(t.right.right) >=
AVLTree.Node.height(t.right.left): # RR
t.rotate_left()
else:
t.right.rotate_right()
t.rotate_left()
def add(self, val):
assert (val not in self)
def add_rec(node):
if not node:
return AVLTree.Node(val)
elif val < node.val:
node.left = add_rec(node.left)
elif val > node.val:
node.right = add_rec(node.right)
if abs(AVLTree.Node.height(node.left) -
AVLTree.Node.height(node.right)) >= 2:
AVLTree.rebalance(node)
return node
self.root = add_rec(self.root)
self.size += 1
def _delitem_(self, val):
assert (val in self)
rebal = []
def delitem_rec(node):
if val < node.val:
node.left = delitem_rec(node.left)
elif val > node.val:
node.right = delitem_rec(node.right)
else:
if not node.left and not node.right:
return None
elif node.left and not node.right:
return node.left
elif node.right and not node.left:
return node.right
else:
# remove the largest value from the left subtree (t) as a
replacement
# for the root value of this tree
t = node.left
rebal.append(t)
if not t.right:
node.val = t.val
node.left = t.left
else:
n = t
while n.right.right:
n = n.right
rebal.append(n)
rebal.append(n)
t = n.right
n.right = t.left
node.val = t.val
while rebal:
s = rebal.pop()
if abs(AVLTree.Node.height(s.left) - AVLTree.Node.height(s.right))
>= 2:
AVLTree.rebalance(s)
if abs(AVLTree.Node.height(node.left) -
AVLTree.Node.height(node.right)) >= 2:
AVLTree.rebalance(node)
return node
self.root = delitem_rec(self.root)
self.size -= 1
def _contains_(self, val):
def contains_rec(node):
if not node:
return False
elif val < node.val:
return contains_rec(node.left)
elif val > node.val:
return contains_rec(node.right)
else:
return True
return contains_rec(self.root)
def _len_(self):
return self.size
def _iter_(self):
def iter_rec(node):
if node:
yield from iter_rec(node.left)
yield node.val
yield from iter_rec(node.right)
yield from iter_rec(self.root)
def pprint(self, width=64):
"""Attempts to pretty-print this tree's contents."""
height = self.height()
nodes = [(self.root, 0)]
prev_level = 0
repr_str = ''
while nodes:
n, level = nodes.pop(0)
if prev_level != level:
prev_level = level
repr_str += '\n'
if not n:
if level < height - 1:
nodes.extend([(None, level + 1), (None, level + 1)])
repr_str += '{val:^{width}}'.format(val='-', width=width // 2 **
level)
elif n:
if n.left or level < height - 1:
nodes.append((n.left, level + 1))
if n.right or level < height - 1:
nodes.append((n.right, level + 1))
repr_str += '{val:^{width}}'.format(val=n.val, width=width // 2 **
level)
print(repr_str)
def height(self):
"""Returns the height of the longest branch of the tree."""
def height_rec(t):
if not t:
return 0
else:
return max(1 + height_rec(t.left), 1 + height_rec(t.right))
return height_rec(self.root)