C++ Disjoint-set Solution


  • 4
    G
    // Node of a disjoint-set (union-find) forest.
    // Members default to a well-defined "unlinked" state; make_set() turns the
    // node into a singleton set whose root is the node itself.
    struct SetNode {
    	SetNode* parent = nullptr;  // parent in the forest; a root points to itself after make_set()
    	int rank = 0;               // rank consulted by set_union() to decide which root becomes the parent
    };
    
    // Initialises `node` as a singleton set: it is its own root and has rank 0.
    void make_set(SetNode* node) {
    	node->rank = 0;
    	node->parent = node;  // a self-parented node is a root
    }
    
    // Returns the root of the set containing `node`.
    // Every node on the walked path is re-pointed directly at that root
    // (path compression), so later lookups on those nodes take one hop.
    // Precondition: `node` and its ancestors have been initialised with make_set().
    SetNode* find(SetNode* node) {
    	// First pass: climb to the root.
    	SetNode* root = node;
    	while (root != root->parent)
    		root = root->parent;
    
    	// Second pass: hang every node on the path directly off the root.
    	while (node != root) {
    		SetNode* next = node->parent;
    		node->parent = root;
    		node = next;
    	}
    	return root;
    }
    
    // Merges the sets containing `node1` and `node2` using union by rank.
    // Does nothing if both nodes already share a root.
    // The root with the larger rank becomes the parent; on a tie, node2's
    // root becomes the parent and its rank grows by one.
    void set_union(SetNode* node1, SetNode* node2) {
    	SetNode* root1 = find(node1);
    	SetNode* root2 = find(node2);
    	if (root1 == root2)
    		return;
    
    	const bool root1_higher = root1->rank > root2->rank;
    	SetNode* const upper = root1_higher ? root1 : root2;
    	SetNode* const lower = root1_higher ? root2 : root1;
    
    	if (lower->rank == upper->rank)
    		++upper->rank;
    	lower->parent = upper;
    }
    
    
    class Solution {
    public:
    	// Counts islands after each land addition on an m x n grid of water.
    	//   m, n      : grid dimensions (rows, columns).
    	//   positions : (row, col) cells turned into land, in order.
    	// Returns res where res[i] is the number of 4-connected islands after
    	// applying positions[0..i]. A position that is already land leaves the
    	// count unchanged.
    	vector<int> numIslands2(int m, int n, vector<pair<int, int>>& positions) {
    		const size_t count = positions.size();
    
    		// One disjoint-set node per addition, each starting as its own set.
    		vector<SetNode> nodes(count);
    		for (size_t i = 0; i < count; ++i)
    			make_set(&nodes[i]);
    
    		// Linearised cell index (r * n + c) -> node of the land cell there.
    		// long long keeps the key from overflowing int on very large grids.
    		unordered_map<long long, SetNode*> land;
    		vector<int> res;
    		res.reserve(count);
    
    		// Neighbour offsets: up, down, left, right.
    		const int dr[4] = {-1, 1, 0, 0};
    		const int dc[4] = {0, 0, -1, 1};
    
    		int islands = 0;
    		for (size_t i = 0; i < count; ++i) {
    			const int r = positions[i].first;
    			const int c = positions[i].second;
    			const long long key = static_cast<long long>(r) * n + c;
    
    			// Re-adding a cell that is already land changes nothing; without this
    			// check an isolated duplicate would be counted as a second island.
    			if (land.count(key)) {
    				res.push_back(islands);
    				continue;
    			}
    
    			SetNode* cur = &nodes[i];
    			++islands;  // the new cell starts as its own island
    			for (int d = 0; d < 4; ++d) {
    				const int nr = r + dr[d];
    				const int nc = c + dc[d];
    				if (nr < 0 || nr >= m || nc < 0 || nc >= n)
    					continue;
    				auto it = land.find(static_cast<long long>(nr) * n + nc);
    				if (it == land.end())
    					continue;
    				// Each distinct neighbouring island merged in reduces the count by one.
    				if (find(cur) != find(it->second)) {
    					set_union(cur, it->second);
    					--islands;
    				}
    			}
    
    			land[key] = cur;
    			res.push_back(islands);
    		}
    
    		return res;
    	}
    };

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.