Advanced Divide-n-Conquer: Recursion with Byproduct

Most of the time, you use divide-n-conquer to solve one task, like sorting or selection. But in this section, you will learn to use divide-n-conquer in a more interesting way: to solve two tasks simultaneously. One task might be the main task, while the other is a byproduct that tags along. In other words, your recursive function will return two solutions, one for each task. This is clumsy in C/C++/Java (you would need a struct, a pair, or output parameters), but super easy in Python:

def solve(problem):
    ...          # divide-conquer-combine
    return a, b

and you can use Python’s tuple unpacking (a simple form of pattern matching) to receive two things from a function: a, b = solve(sub_problem). Here is a complete template:

def solve(problem): # two tasks
    subp1, subp2 = ... # divide problem into two subproblems
    a1, b1 = solve(subp1)      # conquer two tasks for subp1
    a2, b2 = solve(subp2)      # conquer two tasks for subp2
    a, b =  ...    # combine (a1, b1) and (a2, b2) to (a, b)
    return a, b     # solutions of the two tasks for problem
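To see the template in action before the real examples, here is a minimal runnable instance on a toy problem chosen only for illustration (it is not from the original text): the problem is a non-empty list, and the two tasks are finding its minimum and its maximum. The name min_max is hypothetical.

```python
def min_max(xs):
    """Return (min, max) of a non-empty list, solving both tasks in one recursion."""
    if len(xs) == 1:                      # base case: one element is both min and max
        return xs[0], xs[0]
    mid = len(xs) // 2
    lo1, hi1 = min_max(xs[:mid])          # conquer both tasks for the left half
    lo2, hi2 = min_max(xs[mid:])          # conquer both tasks for the right half
    return min(lo1, lo2), max(hi1, hi2)   # combine (lo1, hi1) and (lo2, hi2)


lo, hi = min_max([4, 3, 1, 2])            # tuple unpacking receives both answers
print(lo, hi)                             # prints: 1 4
```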

We will demonstrate two classical examples of this paradigm.

Caveat: unlike previous sections, we need to distinguish “problem” and “task” in this section. For example, a problem might be an array like [4,3,1,2] and a task would be “sort the array” or “find the median in this array”. If a problem is a binary search tree, a task could be sorted(tree) or search(tree, query) or depth(tree).

Counting Number of Inversions

First, let’s consider how to efficiently count the number of inversions in an unsorted array. We define an inversion (or “inverted pair”) to be \((a_i, a_j)\) where \(i<j\) but \(a_i>a_j\). For example, in [4, 1, 3, 2] there are 4 inversions:

(4, 1), (4, 3), (4, 2), (3, 2)

(As a special case, a sorted array has 0 inversions.)

Obviously, you can do it in \(O(n^2)\) time with two nested loops that enumerate all pairs. But can you do it faster?
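For reference, the quadratic method can be sketched as follows (the function name is hypothetical). It simply checks every pair \(i<j\), and it is handy later as a correctness check for faster versions.

```python
def count_inversions_brute(a):
    """Count pairs (i, j) with i < j and a[i] > a[j] by checking all pairs: O(n^2)."""
    n = len(a)
    return sum(1 for i in range(n) for j in range(i + 1, n) if a[i] > a[j])


print(count_inversions_brute([4, 1, 3, 2]))   # prints: 4
```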

Well, the next faster complexity is \(O(n\log n)\). Can you do it that fast? Whenever you see \(O(n\log n)\) (esp. in interviews), you should think of sorting, because this complexity comes naturally from many sorting algorithms (quicksort, mergesort, heapsort), and we’ll see in later sections that it is the fastest that comparison-based sorting can ever be. So \(O(n\log n)\) is inherently related to sorting. Now, can you use sorting to count the number of inversions?

In fact, you can! And not just with one sorting algorithm. In this section we’ll see how to use mergesort to solve it, but you should think about (as an exercise) how to use quicksort for it as well. The basic idea is to tag the inversion count along with mergesort, so that counting becomes a byproduct of sorting.

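Concretely: split the array a into two halves, left and right; recursively sort and count each half, getting (sorted_left, inv_left) and (sorted_right, inv_right); then merge sorted_left and sorted_right as in mergesort while counting the crossing inversions inv_cross, i.e., pairs with one number in left and the other in right. The function returns the merged (sorted) array together with inv_left + inv_right + inv_cross.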

The only new thing is how to calculate the crossing inversions between left and right. Why only crossing inversions? Because, by the principle of divide-n-conquer, the internal inversions within left and within right have already been counted by those two subproblems, so your job at the current level (the whole array a) is just to count the remaining inversions in a that lie beyond the scope of left or right alone. Again, the principle of divide-conquer-combine is that the vast majority of the job is already done by the “conquer” steps, and you only need to do the “combine” step that your children can’t do by themselves.
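Here is a sketch of just the recursion structure, under the assumption that we do not yet know the clever merge: it combines with Python’s built-in sorted() and counts inv_cross by brute force over all (left, right) pairs. This is correct but still \(O(n^2)\) overall; its only purpose is to show that the conquer steps handle the internal inversions and the combine step only adds the crossing ones. The function name is hypothetical.

```python
def sort_and_count_naive_cross(a):
    """Return (sorted copy of a, number of inversions in a).

    Structure-only sketch: the crossing count is brute force, so this is O(n^2)."""
    if len(a) <= 1:
        return list(a), 0
    mid = len(a) // 2
    left, right = a[:mid], a[mid:]
    sorted_left, inv_left = sort_and_count_naive_cross(left)     # conquer
    sorted_right, inv_right = sort_and_count_naive_cross(right)  # conquer
    # combine: crossing inversions are pairs x in left, y in right with x > y
    inv_cross = sum(1 for x in left for y in right if x > y)
    return sorted(sorted_left + sorted_right), inv_left + inv_right + inv_cross
```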

Here is an example: let’s say left = [5, 1, 7] and right = [6, 4, 2]; after the two conquer steps, we get

sorted_left = [1, 5, 7],   inv_left = 1 (only 1 pair: (5,1))
sorted_right = [2, 4, 6],  inv_right = 3 (3 pairs: (6,4) (6,2) (4,2))

Now we combine them like in mergesort:

[1, 5, 7] [2, 4, 6]
 ^         ^
 *=>
 1                    inv_cross = 0

Whenever you take a number from left, no inversion is involved, but when you take a number from right (while left still has numbers remaining), you have definitely encountered at least one inversion, e.g., the pair (5, 2) here.

[1, 5, 7] [2, 4, 6]
    ^      ^
           *=>
 1, 2                inv_cross = 1: (5,2)

[1, 5, 7] [2, 4, 6]
    ^         ^
              *=>
 1, 2, 4             inv_cross = 2: (5,2), (5,4)

[1, 5, 7] [2, 4, 6]
    ^            ^
    *=>
 1, 2, 4, 5          inv_cross = 2: (5,2), (5,4)

[1, 5, 7] [2, 4, 6]
       ^         ^
                 *=>
 1, 2, 4, 5, 6       inv_cross = 3: (5,2), (5,4), (7,6)

[1, 5, 7] [2, 4, 6]
       ^           ^
       *=>
 1, 2, 4, 5, 6, 7    inv_cross = 3: (5,2), (5,4), (7,6)
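The first attempt in this walkthrough can be reproduced with a short sketch: an ordinary mergesort merge that adds just 1 to inv_cross each time a number is taken from right. The function name is hypothetical, and the code is deliberately flawed so that it matches the count of 3 above.

```python
def merge_count_first_attempt(sorted_left, sorted_right):
    """Flawed first attempt: +1 per number taken from right. Undercounts."""
    merged, inv_cross = [], 0
    i = j = 0
    while i < len(sorted_left) and j < len(sorted_right):
        if sorted_left[i] <= sorted_right[j]:   # take from left: no inversion
            merged.append(sorted_left[i])
            i += 1
        else:                                   # take from right
            inv_cross += 1                      # counts only (left[i], right[j])
            merged.append(sorted_right[j])
            j += 1
    merged += sorted_left[i:] + sorted_right[j:]
    return merged, inv_cross


print(merge_count_first_attempt([1, 5, 7], [2, 4, 6]))   # ([1, 2, 4, 5, 6, 7], 3)
```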

So we counted 3 crossing inversions ((5,2), (5,4), (7,6)), but did we miss anything? Clearly, there are two other crossing inversions that we didn’t count: (7,2) and (7,4). What was the problem?

Here it is: when you take a number from right, it can be inverted with more than one number in left. In fact, all the remaining numbers in left are inverted with the current number in right:

[1, 5, 7] [2, 4, 6]
    ^      ^
           *=>
 1, 2                inv_cross = 2: (5,2) implies (7,2)
 
[1, 5, 7] [2, 4, 6]
    ^         ^
              *=>
 1, 2, 4             inv_cross = 4: (5,2), (7,2), (5,4) implies (7,4)

Now we recovered the two missing (implied) inversions. In general, when left[i] > right[j],

[<<<< i >>>>>>>] [..... j ......]
      ^                 ^
                        *=>

the current pair (left[i], right[j]) is obviously inverted, but that also implies that all remaining numbers left[i+1], left[i+2], … (marked > in the diagram above) are inverted with right[j] too, because sorted_left is sorted, so they are no smaller than left[i]:

left[i'] >= left[i] > right[j] for i' = i+1, i+2, ...

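So you should add |left| - i to inv_cross whenever you take a number from right; with 0-based indexing, that is exactly the number of remaining elements left[i], left[i+1], …, left[|left|-1]. In the example above, taking 2 (while i points at 5, i = 1) adds 3 - 1 = 2, taking 4 adds another 2, and taking 6 (while i points at 7, i = 2) adds 1, for a total of 5 crossing inversions.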
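The corrected merge step can be sketched like this (the function name is hypothetical). The only change from an ordinary merge is the one line that adds len(sorted_left) - i when a number is taken from right, so the merge stays linear in the total length.

```python
def merge_and_count(sorted_left, sorted_right):
    """Merge two sorted lists and count crossing inversions in O(|left| + |right|)."""
    merged, inv_cross = [], 0
    i = j = 0
    while i < len(sorted_left) and j < len(sorted_right):
        if sorted_left[i] <= sorted_right[j]:   # ties go left: equal values are not inverted
            merged.append(sorted_left[i])
            i += 1
        else:
            # left[i], left[i+1], ... are all > right[j]: |left| - i inversions at once
            inv_cross += len(sorted_left) - i
            merged.append(sorted_right[j])
            j += 1
    merged += sorted_left[i:]                   # at most one of these is non-empty
    merged += sorted_right[j:]
    return merged, inv_cross


print(merge_and_count([1, 5, 7], [2, 4, 6]))    # ([1, 2, 4, 5, 6, 7], 5)
```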

Now we have a complete method of tagging along the number of inversions while doing mergesort. The complexity stays the same: the tagging costs only \(O(1)\) per merge step, or \(O(n)\) in total for each “combine”, so the recurrence is still \(T(n) = 2T(n/2) + O(n) = O(n\log n)\).

Caveat: for illustration purposes, we listed the inversion pairs explicitly when increasing inv_cross, but a real implementation can’t do that (otherwise it would cost \(O(n)\) per step instead of \(O(1)\)). In other words, we can count the number of inversions in \(O(n\log n)\) time, but we can’t list all inversion pairs in that time; the latter task needs \(O(n^2)\) time in the worst case, because there can be that many pairs: a reverse-sorted array has \(n(n-1)/2\) of them!
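Putting the pieces together, here is a sketch of the whole algorithm as one self-contained function (the name sort_and_count is hypothetical). It is plain mergesort returning a pair: the sorted array is the main task, and the inversion count is the byproduct.

```python
def sort_and_count(a):
    """Return (sorted copy of a, number of inversions in a) in O(n log n) time."""
    if len(a) <= 1:
        return list(a), 0
    mid = len(a) // 2
    sorted_left, inv_left = sort_and_count(a[:mid])     # conquer left half
    sorted_right, inv_right = sort_and_count(a[mid:])   # conquer right half
    # combine: standard merge, tagging along the crossing-inversion count
    merged, inv_cross = [], 0
    i = j = 0
    while i < len(sorted_left) and j < len(sorted_right):
        if sorted_left[i] <= sorted_right[j]:
            merged.append(sorted_left[i])
            i += 1
        else:
            inv_cross += len(sorted_left) - i   # all remaining left numbers are bigger
            merged.append(sorted_right[j])
            j += 1
    merged += sorted_left[i:] + sorted_right[j:]
    return merged, inv_left + inv_right + inv_cross


print(sort_and_count([5, 1, 7, 6, 4, 2]))   # ([1, 2, 4, 5, 6, 7], 9): 1 + 3 + 5
```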

Exercise: How would you do the same problem with quicksort instead of mergesort? Note that unlike mergesort, quicksort is divide-heavy and combine-light, meaning most of the work is done in the partition step, so you should also count the crossing inversions in the partition step. (think about it: there are no crossing inversions after partitioning, or at the combination step, because everything in left is smaller than everything in right).
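If you want to check your answer to the exercise (try it yourself first), here is one possible sketch. Its assumptions are ours, not the original’s: the first element is the pivot, and the partition is out-of-place and stable (small and big keep their original relative order). Every big element seen before a small one is an inversion with it, and the pivot (at index 0) is inverted with every small element. With a first-element pivot the worst case is \(O(n^2)\) (e.g., already-sorted input). The function name is hypothetical.

```python
def quicksort_and_count(a):
    """Return (sorted copy of a, number of inversions), counting during partition."""
    if len(a) <= 1:
        return list(a), 0
    pivot, rest = a[0], a[1:]
    small, big = [], []
    inv_cross = 0
    for x in rest:                      # scan in original order (stable partition)
        if x < pivot:
            # (pivot, x) is inverted, and so is (y, x) for every big y seen so far,
            # since y >= pivot > x and y comes before x
            inv_cross += 1 + len(big)
            small.append(x)
        else:
            big.append(x)
    sorted_small, inv_small = quicksort_and_count(small)   # conquer
    sorted_big, inv_big = quicksort_and_count(big)         # conquer
    # combine: plain concatenation, no crossing inversions left to count here
    return sorted_small + [pivot] + sorted_big, inv_small + inv_big + inv_cross
```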

Longest Path in Binary Tree

A more interesting example in this “recursion with byproduct” paradigm is to use it to find the longest path in a binary tree (doesn’t need to be a binary search tree). For example, for this tree:

                      4
                    /   \
                  2       6
                 / \     / \
                1   3   5   7
               /             \
              0               9
                             /
                             8

the longest path is 0-1-2-4-6-7-9-8 with a length of 7 edges.

The first observation is that we need to go as deep as possible at both ends; obviously, if you stop somewhere in the middle, the path is not optimal. The second “observation” that many students make when first looking at this problem is that the longest path has to go through the root, which is the case in the above example. But is it always the case?

Actually no! What if one side of the tree is tiny and the other side is huge? Then the longest path would be completely embedded in that bigger side. For example:

        2                  2
      /   \                  \ 
    1       6                  6
           / \                / \
          5   7              5   7
         /     \            /     \  
        3       9          3       9
         \     /            \      /
         4     8            4     8

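Clearly, for the left tree, the longest path would be 4-3-5-6-7-9-8 (6 edges), which does not go through the root; the longest paths through the root, such as 1-2-6-7-9-8, have only 5 edges. If we remove node 1, it would be even more obvious (the right tree).

So how do we solve this problem? Well, we can let each subtree return the longest path that is completely embedded in that subtree. Then the longest path of a tree falls into one of three cases: (1) it is completely embedded in the left subtree; (2) it is completely embedded in the right subtree; or (3) it goes through the root, in which case it goes as deep as possible on both sides, and its length in edges is depth(left) + depth(right), where depth counts the nodes on the longest downward path. The answer is the maximum of the three.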

So you can (naively) write two recursive functions for this algorithm:

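# assumed representation: a tree is either [] (empty) or [left, value, right]
def depth(t):
    # number of nodes on the longest downward path from the root of t
    return 0 if t == [] else max(depth(t[0]), depth(t[2])) + 1

def longest(t):
    # longest path (in edges): inside left, inside right, or through the root
    return 0 if t == [] else max(longest(t[0]), longest(t[2]), depth(t[0]) + depth(t[2]))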

However, this is a really bad solution, not just for aesthetic reasons. Its worst-case complexity is actually \(O(n^2)\) (see below)!

The better solution is to do these two tasks (longest and depth) together, just like sorting and counting inversions. Each node should return two numbers: not just the longest path, but also its depth. This way we guarantee the runtime is \(O(n)\), because it’s just a tree traversal with \(O(1)\) work per node for combining the subsolutions \((l_1, d_1)\) for left and \((l_2, d_2)\) for right into \((l, d)\): \(l = \max(l_1, l_2, d_1 + d_2)\) and \(d = \max(d_1, d_2) + 1\).
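A sketch of the combined recursion, assuming the same nested-list representation as depth() and longest() above (a tree is [] or [left, value, right]); the names longest_and_depth and leaf are hypothetical. Each call does constant work on top of its two recursive calls, so every node is visited exactly once.

```python
def longest_and_depth(t):
    """Return (longest path in edges, depth in nodes) for tree t in one traversal."""
    if t == []:
        return 0, 0
    l1, d1 = longest_and_depth(t[0])   # both tasks for the left subtree
    l2, d2 = longest_and_depth(t[2])   # both tasks for the right subtree
    l = max(l1, l2, d1 + d2)           # inside left, inside right, or through the root
    d = max(d1, d2) + 1                # depth: the byproduct
    return l, d


def leaf(v):
    return [[], v, []]


# the first example tree: root 4, longest path 0-1-2-4-6-7-9-8
tree = [[[leaf(0), 1, []], 2, leaf(3)], 4,
        [leaf(5), 6, [[], 7, [leaf(8), 9, []]]]]
print(longest_and_depth(tree))         # (7, 5)
```

On the left tree of the second figure, the same function returns a longest path of 6 edges even though the best path through the root has only 5.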

Now let’s see how bad the separate-recursion solution would be. Clearly, depth() is \(O(n)\), not \(O(1)\). So in the balanced case:

\[ T(n) = 2T(n/2) + O(n) = O(n\log n)\]

But in the worst-case (single chain):

\[ T(n) = T(n-1) + O(n) = O(n^2) \]

This is intuitive, because you call depth() at each node, so \(n + (n-1) + ... + 1=O(n^2)\). Note that these depth() calls have many redundant calculations, because depth(root) would call depth(root.left) but depth(root.left) was already called when doing longest(root). If we could “memorize” (or “memoize”) the work already done, then we can avoid all these repetitions and get back to the \(O(n)\) total time, but that’s the topic for Chapter 2 (Dynamic Programming).
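To see the quadratic blow-up concretely, here is a sketch that instruments the naive pair of functions with a visit counter. The extra visits argument and the names depth_counted, longest_counted, and left_chain are hypothetical additions for illustration. On a single chain of n nodes, depth() touches \((n-1) + (n-2) + \cdots + 0 = n(n-1)/2\) nodes in total.

```python
def depth_counted(t, visits):
    """Naive depth(); visits[0] counts the non-empty nodes it touches."""
    if t == []:
        return 0
    visits[0] += 1
    return max(depth_counted(t[0], visits), depth_counted(t[2], visits)) + 1


def longest_counted(t, visits):
    """Naive longest(): recomputes the depth of both children at every node."""
    if t == []:
        return 0
    return max(longest_counted(t[0], visits),
               longest_counted(t[2], visits),
               depth_counted(t[0], visits) + depth_counted(t[2], visits))


def left_chain(n):
    """Build a single chain of n nodes, each the left child of the next."""
    t = []
    for v in range(n):
        t = [t, v, []]
    return t


for n in (50, 100, 200):
    visits = [0]
    longest_counted(left_chain(n), visits)
    print(n, visits[0])   # prints 50 1225, then 100 4950, then 200 19900
```

Doubling n roughly quadruples the work, as \(T(n) = T(n-1) + O(n) = O(n^2)\) predicts.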

In any case, recursion with a byproduct is simple, easy, and fast!

Historical Notes

The counting inversions problem was taken from Roughgarden’s textbook (3.2). The longest path problem was a commonly asked interview question, but its analysis (for the bad solution) is quite interesting. This paradigm of “recursion with byproduct” was summarized by me.