Java Union Find Solution


  • 43
    C

    /**
     * Disjoint-set (union-find) over the cells of an m x n grid, indexed row-major as i * n + j.
     *
     * <p>{@link #count} starts at the number of '1' cells and drops by one on every successful
     * union, so after all adjacent land cells are unioned it equals the number of islands.
     * Uses path halving in {@link #find} and union by size in {@link #union}.
     */
    class UF {

        /** Number of disjoint sets among the '1' cells. */
        public int count = 0;
        /** Parent links; id[p] == p means p is a root. */
        public int[] id = null;
        /** Number of elements in the tree rooted at each index (meaningful only for roots). */
        private int[] size = null;

        /**
         * @param m number of rows
         * @param n number of columns
         * @param grid m x n grid of '1' (land) / '0' (water); only used to seed {@link #count}
         */
        public UF(int m, int n, char[][] grid) {
            for (int i = 0; i < m; i++) {
                for (int j = 0; j < n; j++) {
                    if (grid[i][j] == '1') {
                        count++;
                    }
                }
            }
            id = new int[m * n];
            size = new int[m * n];
            for (int i = 0; i < m * n; i++) {
                id[i] = i;
                size[i] = 1;
            }
        }

        /** Returns the root of p, halving the path (each visited node skips to its grandparent). */
        public int find(int p) {
            while (p != id[p]) {
                id[p] = id[id[p]];
                p = id[p];
            }
            return p;
        }

        /** Returns true when p and q are in the same set. */
        public boolean isConnected(int p, int q) {
            return find(p) == find(q);
        }

        /**
         * Merges the sets containing p and q; no-op if they are already joined.
         * The smaller tree is hung under the larger one to keep trees shallow.
         */
        public void union(int p, int q) {
            int pRoot = find(p);
            int qRoot = find(q);
            if (pRoot == qRoot) {
                return;
            }
            if (size[pRoot] < size[qRoot]) {
                id[pRoot] = qRoot;
                size[qRoot] += size[pRoot];
            } else {
                id[qRoot] = pRoot;
                size[pRoot] += size[qRoot];
            }
            count--;
        }
    }

    /**
     * Counts islands: maximal groups of '1' cells connected horizontally or vertically.
     *
     * <p>The UF starts with one set per '1' cell. Each land cell is unioned with its land
     * neighbour below and to the right. Those two directions cover every adjacent pair,
     * because the up/left link of a cell is the down/right link of its neighbour.
     *
     * @param grid rows of '1' (land) / '0' (water); null or empty yields 0
     * @return number of islands
     */
    public int numIslands(char[][] grid) {
        if (grid == null || grid.length == 0 || grid[0].length == 0) {
            return 0;
        }
        int m = grid.length, n = grid[0].length;
        UF uf = new UF(m, n, grid);

        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                // Only '1' cells were counted by UF, so only they may participate in unions.
                if (grid[i][j] != '1') {
                    continue;
                }
                int p = i * n + j;
                if (i + 1 < m && grid[i + 1][j] == '1') {
                    uf.union(p, p + n);
                }
                if (j + 1 < n && grid[i][j + 1] == '1') {
                    uf.union(p, p + 1);
                }
            }
        }
        return uf.count;
    }

  • 34
    J

    You only need to check two directions, "right" and "down". Every "left" or "up" connection has already been seen from the neighbouring cell, because Union Find connections are undirected.


  • 0
    J

    By the way, does anyone know the time complexity of using Union-Find?


  • 5
    J

    If done correctly, you can think of each operation as constant (amortized) time. This is not technically correct: they are slightly worse than constant. But it is so "slightly" that it will never, ever, matter to you.

    This implementation doesn't quite qualify: it returns correct results, but its union function does not consider rank. (You should always point the tree of lower rank to the tree of higher rank.) So the "constant" runtime doesn't necessarily apply.


  • 0
    R

    You are so brilliant!!!


  • 6
    O
    /**
     * Number of Islands using a lazily initialised, size-weighted union-find.
     *
     * <p>Each land cell is registered on first visit and merged with its upper and left land
     * neighbours, the only neighbours already visited in a row-major scan. The answer is
     * the number of land cells that are the head (root) of their set.
     */
    public class Solution {
        /**
         * Counts groups of horizontally/vertically connected '1' cells.
         *
         * @param g grid of '1' (land) / '0' (water); may be null or empty
         * @return number of islands, 0 for a null or empty grid
         */
        public int numIslands(char[][] g) {
            if (g == null || g.length < 1 || g[0].length < 1) {
                return 0;
            }
            int n = g.length, m = g[0].length, island = 0;
            UnionFindSet uf = new UnionFindSet(n, m);

            for (int i = 0; i < n; i++) {
                for (int j = 0; j < m; j++) {
                    if (g[i][j] == '1') {
                        uf.find(i, j); // registers the cell as a singleton set on first visit
                        if (i > 0 && g[i - 1][j] == '1') {
                            uf.merge(i - 1, j, i, j);
                        }
                        if (j > 0 && g[i][j - 1] == '1') {
                            uf.merge(i, j - 1, i, j);
                        }
                    }
                }
            }

            for (int i = 0; i < n; i++) {
                for (int j = 0; j < m; j++) {
                    if (g[i][j] == '1' && uf.isSetHead(i, j)) {
                        island++;
                    }
                }
            }
            return island;
        }

        /**
         * Union-find over an n x m grid. It is static because it needs nothing from the
         * enclosing Solution, so it should not hold a hidden reference to it.
         * p[x] == -1 marks a cell that has not been registered yet.
         */
        private static class UnionFindSet { // 2d
            private final int n, m;
            private final int[] size, p;

            UnionFindSet(int nn, int mm) {
                n = nn;
                m = mm;
                size = new int[n * m];
                p = new int[n * m];
                Arrays.fill(p, -1);
            }

            /** Row-major index of cell (i, j). */
            private int id(int i, int j) {
                return i * m + j;
            }

            private int find(int i, int j) {
                return find(id(i, j));
            }

            /** Returns the root of x, registering x as a singleton on first use; compresses the path. */
            private int find(int x) {
                if (p[x] == -1) {
                    size[x] = 1;
                    p[x] = x;
                    return x;
                }
                p[x] = (p[x] == x) ? x : find(p[x]);
                return p[x];
            }

            /** Joins the sets of the two cells, hanging the smaller tree under the larger. */
            private void merge(int i1, int j1, int i2, int j2) {
                int s1 = find(i1, j1), s2 = find(i2, j2);
                if (s1 == s2) {
                    return;
                }
                if (size[s1] > size[s2]) {
                    p[s2] = s1;
                    size[s1] += size[s2];
                } else {
                    p[s1] = s2;
                    size[s2] += size[s1];
                }
            }

            /** True when cell (i, j) is the root of its set. */
            private boolean isSetHead(int i, int j) {
                return id(i, j) == find(i, j);
            }
        }
    }

  • 0
    X

    @jie27 Very smart, thanks for sharing.


  • 0
    D
    This post is deleted!

  • 4
    T

    Thanks for sharing. I followed your logic and implemented a version using a more standard union-find class.

    public class Solution {
        /**
         * Counts islands in {@code grid}.
         *
         * <p>Every land cell is joined to its land neighbour on the right and below; the
         * up/left links are made from the other cell. The number of disjoint sets of '1'
         * cells left in the union-find is returned.
         *
         * @param grid rows of '1' (land) / '0' (water); null or empty yields 0
         * @return number of islands
         */
        public int numIslands(char[][] grid) {
            if (grid == null || grid.length == 0 || grid[0].length == 0) {
                return 0;
            }
            final int row = grid.length;
            final int col = grid[0].length;
            final UnionFind island = new UnionFind(row, col, grid);

            for (int r = 0; r < row; r++) {
                for (int c = 0; c < col; c++) {
                    if (grid[r][c] != '1') {
                        continue;
                    }
                    final int cell = r * col + c;
                    // right neighbour
                    if (c + 1 < col && grid[r][c + 1] == '1') {
                        connect(island, cell, cell + 1);
                    }
                    // neighbour below
                    if (r + 1 < row && grid[r + 1][c] == '1') {
                        connect(island, cell, cell + col);
                    }
                }
            }
            return island.size();
        }

        /** Unions p and q unless {@code uf.find(p, q)} reports them as already joined. */
        private void connect(UnionFind uf, int p, int q) {
            if (uf.find(p, q)) {
                return;
            }
            uf.union(p, q);
        }
    }

    /**
     * Weighted quick-union (union by size, path halving) over a row x col grid, indexed
     * row-major. {@link #size()} reports the number of disjoint sets among the '1' cells:
     * it starts at the number of '1' cells and drops by one on each real merge.
     */
    class UnionFind {
        private int[] id, size;
        private int count;

        /**
         * @param row number of rows
         * @param col number of columns
         * @param grid row x col grid of '1' / '0'; only used to seed the set count
         */
        public UnionFind(int row, int col, char[][] grid) {
            id = new int[row * col];
            size = new int[row * col];

            for (int i = 0; i < row; i++) {
                for (int j = 0; j < col; j++) {
                    if (grid[i][j] == '1') {
                        this.count++;
                    }
                }
            }

            for (int i = 0; i < row * col; i++) {
                id[i] = i;
                size[i] = 1;
            }
        }

        /** Number of disjoint sets among the '1' cells (i.e. the island count). */
        public int size() {
            return this.count;
        }

        /** Returns the root of i, pointing each visited node at its grandparent (path halving). */
        private int root(int i) {
            while (i != id[i]) {
                id[i] = id[id[i]];
                i = id[i];
            }
            return i;
        }

        /** Returns true when p and q are in the same set (a connectivity query, not a lookup). */
        public boolean find(int p, int q) {
            return root(p) == root(q);
        }

        /**
         * Merges the sets containing p and q, hanging the smaller tree under the larger.
         * Calling it for an already-connected pair is a no-op and leaves the count unchanged.
         */
        public void union(int p, int q) {
            int i = root(p);
            int j = root(q);
            if (i == j) {
                return; // same set already: decrementing count here would under-count islands
            }

            if (size[i] < size[j]) {
                id[i] = j;
                size[j] += size[i];
            } else {
                id[j] = i;
                size[i] += size[j]; // was "-=", which corrupted the weights used for balancing
            }
            count--;
        }
    }


  • 0
    S

    Excellent solution ! Thank you for sharing.
    By the way, I think checking grid[i+1][j] and grid[i][j+1] is sufficient for each cell. Checking grid[i-1][j] and grid[i][j-1] duplicates the downward check from grid[i-1][j] and the rightward check from grid[i][j-1].


  • 0
    M

    @tianqi5
    Great. This is the so-called weighted union.
    In your union function, it is better to follow
    int i = root(p), j = root(q);

    with the guard if (i == j) return;


  • 1
    H

    Nice solution. I have a similar one using Union Find, based on the idea from "Number of Islands II", with some improvements:

    1. We don't need to check all four directions (cells). If we scan from top-left to bottom-right, we only need to check the top and left cells.
    2. Optimize the UF with union by size (weighting) and path compression.
    class Solution {
        // Offsets of the neighbours a row-major scan has already visited: left, then up.
        int[][] dir = {{0, -1}, {-1, 0}};

        /**
         * Counts islands in a single row-major pass.
         *
         * <p>Each new '1' cell starts as its own island. It is then merged into the sets of
         * its left and upper land neighbours, and every real merge removes one island.
         */
        public int numIslands(char[][] grid) {
            int rows = grid.length;
            if (rows == 0) {
                return 0;
            }
            int cols = grid[0].length;
            int total = rows * cols;
            int[] parent = new int[total];
            Arrays.fill(parent, -1);
            int[] weight = new int[total];

            int islands = 0;
            for (int cell = 0; cell < total; cell++) {
                int r = cell / cols;
                int c = cell % cols;
                if (grid[r][c] != '1') {
                    continue;
                }
                islands++;
                parent[cell] = cell;
                weight[cell] = 1;

                int root = cell; // root of the set that now contains this cell
                for (int[] offset : dir) {
                    int nr = r + offset[0];
                    int nc = c + offset[1];
                    boolean inside = nr >= 0 && nc >= 0 && nr < rows && nc < cols;
                    if (!inside || grid[nr][nc] != '1') {
                        continue;
                    }
                    int other = find(parent, nr * cols + nc);
                    if (other != root) {
                        islands--;
                        root = union(root, other, parent, weight);
                    }
                }
            }
            return islands;
        }

        /** Returns the root of {@code id}, halving the path on the way up. */
        public int find(int[] ids, int id) {
            int node = id;
            while (ids[node] != node) {
                ids[node] = ids[ids[node]];
                node = ids[node];
            }
            return node;
        }

        /** Links two roots by size (ties keep {@code id} as root) and returns the surviving root. */
        public int union(int id, int root, int[] ids, int[] sz) {
            boolean idIsSmaller = sz[id] < sz[root];
            int winner = idIsSmaller ? root : id;
            int loser = idIsSmaller ? id : root;
            ids[loser] = winner;
            sz[winner] += sz[loser];
            return winner;
        }
    }
    

  • 0
    J

    @hjy06 said in Java Union Find Solution:

    Nice solution. I have a similar one using Union Find from the idea of "Number of Island II" with some improvements:

    1. We don't need to check for four directions (cells), if we scan from left-top to right-bottom, we only need to check for top and left cells.
    2. Optimize the UF by ranking and path compression.
    class Solution {
        int[][] dir = {{0, -1}, {-1, 0}}; // neighbours already scanned: left, then top

        /**
         * Counts islands. Each '1' cell adds one island, minus the number of distinct
         * neighbouring sets it merges with.
         */
        public int numIslands(char[][] grid) {
            final int m = grid.length;
            if (m == 0) {
                return 0;
            }
            final int n = grid[0].length;
            int[] ids = new int[m * n];
            Arrays.fill(ids, -1);
            int[] sz = new int[m * n];

            int count = 0;
            for (int row = 0; row < m; row++) {
                for (int col = 0; col < n; col++) {
                    if (grid[row][col] == '1') {
                        count += 1 - attach(grid, row, col, ids, sz);
                    }
                }
            }
            return count;
        }

        /**
         * Registers (row, col) as a singleton set, merges it with its already-scanned land
         * neighbours, and returns how many merges actually happened.
         */
        private int attach(char[][] grid, int row, int col, int[] ids, int[] sz) {
            int cols = grid[0].length;
            int current = cols * row + col;
            ids[current] = current;
            sz[current] = 1;

            int merges = 0;
            for (int[] d : dir) {
                int x = row + d[0];
                int y = col + d[1];
                if (x < 0 || y < 0 || x >= grid.length || y >= cols || grid[x][y] != '1') {
                    continue;
                }
                int other = find(ids, cols * x + y);
                if (other != current) {
                    merges++;
                    current = union(current, other, ids, sz);
                }
            }
            return merges;
        }

        /** Returns the root of {@code id}; each visited node is re-pointed at its grandparent. */
        public int find(int[] ids, int id) {
            for (; ids[id] != id; id = ids[id]) {
                ids[id] = ids[ids[id]];
            }
            return id;
        }

        /** Attaches the smaller of two roots under the larger (ties favour {@code id}); returns the new root. */
        public int union(int id, int root, int[] ids, int[] sz) {
            if (sz[id] >= sz[root]) {
                ids[root] = id;
                sz[id] += sz[root];
                return id;
            }
            ids[id] = root;
            sz[root] += sz[id];
            return root;
        }
    }
    

    Refactored to be more readable:

    class Solution {
        int[][] dir = {{0, -1}, {-1, 0}};//only check for top and left cells

        /**
         * Counts islands in one row-major pass with a size-weighted, path-halving union-find.
         *
         * <p>Each '1' cell becomes a new singleton island and is merged with its left and top
         * land neighbours, which are already processed. Every merge of two distinct sets
         * removes one island.
         *
         * @param grid rows of '1' (land) / '0' (water); null or empty yields 0
         * @return number of islands
         */
        public int numIslands(char[][] grid) {
            if (grid == null || grid.length == 0 || grid[0].length == 0) {
                return 0;
            }
            int rows = grid.length;
            int cols = grid[0].length;
            int[] roots = new int[rows * cols];
            Arrays.fill(roots, -1); // -1 marks water or not-yet-visited cells
            int[] size = new int[rows * cols];

            int count = 0;
            for (int i = 0; i < rows; i++) {
                for (int j = 0; j < cols; j++) {
                    if (grid[i][j] != '1') {
                        continue;
                    }
                    // Each cell is visited exactly once and later cells are never touched before
                    // their turn, so roots[id] is always -1 here; no "already visited" check needed.
                    int id = cols * i + j;
                    count++;
                    roots[id] = id;
                    size[id] = 1;

                    // 'id' tracks the root of the set that currently contains cell [i, j].
                    for (int[] d : dir) {
                        int neighborX = i + d[0];
                        int neighborY = j + d[1];
                        if (neighborX < 0 || neighborY < 0 || neighborX >= rows || neighborY >= cols
                                || grid[neighborX][neighborY] != '1') {
                            continue;
                        }
                        int neighborId = cols * neighborX + neighborY;
                        int root = find(roots, neighborId);
                        if (root != id) {
                            count--; // two separate islands just became one
                            id = union(id, root, roots, size);
                        }
                    }
                }
            }
            return count;
        }

        /** Returns the root of {@code i}, pointing each visited node at its grandparent (path halving). */
        public int find(int[] roots, int i) {
            while (roots[i] != i) {
                roots[i] = roots[roots[i]];
                i = roots[i];
            }
            return i;
        }

        /**
         * Weighted union of two roots: the smaller tree is hung under the larger (ties keep
         * {@code root1}). Only root sizes are read later, so non-root entries are left alone.
         *
         * @return the root of the merged set
         */
        public int union(int root1, int root2, int[] roots, int[] size) {
            if (size[root1] < size[root2]) {
                roots[root1] = root2;
                size[root2] += size[root1];
                return root2;
            }
            roots[root2] = root1;
            size[root1] += size[root2];
            return root1;
        }
    }
    

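    Assume the number of 1s is n and the largest island contains k 1s. Then the time complexity is O(n*k), which is a loose upper bound. With union by size plus path compression, each find is O(log k) in the worst case and nearly constant amortized, on top of the O(rows * cols) scan. The space complexity is O(rows * cols).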


  • 0
    S
    This post is deleted!

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.