yet another blog

[Tech] TIL: Binary lifting

31 Mar 2021

This is a super helpful trick with many applications.

Problem

You have a tree, and given two nodes A and B you want to find the lowest common ancestor (LCA) of these two nodes. Oftentimes you are given Q queries. You can brute force them with a total complexity of O(Q*N), where N is the number of nodes in the tree, or you can use the trick that I describe in this post, binary lifting: after some pre-processing, you can answer each query in O(log N).

How to build

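I denote by up[a][v] the ancestor of node a that is reached by moving up 2^{v} times. The table is filled with the recurrence up[a][v] = up[up[a][v-1]][v-1], because moving up 2^{v} steps is the same as moving up 2^{v-1} steps twice.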

How to query

First we move the deeper node up to the same depth as the other. Then we move the two of them up together until they meet.

Code

I still think it is easier to look at the code and try to figure out the idea.


#include <bits/stdc++.h>
using namespace std;

// Upper bound on the number of nodes the static tables can hold.
constexpr int MAX_N = 1000;
// Number of jump levels per node; 2^10 = 1024 exceeds MAX_N, so ten bits
// are enough to express any depth difference inside the tree.
constexpr int LOG = 10;

// Adjacency lists of the undirected tree, indexed by 0-based node id.
vector<int> edges[MAX_N];
// depth[v]: number of edges between v and the node the DFS started from.
int depth[MAX_N];
// up[v][j]: the ancestor of v that lies 2^j edges above it.
int up[MAX_N][LOG];
// Walks the tree depth-first from node `a`, whose parent is `par` (pass a
// value that matches no node id for the starting node). For every node below
// `a` it records depth[] and the whole up[] row. The row of `a` itself is
// never written here, so it must already be valid when the call is made.
void dfs(int a, int par){
    for(const int child : edges[a]){
        if(child != par){
            depth[child] = 1 + depth[a];
            // Level 0 is the direct parent.
            up[child][0] = a;
            // Level i is two consecutive level-(i-1) jumps; the intermediate
            // node is an ancestor, so its row has been filled already.
            for(int level = 1; level < LOG; ++level){
                const int halfway = up[child][level - 1];
                up[child][level] = up[halfway][level - 1];
            }
            dfs(child, a);
        }
    }
}
 
// Returns the lowest common ancestor of nodes `a` and `b` (0-based ids).
// Expects depth[] and up[][] to have been filled for the tree beforehand.
// Both phases loop LOG times, so a query costs O(log n).
int get_lca(int a, int b){
    // Make `a` the deeper (or equally deep) node of the pair.
    if(depth[b] > depth[a]){
        std::swap(a, b);
    }
    // Phase 1: lift `a` by the depth gap, one power-of-two jump per set bit.
    const int gap = depth[a] - depth[b];
    for(int bit = LOG - 1; bit >= 0; --bit){
        if((gap >> bit) & 1){
            a = up[a][bit];
        }
    }
    // Equal after levelling: `b` was an ancestor of `a` (or the same node).
    if(a == b){
        return a;
    }
    // Phase 2: take every jump, largest first, that keeps the two nodes
    // distinct; they end up directly below their first shared ancestor.
    for(int bit = LOG - 1; bit >= 0; --bit){
        const int next_a = up[a][bit];
        const int next_b = up[b][bit];
        if(next_a != next_b){
            a = next_a;
            b = next_b;
        }
    }
    return up[a][0];
}

// Reads a tree from stdin, builds the jump table and answers LCA queries.
// Input:  "V E", then E lines "a b" (1-based endpoints of an undirected
//         edge), then "q", then q lines "a b" (1-based query nodes).
// Output: one line per query holding the 1-based lowest common ancestor.
// The tree is rooted at node 1 (index 0 internally).
int32_t main() {
    int V, E; std::cin>>V>>E;
    for(int i = 0; i < V; i++){
        edges[i].clear();
    }
    // memset takes (destination, fill byte, byte count). Zeroing the table
    // makes every jump from the root (index 0) stay at the root, which the
    // DFS relies on because it never writes the root's own row.
    std::memset(up, 0, sizeof up);
    for(int i = 0; i < E; i++){
        int a, b; std::cin>>a>>b;
        a--; b--; // convert to 0-based ids
        edges[a].push_back(b);
        edges[b].push_back(a);
    }
    dfs(0, -1); // build up[] and depth[]; -1 matches no node id
    int q; std::cin>>q;
    while(q--){
        int a, b; std::cin>>a>>b;
        a--; b--; // convert to 0-based ids
        std::cout<<get_lca(a, b) + 1<<'\n'; // report as 1-based
    }
}

comments powered by Disqus