Java O(n) solution using Trie


  •
        class Trie {
            Trie[] children;
            public Trie() {
                children = new Trie[2];
            }
        }
        
        public int findMaximumXOR(int[] nums) {
            if(nums == null || nums.length == 0) {
                return 0;
            }
            // Init Trie.
            Trie root = new Trie();
            for(int num: nums) {
                Trie curNode = root;
                for(int i = 31; i >= 0; i --) {
                    int curBit = (num >>> i) & 1;
                    if(curNode.children[curBit] == null) {
                        curNode.children[curBit] = new Trie();
                    }
                    curNode = curNode.children[curBit];
                }
            }
            int max = Integer.MIN_VALUE;
            for(int num: nums) {
                Trie curNode = root;
                int curSum = 0;
                for(int i = 31; i >= 0; i --) {
                    int curBit = (num >>> i) & 1;
                    if(curNode.children[curBit ^ 1] != null) {
                        curSum += (1 << i);
                        curNode = curNode.children[curBit ^ 1];
                    }else {
                        curNode = curNode.children[curBit];
                    }
                }
                max = Math.max(curSum, max);
            }
            return max;
        }
    

  •

    Good solution. Using a Trie makes the problem much clearer and easier to understand.


  •

    @airflyctl said in Java O(n) solution using Trie:

    Good solution. Using a Trie makes the problem much clearer and easier to understand.

    But it can't pass the large test cases (it gets TLE).


  •

    @airflyctl Yeah, it gets TLE for me as well. Ditching the Trie class and just using Object[] gets it accepted in about 185 ms:

    public class Solution {
        public int findMaximumXOR(int[] nums) {
            if(nums == null || nums.length == 0) {
                return 0;
            }
            // Init Trie.
            Object[] root = {null, null};
            for(int num: nums) {
                Object[] curNode = root;
                for(int i = 31; i >= 0; i --) {
                    int curBit = (num >>> i) & 1;
                    if(curNode[curBit] == null) {
                        curNode[curBit] = new Object[]{null, null};
                    }
                    curNode = (Object[]) curNode[curBit];
                }
            }
            int max = Integer.MIN_VALUE;
            for(int num: nums) {
                Object[] curNode = root;
                int curSum = 0;
                for(int i = 31; i >= 0; i --) {
                    int curBit = (num >>> i) & 1;
                    if(curNode[curBit ^ 1] != null) {
                        curSum += (1 << i);
                        curNode = (Object[]) curNode[curBit ^ 1];
                    }else {
                        curNode = (Object[]) curNode[curBit];
                    }
                }
                max = Math.max(curSum, max);
            }
            return max;
        }
    }
    

  •

    @StefanPochmann Thanks for the help :)


  •

    @StefanPochmann Thanks for the correction. My solution sometimes passed the tests in 400 ms, but sometimes it got TLE.


  •

    I changed the line curSum += (1 << i); to curSum += (1 <<< i); and got a compile error: "error: > expected". Does anyone know why changing the arithmetic left shift to the logical one causes this error? I thought they'd behave the same: discarding bits on the left and padding with 0s on the right.


  •

    Very clear code. Here is my improvement:

        public int findMaximumXOR(int[] nums) {
            if(nums == null || nums.length == 0) {
                return 0;
            }
            // Init Trie.
            Trie root = new Trie();
            for(int num: nums) {
                Trie curNode = root;
                for(int i = 31; i >= 0; i --) {
                    int curBit = (num >>> i) & 1;
                    if(curNode.children[curBit] == null) {
                        curNode.children[curBit] = new Trie();
                    }
                    curNode = curNode.children[curBit];
                }
            }
            int max = Integer.MIN_VALUE;
            for(int num: nums) {
                Trie curNode = root;
                int curSum = 0;
                for(int i = 31; i >= 0; i --) {
                    int curBit = (num >>> i) & 1;
                    if(curNode.children[curBit ^ 1] != null) {
                        curSum += (1 << i);
                        curNode = curNode.children[curBit ^ 1];
                    }else {
                        curNode = curNode.children[curBit];
                    }
    
                // Prune: even if all remaining lower bits came out as 1s, curSum could not exceed max
                    if (curSum < max && max - curSum >= (1 << i) - 1) {
                        break;
                    }
                }
                max = Math.max(curSum, max);
            }
            return max;
        }

  •

    @keaton Can you explain this part, please? I don't get why we do curNode = curNode.children[curBit ^ 1]. It seems we are going down a branch that is not the current number, num?

    if(curNode.children[curBit ^ 1] != null) {
        curSum += (1 << i);
        curNode = curNode.children[curBit ^ 1];
    }


  •

    Could you please explain this line? Why do you do curNode = curNode.children[curBit ^ 1]?
    It seems we are using a different number to XOR with. I think we should do curNode = curNode.children[curBit], but that's wrong, and I don't know why.

    if(curNode.children[curBit ^ 1] != null) {
        curSum += (1 << i);
        curNode = curNode.children[curBit ^ 1];
    }


  •

    @xin77 This section is trying to find a value "x" in nums that makes "x ^ num" as large as possible, so we are not following the path of num itself. At each bit we prefer the child with the opposite bit, because that bit of x ^ num then becomes 1.
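    To make this concrete, here is a small sketch in which the query records the bits of the path it actually walks and returns that stored value x. It assumes non-negative inputs (as LeetCode 421 guarantees), so bits 30..0 suffice; the class and method names are hypothetical, not from the thread.

```java
// Hypothetical demo: the query walks the path of the best partner x,
// not the path of num, and returns x so that x ^ num is the max for num.
class XorPartnerDemo {
    static final class Node {
        Node[] children = new Node[2];
    }

    // Insert a non-negative int along bits 30..0.
    static void insert(Node root, int num) {
        Node cur = root;
        for (int i = 30; i >= 0; i--) {
            int bit = (num >>> i) & 1;
            if (cur.children[bit] == null) {
                cur.children[bit] = new Node();
            }
            cur = cur.children[bit];
        }
    }

    // Returns the stored value x maximizing x ^ num (trie must be non-empty).
    static int bestPartner(Node root, int num) {
        Node cur = root;
        int partner = 0;
        for (int i = 30; i >= 0; i--) {
            int bit = (num >>> i) & 1;
            int wanted = bit ^ 1; // the opposite bit makes this XOR bit 1
            // Every stored path has full depth, so if the wanted child is
            // missing, the same-bit child must exist.
            int taken = cur.children[wanted] != null ? wanted : bit;
            partner |= taken << i; // record the bit of x we just walked through
            cur = cur.children[taken];
        }
        return partner;
    }
}
```

    For nums = {3, 10, 5, 25, 2, 8}, querying 5 walks down the path of 25, and 5 ^ 25 = 28.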


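  •

    @s In Java, <<< is not an operator. There is only one left shift, <<, because arithmetic and logical left shifts are identical (both pad with 0s on the right); the distinction exists only for right shifts, which is why Java has both >> (sign-extending) and >>> (zero-filling).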
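    A small sketch of Java's three shift operators (class and method names are made up for illustration). The two right shifts differ only on negative values, while the single left shift always fills with zeros:

```java
// Hypothetical demo of <<, >> and >>> on ints.
class ShiftDemo {
    // Left shift: zeros come in on the right for any sign.
    static int left(int x, int n) { return x << n; }

    // Arithmetic right shift: copies of the sign bit come in on the left.
    static int signedRight(int x, int n) { return x >> n; }

    // Logical right shift: zeros come in on the left.
    static int unsignedRight(int x, int n) { return x >>> n; }
}
```

    For example, -16 >> 2 stays negative (-4), while -16 >>> 28 keeps only the top four 1 bits as a small positive number (15).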


  •

    @StefanPochmann Using a custom binary trie, I got 55 ms. I'm not sure how much of that is due to the recent OJ runtime improvements.

    public class Solution {
        static class BinaryTrie {
            public BinaryTrie zero;
            public BinaryTrie one;
        }

        public int findMaximumXOR(int[] nums) {
            BinaryTrie trie = new BinaryTrie();
            for(int n: nums) {
                BinaryTrie node = trie;
                for(int i = 30; i > -1; i--) {
                    if(((n >> i) & 1) == 0) {
                        if(node.zero == null)
                            node.zero = new BinaryTrie();
                        node = node.zero;
                    } else {
                        if(node.one == null)
                            node.one = new BinaryTrie();
                        node = node.one;
                    }
                }
            }
            int max = 0;
            for(int n: nums) {
                BinaryTrie node = trie;
                int xor = 0;
                for(int i = 30; i > -1; i--) {
                    int bit = 1 << i;
                    if((bit & n) > 0) {
                        if(node.zero == null)
                            node = node.one;
                        else {
                            xor += bit;
                            node = node.zero;
                        }
                    } else {
                        if(node.one == null)
                            node = node.zero;
                        else {
                            xor += bit;
                            node = node.one;
                        }
                    }
                }
                max = Math.max(max, xor);
            }
            return max;
        }
    }


  •

    @macrohard I just submitted our solutions five times each. Your solution's times were 102, 110, 89, 82, 106 ms. My solution's times were 48, 75, 75, 74, 75 ms. Apparently something has indeed changed, as that's much faster than before.


  •

    @StefanPochmann Hey Stefan, I have just one question: why wouldn't it work if we do

    curSum += (curBit ^ 1 << i); instead of curSum += (1 << i);? I believe we are reconstructing a reversed (negated) bit image of the number on the XOR trie.

    Thank you in advance!
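    Two things seem to be going on here. First, in Java the shift operators bind tighter than ^, so curBit ^ 1 << i parses as curBit ^ (1 << i). Second, even with parentheses, (curBit ^ 1) << i is the bit of the partner number at position i, not the bit of the XOR: when curBit is 1 and the opposite branch exists, it adds 0 even though that XOR bit is 1. curSum is meant to accumulate the XOR, whose bit is 1 whenever the opposite branch exists, hence 1 << i. A tiny sketch of the parse (hypothetical class name):

```java
// Hypothetical demo: << binds tighter than ^ in Java.
class PrecedenceDemo {
    // Parses as curBit ^ (1 << i).
    static int asWritten(int curBit, int i) {
        return curBit ^ 1 << i;
    }

    // The flipped bit of num moved to position i, i.e. the partner's bit.
    static int parenthesized(int curBit, int i) {
        return (curBit ^ 1) << i;
    }
}
```

    With curBit = 1 and i = 3, the first expression gives 9 and the second gives 0, while the correct contribution to the XOR is 8.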


  •

    @xin77 As I understand it, at every bit we check whether some other number has the opposite bit, since 0 ^ 1 = 1. As soon as we find such a flip, we take that branch: we move from the MSB to the LSB, so setting a higher bit to 1 matters more than anything below it. It's a greedy way to maximize the sum.
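    Why the greedy choice is safe: a 1 at bit i is worth more than all lower bits together, so no choice made further down can make up for a 0 at a higher bit:

```latex
2^i \;>\; \sum_{j=0}^{i-1} 2^j \;=\; 2^i - 1
```

    The same bound, 2^i - 1 for what the remaining bits can still add, is what the early-exit check in keaton's version relies on.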


  •

    Thanks for the solution! It is a little more efficient if we compute the max XOR value while building the trie. Here is my C++ version:

    #include <algorithm>
    #include <vector>
    using namespace std;

    class Solution {
    struct treenode {
        treenode* left, *right;
        treenode():left(NULL), right(NULL) {} 
    };
    public:
        int findMaximumXOR(vector<int>& nums) {
            treenode* root = new treenode();
            int ans = 0;
            // build a trie and get the max xor value on the fly for the current trie
            for (int num:nums)  ans = max(ans, build(root, num));
            return ans;
        }
    private:
        int build(treenode* p, int num) {
            int ans = 0, mask = 1<<30;
            treenode* q = p;
            for (int i = 30; i >= 0; i--) {
                ans <<= 1;
                if (num&mask) {
                    //build a trie
                    if (p->right == NULL)  p->right = new treenode();
                    p = p->right;
                    //get xor value on the fly
                    if (q->left) {
                        ans++;
                        q = q->left;
                    }
                    else 
                        q = q->right;
                }
                else {
                    if (p->left == NULL)  p->left = new treenode();
                    p = p->left;
                    if (q->right) {
                        ans++;
                        q = q->right;
                    }
                    else 
                        q = q->left;
                }
                mask >>= 1;
            }
            return ans;
        }
    };
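    A Java sketch of the same single-pass idea, assuming non-negative inputs (bits 30..0); the names are hypothetical. Each number is inserted and queried in the same loop over its bits, so the query may also follow the number's own path, which only ever contributes 0s:

```java
// Hypothetical sketch: insert num and query the best XOR for it in one loop.
class OnTheFlyMaxXor {
    private static final class Node {
        Node zero, one;
    }

    static int findMaximumXOR(int[] nums) {
        Node root = new Node();
        int best = 0;
        for (int num : nums) {
            Node p = root; // insertion cursor
            Node q = root; // query cursor
            int xor = 0;
            for (int i = 30; i >= 0; i--) {
                int bit = (num >>> i) & 1;
                // Insert step: extend num's path by one level.
                if (bit == 0) {
                    if (p.zero == null) p.zero = new Node();
                    p = p.zero;
                } else {
                    if (p.one == null) p.one = new Node();
                    p = p.one;
                }
                // Query step: prefer the opposite branch; if it is missing,
                // the same-bit child exists (either just created by p or
                // part of an earlier full-depth path).
                Node opposite = bit == 0 ? q.one : q.zero;
                if (opposite != null) {
                    xor |= 1 << i;
                    q = opposite;
                } else {
                    q = bit == 0 ? q.zero : q.one;
                }
            }
            best = Math.max(best, xor);
        }
        return best;
    }
}
```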
    

  •

    My trie solution, verbose, but fast: 36ms (95%) @ 2017-08-20 19:00:52

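    import java.util.ArrayList;
    import java.util.List;

    // Minimal TreeNode so the code compiles on its own; assumed to match the
    // val/left/right node type the judge provided.
    class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int x) { val = x; }
    }

    class Solution {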
        public int findMaximumXOR(int[] nums) {
            int n = nums.length;
            char[][] binArs = new char[n][];
            int maxLen = 0;
            List<Integer> maxLenIndices = new ArrayList<>();
            for (int i = 0; i < n; i++) {
                binArs[i] = Integer.toBinaryString(nums[i]).toCharArray();
                int len = binArs[i].length;
                if (len > maxLen) {
                    maxLen = len;
                    maxLenIndices.clear();
                    maxLenIndices.add(i);
                } else if (len == maxLen) {
                    maxLenIndices.add(i);
                }
            }
            TreeNode root = new TreeNode(1);
            for (char[] bin : binArs) {
                insert(root, bin, maxLen);
            }
            int max = 0;
            for (int idx : maxLenIndices) {
                max = Math.max(max, lookup(root, binArs[idx], 0, 0));
            }
            return max;
        }
    
        public void insert(TreeNode root, char[] chs, int depth) {
            if (depth == 0) {
                root.val = 1;
                return;
            }        
            if (depth > chs.length) {
                if (root.left == null)
                    root.left = new TreeNode(depth);
                insert(root.left, chs, depth - 1);
                return;
            }
            int index = chs.length - depth; // depth = bits still to place; pick the matching char of chs
            if (chs[index] == '1') {
                if (root.right == null)
                    root.right = new TreeNode(depth);
                insert(root.right, chs, depth - 1);
            } else {
                if (root.left == null)
                    root.left = new TreeNode(depth);
                insert(root.left, chs, depth - 1);
            }
        }
    
        public int lookup(TreeNode root, char[] chs, int start, int accum) {
            if (start >= chs.length)
                return accum;
            if (chs[start] == '1') {
                if (root.left != null)
                    return lookup(root.left, chs, start + 1, (accum << 1) + 1);
                else
                    return lookup(root.right, chs, start + 1, accum << 1);
            } else {
                if (root.right != null)
                    return lookup(root.right, chs, start + 1, (accum << 1) + 1);
                else
                    return lookup(root.left, chs, start + 1, accum << 1);
            }
        }
    }
    

    General explanation:

    • Build the trie; this part is similar to the OP's.
    • For the example in the problem description, note that 25 has the leftmost 1 bit (equivalently, the longest binary representation once leading 0s are discarded). We start from its binary representation, 1 1 0 0 1, and walk down the trie from left to right, trying to find a differing bit at each position. Ideally, going left to right, we would like to find:
      1 -> 0
      1 -> 0
      0 -> 1
      0 -> 1
      1 -> 0
    No number in the example matches all five flips (that would be 0 0 1 1 0 = 6), so the walk settles on 5 = 0 0 1 0 1, giving 25 ^ 5 = 28.
    

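    This greedy approach is easy to implement with a trie. Starting only from the longest numbers is safe: if some numbers are shorter, only a pair of one longest and one shorter number can set the top bit, and if all numbers have the same length, every pair already contains a longest number. Since multiple numbers can share the longest binary representation, we loop over all of them, keep a running max, and return it.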

    The code is long, but the algorithm is actually very simple. Trie insertion and lookup are routine operations that most people can hack out quickly anyway, and the logic in the main method is clear-cut.

    Differences from the OP's solution that might have contributed to the speedup:

    • Convert to strings and char arrays at the beginning to avoid repeated bit operations, especially shifting.
    • Avoid processing all 32 bits; only go as deep as the longest bit string.
    • Avoid using an array as the node's instance variable, which introduces overhead.
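    A compact variant of the idea described in this post, assuming non-negative ints (class and method names are hypothetical). It uses int bit operations instead of char arrays, so it does not reproduce the string-conversion trick, but it keeps the other points: it walks only maxLen bits, stores two child fields per node, and starts queries only from numbers of maximal bit length.

```java
// Hypothetical sketch: query only from the numbers with the longest binary form.
class LongestOnlyMaxXor {
    private static final class Node {
        Node zero, one;
    }

    // Number of bits without leading zeros; 0 for x == 0.
    static int bitLength(int x) {
        return 32 - Integer.numberOfLeadingZeros(x);
    }

    static int findMaximumXOR(int[] nums) {
        int maxLen = 0;
        for (int num : nums) maxLen = Math.max(maxLen, bitLength(num));
        // Build the trie over bits maxLen-1..0 only.
        Node root = new Node();
        for (int num : nums) {
            Node cur = root;
            for (int i = maxLen - 1; i >= 0; i--) {
                if (((num >>> i) & 1) == 0) {
                    if (cur.zero == null) cur.zero = new Node();
                    cur = cur.zero;
                } else {
                    if (cur.one == null) cur.one = new Node();
                    cur = cur.one;
                }
            }
        }
        int best = 0;
        for (int num : nums) {
            if (bitLength(num) != maxLen) continue; // shorter numbers are skipped
            Node cur = root;
            int xor = 0;
            for (int i = maxLen - 1; i >= 0; i--) {
                int bit = (num >>> i) & 1;
                Node opposite = bit == 0 ? cur.one : cur.zero;
                if (opposite != null) { // greedy: take the differing bit
                    xor |= 1 << i;
                    cur = opposite;
                } else {
                    cur = bit == 0 ? cur.zero : cur.one;
                }
            }
            best = Math.max(best, xor);
        }
        return best;
    }
}
```

    On the problem's example {3, 10, 5, 25, 2, 8}, only 25 has the maximal length of 5 bits, so a single query is made and it returns 28.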
