Daily bit(e) of C++ | Sum of distances to all nodes
Daily bit(e) of C++ #142, Common interview problem: sum of distances to all nodes.
Today we will look at a common C++ interview problem: Sum of distances to all nodes.
Given a tree with n nodes, represented as a graph using a neighbourhood map, calculate, for each node, the sum of its distances to all other nodes.
The node ids are in the range [0,n).
For example, in the tree above, the sum of distances is 4 for nodes 0 and 1, and 6 for nodes 2 and 3.
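To check faster solutions, a brute-force reference that runs a breadth-first search from every node can be handy. This is a minimal sketch, not the approach discussed below. It assumes the neighbourhood map is a `std::unordered_multimap<int, int>` in which every edge is stored in both directions; the names `Neighbours` and `distance_sums_bfs` are hypothetical. Because the figure is not reproduced here, the usage note uses the path 2-0-1-3, which is one tree consistent with the sums quoted above.

```cpp
#include <queue>
#include <unordered_map>
#include <vector>

// Assumed representation: every undirected edge u-v is stored twice,
// once as {u, v} and once as {v, u}.
using Neighbours = std::unordered_multimap<int, int>;

// Brute-force reference: one BFS per node (O(n) each), O(n*n) in total.
// Only meant for checking faster solutions on small inputs.
std::vector<long long> distance_sums_bfs(const Neighbours& g, int n) {
    std::vector<long long> result(n, 0);
    for (int source = 0; source < n; ++source) {
        std::vector<int> dist(n, -1);  // -1 marks "not visited yet"
        std::queue<int> todo;
        dist[source] = 0;
        todo.push(source);
        while (!todo.empty()) {
            int node = todo.front();
            todo.pop();
            result[source] += dist[node];
            // Visit every neighbour that has not been reached yet.
            auto [begin, end] = g.equal_range(node);
            for (auto it = begin; it != end; ++it) {
                if (dist[it->second] == -1) {
                    dist[it->second] = dist[node] + 1;
                    todo.push(it->second);
                }
            }
        }
    }
    return result;
}
```

Inserting `{2, 0}`, `{0, 2}`, `{0, 1}`, `{1, 0}`, `{1, 3}` and `{3, 1}` into a `Neighbours` map and calling `distance_sums_bfs(g, 4)` yields 4, 4, 6, 6.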
Before you continue reading the solution, I encourage you to try to solve it yourself. Here is a Compiler Explorer link with a couple of test cases: https://compiler-explorer.com/z/5nj1ejG5G.
Solution
Let's start with a sub-optimal solution. Consider a subtree for whose root we have already calculated the sum of distances to all nodes in that subtree. Moving up to the parent of the subtree's root means we are one step further from every node in this subtree.
Therefore, if we do a straightforward post-order traversal of the tree, we can calculate the sum of distances for the root node by adding up, for each child, the child's subtree sum plus the number of nodes in the child's subtree (since being one step further contributes one per node).
Importantly, this gives us the sum of distances only for the root node, so we must repeat the process for each node, treating each node in turn as the tree's root. Because each traversal is an O(n) operation, we end up with O(n*n) time complexity.
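The quadratic approach can be sketched as follows, again assuming a `std::unordered_multimap<int, int>` neighbourhood map with both edge directions stored. The hypothetical helper `subtree_info` performs the post-order traversal and returns both the node count and the distance sum of a subtree, while `distance_sums_quadratic` (also hypothetical) repeats it with every node as the root. Using `long long` for the sums is an assumption made to avoid overflow on larger trees, not part of the problem statement.

```cpp
#include <unordered_map>
#include <utility>
#include <vector>

// Assumed representation: every undirected edge u-v is stored twice.
using Neighbours = std::unordered_multimap<int, int>;

// Post-order traversal of the subtree hanging from `node`, entered from
// `parent` (-1 for the root). Returns {number of nodes in the subtree,
// sum of distances from `node` to all nodes in the subtree}.
std::pair<long long, long long> subtree_info(const Neighbours& g, int node, int parent) {
    long long count = 1;  // the node itself
    long long sum = 0;    // its distance to itself is 0
    auto [begin, end] = g.equal_range(node);
    for (auto it = begin; it != end; ++it) {
        if (it->second == parent) continue;  // don't walk back up
        auto [child_count, child_sum] = subtree_info(g, it->second, node);
        // Every node in the child's subtree is one step further from `node`.
        sum += child_sum + child_count;
        count += child_count;
    }
    return {count, sum};
}

// One O(n) traversal per root: O(n*n) overall.
std::vector<long long> distance_sums_quadratic(const Neighbours& g, int n) {
    std::vector<long long> result(n, 0);
    for (int root = 0; root < n; ++root)
        result[root] = subtree_info(g, root, -1).second;
    return result;
}
```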
If you have been following this series for a while, you might have an inkling that we are doing a lot of repetitive work and that there is room for improvement.
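Let's consider two neighbouring nodes x and y in the tree. Removing the edge between them splits the tree into two parts. Let node_count(x) be the number of nodes on x's side and subtree_sum(x) the sum of distances from x to the nodes on its side, and define node_count(y) and subtree_sum(y) symmetrically.
The sums of distances for the two nodes can then be written as: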
distance_sum(x) = subtree_sum(x) + subtree_sum(y) + node_count(y)
distance_sum(y) = subtree_sum(y) + subtree_sum(x) + node_count(x)
distance_sum(x) - distance_sum(y) = node_count(y) - node_count(x)
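This formula lets us calculate the answer for a child from the answer for its parent. Here node_count(child) is the size of the child's subtree, and node_count(parent) is the number of nodes on the parent's side of the edge, i.e. total_nodes - node_count(child):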
distance_sum(child) = distance_sum(parent) + node_count(parent) - node_count(child)
distance_sum(child) = distance_sum(parent) + (total_nodes - node_count(child)) - node_count(child)
distance_sum(child) = distance_sum(parent) + total_nodes - 2*node_count(child)
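As a worked check, assume the example tree is the path 2-0-1-3 (a shape consistent with the sums quoted earlier) rooted at node 0. Then total_nodes = 4, distance_sum(0) = 4, node_count(1) = 2 (nodes 1 and 3), and node_count(2) = node_count(3) = 1. Applying the formula top-down:

```latex
\begin{aligned}
\text{distance\_sum}(1) &= \text{distance\_sum}(0) + 4 - 2 \cdot 2 = 4 + 4 - 4 = 4\\
\text{distance\_sum}(2) &= \text{distance\_sum}(0) + 4 - 2 \cdot 1 = 4 + 4 - 2 = 6\\
\text{distance\_sum}(3) &= \text{distance\_sum}(1) + 4 - 2 \cdot 1 = 4 + 4 - 2 = 6
\end{aligned}
```

These agree with the values 4 and 6 from the example.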
Therefore, after we calculate the sum of distances for one root node using the post-order traversal (which also gives us node_count for every subtree), we can traverse the tree in pre-order, filling in the missing values for the non-root nodes from their already-computed parents.
Both traversals are O(n), which leaves us with the optimal O(n) solution.
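Putting both passes together, a minimal sketch of the O(n) solution might look like this. It assumes the same `std::unordered_multimap<int, int>` representation (both edge directions stored), picks node 0 as the root, and uses the hypothetical names `post_order`, `pre_order` and `distance_sums`. The post-order pass computes node_count for every subtree along with the root's sum; the pre-order pass then applies distance_sum(child) = distance_sum(parent) + total_nodes - 2*node_count(child).

```cpp
#include <unordered_map>
#include <vector>

// Assumed representation: every undirected edge u-v is stored twice.
using Neighbours = std::unordered_multimap<int, int>;

// Post-order pass: fills count[node] with the size of node's subtree and
// returns the sum of distances from `node` to all nodes in its subtree.
long long post_order(const Neighbours& g, int node, int parent,
                     std::vector<long long>& count) {
    count[node] = 1;
    long long sum = 0;
    auto [begin, end] = g.equal_range(node);
    for (auto it = begin; it != end; ++it) {
        int child = it->second;
        if (child == parent) continue;
        long long child_sum = post_order(g, child, node, count);
        sum += child_sum + count[child];  // one step further per child node
        count[node] += count[child];
    }
    return sum;
}

// Pre-order pass: result[node] is already known when its children are visited.
void pre_order(const Neighbours& g, int node, int parent, long long total,
               const std::vector<long long>& count, std::vector<long long>& result) {
    auto [begin, end] = g.equal_range(node);
    for (auto it = begin; it != end; ++it) {
        int child = it->second;
        if (child == parent) continue;
        // distance_sum(child) = distance_sum(parent) + total_nodes - 2*node_count(child)
        result[child] = result[node] + total - 2 * count[child];
        pre_order(g, child, node, total, count, result);
    }
}

std::vector<long long> distance_sums(const Neighbours& g, int n) {
    std::vector<long long> count(n, 0), result(n, 0);
    if (n == 0) return result;
    result[0] = post_order(g, 0, -1, count);  // node 0 serves as the root
    pre_order(g, 0, -1, n, count, result);
    return result;
}
```

Both passes are recursive, so the call depth equals the height of the tree; for very deep (path-like) trees, an iterative traversal with an explicit stack would avoid running out of stack space.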