C++ Trie with explanation


    Got the idea from https://discuss.leetcode.com/topic/69199/c-trie-69ms-beats-85. Build a 32-level-deep trie in which each level represents one bit of a number, from the most significant bit down. Then, for each num in nums, walk the trie to calculate the largest XOR result that num can produce.
    Time complexity: 2*32*n (32 steps to insert and 32 steps to query each number), so it is O(n).
    #include <algorithm>
    #include <cstddef>
    #include <vector>
    using namespace std;
    
    class Solution {
    private:
        struct TrieNode
        {
            TrieNode* children[2];
            TrieNode()
            {
                children[0]=NULL;
                children[1]=NULL;
            }
        };
        
        class TrieTree{
        private:
            TrieNode* root;
        public:
            TrieTree()
            {
                root=new TrieNode;
            }
            TrieNode* getRoot(){return root;}
            void BuildTrie(int num)
            {
                TrieNode* cur=root;
                for(int i=31; i>=0; i--)
                {
                    int index=((num>>i)&1);   // bit i of num, left-most bit first
                    if(cur->children[index] == NULL)
                    {
                        cur->children[index]=new TrieNode;
                    }
                    cur=cur->children[index];
                }
            }
        };
    public:
        int findMaximumXOR(vector<int>& nums) {
            TrieTree* myTree=new TrieTree;
            for(auto num:nums)
            {
                myTree->BuildTrie(num);
            }
            
            int res=0;
            for(auto num:nums)
            {
                res=max(res, Helper(myTree->getRoot(), num));
            }
            
            return res;
        }
        
        int Helper(TrieNode* cur, int num)
        {
            int res=0;
            
            for(int i=31; i>=0; i--)
            {
                int index=((num>>i)&1)?0:1;   // the opposite of num's bit i
                if(cur->children[index])
                {
                    res<<=1;
                    res|=1;
                    cur=cur->children[index];
                }
                else
                {
                    res<<=1;
                    res|=0;
                    cur=cur->children[index?0:1];   // opposite missing: take the same-bit child
                }
            }
            return res;
        }
    };
    


    Hi @jill-brocli, thanks so much for sharing this easy-to-understand code! I have written a solution in C++11 inspired by yours. For folks having a hard time (like me) understanding this problem and the various solutions, I found it helpful to change MAX_BIT_POS to 3 and step through the code below, checking that it calculates the correct maximum XOR for small numbers (less than 16).
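    When checking the trie by hand like this, it helps to know the right answer in advance. A minimal sketch of an O(n^2) brute-force reference that simply tries every pair (bruteMaxXor is a hypothetical name, not part of either solution):

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical O(n^2) reference: try every pair and keep the largest XOR.
// Too slow for the real constraints, but handy for checking the trie's
// answer on small inputs, e.g. with MAX_BIT_POS changed to 3.
int bruteMaxXor(const std::vector<int>& nums) {
    int best = 0;  // x ^ x == 0, so 0 is the floor for any non-empty input
    for (std::size_t a = 0; a < nums.size(); ++a)
        for (std::size_t b = a + 1; b < nums.size(); ++b)
            best = std::max(best, nums[a] ^ nums[b]);
    return best;
}
```

    For [6, 8, 10] this gives 14, because 6 ^ 8 = 0110 ^ 1000 = 1110.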

    Here's how this solution works:

    Let's assume for simplicity that we have changed MAX_BIT_POS to 3 and that our vector contains the integers [6, 8, 10].

       3210 (bit position)
     6=0110
     8=1000
    10=1010
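    The bit columns above can be reproduced with the same (num >> i) & 1 extraction that both solutions use. A small sketch, where toBits is a hypothetical helper and maxBitPos stands in for MAX_BIT_POS:

```cpp
#include <cassert>
#include <string>

// Hypothetical helper: render bits maxBitPos..0 of num, left-most bit first,
// using the same (num >> i) & 1 test the trie code applies at each level.
std::string toBits(int num, int maxBitPos) {
    assert(maxBitPos >= 0 && maxBitPos < 32);
    std::string s;
    for (int i = maxBitPos; i >= 0; --i)
        s += ((num >> i) & 1) ? '1' : '0';
    return s;
}
```

    toBits(10, 3) yields "1010", matching the table.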
    

    First, generate a trie (the Generate function): for each number in the array, start at the trie root and walk down one level per bit, beginning with the left-most bit, following (or creating) the child that matches that bit. Nodes that already exist are reused and simply traversed; new nodes are created only when needed.

    Example iterations within the Generate function:

    num=6:
    
    6=0110
      3210 (bit position)
    
                 -----root-----
                /
              _0             <-- bit position 3
                \
                _1           <-- bit position 2
                  \
                  _1         <-- bit position 1
                   /
                 _0          <-- bit position 0
    
    num=8:
    
    8=1000
      3210 (bit position)
    
                 -----root-----
                /              \
              _0               _1    <-- bit position 3
                \              /
                _1           _0      <-- bit position 2
                  \          /
                  _1       _0        <-- bit position 1
                  /        /
                _0       _0          <-- bit position 0
    
    num=10:
    
    10=1010
       3210 (bit position)
    
                 -----root-----
                /              \
              _0               _1    <-- bit position 3
                \              /
                _1           _0      <-- bit position 2
                  \          / \
                  _1       _0  _1    <-- bit position 1
                  /        /   /
                _0       _0  _0      <-- bit position 0
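    One way to double-check the diagrams is to count nodes. The sketch below is a hypothetical flat-array variant of Generate (FlatTrie, insert and nodeCount are illustrative names; children are stored as vector indices rather than the post's shared_ptr members). Inserting 6, 8 and 10 with a max bit position of 3 should give 5, 9 and then 11 nodes including the root, since 10 reuses the 1 and 0 nodes that 8 created.

```cpp
#include <array>
#include <cassert>
#include <vector>

// Hypothetical flat-array variant of Generate: children[n][b] is the index of
// node n's child for bit b, or -1 if that child has not been created yet.
struct FlatTrie {
    int maxBitPos;
    std::vector<std::array<int, 2>> children;  // children[0] is the root

    explicit FlatTrie(int maxBit) : maxBitPos(maxBit), children(1, {-1, -1}) {
        assert(maxBit >= 0 && maxBit < 32);
    }

    // Walk from the root, left-most bit first, creating nodes only when missing.
    void insert(int num) {
        int cur = 0;
        for (int i = maxBitPos; i >= 0; --i) {
            int bit = (num >> i) & 1;
            if (children[cur][bit] == -1) {
                children[cur][bit] = static_cast<int>(children.size());
                children.push_back({-1, -1});
            }
            cur = children[cur][bit];  // existing nodes are simply traversed
        }
    }

    int nodeCount() const { return static_cast<int>(children.size()); }
};
```

    Storing nodes in one vector avoids a separate heap allocation per node, but the walk itself is the same as in Generate.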
    

    Second, for each number, traverse the trie to find the maximum XOR that number can achieve against ALL the numbers stored in the trie. To make the XOR as large as possible, we want its left-most bit to be 1, then the next bit, and so on down to the right-most bit. So each number's bits are analyzed from left to right while simultaneously walking down the trie, starting at the root.

    At each bit position, we check whether the current trie node has a child holding the opposite of the current number's bit. If it does, the XOR of two opposite bits is 1, so we set this bit of the current number's maximum XOR result x (x|=1) and move to that child. If it does not, this XOR bit is 0 and we move to the child that does exist, the one with the same bit value.

    Before each bit is analyzed, x is shifted left by one position (x <<= 1) to make room for that bit's result. This also moves the bits set in earlier iterations toward their proper positions, so once the loop finishes, every bit of x sits at the position it was computed for. At that point, x holds the maximum XOR of the current number against ALL the numbers in the trie. If x is larger than any previously calculated maximum, update maxXOR.
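    The greedy walk described above can be sketched as follows. This is a hypothetical self-contained version (bestXorPerNumber is not from either solution) that builds a small index-based trie and returns each number's best XOR instead of only the maximum, so the per-number results can be inspected. For [6, 8, 10] with maxBitPos = 3 it should yield 14, 14 and 12: 6 and 8 pair with each other, while 10's best partner is 6 (1010 ^ 0110 = 1100).

```cpp
#include <array>
#include <cassert>
#include <vector>

// Hypothetical sketch of the "Second" step: for every number, walk a
// (maxBitPos+1)-bit trie greedily, preferring the opposite bit at each level.
// Returns each number's best XOR against all numbers in the trie.
std::vector<int> bestXorPerNumber(const std::vector<int>& nums, int maxBitPos) {
    // Build: flat node storage, -1 means "missing child"; node 0 is the root.
    std::vector<std::array<int, 2>> child(1, {-1, -1});
    for (int num : nums) {
        int cur = 0;
        for (int i = maxBitPos; i >= 0; --i) {
            int bit = (num >> i) & 1;
            if (child[cur][bit] == -1) {
                child[cur][bit] = static_cast<int>(child.size());
                child.push_back({-1, -1});
            }
            cur = child[cur][bit];
        }
    }
    // Query: greedy walk, one bit of x decided per level.
    std::vector<int> best;
    for (int num : nums) {
        int cur = 0, x = 0;
        for (int i = maxBitPos; i >= 0; --i) {
            int want = ((num >> i) & 1) ^ 1;  // opposite of num's bit i
            x <<= 1;                          // make room for this bit
            if (child[cur][want] != -1) {
                x |= 1;                       // opposite found: XOR bit is 1
                cur = child[cur][want];
            } else {
                cur = child[cur][want ^ 1];   // only the same-bit child exists
            }
            assert(cur != -1);                // every path has full depth
        }
        best.push_back(x);
    }
    return best;
}
```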

    If this is still hard to understand, I've added a commented version of the trickiest part of this code at the very end of this post. Without further ado, here is the solution:

    #include <algorithm>
    #include <memory>
    #include <vector>
    using namespace std;
    
    #define MAX_BIT_POS 31
    
    class Solution{
    public:
        int findMaximumXOR(vector<int>& nums){
            auto trie=make_shared<TrieNode>();
            return trie->findMaximumXOR(nums);
        }
        
    private:
        class TrieNode{
        public:
            TrieNode() : _1{nullptr}, _0{nullptr} {}
            
            int findMaximumXOR(const vector<int>& nums){
                int maxXOR=0;
                auto root=Generate(nums);
                for (auto& num : nums){
                    int x=0;
                    auto curr=root;
                    for (int i=MAX_BIT_POS; i>=0; --i){
                        x <<= 1;
                        bool opposite = !(num & (1<<i));
                        if (opposite){
                            if (curr->_1){ x|=1; curr=curr->_1; }
                            else {               curr=curr->_0; }
                        } else {
                            if (curr->_0){ x|=1; curr=curr->_0; }
                            else {               curr=curr->_1; }
                        }
                    }
                    maxXOR=max(maxXOR,x);
                }
                return maxXOR;
            }
            
        private:
            shared_ptr<TrieNode> _1;
            shared_ptr<TrieNode> _0;
            
            shared_ptr<TrieNode> Generate(const vector<int>& nums){
                auto root=make_shared<TrieNode>();
                for (auto& num : nums){
                    auto curr=root;
                    for (int i=MAX_BIT_POS; i>=0; --i){
                        if (num & (1<<i)){
                            if (!curr->_1){ curr->_1=make_shared<TrieNode>(); }
                            curr=curr->_1;
                        } else {
                            if (!curr->_0){ curr->_0=make_shared<TrieNode>(); }
                            curr=curr->_0;
                        }
                    }
                }
                return root;
            }
        };
    };
    

    I believe the hardest part of this code to understand is the following, so I've added comments above each piece of it to help:

    //
    // find the value of the current number's current bit under analysis,
    // and try to find the opposite value in the current trie node's children
    //
    bool opposite = !(num & (1<<i));
    
    //
    // opposite is 1
    //
    if (opposite){                            
    
        //
        // opposite found in trie, set current number's max xor (x's) bit
        // for this bit position and traverse forward to this trie node
        //
        if (curr->_1){ x|=1; curr=curr->_1; } 
    
        //
        // opposite NOT found in trie: move to the other child, _0.
        // _0 must exist since _1 does NOT: every number was inserted as a
        // full MAX_BIT_POS+1 bit path, so every internal node has at least one child.
        // (NOTE: the current number's max xor result for this bit is 0, since no
        // opposite bit value was found. This could be coded explicitly as x|=0,
        // but this bit of x is already 0, so that is unnecessary.)
        //
        else {               curr=curr->_0; } 
    
    //
    // opposite is 0
    //
    } else {                                  
        
        //
        // opposite found in trie, set current number's max xor (x's) bit
        // for this bit position and traverse forward to this trie node
        //
        if (curr->_0){ x|=1; curr=curr->_0; } 
    
        //
        // opposite NOT found in trie: move to the other child, _1.
        // _1 must exist since _0 does NOT: every number was inserted as a
        // full MAX_BIT_POS+1 bit path, so every internal node has at least one child.
        // (NOTE: the current number's max xor result for this bit is 0, since no
        // opposite bit value was found. This could be coded explicitly as x|=0,
        // but this bit of x is already 0, so that is unnecessary.)
        //
        else {               curr=curr->_1; } 
    }
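    The two branches above mirror each other, so they can be folded together by computing the wanted bit directly. A minimal sketch, assuming a hypothetical Node type with a child[2] array in place of the _0/_1 members; step and maxXorFor are illustrative names:

```cpp
#include <cassert>

// Hypothetical node type: child[b] plays the role of the post's _b pointer.
struct Node {
    Node* child[2] = {nullptr, nullptr};
};

// One iteration of the commented loop body, with the two mirrored branches
// folded into one: 'want' is the opposite of num's bit i.
Node* step(Node* curr, int num, int i, int& x) {
    int want = ((num >> i) & 1) ^ 1;  // opposite bit: 1 -> 0, 0 -> 1
    x <<= 1;                          // make room for this bit's result
    if (curr->child[want]) {          // opposite found: this XOR bit is 1
        x |= 1;
        return curr->child[want];
    }
    // Opposite missing: the same-bit child must exist (every stored number
    // is a full-length path), and this XOR bit stays 0.
    assert(curr->child[want ^ 1] != nullptr);
    return curr->child[want ^ 1];
}

// Full greedy lookup for one number, calling step() once per bit.
int maxXorFor(Node* root, int num, int maxBitPos) {
    int x = 0;
    Node* curr = root;
    for (int i = maxBitPos; i >= 0; --i) curr = step(curr, num, i, x);
    return x;
}
```

    Indexing the children by bit value removes the duplicated if/else, at the cost of hiding the explicit "opposite is 1" and "opposite is 0" cases that the comments above walk through.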
    
    
