Solve it using Union Find


  • 91
    J
    /**
     * Disjoint-set (union-find) over the integers 0..N-1, using union by rank
     * and path halving.
     *
     * The parent and rank tables live in std::vector, so copying, assigning and
     * destroying a UF are all safe. (With raw new[]/delete[] and no copy
     * constructor or assignment operator, any copy would free the same arrays
     * twice.)
     */
    class UF
    {
    private:
        std::vector<int> id;    // id[i] = parent of i; a root is its own parent
        std::vector<int> rank;  // rank[i] = rank of the tree rooted at i (only meaningful for roots)
        int count;              // number of components
    public:
        /** Creates N components, each containing a single element. */
        UF(int N) : id(N), rank(N), count(N)
        {
            // rank is value-initialised to 0 by the vector constructor.
            for (int i = 0; i < N; i++) {
                id[i] = i;
            }
        }

        /**
         * Returns the root of the component containing p.
         * p must be in [0, N); no bounds check is performed.
         */
        int find(int p) {
            while (p != id[p]) {
                id[p] = id[id[p]];    // path compression by halving
                p = id[p];
            }
            return p;
        }

        /** Returns the current number of components. */
        int getCount() const {
            return count;
        }

        /** Returns true when p and q have the same root. */
        bool connected(int p, int q) {
            return find(p) == find(q);
        }

        /**
         * Merges the components of p and q. Does nothing (and leaves the
         * component count untouched) when they are already in the same one.
         */
        void connect(int p, int q) {
            int i = find(p);
            int j = find(q);
            if (i == j) return;
            // Attach the lower-rank root under the higher-rank one; on a tie
            // pick i as the new root and bump its rank.
            if (rank[i] < rank[j]) id[i] = j;
            else if (rank[i] > rank[j]) id[j] = i;
            else {
                id[j] = i;
                rank[i]++;
            }
            count--;
        }
    };
    
    class Solution {
    public:
        void solve(vector<vector<char>> &board) {
            int n = board.size();
            if(n==0)    return;
            int m = board[0].size();
            UF uf = UF(n*m+1);
            
            for(int i=0;i<n;i++){
                for(int j=0;j<m;j++){
                    if((i==0||i==n-1||j==0||j==m-1)&&board[i][j]=='O') // if a 'O' node is on the boundry, connect it to the dummy node
                        uf.connect(i*m+j,n*m);
                    else if(board[i][j]=='O') // connect a 'O' node to its neighbour 'O' nodes
                    {
                        if(board[i-1][j]=='O')
                            uf.connect(i*m+j,(i-1)*m+j);
                        if(board[i+1][j]=='O')
                            uf.connect(i*m+j,(i+1)*m+j);
                        if(board[i][j-1]=='O')
                            uf.connect(i*m+j,i*m+j-1);
                        if(board[i][j+1]=='O')
                            uf.connect(i*m+j,i*m+j+1);
                    }
                }
            }
            
            for(int i=0;i<n;i++){
                for(int j=0;j<m;j++){
                    if(!uf.connected(i*m+j,n*m)){ // if a 'O' node is not connected to the dummy node, it is captured
                        board[i][j]='X';
                    }
                }
            }
        }
    };
    

    Hi. So here is my accepted code using the Union Find data structure. The idea comes from the observation that if a region is NOT captured, it is connected to the boundary. So if we connect all the 'O' nodes on the boundary to a dummy node, and then connect each 'O' node to its neighbouring 'O' nodes, we can tell directly whether an 'O' node is captured by checking whether it is connected to the dummy node.
    For more about Union Find, the first assignment in the algo1 may help:
    https://www.coursera.org/course/algs4partI


  • 56
    N

    just another version in java:

    /**
     * Surrounded-regions solver built on a union-find over every board cell.
     * Cell (row, col) has index row * width + col. Equal neighbouring characters
     * are merged, and each set remembers whether it contains an 'O' on the border.
     */
    public class Solution {

        int[] unionSet; // parent link per cell; a representative points to itself
        boolean[] hasEdgeO; // read at a representative: its set contains an 'O' on the border

        /** Rewrites in place every 'O' whose set has no border 'O' to 'X'. */
        public void solve(char[][] board) {
            if (board.length == 0 || board[0].length == 0) return;

            final int height = board.length;
            final int width = board[0].length;
            final int cells = height * width;

            // Every cell starts as its own set; border 'O' cells are flagged.
            unionSet = new int[cells];
            hasEdgeO = new boolean[cells];
            for (int row = 0, cell = 0; row < height; row++) {
                for (int col = 0; col < width; col++, cell++) {
                    unionSet[cell] = cell;
                    boolean onBorder = row == 0 || row == height - 1 || col == 0 || col == width - 1;
                    hasEdgeO[cell] = board[row][col] == 'O' && onBorder;
                }
            }

            // Merge each cell with its upper and right neighbour when the characters match.
            for (int row = 0, cell = 0; row < height; row++) {
                for (int col = 0; col < width; col++, cell++) {
                    char here = board[row][col];
                    if (row > 0 && here == board[row - 1][col]) {
                        union(cell, cell - width);
                    }
                    if (col + 1 < width && here == board[row][col + 1]) {
                        union(cell, cell + 1);
                    }
                }
            }

            // An 'O' whose set never reached a border 'O' is captured.
            for (int row = 0, cell = 0; row < height; row++) {
                for (int col = 0; col < width; col++, cell++) {
                    if (board[row][col] == 'O' && !hasEdgeO[findSet(cell)]) {
                        board[row][col] = 'X';
                    }
                }
            }
        }

        /** Hangs x's representative under y's and carries the border flag over. */
        private void union(int x, int y) {
            int rootX = findSet(x);
            int rootY = findSet(y);
            boolean merged = this.hasEdgeO[rootX] || this.hasEdgeO[rootY];
            unionSet[rootX] = rootY;
            this.hasEdgeO[rootY] = merged;
        }

        /** Returns the representative of x, pointing every node on the path straight at it. */
        private int findSet(int x) {
            int root = x;
            while (unionSet[root] != root) {
                root = unionSet[root];
            }
            // Second walk: full path compression.
            while (unionSet[x] != root) {
                int next = unionSet[x];
                unionSet[x] = root;
                x = next;
            }
            return root;
        }
    }

  • 0
    S
    This post is deleted!

  • 0
    S

    Thanks for your post. However it would be better to share solution with correct code format and elaborated thoughts. Please read the Discuss FAQ for more info. Take a look at good sharing example


  • 3
    J

    I like your Java version. So I write another C++ version based on yours. :p

    /**
     * Surrounded-regions solver in which every cell is a Point belonging to a
     * group (a union-find tree linked through group_ pointers). A group "has an
     * exit" when it contains a border cell; 'O' cells whose group has no exit
     * are rewritten to 'X'.
     */
    class Solution {
    private:
        struct Point {
            int x;
            int y;
            Point* group_;   // parent in the group tree; a representative points to itself
            bool hasExit_;   // cached "my group reaches the border" flag
            // Returns whether this point's group reaches the border, looking
            // up the group chain and caching a positive answer.
            bool hasExit() {
                if (!hasExit_ && this != group_ && group_->hasExit()) {
                    hasExit_ = true;  // speed optimization
                }
                return hasExit_;
            }
            void hasExit(bool a) {
                hasExit_ = a;
            }
            // Returns the representative of this point's group, compressing the path.
            Point* getGroup() {
                if (group_ != this) {
                    group_ = group_->getGroup();  // speed optimization
                }
                return group_;
            }
            void setGroup(Point* point) {
                group_ = point;
            }
            // Merges this point's group into point's group; the surviving
            // representative has an exit if either group had one.
            void connectTo(Point &point) {
                Point* p = point.getGroup();
                Point* q = getGroup();
                bool b = q->hasExit() || p->hasExit();
                p->hasExit(b);
                q->setGroup(p);
            }
            // Every member is initialised: the grid below is filled by copying a
            // prototype Point, and copying indeterminate members is undefined.
            Point(int x_, int y_) : x(x_), y(y_), group_(nullptr), hasExit_(false) {}
        };
    public:
        void solve(std::vector<std::vector<char>> &board) {
            if (board.empty() || board[0].empty()) return;
            const int rows = static_cast<int>(board.size());
            const int cols = static_cast<int>(board[0].size());
            std::vector<std::vector<Point>> points(rows, std::vector<Point>(cols, Point(0, 0)));
            // Each point starts as its own group; border points start with an exit.
            for (int i = 0; i < rows; i++) {
                for (int j = 0; j < cols; j++) {
                    Point &p = points[i][j];
                    p.x = i;
                    p.y = j;
                    p.setGroup(&p);
                    p.hasExit(i == 0 || j == 0 || i == rows - 1 || j == cols - 1);
                }
            }
            // Only 'O' cells are linked, and only to the 'O' above / to the left;
            // the cells below / to the right link back when they are visited.
            for (int i = 0; i < rows; i++) {
                for (int j = 0; j < cols; j++) {
                    if (board[i][j] != 'O') continue;
                    if (i > 0 && board[i-1][j] == board[i][j]) {
                        points[i][j].connectTo(points[i-1][j]);  // it -> up
                    }
                    if (j > 0 && board[i][j-1] == board[i][j]) {
                        points[i][j].connectTo(points[i][j-1]);  // it -> left
                    }
                }
            }
            for (int i = 0; i < rows; i++) {
                for (int j = 0; j < cols; j++) {
                    if (board[i][j] != 'O') continue;
                    if (!points[i][j].hasExit()) board[i][j] = 'X';
                }
            }
        }
    };
    

  • 0
    C

    Yes, this is similar to the two-pass algorithm for connected-component labeling.


  • 0
    G

    Nice!! I got the same source of inspiration as you had!!

     // Surrounded-regions solver: every 'O' cell is a node in a union-find forest,
     // and border 'O' cells are joined to a special "percolate" node. Cells whose
     // root is not the percolate node end up as 'X'.
     class Solution {
    private:
        struct node {
            bool percolates;      // true only for the special percolate node
            int cnt;              // weight used to choose which root survives a merge
            struct node *root;    // parent link; nullptr marks a root

            node():percolates(false),cnt(0),root(nullptr) {}
        };

        // Returns the root of now's tree. Iterative (no recursion depth to worry
        // about) and compresses the walked path onto the root.
        struct node *findRoot(struct node *now) {
            struct node *top = now;
            while (top->root) top = top->root;
            while (now != top) {
                struct node *next = now->root;
                now->root = top;
                now = next;
            }
            return top;
        }

        // Merges the trees of p1 and p2. The percolate node always stays a root;
        // otherwise the root with the larger cnt survives.
        void connect(struct node *p1, struct node *p2) {
            struct node *root1 = findRoot(p1);
            struct node *root2 = findRoot(p2);

            if (root1 == root2) return;

            if (root1->percolates) {
                assert(!root2->percolates);
                root2->root = root1;
            } else if (root2->percolates) {
                assert(!root1->percolates);
                root1->root = root2;
            } else {
                if (root1->cnt >= root2->cnt) {
                    root1->cnt += (root2->cnt + 1);
                    root2->root = root1;
                } else {
                    root2->cnt += (root1->cnt + 1);
                    root1->root = root2;
                }
            }
        }

        bool is_percolates(struct node *now) {
            return findRoot(now)->percolates;
        }

    public:
        void solve(std::vector<std::vector<char>> &board) {
            // All nodes are owned by this function (a local plus a vector table)
            // instead of being allocated with new and never released.
            node percolate;
            percolate.percolates = true;

            const std::size_t rows = board.size();
            std::vector<std::vector<node>> table(rows);
            // Size every row up front: the forest stores pointers into these
            // vectors, so they must not reallocate afterwards.
            for (std::size_t i = 0; i < rows; i++) {
                table[i].resize(board[i].size());
            }

            for (std::size_t i = 0; i < rows; i++) {
                const std::size_t cols = board[i].size();
                for (std::size_t j = 0; j < cols; j++) {
                    if (board[i][j] != 'O') continue;
                    node *now = &table[i][j];
                    if (i == 0 || j == 0 || i + 1 == rows || j + 1 == cols) {
                        // Boundary case
                        connect(&percolate, now);
                    }
                    // Check the left one and the upper one; cells to the right and
                    // below link back when they are visited.
                    if (j > 0 && board[i][j-1] == 'O') connect(&table[i][j-1], now);
                    if (i > 0 && j < board[i-1].size() && board[i-1][j] == 'O') {
                        connect(&table[i-1][j], now);
                    }
                }
            }

            for (std::size_t i = 0; i < rows; i++) {
                for (std::size_t j = 0; j < board[i].size(); j++) {
                    board[i][j] = is_percolates(&table[i][j]) ? 'O' : 'X';
                }
            }
        }
    };

  • 0
    I

    Great solution!


  • 1
    K

    Why can't rank be more than 31 here?


  • 1
    H

    two suggestions

    1. you can combine the initialization of unionSet[] and hasEdgeO[] together
    2. in the union part, you can only find and union "O"s, now you are doing union-find to all "x"s as well

  • 0
    Y

    I think rank is actually the size of the connected components. In the function connect(int p, int q), it calls find(p) and find(q), note the find() function will do "path compression", which will flatten the tree. So I think rank is a little misleading here...I think size will be a more appropriate name. reference: https://www.cs.princeton.edu/~rs/AlgsDS07/01UnionFind.pdf


  • 0
    N

    Is checking the left neighbor and the above neighbor necessary? I don't see the https://leetcode.com/discuss/52833/8ms-c-solution-using-union-find-with-a-dummy-point solution checking them.


  • 0
    D

    Checking left & above is needed here, but not in the link you provided, because here the bottom and right bounds are only union()ed with RootO. If you omit the left & above checks, those two bounds will never be union()ed with their above/left neighbors.
    So if you union() those bounds with RootO in advance and then loop through all cells, you only need to check two sides.


  • 4
    F

    Nice thought. I came up with this Java union-find with path compression and weighted union. Currently its run time is 17 ms. Can this be further improved? Thank you.

    /**
     * Surrounded-regions solver using weighted quick-union with path halving.
     * 'O' cell (i, j) has index i * n + j. All border 'O' cells are united into
     * one set, identified by the first border 'O' encountered (OOnEdge).
     */
    public class Solution {

        // Parent link per cell; only entries for 'O' cells are initialised and used
        private int[] ids;
        // Weight (size) of each union set
        private int[] sizes;
        // The id of union set for 'O's on edge, or -1 when the border has no 'O'
        private int OOnEdge;
        int m;
        int n;

        /** Rewrites in place every interior 'O' that is not connected to a border 'O' to 'X'. */
        public void solve(char[][] board) {
            if((m = board.length) == 0 || (n = board[0].length) == 0) return;

            ids = new int[m * n];
            sizes = new int[m * n];
            Arrays.fill(sizes, 1);
            OOnEdge = -1;

            for (int i = 0; i < m; i++) {
                for (int j = 0; j < n; j++) {
                    if (board[i][j] == 'X') {
                        continue;
                    }
                    int index = i * n + j;
                    ids[index] = index;
                    // Nodes on edges
                    if (i == 0 || j == 0 || i == m - 1 || j == n - 1) {
                        if (OOnEdge == -1) {
                            // Set OOnEdge if it has not been set yet
                            OOnEdge = index;
                        } else {
                            // If OOnEdge is already set, unite it with index
                            unite(OOnEdge, index);
                        }
                    }
                    // Unite node with its upper neighbor
                    if (i > 0 && board[i - 1][j] == 'O') {
                        unite(index, index - n);
                    }
                    // Unite node with its left neighbor
                    if (j > 0 && board[i][j - 1] == 'O') {
                        unite(index, index - 1);
                    }
                }
            }

            // Find: only interior cells can be captured, so the border is skipped
            for (int i = 1; i < m - 1; i++) {
                for (int j = 1; j < n - 1; j++) {
                    if (board[i][j] == 'X') {
                        continue;
                    }
                    int index = i * n + j;
                    if (OOnEdge == -1 || !find(index, OOnEdge)) {
                        board[i][j] = 'X';
                    }
                }
            }
        }

        /** Merges the sets containing a and b, hanging the smaller tree under the larger. */
        private void unite(int a, int b){
            int i = findRoot(a);
            int j = findRoot(b);

            // Already in the same set: without this guard the root's size would be
            // added to itself, doubling on every redundant call and skewing the weights.
            if (i == j) {
                return;
            }

            // Weighted quick union
            if (sizes[i] < sizes[j]) {
                ids[i] = j;
                sizes[j] += sizes[i];
            } else {
                ids[j] = i;
                sizes[i] += sizes[j];
            }
        }

        /** Returns true when a and b are in the same set. */
        private boolean find(int a, int b){
            return findRoot(a) == findRoot(b);
        }

        /** Returns the root of i's set, halving the path on the way up. */
        private int findRoot(int i) {
            while (i != ids[i]) {
                // Path compression
                ids[i] = ids[ids[i]];
                i = ids[i];
            }

            return i;
        }
    }

  • 0
    D

    Agreed, the rank array should be the size array. But when do the path compression, shouldn't we also change the rank array? like this:
    rank[id[i]] -= rank[i];
    id[i] = id[id[i]];


  • 0
    W

    @jakwings why just have the left and up


  • 6
    R

    Cleaner Java code

    /**
     * Surrounded-regions solver: border 'O' cells are joined to a dummy node,
     * interior 'O' cells to their 'O' neighbours, and every cell is then rewritten
     * according to whether it is connected to the dummy node.
     */
    public class Solution {
        int rows, cols;

        public void solve(char[][] board) {
            if (board == null || board.length == 0) {
                return;
            }

            rows = board.length;
            cols = board[0].length;

            // One node per cell plus a trailing dummy that all outer 'O's hang from.
            final int dummyNode = rows * cols;
            UnionFind uf = new UnionFind(dummyNode + 1);

            for (int r = 0; r < rows; r++) {
                for (int c = 0; c < cols; c++) {
                    if (board[r][c] != 'O') {
                        continue;
                    }
                    final int here = node(r, c);
                    final boolean outer = r == 0 || r == rows - 1 || c == 0 || c == cols - 1;
                    if (outer) {
                        uf.union(here, dummyNode);
                        continue;
                    }
                    // Interior cell: all four neighbours are inside the board.
                    if (board[r - 1][c] == 'O') {
                        uf.union(here, node(r - 1, c));
                    }
                    if (board[r + 1][c] == 'O') {
                        uf.union(here, node(r + 1, c));
                    }
                    if (board[r][c - 1] == 'O') {
                        uf.union(here, node(r, c - 1));
                    }
                    if (board[r][c + 1] == 'O') {
                        uf.union(here, node(r, c + 1));
                    }
                }
            }

            // Every cell is rewritten: connected to the dummy -> 'O', otherwise 'X'.
            for (int r = 0; r < rows; r++) {
                for (int c = 0; c < cols; c++) {
                    board[r][c] = uf.isConnected(node(r, c), dummyNode) ? 'O' : 'X';
                }
            }
        }

        /** Maps cell (i, j) to its union-find index. */
        int node(int i, int j) {
            return i * cols + j;
        }
    }
    
    /** Unweighted union-find over 0..totalNodes-1 with path halving in find. */
    class UnionFind {
        int [] parents; // parents[i] == i marks a root

        public UnionFind(int totalNodes) {
            parents = new int[totalNodes];
            int i = 0;
            while (i < totalNodes) {
                parents[i] = i;
                i++;
            }
        }

        /** Hangs the root of node2 under the root of node1 (no-op when they already match). */
        void union(int node1, int node2) {
            final int keep = find(node1);
            final int absorb = find(node2);
            if (keep == absorb) {
                return;
            }
            parents[absorb] = keep;
        }

        /** Returns the root of node, re-pointing each visited node at its grandparent. */
        int find(int node) {
            int current = node;
            while (parents[current] != current) {
                int grandparent = parents[parents[current]];
                parents[current] = grandparent;
                current = grandparent;
            }
            return current;
        }

        boolean isConnected(int node1, int node2) {
            int first = find(node1);
            int second = find(node2);
            return first == second;
        }
    }

  • 0

    @rhasan_82 You seem to use quick find instead of quick union in your solution, so the time complexity may differ.


  • 0

    Can anyone explain what does the rank array in original solution do? Thanks a lot.


  • 0
    F

    Thanks for this union-find approach; I rewrote it in Python:

    class Solution(object):
        """Surrounded-regions solver backed by a union-find with a dummy border node."""

        def solve(self, board):
            """
            Rewrite the board so that only 'O' cells connected (through 'O' cells)
            to an 'O' on the border remain 'O'; every other cell becomes 'X'.

            :type board: List[List[str]]
            :rtype: void Do not return anything, modify board in-place instead.
            """
            # An empty board, or a board of empty rows, has nothing to process
            # (the latter would otherwise fail on board[i][0]).
            if not board or not board[0]:
                return
            m = len(board)
            n = len(board[0])
            uf = UnionFind(m * n + 1)
            dummyNode = m * n  # extra node that every border 'O' is joined to
            for i in range(m):
                for j in range(n):
                    if board[i][j] != 'O':
                        continue
                    here = self.node(i, j, n)
                    if i in (0, m - 1) or j in (0, n - 1):
                        # Border 'O': join to the dummy node only; interior
                        # neighbours link to it from their side.
                        uf.union(here, dummyNode)
                        continue
                    # Interior 'O': all four neighbours are inside the board.
                    for di, dj in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                        if board[i + di][j + dj] == 'O':
                            uf.union(self.node(i + di, j + dj, n), here)
            # No unions happen below, so the dummy node's root is fixed.
            dummyRoot = uf.find(dummyNode)
            for i in range(m):
                for j in range(n):
                    if uf.find(self.node(i, j, n)) == dummyRoot:
                        board[i][j] = 'O'
                    else:
                        board[i][j] = 'X'

        def node(self, i, j, n):
            """Map cell (i, j) of a board with n columns to its union-find index."""
            return i * n + j
    class UnionFind(object):
        """Weighted quick-union over 0..n-1 with path halving."""

        def __init__(self, n):
            # Stored under a private name: an instance attribute called `count`
            # would shadow the count() method and make it uncallable.
            self._count = n
            self.ids = list(range(n))  # ids[i] == i marks a root
            self.sz = [1] * n          # sz[r] = size of the tree rooted at r

        def union(self, p, q):
            """Merge the sets containing p and q, hanging the smaller tree under the larger."""
            i = self.find(p)
            j = self.find(q)
            if i == j:
                return
            if self.sz[i] < self.sz[j]:
                self.ids[i] = j
                self.sz[j] += self.sz[i]
            else:
                self.ids[j] = i
                self.sz[i] += self.sz[j]
            self._count -= 1

        def find(self, p):
            """Return the root of p's set, re-pointing each visited node at its grandparent."""
            while self.ids[p] != p:
                self.ids[p] = self.ids[self.ids[p]]
                p = self.ids[p]
            return p

        def count(self):
            """Return the current number of disjoint sets."""
            return self._count
    

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.