Tree
- Tree problems usually revolve around DFS or BFS and make use of a few tree-specific properties, so it's important to be able to implement these two algorithms blindfolded
Overview
Traversal
DFS
The idea is usually to recurse into the left and right subtrees while performing some operation at each node
Sometimes it helps to pass extra parameters to the recursive function to keep track of some data (e.g. path sum, current node path, etc.)
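As a rough illustration of carrying extra state through the recursion, the sketch below collects every root-to-leaf path by passing the current path as a parameter. The `TreeNode` class and the `root_to_leaf_paths` name are assumptions made for this example (LeetCode normally supplies its own `TreeNode`).

```python
from typing import List, Optional


class TreeNode:
    """Minimal node type, assumed here so the sketch runs on its own."""
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def root_to_leaf_paths(root: Optional[TreeNode]) -> List[List[int]]:
    """Collect every root-to-leaf path, carrying the current path as a parameter."""
    res: List[List[int]] = []

    def dfs(node: Optional[TreeNode], path: List[int]) -> None:
        if not node:
            return
        path.append(node.val)            # extend the path on the way down
        if not node.left and not node.right:
            res.append(path.copy())      # leaf reached: record a snapshot
        else:
            dfs(node.left, path)
            dfs(node.right, path)
        path.pop()                       # backtrack on the way up

    dfs(root, [])
    return res


tree = TreeNode(1, TreeNode(2, TreeNode(4)), TreeNode(3))
paths = root_to_leaf_paths(tree)
```

Sharing one `path` list and popping on the way back up avoids copying the list at every call; a copy is only taken when a leaf is recorded.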
- Preorder Traversal
def preorderTraversal(self, root: Optional[TreeNode]) -> List[int]:
    if not root:
        return []
    left = self.preorderTraversal(root.left)
    right = self.preorderTraversal(root.right)
    return [root.val] + left + right
- Inorder Traversal
def inorderTraversal(self, root: Optional[TreeNode]) -> List[int]:
    if not root:
        return []
    left = self.inorderTraversal(root.left)
    right = self.inorderTraversal(root.right)
    return left + [root.val] + right
- Postorder Traversal
def postorderTraversal(self, root: Optional[TreeNode]) -> List[int]:
    if not root:
        return []
    left = self.postorderTraversal(root.left)
    right = self.postorderTraversal(root.right)
    return left + right + [root.val]
def buildTree(self, preorder: List[int], inorder: List[int]) -> Optional[TreeNode]:
    if not preorder or not inorder:
        return None
    root = TreeNode(preorder[0])
    # We can use rootIndex of inorder because
    # preorder: root + left + right
    # inorder: left + root + right
    # So rootIndex is also the size of the left subtree, and we can use it
    # to split both preorder and inorder into 2 subarrays
    rootIndex = inorder.index(root.val)
    root.left = self.buildTree(preorder[1:rootIndex+1], inorder[:rootIndex])
    root.right = self.buildTree(preorder[rootIndex+1:], inorder[rootIndex+1:])
    return root
# https://leetcode.com/problems/path-sum/
def hasPathSum(self, root: Optional[TreeNode], targetSum: int) -> bool:
    def getSum(node, curSum):
        if not node:
            return False
        # Only a leaf can complete a root-to-leaf path
        if not node.left and not node.right and curSum + node.val == targetSum:
            return True
        left = getSum(node.left, curSum + node.val)
        right = getSum(node.right, curSum + node.val)
        return left or right
    return getSum(root, 0)
https://leetcode.com/problems/boundary-of-binary-tree
- The actual implementation is not bad, and it is good practice for DFS thinking
- The hard part is understanding what the description is actually asking for
def boundaryOfBinaryTree(self, root: Optional[TreeNode]) -> List[int]:
    if not root:
        return []
    # Left boundary, top-down, excluding leaves
    def getLeft(node):
        if not node:
            return []
        if not node.left and not node.right:
            return []
        current = [node.val]
        if node.left:
            return current + getLeft(node.left)
        else:
            return current + getLeft(node.right)
    # Right boundary, bottom-up, excluding leaves
    def getRight(node):
        if not node:
            return []
        if not node.left and not node.right:
            return []
        current = [node.val]
        if node.right:
            return getRight(node.right) + current
        else:
            return getRight(node.left) + current
    leaves = []
    # All leaves, left to right
    def getLeaves(node):
        if not node:
            return
        if not node.left and not node.right:
            leaves.append(node.val)
            return
        getLeaves(node.left)
        getLeaves(node.right)
    left = getLeft(root.left)
    current = [root.val]
    right = getRight(root.right)
    # A single-node tree has the root as its only boundary node
    if root.left or root.right:
        getLeaves(root)
    return current + left + leaves + right
BFS
- The idea is to use a queue to keep track of what to traverse next
- Use when
    - Level order traversal
        - When the problem asks about relationships between nodes at the same height/depth
        - When you need to find the minimum depth or shortest path
    - Parent relationship
        - When you need to keep track of the parent of the current node (and maybe compare it to other nodes on the same level)
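A minimal sketch of the "minimum depth" use case, assuming a plain `TreeNode` class and a hypothetical `min_depth` helper: because BFS visits nodes level by level, the first leaf it dequeues is the shallowest one. It uses `collections.deque`, whose `popleft()` is O(1), whereas `list.pop(0)` is O(n).

```python
from collections import deque
from typing import Optional


class TreeNode:
    """Minimal node type, assumed here so the sketch runs on its own."""
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def min_depth(root: Optional[TreeNode]) -> int:
    """Return the number of nodes on the shortest root-to-leaf path."""
    if not root:
        return 0
    queue = deque([(root, 1)])           # (node, depth of that node)
    while queue:
        node, depth = queue.popleft()
        if not node.left and not node.right:
            return depth                 # first leaf reached is the shallowest
        if node.left:
            queue.append((node.left, depth + 1))
        if node.right:
            queue.append((node.right, depth + 1))
    return 0                             # unreachable for a non-empty tree


tree = TreeNode(3, TreeNode(9), TreeNode(20, TreeNode(15), TreeNode(7)))
shallowest = min_depth(tree)
```

A DFS would have to explore the whole tree before it could be sure of the minimum, while BFS can stop at the first leaf.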
- Level order traversal
# https://leetcode.com/problems/binary-tree-level-order-traversal/
def levelOrder(self, root: Optional[TreeNode]) -> List[List[int]]:
    res = []
    if not root:
        return res
    queue = [root]
    while queue:
        cur_level = []
        # Everything currently in the queue belongs to the same level
        cur_len = len(queue)
        for i in range(cur_len):
            node = queue.pop(0)
            cur_level.append(node.val)
            if node.left:
                queue.append(node.left)
            if node.right:
                queue.append(node.right)
        res.append(cur_level)
    return res
https://leetcode.com/problems/maximum-level-sum-of-a-binary-tree
def maxLevelSum(self, root: Optional[TreeNode]) -> int:
    maxVal = float('-inf')
    maxLevel = -1
    queue = [root]
    level = 1
    while queue:
        curLen = len(queue)
        curSum = 0
        for _ in range(curLen):
            node = queue.pop(0)
            curSum += node.val
            if node.left:
                queue.append(node.left)
            if node.right:
                queue.append(node.right)
        # Strict comparison keeps the smallest level on ties
        if curSum > maxVal:
            maxVal = curSum
            maxLevel = level
        level += 1
    return maxLevel
https://leetcode.com/problems/cousins-in-binary-tree
def isCousins(self, root: Optional[TreeNode], x: int, y: int) -> bool:
    if not root:
        return False
    # Each queue entry is (node, parent, depth)
    queue = [(root, None, 0)]
    xDepth, xParent = -1, None
    yDepth, yParent = -1, None
    while queue:
        curLen = len(queue)
        for i in range(curLen):
            node, parent, depth = queue.pop(0)
            if node.val == x:
                xParent, xDepth = parent, depth
            elif node.val == y:
                yParent, yDepth = parent, depth
            if xParent and yParent:
                break
            if node.left:
                queue.append((node.left, node, depth + 1))
            if node.right:
                queue.append((node.right, node, depth + 1))
    # Cousins: same depth, different parents
    return xDepth == yDepth and xParent != yParent
https://leetcode.com/problems/cousins-in-binary-tree-ii
- More advanced version here
- The trick is to avoid recomputing, for every node, the sum of all the other nodes on its level that do not share its parent
- Pre-computation:
node.val = total - node.val - sum(siblings)
=> We can pre-calculate the level total and each node's own-plus-sibling sum - HashTable
def replaceValueInTree(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
    if not root:
        return
    queue = [(root, None)]
    while queue:
        curLen = len(queue)
        level = []
        for _ in range(curLen):
            node, parent = queue.pop(0)
            level.append((node, parent))
            if node.left:
                queue.append((node.left, node))
            if node.right:
                queue.append((node.right, node))
        # negate[i] = value of node i plus the value of its sibling (if any)
        negate = {}
        for i in range(len(level)):
            node, parent = level[i]
            negate[i] = node.val
            # Siblings are adjacent in level order and share the same parent
            if i > 0 and parent == level[i-1][1]:
                negate[i-1] += node.val
                negate[i] += level[i-1][0].val
        total = sum([node.val for node, _ in level])
        for i in range(len(level)):
            node, _ = level[i]
            node.val = total - negate[i]
    return root
Problems
- This list provides a good variety of basic patterns to get started with: https://leetcode.com/discuss/study-guide/5020529/Master-Tree-Patterns/
- It does not really go deep into more advanced patterns or tricky questions, but it is more than enough to develop a basic intuition for trees
- This is also a good list: https://leetcode.com/discuss/study-guide/1212004/Binary-Trees-study-guide
Variants
Binary Search Tree
Definition: A tree where each node has at most two children, and for any node:
- All values in the left subtree are less than the node's value
- All values in the right subtree are greater than the node's value
- In-order traversal yields sorted order
Efficient for search, insert, and delete operations when balanced: O(log n)
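A small sketch of why the ordering property makes search cheap, assuming a minimal `TreeNode` class and hypothetical `search_bst` and `inorder` helpers: each comparison discards one whole subtree, so the loop runs once per level of the tree.

```python
from typing import List, Optional


class TreeNode:
    """Minimal node type, assumed here so the sketch runs on its own."""
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def search_bst(root: Optional[TreeNode], target: int) -> Optional[TreeNode]:
    """Walk down from the root, choosing a side at every node."""
    node = root
    while node:
        if target == node.val:
            return node
        # Smaller values can only be on the left, larger ones on the right
        node = node.left if target < node.val else node.right
    return None


def inorder(root: Optional[TreeNode]) -> List[int]:
    """In-order traversal; on a valid BST this comes out sorted."""
    if not root:
        return []
    return inorder(root.left) + [root.val] + inorder(root.right)


bst = TreeNode(4,
               TreeNode(2, TreeNode(1), TreeNode(3)),
               TreeNode(6, TreeNode(5), TreeNode(7)))
found = search_bst(bst, 5)
missing = search_bst(bst, 8)
sorted_vals = inorder(bst)
```

On a balanced tree the height is about log n, which is where the O(log n) bound comes from; on a degenerate (linked-list shaped) tree the same loop takes O(n).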
These are some of the fundamental problems that demonstrate BST properties
def sortedArrayToBST(self, nums: List[int]) -> Optional[TreeNode]:
    if not nums:
        return
    # The middle element becomes the root so both halves stay balanced
    mid = len(nums) // 2
    root = TreeNode(nums[mid])
    root.left = self.sortedArrayToBST(nums[:mid])
    root.right = self.sortedArrayToBST(nums[mid+1:])
    return root
https://leetcode.com/problems/construct-binary-search-tree-from-preorder-traversal
- Good question to notice BST properties
def bstFromPreorder(self, preorder: List[int]) -> Optional[TreeNode]:
    if not preorder:
        return
    root = TreeNode(preorder[0])
    # Binary search for the first value greater than the root:
    # everything before it is the left subtree, everything from it on is the right
    left, right = 1, len(preorder) - 1
    while left <= right:
        mid = left + (right - left) // 2
        if preorder[mid] > root.val:
            right = mid - 1
        else:
            left = mid + 1
    leftHalf = self.bstFromPreorder(preorder[1:left])
    rightHalf = self.bstFromPreorder(preorder[left:])
    root.left = leftHalf
    root.right = rightHalf
    return root
https://leetcode.com/problems/validate-binary-search-tree
def isValidBST(self, root: Optional[TreeNode]) -> bool:
    # Every node must lie strictly inside the (lower, upper) range set by its ancestors
    def helper(node, lower, upper):
        if not node:
            return True
        if node.val >= upper or node.val <= lower:
            return False
        return helper(node.left, lower, node.val) and helper(node.right, node.val, upper)
    return helper(root, float('-inf'), float('inf'))
https://leetcode.com/problems/insert-into-a-binary-search-tree
def insertIntoBST(self, root: Optional[TreeNode], val: int) -> Optional[TreeNode]:
    if not root:
        return TreeNode(val)
    if val > root.val:
        root.right = self.insertIntoBST(root.right, val)
    if val < root.val:
        root.left = self.insertIntoBST(root.left, val)
    return root
https://leetcode.com/problems/delete-node-in-a-bst
- Important algorithm to memorize
- 3 steps:
    - Find the node to remove
    - Replace the node's value with its in-order successor (smallest in the right subtree); the in-order predecessor (largest in the left subtree) works as well
    - Remove the successor (or predecessor) from that subtree
def deleteNode(self, root: Optional[TreeNode], key: int) -> Optional[TreeNode]:
    # Smallest node of a subtree: keep going left
    def findSuccessor(node):
        if node.left:
            return findSuccessor(node.left)
        return node
    if not root:
        return
    if key > root.val:
        root.right = self.deleteNode(root.right, key)
    elif key < root.val:
        root.left = self.deleteNode(root.left, key)
    else:
        if not root.left and not root.right:
            return None
        elif not root.left:
            return root.right
        elif not root.right:
            return root.left
        else:
            # Two children: copy the successor's value, then delete the successor
            successor = findSuccessor(root.right)
            root.val = successor.val
            root.right = self.deleteNode(root.right, root.val)
    return root
Lowest Common Ancestor
https://leetcode.com/discuss/interview-question/6024811
https://leetcode.com/problems/lowest-common-ancestor-of-a-binary-tree
def lowestCommonAncestor(self, root: 'TreeNode', p: 'TreeNode', q: 'TreeNode') -> 'TreeNode':
    if not root or root.val == p.val or root.val == q.val:
        return root
    left = self.lowestCommonAncestor(root.left, p, q)
    right = self.lowestCommonAncestor(root.right, p, q)
    # p and q were found in different subtrees, so root is the LCA
    if left and right:
        return root
    return left or right
https://leetcode.com/problems/lowest-common-ancestor-of-a-binary-tree-ii
- Use the helper/dfs function both to find the common ancestor and to check the existence of each node
- Maintain 2 variables to keep track of the existence of each node
def lowestCommonAncestor(self, root: 'TreeNode', p: 'TreeNode', q: 'TreeNode') -> 'TreeNode':
    foundP = False
    foundQ = False
    def helper(node):
        nonlocal foundP
        nonlocal foundQ
        if not node:
            return
        # Recurse first so the whole tree is searched even after one node is found
        left = helper(node.left)
        right = helper(node.right)
        if node.val == p.val:
            foundP = True
            return node
        if node.val == q.val:
            foundQ = True
            return node
        if left and right:
            return node
        return left or right
    res = helper(root)
    return res if foundP and foundQ else None
https://leetcode.com/problems/lowest-common-ancestor-of-a-binary-search-tree
# Recursive
def lowestCommonAncestor(self, root: 'TreeNode', p: 'TreeNode', q: 'TreeNode') -> 'TreeNode':
    if root.val == q.val or root.val == p.val:
        return root
    # root sits between p and q, so they split here
    if (p.val < root.val < q.val) or (q.val < root.val < p.val):
        return root
    if p.val < root.val and q.val < root.val:
        return self.lowestCommonAncestor(root.left, p, q)
    if p.val > root.val and q.val > root.val:
        return self.lowestCommonAncestor(root.right, p, q)
# Iterative
def lowestCommonAncestor(self, root: 'TreeNode', p: 'TreeNode', q: 'TreeNode') -> 'TreeNode':
    node = root
    while node:
        if node.val > p.val and node.val > q.val:
            node = node.left
        elif node.val < p.val and node.val < q.val:
            node = node.right
        elif (node.val >= p.val and node.val <= q.val) or (node.val <= p.val and node.val >= q.val):
            return node
    return None
https://leetcode.com/problems/lowest-common-ancestor-of-a-binary-tree-iii
- Traverse from one node all the way up to the root, recording every node visited
- Traverse up from the other node; the first node we meet that was visited by the previous traversal is the LCA
def lowestCommonAncestor(self, p: 'Node', q: 'Node') -> 'Node':
    visit = set()
    while p:
        visit.add(p)
        p = p.parent
    while q:
        if q in visit:
            return q
        visit.add(q)
        q = q.parent
    return None
https://leetcode.com/problems/lowest-common-ancestor-of-a-binary-tree-iv
def lowestCommonAncestor(self, root: 'TreeNode', nodes: 'List[TreeNode]') -> 'TreeNode':
    nodes = set(nodes)
    def dfs(node):
        if not node or node in nodes:
            return node
        left = dfs(node.left)
        right = dfs(node.right)
        if left and right:
            return node
        return left or right
    return dfs(root)
https://leetcode.com/problems/lowest-common-ancestor-of-deepest-leaves
def lcaDeepestLeaves(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
    # Returns (depth of subtree, LCA of the deepest leaves in that subtree)
    def depth(node):
        if not node:
            return 0, node
        leftDepth, lcaLeft = depth(node.left)
        rightDepth, lcaRight = depth(node.right)
        if leftDepth == rightDepth:
            return leftDepth + 1, node
        elif leftDepth > rightDepth:
            return leftDepth + 1, lcaLeft
        else:
            return rightDepth + 1, lcaRight
    return depth(root)[1]
Tree Diameter
https://leetcode.com/problems/diameter-of-binary-tree
def diameterOfBinaryTree(self, root: Optional[TreeNode]) -> int:
    res = 0
    # dfs returns the height of the subtree; the diameter is updated on the side
    def dfs(node):
        nonlocal res
        if not node:
            return 0
        left = dfs(node.left)
        right = dfs(node.right)
        res = max(res, left + right)
        return 1 + max(left, right)
    dfs(root)
    return res
https://leetcode.com/problems/diameter-of-n-ary-tree
- Same idea as before, we just need to take the sum of the 2 longest child paths
def diameter(self, root: 'Node') -> int:
    res = 0
    def dfs(node):
        nonlocal res
        if not node:
            return 0
        childs = []
        # Track the two tallest child subtrees
        maxChild, secondMaxChild = 0, 0
        for child in node.children:
            current = dfs(child)
            if current >= maxChild:
                secondMaxChild = maxChild
                maxChild = current
            elif secondMaxChild <= current < maxChild:
                secondMaxChild = current
            childs.append(current)
        res = max(res, maxChild + secondMaxChild)
        return 1 + max(childs) if childs else 1
    dfs(root)
    return res
Patterns
Tree as Graph
- Sometimes a problem is given as a tree, but we actually want to treat (maybe convert) the tree as a [[9. Graph| graph]].
- In a typical tree problem, we only need to
- Move from parent to children (downward)
- Process nodes in a specific order
- Track information along a single path
- These signals/characteristics in a problem should make you consider transforming the tree into a graph:
Non-standard movement requirements
- For example with this question: https://leetcode.com/problems/all-nodes-distance-k-in-binary-tree
- To find all valid nodes when the target is in the middle of the tree, we would have to
    - Go up k levels
    - Go down k levels
    - Go up one and down k-1, etc.
https://leetcode.com/problems/all-nodes-distance-k-in-binary-tree
def distanceK(self, root: TreeNode, target: TreeNode, k: int) -> List[int]:
    # Turn the tree into an undirected graph keyed by node value
    def buildGraph(node, parent, graph):
        if not node:
            return
        if parent:
            graph[node.val].append(parent.val)
        if node.left:
            graph[node.val].append(node.left.val)
            buildGraph(node.left, node, graph)
        if node.right:
            graph[node.val].append(node.right.val)
            buildGraph(node.right, node, graph)
    graph = collections.defaultdict(list)
    buildGraph(root, None, graph)
    visit = set()
    res = []
    # BFS outward from the target, tracking the distance travelled
    queue = [(target.val, 0)]
    visit.add(target.val)
    while queue:
        node, dist = queue.pop(0)
        if dist > k:
            continue
        elif dist == k:
            res.append(node)
        else:
            for nb in graph[node]:
                if nb not in visit:
                    queue.append((nb, dist+1))
                    visit.add(nb)
    return res
https://leetcode.com/problems/amount-of-time-for-binary-tree-to-be-infected/
def amountOfTime(self, root: Optional[TreeNode], start: int) -> int:
    graph = collections.defaultdict(list)
    def buildGraph(node, parent):
        if not node:
            return
        if parent:
            graph[node.val].append(parent.val)
        if node.left:
            graph[node.val].append(node.left.val)
            buildGraph(node.left, node)
        if node.right:
            graph[node.val].append(node.right.val)
            buildGraph(node.right, node)
    buildGraph(root, None)
    infected = set()
    infected.add(start)
    queue = [(start, 0)]
    res = 0
    # The answer is the distance to the farthest node from start
    while queue:
        curLen = len(queue)
        for i in range(curLen):
            node, dist = queue.pop(0)
            infected.add(node)
            res = max(res, dist)
            for nb in graph[node]:
                if nb not in infected:
                    queue.append((nb, dist+1))
    return res
https://leetcode.com/problems/step-by-step-directions-from-a-binary-tree-node-to-another
https://leetcode.com/problems/closest-leaf-in-a-binary-tree
def findClosestLeaf(self, root: Optional[TreeNode], k: int) -> int:
    graph = collections.defaultdict(list)
    leaf = set()
    def buildGraph(node, parent):
        if not node:
            return
        if parent:
            graph[node.val].append(parent.val)
        if not node.left and not node.right:
            leaf.add(node.val)
        if node.left:
            graph[node.val].append(node.left.val)
            buildGraph(node.left, node)
        if node.right:
            graph[node.val].append(node.right.val)
            buildGraph(node.right, node)
    buildGraph(root, None)
    visit = set()
    queue = [k]
    visit.add(k)
    # BFS from k: the first leaf dequeued is the closest one
    while queue:
        curLen = len(queue)
        for _ in range(curLen):
            curNode = queue.pop(0)
            if curNode in leaf:
                return curNode
            for nb in graph[curNode]:
                if nb not in visit:
                    visit.add(nb)
                    queue.append(nb)
    return -1
Relationship-based queries
- When a problem asks about relationships that aren't purely hierarchical, consider a graph
    - Finding nodes at a specific distance
    - Finding the distance between any two nodes
    - Finding all nodes that can reach a target node
    - Finding the shortest path between nodes
https://leetcode.com/problems/find-distance-in-a-binary-tree
def findDistance(self, root: Optional[TreeNode], p: int, q: int) -> int:
    graph = collections.defaultdict(list)
    def buildGraph(node, parent):
        if not node:
            return
        if parent:
            graph[node.val].append(parent.val)
        if node.left:
            graph[node.val].append(node.left.val)
            buildGraph(node.left, node)
        if node.right:
            graph[node.val].append(node.right.val)
            buildGraph(node.right, node)
    buildGraph(root, None)
    visit = set()
    queue = [(p, 0)]
    visit.add(p)
    # BFS from p until q is reached
    while queue:
        curLen = len(queue)
        for _ in range(curLen):
            node, dst = queue.pop(0)
            visit.add(node)
            if node == q:
                return dst
            for nb in graph[node]:
                if nb not in visit:
                    queue.append((nb, dst + 1))
    return -1
https://leetcode.com/problems/binary-tree-maximum-path-sum
Parent Access need
- If you find yourself thinking "I need to know this node's parent" or "I need to move upward from this node", that's often a signal that a graph representation might be helpful.
https://leetcode.com/problems/find-all-the-lonely-nodes
https://leetcode.com/problems/find-nearest-right-node-in-binary-tree
https://leetcode.com/problems/lowest-common-ancestor-of-deepest-leaves
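When only upward movement is needed, a full adjacency list may be more than necessary. One lighter option, sketched here under the assumption of a plain `TreeNode` class and hypothetical `build_parent_map` and `path_to_root` helpers, is a dictionary from each node to its parent built with a single DFS.

```python
from typing import Dict, List, Optional


class TreeNode:
    """Minimal node type, assumed here so the sketch runs on its own."""
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_parent_map(root: Optional[TreeNode]) -> Dict[TreeNode, Optional[TreeNode]]:
    """Map every node to its parent (the root maps to None)."""
    parents: Dict[TreeNode, Optional[TreeNode]] = {}

    def dfs(node: Optional[TreeNode], parent: Optional[TreeNode]) -> None:
        if not node:
            return
        parents[node] = parent
        dfs(node.left, node)
        dfs(node.right, node)

    dfs(root, None)
    return parents


def path_to_root(node: Optional[TreeNode],
                 parents: Dict[TreeNode, Optional[TreeNode]]) -> List[int]:
    """Follow parent links upward, collecting values along the way."""
    path = []
    while node:
        path.append(node.val)
        node = parents[node]
    return path


leaf = TreeNode(4)
tree = TreeNode(1, TreeNode(2, leaf), TreeNode(3))
parents = build_parent_map(tree)
upward = path_to_root(leaf, parents)
```

Keying the map by node object rather than by value means it still works when values repeat, which the value-keyed graphs in the earlier solutions do not guarantee.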
Misc
NOTES
- An important property when dealing with a Full Binary Tree is that every node has either 0 or 2 children; from this we can calculate that:
Binary Tree in Array Representation
If a binary tree is represented as an array:
1. **Index 0** represents the root node.
2. For a node at index i:
    - **Left Child**: The left child is located at index 2i + 1.
    - **Right Child**: The right child is located at index 2i + 2.
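The index arithmetic above can be sketched as follows. The helper names are hypothetical, `None` is assumed to mark a missing node, and the `parent` formula is not from the notes but follows from inverting the two child formulas.

```python
from typing import List, Optional


def left_child(i: int) -> int:
    return 2 * i + 1


def right_child(i: int) -> int:
    return 2 * i + 2


def parent(i: int) -> int:
    # Inverse of the two child formulas; only meaningful for i > 0
    return (i - 1) // 2


def inorder_from_array(tree: List[Optional[int]], i: int = 0) -> List[int]:
    """In-order traversal over the array form; None marks a missing node."""
    if i >= len(tree) or tree[i] is None:
        return []
    return (inorder_from_array(tree, left_child(i))
            + [tree[i]]
            + inorder_from_array(tree, right_child(i)))


complete_tree = [4, 2, 6, 1, 3, 5, 7]   # root 4, children 2 and 6, and so on
sparse_tree = [1, 2, 3, None, 4]        # node 2 has only a right child (4)
complete_inorder = inorder_from_array(complete_tree)
sparse_inorder = inorder_from_array(sparse_tree)
```

This layout needs no pointers, which is why it is the usual choice for binary heaps; for sparse trees it wastes slots on the `None` placeholders.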