Problem Challenge 1
Problem statement
Given a binary tree, find the length of its diameter. The diameter of a tree is the number of nodes on the longest path between any two leaf nodes. The diameter of a tree may or may not pass through the root.
Note: You can always assume that there are at least two leaf nodes in the given tree.
In the following illustration, the diameter of the tree is highlighted in purple.
Dry-run templates
This is the implementation of the solution provided for the problem posed in the Problem Challenge 1 lesson:
class TreeNode:
    """A binary tree node holding a value and optional left/right children."""

    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class TreeDiameter:
    """Computes the diameter of a binary tree.

    The diameter is the number of nodes on the longest path between any two
    leaf nodes; that path may or may not pass through the root.
    """

    def __init__(self):
        # Longest leaf-to-leaf path (counted in nodes) found so far.
        self.treeDiameter = 0

    def find_diameter(self, root):
        """Return the diameter of the tree rooted at *root*.

        The running maximum is reset first, so one instance can be reused for
        several trees without an earlier, larger tree leaking its result.
        Returns 0 if no node has both a left and a right subtree.
        """
        self.treeDiameter = 0
        self.calculate_height(root)
        return self.treeDiameter

    def calculate_height(self, currentNode):
        """Return the height of *currentNode* (0 for None), updating treeDiameter.

        Height is the number of nodes on the longest downward path from
        currentNode to a leaf.
        """
        if currentNode is None:
            return 0

        leftTreeHeight = self.calculate_height(currentNode.left)
        rightTreeHeight = self.calculate_height(currentNode.right)

        # if the current node doesn't have a left or right subtree, we can't have
        # a path passing through it, since we need a leaf node on each side.
        # Heights are always ints (0 for a missing subtree), so compare with 0;
        # an `is not None` test here would always be true.
        if leftTreeHeight != 0 and rightTreeHeight != 0:
            # diameter at the current node will be equal to the height of left subtree +
            # the height of right sub-trees + '1' for the current node
            diameter = leftTreeHeight + rightTreeHeight + 1
            # update the global tree diameter
            self.treeDiameter = max(self.treeDiameter, diameter)

        # height of the current node will be equal to the maximum of the heights of
        # left or right subtrees plus '1' for the current node
        return max(leftTreeHeight, rightTreeHeight) + 1


def main():
    """Print the diameter of the two sample trees used in the lesson."""
    treeDiameter = TreeDiameter()
    root = TreeNode(1)
    root.left = TreeNode(2)
    root.right = TreeNode(3)
    root.left.left = TreeNode(4)
    root.right.left = TreeNode(5)
    root.right.right = TreeNode(6)
    print("Tree Diameter: " + str(treeDiameter.find_diameter(root)))
    root.left.left = None
    root.right.left.left = TreeNode(7)
    root.right.left.right = TreeNode(8)
    root.right.right.left = TreeNode(9)
    root.right.left.right.left = TreeNode(10)
    root.right.right.left.left = TreeNode(11)
    print("Tree Diameter: " + str(treeDiameter.find_diameter(root)))


main()
Sample input #1
root = TreeNode(1)
root.left = TreeNode(2)
root.right = TreeNode(3)
root.left.left = TreeNode(4)
root.right.left = TreeNode(5)
root.right.right = TreeNode(6)
    _ 1 _
   |     |
 _ 2   _ 3 _
|     |     |
4     5     6
Note: In the rows highlighted in orange, currentNode changes because we're backtracking.
Line number | root | Recursive call | currentNode | leftTreeHeight | rightTreeHeight | diameter | treeDiameter |
143-144 | 1 | - | - | - | - | - | 0 |
21 | 1 | Left subtree | 2 | - | - | - | 0 |
21 | 2 | Left subtree | 4 | - | - | - | 0 |
21 | 4 | Left subtree | None | - | - | - | - |
22 | 4 | Right subtree | None | - | - | - | - |
26-33 | - | - | 4 | 0 | 0 | 1 | 1 |
22 | 2 | Right subtree | None | - | - | - | 1 |
26-33 | - | - | 2 | 1 | 0 | 2 | 2 |
22 | 1 | Right subtree | 3 | - | - | - | - |
21 | 3 | Left subtree | 5 | - | - | - | - |
21 | 5 | Left subtree | None | - | - | - | - |
22 | 5 | Right subtree | None | - | - | - | - |
26-33 | - | - | 5 | 0 | 0 | 1 | 2 |
22 | 3 | Right subtree | 6 | - | - | - | - |
21 | 6 | Left subtree | None | - | - | - | - |
22 | 6 | Right subtree | None | - | - | - | - |
26-33 | - | - | 6 | 0 | 0 | 1 | 2 |
26-33 | - | - | 3 | 1 | 1 | 3 | 3 |
26-33 | - | - | 1 | 2 | 2 | 5 | 5 |
Sample input #2
Students may be encouraged to complete the dry-run with this input:
root = TreeNode(1)
root.left = TreeNode(2)
root.right = TreeNode(3)
root.right.left = TreeNode(5)
root.right.right = TreeNode(6)
root.right.left.left = TreeNode(7)
root.right.left.right = TreeNode(8)
root.right.right.left = TreeNode(9)
root.right.left.right.left = TreeNode(10)
root.right.right.left.left = TreeNode(11)
 _ 1 _
|     |
2   _ 3 _________
   |             |
 _ 5 __        _ 6
|      |      |
7   __ 8   __ 9
   |      |
   10     11
Line number | root | Recursive call | currentNode | leftTreeHeight | rightTreeHeight | diameter | treeDiameter |
143-144 | 1 | - | - | - | - | - | 0 |
21 | 1 | Left subtree | 2 | - | - | - | 0 |
21 | 2 | Left subtree | None | - | - | - | 0 |
22 | 2 | Right subtree | None | - | - | - | - |
26-33 | - | - | 2 | 0 | 0 | 1 | 1 |
22 | 1 | Right subtree | 3 | - | - | - | 1 |
... | ... | ... | ... | ... | ... | ... | ... |
Step-by-step solution construction
The output of the following code shows the traces generated via print statements.
Note: The draw_node and display_tree methods aid in visualizing the tree structure. Students can be encouraged to write these methods themselves. They don't affect the code execution.
The first step in our code is to calculate the height of the right and left subtrees for each node. Let’s see this below:
tab = 0  # indentation depth of the trace output (one tab per recursion level)


# Display List Code Below:
def height(node):
    """Return the number of levels in the tree rooted at *node* (0 for None)."""
    if node is None:
        return 0
    return 1 + max(height(node.left), height(node.right))


def draw_node(output, link_above, node, level, p, link_char):
    """Lay out *node* and its subtrees into per-level text rows.

    output[level] holds the text of the nodes on that level, and
    link_above[level] holds a marker ('L', 'R' or ' ') above each node that
    display_tree later turns into a '|' connector. *p* is the column where
    this node would like to start; it is pushed right so the node never
    overlaps text already written on this level or the adjacent ones.
    """
    if node is None:
        return

    h = len(output)
    SP = " "

    # Never start left of what is already drawn on the neighbouring levels or
    # on this one (this also lifts a negative p back to >= 0).
    if level < h - 1:
        p = max(p, len(output[level + 1]))
    if level > 0:
        p = max(p, len(output[level - 1]))
    p = max(p, len(output[level]))

    # Fill in to left
    if node.left:
        leftData = SP + str(node.left.val) + SP
        draw_node(output, link_above, node.left, level + 1, p - len(leftData), 'L')
        p = max(p, len(output[level + 1]))

    # Enter this data
    space = p - len(output[level])
    if space > 0:
        output[level] += ' ' * space
    output[level] += SP + str(node.val) + SP

    # Add vertical link above
    space = p + len(SP) - len(link_above[level])
    if space > 0:
        link_above[level] += ' ' * space
    link_above[level] += link_char

    # Fill in to right
    if node.right:
        draw_node(output, link_above, node.right, level + 1, len(output[level]), 'R')


def display_tree(root):
    """Print an ASCII drawing of the tree rooted at *root* ("\\tNone" if empty)."""
    if root is None:
        print("\tNone")
        return

    h = height(root)
    output = [""] * h
    link_above = [""] * h

    draw_node(output, link_above, root, 0, 5, ' ')

    # Create link lines: turn every 'L'/'R' marker into '|' and underline the
    # gap between the marker and the parent's value on the row above.
    for i in range(h):
        for j in range(len(link_above[i])):
            marker = link_above[i][j]
            if marker == ' ':
                continue
            size = len(output[i - 1])
            if size < j + 1:
                output[i - 1] += " " * (j + 1 - size)
            row = list(output[i - 1])
            jj = j
            if marker == 'L':
                # Parent sits to the right of a left child.
                while row[jj] == ' ':
                    jj += 1
                for k in range(j + 1, jj - 1):
                    row[k] = '_'
            elif marker == 'R':
                # Parent sits to the left of a right child.
                while row[jj] == ' ':
                    jj -= 1
                for k in range(j - 1, jj + 1, -1):
                    row[k] = '_'
            output[i - 1] = ''.join(row)
            links = list(link_above[i])
            links[j] = '|'
            link_above[i] = ''.join(links)

    # Output
    for i in range(h):
        if i:
            print("\t", link_above[i])
        print("\t", output[i])


class TreeNode:
    """A binary tree node holding a value and optional left/right children."""

    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class TreeDiameter:
    """Step 1 of the solution: trace only the subtree-height computation.

    treeDiameter is not updated at this stage, so find_diameter returns 0.
    """

    def __init__(self):
        self.treeDiameter = 0

    def find_diameter(self, root):
        self.calculate_height(root)
        return self.treeDiameter

    def calculate_height(self, currentNode):
        """Return the height of *currentNode* (0 for None), printing each step."""
        global tab
        tab += 1
        if currentNode is None:
            print(tab*"\t" + "currentNode: None ---> backtracking")
            return 0

        print(tab*"\t" + "currentNode: ", currentNode.val)
        print(tab*"\t" + "Recursive call to the left subtree of node ", currentNode.val, sep="")
        leftTreeHeight = self.calculate_height(currentNode.left)
        tab -= 1  # undo the indentation added by the child call
        print(tab*"\t" + "leftTreeHeight of node ", currentNode.val, ": ", leftTreeHeight, sep="")

        print(tab*"\t" + "Recursive call to the right subtree of node ", currentNode.val, sep="")
        rightTreeHeight = self.calculate_height(currentNode.right)
        tab -= 1
        print(tab*"\t" + "rightTreeHeight of node ", currentNode.val, ": ", rightTreeHeight, sep="")
        print("")

        return max(leftTreeHeight, rightTreeHeight) + 1


def main():
    """Draw both sample trees and trace the height computation on each."""
    global tab
    treeDiameter = TreeDiameter()
    root = TreeNode(1)
    root.left = TreeNode(2)
    root.right = TreeNode(3)
    root.left.left = TreeNode(4)
    root.right.left = TreeNode(5)
    root.right.right = TreeNode(6)

    treeDiameter2 = TreeDiameter()
    root2 = TreeNode(1)
    root2.left = TreeNode(2)
    root2.right = TreeNode(3)
    root2.right.left = TreeNode(5)
    root2.right.right = TreeNode(6)
    root2.right.left.left = TreeNode(7)
    root2.right.left.right = TreeNode(8)
    root2.right.right.left = TreeNode(9)
    root2.right.left.right.left = TreeNode(10)
    root2.right.right.left.left = TreeNode(11)

    inputs = [(treeDiameter, root), (treeDiameter2, root2)]
    for solver, tree in inputs:
        print("Input tree: ")
        display_tree(tree)
        print("Tree Diameter: " + str(solver.find_diameter(tree)))
        print(("-"*100)+"\n")
        # The outermost calculate_height call leaves tab at 1; start fresh.
        tab = 0


main()
Next, we use these heights to find the diameter of our binary tree. We’ll use the following formula:
diameter = leftTreeHeight + rightTreeHeight + 1
The global diameter variable, treeDiameter, is updated with the max of diameter and treeDiameter:
self.treeDiameter = max(self.treeDiameter, diameter)
Let’s see this below:
from platform import node  # NOTE(review): unused here; looks like an accidental IDE auto-import
listnodes = []  # scratch buffer: listnodes[i] is the i-th node on the DFS branch being explored
tab = 0  # indentation depth of the trace output

# Display List Code Below:
# State shared by the diameter-path helpers (reset by diameter() and main()):
#   k       - node where the longest leaf-to-leaf path bends (its apex)
#   lh, rh  - heights of k's left and right subtrees
#   f       - which half printPathsRecur prints: 0 = left (reversed), 1 = right, 2 = done
#   ans     - length in nodes of the best path found so far
#   pathLen - not used by any function below
k, lh, rh, f, ans, pathLen = None, 0, 0, 0, 0 - 10 ** 19, 0
nodeslist = []  # nodes on the diameter path, in left-to-right order


def height(node):
    """Return the number of levels in the tree rooted at *node* (0 for None)."""
    if node is None:
        return 0
    return 1 + max(height(node.left), height(node.right))


def draw_node(output, link_above, node, level, p, link_char):
    """Lay out *node* and its subtrees into per-level text rows.

    output[level] holds the text of the nodes on that level, and
    link_above[level] holds a marker ('L', 'R' or ' ') above each node that
    display_tree later turns into a '|' connector. *p* is the column where
    this node would like to start; it is pushed right so the node never
    overlaps text already written on this level or the adjacent ones.
    """
    if node is None:
        return

    h = len(output)
    SP = " "

    # Never start left of what is already drawn on the neighbouring levels or
    # on this one (this also lifts a negative p back to >= 0).
    if level < h - 1:
        p = max(p, len(output[level + 1]))
    if level > 0:
        p = max(p, len(output[level - 1]))
    p = max(p, len(output[level]))

    # Fill in to left
    if node.left:
        leftData = SP + str(node.left.val) + SP
        draw_node(output, link_above, node.left, level + 1, p - len(leftData), 'L')
        p = max(p, len(output[level + 1]))

    # Enter this data
    space = p - len(output[level])
    if space > 0:
        output[level] += ' ' * space
    output[level] += SP + str(node.val) + SP

    # Add vertical link above
    space = p + len(SP) - len(link_above[level])
    if space > 0:
        link_above[level] += ' ' * space
    link_above[level] += link_char

    # Fill in to right
    if node.right:
        draw_node(output, link_above, node.right, level + 1, len(output[level]), 'R')


def display_tree(root):
    """Print an ASCII drawing of the tree rooted at *root* ("\\tNone" if empty)."""
    if root is None:
        print("\tNone")
        return

    h = height(root)
    output = [""] * h
    link_above = [""] * h

    draw_node(output, link_above, root, 0, 5, ' ')

    # Create link lines: turn every 'L'/'R' marker into '|' and underline the
    # gap between the marker and the parent's value on the row above.
    for i in range(h):
        for j in range(len(link_above[i])):
            marker = link_above[i][j]
            if marker == ' ':
                continue
            size = len(output[i - 1])
            if size < j + 1:
                output[i - 1] += " " * (j + 1 - size)
            row = list(output[i - 1])
            jj = j
            if marker == 'L':
                # Parent sits to the right of a left child.
                while row[jj] == ' ':
                    jj += 1
                for k in range(j + 1, jj - 1):
                    row[k] = '_'
            elif marker == 'R':
                # Parent sits to the left of a right child.
                while row[jj] == ' ':
                    jj -= 1
                for k in range(j - 1, jj + 1, -1):
                    row[k] = '_'
            output[i - 1] = ''.join(row)
            links = list(link_above[i])
            links[j] = '|'
            link_above[i] = ''.join(links)

    # Output
    for i in range(h):
        if i:
            print("\t", link_above[i])
        print("\t", output[i])


def traversal(root, result):
    """Append '*' to the value of every node in the tree that appears in *result*."""
    if root is None:
        return
    if root in result:
        root.val = str(root.val) + "*"
    traversal(root.left, result)
    traversal(root.right, result)


class TreeNode:
    """A binary tree node holding a value and optional left/right children."""

    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class TreeDiameter:
    """Traced solution: prints every step while computing the diameter."""

    def __init__(self):
        self.treeDiameter = 0

    def find_diameter(self, root):
        """Return the leaf-to-leaf diameter of *root*; safe to reuse per instance."""
        self.treeDiameter = 0
        self.calculate_height(root)
        return self.treeDiameter

    def calculate_height(self, currentNode):
        """Return the height of *currentNode* (0 for None), printing each step."""
        global tab
        tab += 1
        if currentNode is None:
            print(tab*"\t" + "currentNode: None ---> backtracking")
            return 0

        print(tab*"\t" + "currentNode: ", currentNode.val)
        print(tab*"\t" + "Recursive call to the left subtree of node ", currentNode.val, sep="")
        leftTreeHeight = self.calculate_height(currentNode.left)
        tab -= 1  # undo the indentation added by the child call
        print(tab*"\t" + "Recursive call to the right subtree of node ", currentNode.val, sep="")
        rightTreeHeight = self.calculate_height(currentNode.right)
        tab -= 1

        # if the current node doesn't have a left or right subtree, we can't have
        # a path passing through it, since we need a leaf node on each side.
        # Heights are ints (0 for a missing subtree), so compare with 0.
        if leftTreeHeight != 0 and rightTreeHeight != 0:
            # diameter at the current node will be equal to the height of left subtree +
            # the height of right sub-trees + '1' for the current node
            diameter = leftTreeHeight + rightTreeHeight + 1
            print("")
            print(tab*"\t" + "Current node: ", currentNode.val, sep="")
            print(tab*"\t" + "leftTreeHeight: ", leftTreeHeight, sep="")
            print(tab*"\t" + "rightTreeHeight: ", rightTreeHeight, sep="")
            print(tab*"\t" + "Diameter: ", leftTreeHeight, " + ",
                  rightTreeHeight, " + 1 = ", diameter, sep="")
            # update the global tree diameter
            print(tab*"\t" + "Updating treeDiameter with the max(", self.treeDiameter,
                  ", ", diameter, ") = ", max(self.treeDiameter, diameter), sep="")
            self.treeDiameter = max(self.treeDiameter, diameter)
            print("")

        # height of the current node will be equal to the maximum of the heights of
        # left or right subtrees plus '1' for the current node
        return max(leftTreeHeight, rightTreeHeight) + 1


def path_height(root):
    """Return the height of *root* and record the apex of the longest leaf path.

    Uses the same rule as TreeDiameter.calculate_height: a node qualifies only
    when both of its subtrees are non-empty. Updates the globals ans, k, lh
    and rh whenever a strictly longer path is found.
    """
    global ans, k, lh, rh
    if root is None:
        return 0
    left_height = path_height(root.left)
    right_height = path_height(root.right)
    if left_height != 0 and right_height != 0 and ans < 1 + left_height + right_height:
        ans = 1 + left_height + right_height
        k = root
        lh = left_height
        rh = right_height
    return 1 + max(left_height, right_height)


def printArray(ints, lenn, f, listnodes):
    """Print the first *lenn* values of *ints* and append their nodes to nodeslist.

    f == 0 emits them in reverse order (leaf first); f == 1 in forward order;
    any other f emits nothing.
    """
    if f == 0:
        order = range(lenn - 1, -1, -1)
    elif f == 1:
        order = range(lenn)
    else:
        return
    for i in order:
        nodeslist.append(listnodes[i])
        print(ints[i], end=" ")


def printPathsRecur(node, path, maxm, pathlen):
    """Print the first downward path under *node* that ends at a leaf after *maxm* nodes.

    path/listnodes hold the values/nodes of the current DFS branch. After a
    matching leaf is printed, f is set to 2 so later leaves are ignored.
    """
    global f
    if node is None:
        return
    listnodes[pathlen] = node
    path[pathlen] = node.val
    pathlen += 1
    if node.left is None and node.right is None:
        if pathlen == maxm and f in (0, 1):
            printArray(path, pathlen, f, listnodes)
            f = 2
    else:
        printPathsRecur(node.left, path, maxm, pathlen)
        printPathsRecur(node.right, path, maxm, pathlen)


def diameter(root):
    """Print the values on the diameter path and append its nodes to nodeslist.

    Values are printed left to right, each followed by a space. Nothing is
    printed if the tree is empty or no node has both a left and a right subtree.
    """
    global ans, k, lh, rh, f, listnodes
    if root is None:
        return
    # Start from a clean slate so repeated calls don't reuse an old apex.
    k, lh, rh, f, ans = None, 0, 0, 0, 0 - 10 ** 19
    path_height(root)
    if k is None:
        return
    # The DFS buffers only need to be as deep as the taller half of the path.
    depth = max(lh, rh)
    if len(listnodes) < depth:
        listnodes = [0] * depth
    lPath = [0] * lh
    printPathsRecur(k.left, lPath, lh, 0)   # left half, printed leaf first
    print(k.val, end=" ")
    nodeslist.append(k)
    rPath = [0] * rh
    f = 1
    printPathsRecur(k.right, rPath, rh, 0)  # right half, printed from k's child down


def main():
    """Trace the diameter computation on both sample trees and mark the path."""
    global k, lh, rh, f, ans, pathLen, listnodes, tab
    treeDiameter = TreeDiameter()
    root = TreeNode(1)
    root.left = TreeNode(2)
    root.right = TreeNode(3)
    root.left.left = TreeNode(4)
    root.right.left = TreeNode(5)
    root.right.right = TreeNode(6)

    treeDiameter2 = TreeDiameter()
    root2 = TreeNode(1)
    root2.left = TreeNode(2)
    root2.right = TreeNode(3)
    root2.right.left = TreeNode(5)
    root2.right.right = TreeNode(6)
    root2.right.left.left = TreeNode(7)
    root2.right.left.right = TreeNode(8)
    root2.right.right.left = TreeNode(9)
    root2.right.left.right.left = TreeNode(10)
    root2.right.right.left.left = TreeNode(11)

    inputs = [(treeDiameter, root), (treeDiameter2, root2)]
    for solver, tree in inputs:
        listnodes = [0 for i in range(100)]
        k, lh, rh, f, ans, pathLen = None, 0, 0, 0, 0 - 10 ** 19, 0
        print("Input tree: ")
        display_tree(tree)
        print("Tree Diameter: " + str(solver.find_diameter(tree)))
        print("Nodes included in the diameter:")
        diameter(tree)
        print("\n")
        traversal(tree, nodeslist)
        display_tree(tree)
        print(("-"*100)+"\n")
        # The outermost calculate_height call leaves tab at 1; start fresh.
        tab = 0
        nodeslist.clear()
        listnodes.clear()


main()
An asterisk (*) next to a node indicates that it is included in the diameter.
The functions traversal(), height(), printArray(), printPathsRecur(), and diameter() are only for tree printing and do not play a part in the algorithm.