Clean Java solution for "Surrounded Regions" using Union-Find


  • 1
    R
    public class Solution {
        /** Board dimensions, cached so {@link #node(int, int)} can flatten (row, col) coordinates. */
        int rows, cols;

        /**
         * Flips, in place, every 'O' that is not 4-directionally connected to an 'O' on the
         * border of the board to 'X'.
         *
         * <p>Every cell of the board is rewritten: cells connected to the border region become
         * 'O', all other cells become 'X'.
         *
         * @param board the grid to modify in place; a null or zero-row board is left untouched
         */
        public void solve(char[][] board) {
            if (board == null || board.length == 0) {
                return;
            }

            rows = board.length;
            cols = board[0].length;

            // One extra node past the last cell acts as a shared root. Every border 'O' is
            // joined to it, so a cell survives iff it ends up in the same set as this dummy.
            UnionFind uf = new UnionFind(rows * cols + 1);
            int dummyNode = rows * cols;

            for (int i = 0; i < rows; i++) {
                for (int j = 0; j < cols; j++) {
                    if (board[i][j] != 'O') {
                        continue;
                    }
                    int here = node(i, j);
                    if (i == 0 || i == rows - 1 || j == 0 || j == cols - 1) {
                        uf.union(here, dummyNode);
                        continue;
                    }
                    // Link to each in-range 'O' neighbour. The guards use i + 1 / j + 1 so that
                    // every array access is in bounds on its own, not only because border cells
                    // were handled above. Border neighbours pull this cell into the dummy's set.
                    if (i - 1 >= 0 && board[i - 1][j] == 'O') {
                        uf.union(here, node(i - 1, j));
                    }
                    if (i + 1 < rows && board[i + 1][j] == 'O') {
                        uf.union(here, node(i + 1, j));
                    }
                    if (j - 1 >= 0 && board[i][j - 1] == 'O') {
                        uf.union(here, node(i, j - 1));
                    }
                    if (j + 1 < cols && board[i][j + 1] == 'O') {
                        uf.union(here, node(i, j + 1));
                    }
                }
            }

            // 'X' cells were never unioned, so they are never connected to the dummy and stay 'X'.
            for (int i = 0; i < rows; i++) {
                for (int j = 0; j < cols; j++) {
                    board[i][j] = uf.isConnected(node(i, j), dummyNode) ? 'O' : 'X';
                }
            }
        }

        /**
         * Maps a (row, col) cell to its row-major index in the union-find.
         *
         * @param i row index
         * @param j column index
         * @return {@code i * cols + j}
         */
        int node(int i, int j) {
            return i * cols + j;
        }
    }
    
    class UnionFind {
        /** Parent pointer of each node; a node that is its own parent is the root of its set. */
        int [] parents;

        /**
         * Builds {@code totalNodes} singleton sets, one per index 0..totalNodes-1.
         *
         * @param totalNodes number of nodes managed by this structure
         */
        public UnionFind(int totalNodes) {
            parents = new int[totalNodes];
            int k = totalNodes;
            while (k-- > 0) {
                parents[k] = k;
            }
        }

        /**
         * Merges the sets containing the two nodes. When they differ, the root of
         * {@code node2}'s set is attached beneath the root of {@code node1}'s set.
         */
        void union(int node1, int node2) {
            final int keep = find(node1);
            final int attach = find(node2);
            if (keep == attach) {
                return;
            }
            parents[attach] = keep;
        }

        /**
         * Returns the root of {@code node}'s set. While walking up, each visited node is
         * re-pointed to its grandparent (path halving) to shorten later lookups.
         */
        int find(int node) {
            int cur = node;
            for (int up = parents[cur]; up != cur; up = parents[cur]) {
                parents[cur] = parents[up];
                cur = parents[cur];
            }
            return cur;
        }

        /** Reports whether the two nodes currently share a root. */
        boolean isConnected(int node1, int node2) {
            final int firstRoot = find(node1);
            return firstRoot == find(node2);
        }
    }
    

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.