Java/Python clear solution with UnionFind Class (Weighting and Path compression)


  • 109

    Union Find is an abstract data structure supporting find and unite operations on disjoint sets of objects, typically used to solve the network connectivity problem.

    The two operations are defined like this:

    find(a,b) : do a and b belong to the same set?

    unite(a,b) : if a and b are not in the same set, unite the sets they belong to.

    With this data structure, our problem can be solved very quickly. Every position is a new piece of land; if the new land connects two islands a and b, we combine them into one. The answer after each step is then the number of disjoint sets.

    The following algorithm is derived from Princeton's lecture notes on Union Find in Algorithms and Data Structures. They are well-organized notes with clear illustrations, going from the naive QuickFind all the way to the version with weighting and path compression.
    With weighting and path compression, the algorithm runs in O((M+N) log* N), where M is the number of operations (unite and find), N is the number of objects, and log* is the iterated logarithm. The naive version runs in O(MN).
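    For contrast with the weighted version, here is a minimal sketch of the naive QuickFind scheme described in the same notes, written in Python. The class name QuickFind is my own label. find is O(1), but every unite scans the whole array to relabel a set, which is where the O(MN) bound comes from.

```python
class QuickFind:
    """Naive eager union-find: id[i] stores the set label of element i directly."""

    def __init__(self, n):
        self.id = list(range(n))  # every element starts in its own set

    def find(self, a, b):
        # O(1): two elements are connected iff they carry the same label.
        return self.id[a] == self.id[b]

    def unite(self, a, b):
        # O(N): relabel every member of a's set with b's label.
        pid, qid = self.id[a], self.id[b]
        if pid == qid:
            return
        for i in range(len(self.id)):
            if self.id[i] == pid:
                self.id[i] = qid


qf = QuickFind(5)
qf.unite(0, 1)
qf.unite(1, 2)
print(qf.find(0, 2), qf.find(0, 3))  # True False
```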

    For our problem, if there are N positions, there are O(N) operations and N objects, so the total is O(N log* N), not counting the O(mn) array initialization.

    Note that log*N is almost constant in this universe (for N = 2^65536, log*N = 5), so the algorithm is almost linear in N.
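    A minimal Python sketch of the iterated logarithm (base 2), to make the "log*N = 5" claim concrete. The function name log_star is hypothetical; it simply counts how many times log2 has to be applied before the value drops to 1 or below.

```python
import math


def log_star(n):
    """Iterated logarithm (base 2): how many times log2 is applied until the value is <= 1."""
    count = 0
    x = n
    while x > 1:
        x = math.log2(x)  # math.log2 also accepts ints too large to fit in a float
        count += 1
    return count


print(log_star(2 ** 65536))  # 5
print(log_star(65536))       # 4
```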

    However, if the map is very big (mn much larger than N), initializing the arrays can dominate the running time. In this case we should consider using a hashmap/dictionary as the underlying data structure to avoid this overhead.

    Of course, we could put all the functionality into the Solution class, which would make the code a lot shorter. But from a design point of view, a separate class dedicated to the data structure is more readable and reusable.

    I implemented the idea with a 2D interface to better fit the problem.

    Java

    import java.util.ArrayList;
    import java.util.List;
    
    public class Solution {
    
        private int[][] dir = {{0, 1}, {0, -1}, {-1, 0}, {1, 0}};
    
        public List<Integer> numIslands2(int m, int n, int[][] positions) {
            UnionFind2D islands = new UnionFind2D(m, n);
            List<Integer> ans = new ArrayList<>();
            for (int[] position : positions) {
                int x = position[0], y = position[1];
                int p = islands.add(x, y);
                for (int[] d : dir) {
                    int q = islands.getID(x + d[0], y + d[1]);
                    if (q > 0 && !islands.find(p, q))
                        islands.unite(p, q);
                }
                ans.add(islands.size());
            }
            return ans;
        }
    }
    
    class UnionFind2D {
        private int[] id;
        private int[] sz;
        private int m, n, count;
    
        public UnionFind2D(int m, int n) {
            this.count = 0;
            this.n = n;
            this.m = m;
            this.id = new int[m * n + 1];
            this.sz = new int[m * n + 1];
        }
    
        public int index(int x, int y) { return x * n + y + 1; }
    
        public int size() { return this.count; }
    
        public int getID(int x, int y) {
            if (0 <= x && x < m && 0<= y && y < n)
                return id[index(x, y)];
            return 0;
        }
    
        public int add(int x, int y) {
            int i = index(x, y);
            id[i] = i; sz[i] = 1;
            ++count;
            return i;
        }
    
        public boolean find(int p, int q) {
            return root(p) == root(q);
        }
    
        public void unite(int p, int q) {
            int i = root(p), j = root(q);
            if (sz[i] < sz[j]) { //weighted quick union
                id[i] = j; sz[j] += sz[i];
            } else {
                id[j] = i; sz[i] += sz[j];
            }
            --count;
        }
    
        private int root(int i) {
            for (;i != id[i]; i = id[i])
                id[i] = id[id[i]]; //path compression
            return i;
        }
    }
    //Runtime: 20 ms
    

    Python (using dict)

    class Solution(object):
        def numIslands2(self, m, n, positions):
            ans = []
            islands = Union()
            for p in map(tuple, positions):
                islands.add(p)
                for dp in (0, 1), (0, -1), (1, 0), (-1, 0):
                    q = (p[0] + dp[0], p[1] + dp[1])
                    if q in islands.id:
                        islands.unite(p, q)
                ans += [islands.count]
            return ans
    
    class Union(object):
        def __init__(self):
            self.id = {}
            self.sz = {}
            self.count = 0
    
        def add(self, p):
            self.id[p] = p
            self.sz[p] = 1
            self.count += 1
    
        def root(self, i):
            while i != self.id[i]:
                self.id[i] = self.id[self.id[i]]
                i = self.id[i]
            return i
    
        def unite(self, p, q):
            i, j = self.root(p), self.root(q)
            if i == j:
                return
            if self.sz[i] > self.sz[j]:
                i, j = j, i
            self.id[i] = j
            self.sz[j] += self.sz[i]
            self.count -= 1
    
    #Runtime: 300 ms
    

  • 2

    Thanks @peisi for sharing such nice material on UnionFind. In case anyone is confused about log*: it is the number of times we need to take the log (base 2) of a number to bring it down to 1. You may refer to page 31 of the linked note for it :-)


  • 0

    Yes. I forgot to mention that.
    See also Wikipedia on the iterated logarithm:

    https://en.wikipedia.org/wiki/Iterated_logarithm


  • 0

    Very nice. Thanks so much for the solution.


  • 0

    I would like to clarify the following line:
    int q = islands.getID(x + d[0], y + d[1]);

    Should it call getID or getIndex?
    When it calls islands.unite(p, q), p is the index returned by add. So I think q should also be an index and not an ID.
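    A small sketch (not from the thread) of why passing getID's result to unite still works. getID returns id[index(x, y)], the cell's parent pointer, and a parent is always in the same set as the cell, so root() of either one gives the same representative. The array below is a hypothetical tiny forest in UnionFind2D's 1-based layout, with 0 meaning "not land"; root mirrors the posted path-compression loop.

```python
# Hypothetical parent array in UnionFind2D's 1-based layout:
# index 0 is unused, 0 means "not land", and id_[i] == i marks a root.
id_ = [0, 1, 1, 2, 0]  # node 1 is a root, 2 -> 1, 3 -> 2, node 4 is water


def root(i):
    # Same loop as UnionFind2D.root: follow parents, halving the path as we go.
    while i != id_[i]:
        id_[i] = id_[id_[i]]
        i = id_[i]
    return i


cell = 3              # what index(x, y) would return for some land cell
parent = id_[cell]    # what getID(x, y) returns: the parent pointer, here 2
print(parent, root(parent), root(cell))  # 2 1 1
```

    So unite(p, getID(...)) and unite(p, index(...)) end up linking the same two roots; the q > 0 check only uses getID to tell land from water.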


  • 0

    Same idea; here is my concise solution (20 ms):

    import java.util.ArrayList;
    import java.util.List;
    
    public class Solution {
        private static final int[][] dir = {{0, 1},{1, 0},{0, -1},{-1, 0}};
        
        public List<Integer> numIslands2(int n, int m, int[][] positions) {
            int[][] map = new int[n + 2][m + 2];
            List<Integer> ans = new ArrayList<>();
            int islandN = 0;
            UnionSet us = new UnionSet(n, m);
            
            for (int[] p : positions) {
                map[p[0] + 1][p[1] + 1] = 1;
                islandN++;
                for (int[] d : dir)
                    if (map[p[0] + d[0] + 1][p[1] + d[1] + 1] > 0 && us.union(p[0], p[1], p[0] + d[0], p[1] + d[1])) 
                        islandN--;
                ans.add(islandN);
            }
            return ans;
        }
        
        private class UnionSet {
            int n, m;
            int[] p, size;
            
            public UnionSet(int a, int b) {
                n = a; m = b;
                p = new int[getID(n, m)];
                size = new int[getID(n, m)];
            }
            
            private int getID(int i, int j) {
                return i * m + j + 1; // ensure no id == 0;
            }
            
            private int find(int i) {
                if (p[i] == 0) { // == 0 means not yet initialized
                    p[i] = i;
                    size[i] = 1;
                }
                p[i] = (p[i] == i) ? i : find(p[i]);
                return p[i];
            }
            
            private boolean union(int i1, int j1, int i2, int j2) { // true if combines two element of two different sets
                int s1 = find(getID(i1, j1)), s2 = find(getID(i2, j2));
                if (s1 == s2) return false;
                if (size[s1] > size[s2]) {
                    p[s2] = s1;
                    size[s1] += size[s2];
                } else {
                    p[s1] = s2;
                    size[s2] += size[s1];
                }
                return true;
            }
        }
    }

  • 0

    UnionFind2D islands = new UnionFind2D(m, n); // it uses O(m*n) time


  • 1

    For weighting why not use the depth of the tree instead of the size of the tree?
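    Weighting by depth also works; it is usually called union by rank. Because path compression shortens trees without updating the stored value, the value is kept as a rank (an upper bound on the height) rather than an exact depth, and combined with path compression it gives the same asymptotic guarantees as weighting by size. A minimal Python sketch, assuming integer elements 0..n-1; the class name UnionByRank is hypothetical.

```python
class UnionByRank:
    """Union-find that links by rank (an upper bound on tree height) instead of size."""

    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n
        self.count = n  # number of disjoint sets

    def root(self, i):
        # Path halving, as in the posted solutions; ranks are not updated here.
        while i != self.parent[i]:
            self.parent[i] = self.parent[self.parent[i]]
            i = self.parent[i]
        return i

    def unite(self, p, q):
        i, j = self.root(p), self.root(q)
        if i == j:
            return False
        # Hang the lower-rank tree under the higher-rank one.
        if self.rank[i] < self.rank[j]:
            i, j = j, i
        self.parent[j] = i
        # The height bound only grows when two trees of equal rank are linked.
        if self.rank[i] == self.rank[j]:
            self.rank[i] += 1
        self.count -= 1
        return True
```

    One practical reason to prefer size: it is exact, so it can also answer "how big is this island?", while rank only bounds the height.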


  • 0

    @dietpepsi said in Java/Python clear solution with UnionFind Class (Weighting and Path compression):

    Note that log*N is almost constant (for N = 265536, log*N = 5) in this universe

    265536 isn't that big...


  • 0

    @StefanPochmann It is actually N = 2*65536, logN = 5


  • 0

    @peteraristo said in Java/Python clear solution with UnionFind Class (Weighting and Path compression):

    @StefanPochmann It is actually N = 2*65536, logN = 5

    That's not what it says there. Also, about what you just wrote: If N is 2*65536 then logN isn't 5 unless you have an extremely unusual base, which you should provide.

    Just pointing out markdown accidents in hopes they'll get fixed.


  • 1

    @StefanPochmann I thought it was my typo but it turns out to be the formatting bug. I format it as follows.

    It is N = 2**65536, log*N = 5
    log*(n) is the function that counts how many times log() must be applied to get down to 1.
    Here when N = 2**65536, log(log(log(log(log(N)))))=1, it takes 5 log functions, so log*N = 5.


  • 0

    Very good code and an excellent explanation.

    I like posts that link to articles. Reading papers is far more interesting and important than just finding a solution to the problem.


  • 0

    @dietpepsi Thanks for sharing. Nice solution.
    I don't understand why the id and sz arrays should have length m * n + 1, so I tried using m * n instead. No error occurs, but the results are not correct.
    For the test case:

    3
    3
    [[0,0],[0,1],[1,2],[2,1]]
    

    I got [1,2,3,4] while we should get [1,1,2,3]

    I don't see the difference between m * n + 1 and m * n, and I am also confused about why the two results differ. Can anyone explain? Thanks a lot.


  • 0

    @Tōsaka-Rin

    Please don't overthink it. Notice:

        public int index(int x, int y) { return x * n + y + 1; }
    

    and

        int i = index(x, y);
        id[i] = i; sz[i] = 1;
    

    When x = 0 and y = 0, index i gets its min value, 1.
    When x = m - 1 and y = n - 1, index i gets its max value, m * n.

    So i starts from 1, and its range is [1, m * n].
    The id[] array is 0-based, so its size should be m * n + 1, with the value at id[0] left unchanged.


  • 0

    @zzhai Thanks for explaining.
    I see where the min and max index come from, but if I define the index as idx = x * n + y, the range should be 0 to m * n - 1. So idx starts from 0 and ends at m * n - 1. I don't quite understand why we need an extra slot with an unchanged value (id[0], as you said). Anyway, thanks for the explanation.


  • 0

    @zzhai So the difference is you initially fill the id array with -1?
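    A minimal sketch of that idea, assuming a 0-based layout. In the posted code the value 0 does double duty: getID returns 0 for water, and the q > 0 check relies on every real index being positive. With idx = x * n + y, cell (0, 0) gets index 0, so it most likely looked like water and was never merged, which would explain [1,2,3,4]. The version below marks water with -1 instead; the function name num_islands2 and the inlined structure are my own choices, not the original author's.

```python
def num_islands2(m, n, positions):
    """0-based variant of UnionFind2D that marks water with -1 instead of 0."""
    parent = [-1] * (m * n)  # -1 = not land yet; valid indices are 0..m*n-1
    size = [0] * (m * n)
    count = 0
    ans = []

    def root(i):
        while i != parent[i]:
            parent[i] = parent[parent[i]]  # path halving
            i = parent[i]
        return i

    for x, y in positions:
        p = x * n + y  # 0-based index: cell (0, 0) maps to 0
        parent[p], size[p] = p, 1
        count += 1
        for dx, dy in ((0, 1), (0, -1), (1, 0), (-1, 0)):
            nx, ny = x + dx, y + dy
            if 0 <= nx < m and 0 <= ny < n and parent[nx * n + ny] != -1:
                i, j = root(p), root(nx * n + ny)
                if i != j:
                    if size[i] < size[j]:
                        i, j = j, i
                    parent[j] = i  # attach the smaller tree under the larger
                    size[i] += size[j]
                    count -= 1
        ans.append(count)
    return ans


print(num_islands2(3, 3, [[0, 0], [0, 1], [1, 2], [2, 1]]))  # [1, 1, 2, 3]
```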


  • 0

    @dietpepsi is 265536 a typo by any chance? I thought you meant N=2^65536 there instead of N=265536.


  • 0

    Really love your Python version. Neat and clean.

