C++ Union Find Solution with explanation


  • 0
    R

    Create a node for each number with parent as NULL
    If a node has not been visited, its parent is NULL; once it has been visited, its parent is either itself or some other node. For each pair, if both nodes have NULL as parent, we increment the count; if both already have parents and their roots are different, we decrement the count; if exactly one of them has been visited, the other simply joins its set and the count does not change. We also keep track of the unvisited nodes, since each of them is a component by itself and is added to the count at the end.

    /// Counts the connected components of an undirected graph using a
    /// pointer-based union-find (disjoint-set) forest.
    class Solution {
    public:
        /// One graph vertex inside the disjoint-set forest.
        ///  - parent == nullptr : no edge has touched this vertex yet.
        ///  - parent == this    : this vertex is the root of its set.
        ///  - otherwise         : parent points one step closer to the root.
        struct node{
            int val;
            node* parent;
            node(int x) : val(x), parent(nullptr){}
        };

        /// Returns the root of the set containing `a`, or nullptr when `a`
        /// has not been visited yet (its parent is still nullptr).
        /// Every node walked on the way up is re-pointed at the root
        /// (path compression) so later lookups are shorter.
        node* find(node* a){
            if(a->parent == nullptr){
                return nullptr;
            }
            node* root = a;
            while(root->parent != nullptr && root->parent != root){
                root = root->parent;
            }
            // Second pass: hang every node of the walked chain directly on the root.
            while(a != root){
                node* next = a->parent;
                a->parent = root;
                a = next;
            }
            return root;
        }

        /// Joins the sets of `a` and `b` and reports how the number of
        /// components among the *visited* nodes changed:
        ///   +1  both nodes were unvisited, so a brand-new component appears;
        ///    0  one node joined an existing component, or both were already together;
        ///   -1  two distinct existing components were merged.
        /// After the call both `a` and `b` have a non-null parent.
        int u(node* a, node* b){
            node* m = find(a);
            node* n = find(b);
            if(m == nullptr && n == nullptr){
                // `a` becomes the root of the new set (also covers the self-loop a == b).
                a->parent = a;
                b->parent = a;
                return 1;
            }
            if(m == nullptr){
                a->parent = n;
                return 0;
            }
            if(n == nullptr){
                b->parent = m;
                return 0;
            }
            if(m != n){
                m->parent = n;
                return -1;
            }
            return 0;
        }

        /// Returns the number of connected components in a graph with
        /// vertices 0..n-1 and the given undirected edges.
        /// Returns 0 when n is not positive. Edge endpoints are expected to
        /// lie in [0, n); they are not validated here.
        int countComponents(int n, std::vector<std::pair<int, int>>& edges) {
            if(n <= 0){
                return 0;
            }

            // The nodes live inside the vector, so they are released automatically.
            // Reserving up front means no reallocation happens while filling it,
            // which keeps the node addresses stored in `parent` valid.
            std::vector<node> nodes;
            nodes.reserve(static_cast<std::size_t>(n));
            for(int i = 0; i < n; i++){
                nodes.emplace_back(i);
            }

            // Net number of components formed among vertices touched by an edge.
            int cnt = 0;
            for(const auto& e : edges){
                cnt += u(&nodes[e.first], &nodes[e.second]);
            }

            // A vertex never touched by an edge still has a null parent and
            // is a component on its own.
            int untouched = 0;
            for(const node& v : nodes){
                if(v.parent == nullptr){
                    ++untouched;
                }
            }
            return cnt + untouched;
        }
    };

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.