import sys
from random import randint
class Node:
    """A search key paired with how often it is looked up.

    Instances are the input to ``find_optimal_binary_search_tree``; the
    frequency is the weight that the optimal tree's cost is computed from.
    """

    def __init__(self, key, freq):
        # ``key`` orders the node within the tree; ``freq`` is its access weight.
        self.key, self.freq = key, freq

    def __str__(self):
        """Render the node as ``Node(key=..., freq=...)``.

        >>> str(Node(1, 2))
        'Node(key=1, freq=2)'
        """
        key, freq = self.key, self.freq
        return f"Node(key={key}, freq={freq})"
def print_binary_search_tree(root, key, i, j, parent, is_left):
    """
    Print the BST encoded in a ``root`` table, one line per node.

    ``root[a][b]`` is the index (into ``key``) of the root of the subtree
    holding keys ``a..b``. The subtree for ``i..j`` is printed in pre-order:
    the node, then its left subtree, then its right subtree.

    Args:
        root: square table of root indices, as built by
            ``find_optimal_binary_search_tree``.
        key: the keys in sorted order.
        i: inclusive lower bound of the key range to print.
        j: inclusive upper bound of the key range to print. An empty
            (``i > j``) or out-of-range interval prints nothing.
        parent: key of the parent node, or ``-1`` to mark the top node as
            the root of the whole tree.
        is_left: whether the top node is its parent's left child (ignored
            when ``parent == -1``).

    Only the caller's ``parent == -1`` marks the root. Children are never
    mistaken for the root, so a key that is itself ``-1`` is reported
    correctly as a parent. The traversal uses an explicit stack, so deep
    (skewed) trees do not hit Python's recursion limit.

    >>> key = [3, 8, 9, 10, 17, 21]
    >>> root = [[0, 1, 1, 1, 1, 1], [0, 1, 1, 1, 1, 3], [0, 0, 2, 3, 3, 3],
    ...         [0, 0, 0, 3, 3, 3], [0, 0, 0, 0, 4, 5], [0, 0, 0, 0, 0, 5]]
    >>> print_binary_search_tree(root, key, 0, 5, -1, False)
    8 is the root of the binary search tree.
    3 is the left child of key 8.
    10 is the right child of key 8.
    9 is the left child of key 10.
    21 is the right child of key 10.
    17 is the left child of key 21.
    >>> print_binary_search_tree([[0, 0], [0, 1]], [-1, 5], 0, 1, -1, False)
    -1 is the root of the binary search tree.
    5 is the right child of key -1.
    """
    # Each pending entry: (lo, hi, parent_key, is_left_child, is_tree_root).
    # Only the initial entry can be the tree root; the sentinel check is done
    # once here instead of on every (possibly -1) parent key.
    pending = [(i, j, parent, is_left, parent == -1)]
    while pending:
        lo, hi, parent_key, left_child, is_tree_root = pending.pop()
        if lo > hi or lo < 0 or hi > len(root) - 1:
            continue  # empty or out-of-range interval: nothing to print
        node = root[lo][hi]
        if is_tree_root:
            print(f"{key[node]} is the root of the binary search tree.")
        elif left_child:
            print(f"{key[node]} is the left child of key {parent_key}.")
        else:
            print(f"{key[node]} is the right child of key {parent_key}.")
        # Push the right subtree first so the left one is popped (and fully
        # printed) first, reproducing the recursive pre-order output.
        pending.append((node + 1, hi, key[node], False, False))
        pending.append((lo, node - 1, key[node], True, False))
def find_optimal_binary_search_tree(nodes):
    """
    Compute and print the optimal binary search tree for the given nodes.

    The cost of a tree is the sum over all nodes of ``freq * depth``, with
    the root at depth 1. ``dp[i][j]`` holds the minimum such cost for the
    sorted keys ``i..j``, and ``root[i][j]`` holds the index of the key chosen
    as their root. Knuth's bound
    ``root[i][j - 1] <= root[i][j] <= root[i + 1][j]`` limits the candidate
    roots, so the tables are filled in O(n^2) time and space.

    Based on the optimal-BST dynamic program from CLRS (Introduction to
    Algorithms).
    https://en.wikipedia.org/wiki/Introduction_to_Algorithms

    The input list is not modified; nodes are printed in key order. An empty
    list prints a cost of 0 and no tree.

    >>> find_optimal_binary_search_tree([Node(12, 8), Node(10, 34), Node(20, 50),
    ...                                  Node(42, 3), Node(25, 40), Node(37, 30)])
    Binary search tree nodes:
    Node(key=10, freq=34)
    Node(key=12, freq=8)
    Node(key=20, freq=50)
    Node(key=25, freq=40)
    Node(key=37, freq=30)
    Node(key=42, freq=3)
    <BLANKLINE>
    The cost of optimal BST for given tree nodes is 324.
    20 is the root of the binary search tree.
    10 is the left child of key 20.
    12 is the right child of key 10.
    25 is the right child of key 20.
    37 is the right child of key 25.
    42 is the right child of key 37.
    >>> find_optimal_binary_search_tree([])
    Binary search tree nodes:
    <BLANKLINE>
    The cost of optimal BST for given tree nodes is 0.
    """
    # Sort a copy so the caller's list keeps its original order.
    ordered = sorted(nodes, key=lambda node: node.key)
    n = len(ordered)
    keys = [node.key for node in ordered]
    freqs = [node.freq for node in ordered]

    # prefix[k] == sum(freqs[:k]), so the combined frequency of keys i..j is
    # prefix[j + 1] - prefix[i]; this replaces a second n x n table.
    prefix = [0]
    for freq in freqs:
        prefix.append(prefix[-1] + freq)

    # Single-key intervals: the key is its own root, at depth 1.
    dp = [[freqs[i] if i == j else 0 for j in range(n)] for i in range(n)]
    root = [[i if i == j else 0 for j in range(n)] for i in range(n)]

    for interval_length in range(2, n + 1):
        for i in range(n - interval_length + 1):
            j = i + interval_length - 1
            # Knuth's bound: the optimal root of i..j lies between the optimal
            # roots of i..j-1 and i+1..j. Comparing (cost, r) tuples picks the
            # cheapest root, breaking ties toward the smaller index r.
            subtree_cost, best_root = min(
                (
                    (dp[i][r - 1] if r > i else 0)
                    + (dp[r + 1][j] if r < j else 0),
                    r,
                )
                for r in range(root[i][j - 1], root[i + 1][j] + 1)
            )
            # Placing r on top costs freq[r] and pushes every key of both
            # subtrees one level deeper: together that adds the combined
            # frequency of keys i..j exactly once.
            dp[i][j] = subtree_cost + prefix[j + 1] - prefix[i]
            root[i][j] = best_root

    print("Binary search tree nodes:")
    for node in ordered:
        print(node)

    cost = dp[0][n - 1] if n else 0
    print(f"\nThe cost of optimal BST for given tree nodes is {cost}.")
    print_binary_search_tree(root, keys, 0, n - 1, -1, False)
def main():
    """Solve a random demo instance: keys 10 down to 1, frequencies in [1, 50]."""
    # Keys are deliberately supplied in descending order; the solver sorts them.
    demo_keys = reversed(range(1, 11))
    demo_nodes = [Node(k, randint(1, 50)) for k in demo_keys]
    find_optimal_binary_search_tree(demo_nodes)
if __name__ == "__main__":
    # Run the random demo only when this file is executed directly, not on import.
    main()