
All Paths for a Sum (medium)

Problem Statement

Given a binary tree and a number ‘S’, find all root-to-leaf paths such that the sum of the node values along each path equals ‘S’.
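As a quick worked example, take the tree that the main method below builds: root 12, whose left child 7 has a left child 4, and whose right child 1 has children 10 and 5. With S = 23, the three root-to-leaf sums are:

```latex
\begin{aligned}
12 + 7 + 4  &= 23 \quad \checkmark \\
12 + 1 + 10 &= 23 \quad \checkmark \\
12 + 1 + 5  &= 18
\end{aligned}
```

So the expected answer is [[12, 7, 4], [12, 1, 10]]. Note that 12 + 7 = 19 would not count for S = 19, because 7 is not a leaf.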

Try it yourself

Try solving this question yourself, starting from the following skeleton:

import java.util.*;

class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;

    TreeNode(int x) {
        val = x;
    }
}

class FindAllTreePaths {
    public static List<List<Integer>> findPaths(TreeNode root, int sum) {
        List<List<Integer>> allPaths = new ArrayList<>();
        // TODO: Write your code here
        return allPaths;
    }

    public static void main(String[] args) {
        TreeNode root = new TreeNode(12);
        root.left = new TreeNode(7);
        root.right = new TreeNode(1);
        root.left.left = new TreeNode(4);
        root.right.left = new TreeNode(10);
        root.right.right = new TreeNode(5);
        int sum = 23;
        List<List<Integer>> result = FindAllTreePaths.findPaths(root, sum);
        System.out.println("Tree paths with sum " + sum + ": " + result);
    }
}

Solution

This problem follows the Binary Tree Path Sum pattern, so we can use the same DFS approach, with two differences:

  1. Every time we find a root-to-leaf path whose sum equals ‘S’, we store a copy of it in a list.
  2. We traverse all paths instead of stopping after finding the first matching path.
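Point 1 says we store a copy of the path, not the path list itself. The reason is that DFS reuses one list for the current path and removes nodes from it while backtracking. The small demo below shows what goes wrong without the copy. The class CopyVsAliasDemo and the method storePaths are hypothetical names used only for illustration. The demo simulates reaching the leaf path 12 -> 7 -> 4, storing it, and then backtracking to the root.

```java
import java.util.*;

class CopyVsAliasDemo {
    // Simulates one "store, then backtrack" cycle for the leaf path 12 -> 7 -> 4.
    // If copy is true, a snapshot of the path is stored; otherwise the shared list itself is stored.
    static List<List<Integer>> storePaths(boolean copy) {
        List<Integer> currentPath = new ArrayList<>();
        List<List<Integer>> allPaths = new ArrayList<>();

        currentPath.add(12);
        currentPath.add(7);
        currentPath.add(4);

        if (copy) {
            allPaths.add(new ArrayList<>(currentPath)); // independent snapshot
        } else {
            allPaths.add(currentPath);                  // reference to the shared list
        }

        // Backtracking mutates the shared list, which may still be referenced by allPaths.
        while (!currentPath.isEmpty()) {
            currentPath.remove(currentPath.size() - 1);
        }
        return allPaths;
    }

    public static void main(String[] args) {
        System.out.println("stored reference: " + storePaths(false)); // stored reference: [[]]
        System.out.println("stored copy: " + storePaths(true));       // stored copy: [[12, 7, 4]]
    }
}
```

Storing the reference leaves an empty list in the results once backtracking finishes, while the copy keeps the path that was actually found.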

Code

Here is what our algorithm will look like:

import java.util.*;

class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;

    TreeNode(int x) {
        val = x;
    }
}

class FindAllTreePaths {
    public static List<List<Integer>> findPaths(TreeNode root, int sum) {
        List<List<Integer>> allPaths = new ArrayList<>();
        List<Integer> currentPath = new ArrayList<Integer>();
        findPathsRecursive(root, sum, currentPath, allPaths);
        return allPaths;
    }

    private static void findPathsRecursive(TreeNode currentNode, int sum, List<Integer> currentPath,
            List<List<Integer>> allPaths) {
        if (currentNode == null)
            return;

        // add the current node to the path
        currentPath.add(currentNode.val);

        // if the current node is a leaf and its value equals the remaining sum, save a copy of the current path
        if (currentNode.val == sum && currentNode.left == null && currentNode.right == null) {
            allPaths.add(new ArrayList<Integer>(currentPath));
        } else {
            // traverse the left sub-tree
            findPathsRecursive(currentNode.left, sum - currentNode.val, currentPath, allPaths);
            // traverse the right sub-tree
            findPathsRecursive(currentNode.right, sum - currentNode.val, currentPath, allPaths);
        }

        // remove the current node from the path to backtrack;
        // we need to remove the current node while we are going up the recursive call stack
        currentPath.remove(currentPath.size() - 1);
    }

    public static void main(String[] args) {
        TreeNode root = new TreeNode(12);
        root.left = new TreeNode(7);
        root.right = new TreeNode(1);
        root.left.left = new TreeNode(4);
        root.right.left = new TreeNode(10);
        root.right.right = new TreeNode(5);
        int sum = 23;
        List<List<Integer>> result = FindAllTreePaths.findPaths(root, sum);
        // prints: Tree paths with sum 23: [[12, 7, 4], [12, 1, 10]]
        System.out.println("Tree paths with sum " + sum + ": " + result);
    }
}
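The same traversal can also be written without recursion by using an explicit stack. This is a swapped-in variant, not the lesson's solution. It is a minimal sketch in which the class IterativeTreePaths and the helper type Frame are hypothetical names. Each stack frame carries its own copy of the path, so no backtracking step is needed.

```java
import java.util.*;

class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;

    TreeNode(int x) {
        val = x;
    }
}

class IterativeTreePaths {
    // Hypothetical helper: a node still to visit, the path from the root down to
    // (but not including) that node, and the sum still needed on arrival.
    private static final class Frame {
        final TreeNode node;
        final List<Integer> pathSoFar;
        final int remaining;

        Frame(TreeNode node, List<Integer> pathSoFar, int remaining) {
            this.node = node;
            this.pathSoFar = pathSoFar;
            this.remaining = remaining;
        }
    }

    public static List<List<Integer>> findPaths(TreeNode root, int sum) {
        List<List<Integer>> allPaths = new ArrayList<>();
        if (root == null) {
            return allPaths;
        }
        Deque<Frame> stack = new ArrayDeque<>();
        stack.push(new Frame(root, new ArrayList<>(), sum));

        while (!stack.isEmpty()) {
            Frame frame = stack.pop();
            TreeNode node = frame.node;

            // Build this node's own path list; the parent's list is never modified.
            List<Integer> path = new ArrayList<>(frame.pathSoFar);
            path.add(node.val);
            int remaining = frame.remaining - node.val;

            if (node.left == null && node.right == null) {
                if (remaining == 0) {
                    allPaths.add(path);
                }
                continue;
            }
            // Push right before left so the left subtree is popped first,
            // matching the order of the recursive solution. Both children share
            // 'path', which is safe because it is not modified after this point.
            if (node.right != null) {
                stack.push(new Frame(node.right, path, remaining));
            }
            if (node.left != null) {
                stack.push(new Frame(node.left, path, remaining));
            }
        }
        return allPaths;
    }

    public static void main(String[] args) {
        TreeNode root = new TreeNode(12);
        root.left = new TreeNode(7);
        root.right = new TreeNode(1);
        root.left.left = new TreeNode(4);
        root.right.left = new TreeNode(10);
        root.right.right = new TreeNode(5);
        System.out.println(findPaths(root, 23)); // [[12, 7, 4], [12, 1, 10]]
    }
}
```

The design trade-off is that this variant copies the path at every node, not only at matching leaves. That costs O(depth) per node even when few paths match. In exchange, it avoids both deep recursion and the explicit undo step.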

Time complexity

The time complexity of the above algorithm is O(N^2), where ‘N’ is the total number of nodes in the tree. This is due to the fact that we traverse each node once (which will take O(N)), and for every leaf node, we might have to store its path (by making a copy of the current path) which will take O(N) ...
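The bound can be spelled out as follows. The symbols are introduced here for illustration: T(N) is the total work, L is the number of leaves, and d(ℓ) is the number of nodes on the root-to-leaf path ending at leaf ℓ. Both L and d(ℓ) are at most N.

```latex
T(N) \;=\; \underbrace{O(N)}_{\text{visit each node once}}
\;+\; \underbrace{\sum_{\ell \,\in\, \text{leaves}} O\big(d(\ell)\big)}_{\text{copy each stored path}}
\;\le\; O(N) + L \cdot O(N)
\;\le\; O(N) + N \cdot O(N)
\;=\; O(N^2)
```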