A solution based on @bartoszkp's, with missing test cases


  • 6

    Let's say we find the MSB that can be set to 1 in the result. Then we can partition the whole array into two subsets by that bit: one element must be taken from one subset, the other from the other one. Then we move on to the next bit that could possibly be set to 1, but this time we're restricted to picking elements from the different subsets generated at the first step. We do that by partitioning each subset in two again based on the value of the next candidate bit, and we try to combine elements in such a way that this bit is set to 1 as well. There are now two subsets in each of the original subsets, so there are two ways to combine them on that next bit: 1-0 and 0-1. The whole thing goes on recursively until we run out of bits, and then we just return the maximum, and everyone is happy. Or at least that's the idea.

    This idea occurred to me when I first solved this problem, but I thought it would be too slow because I might get bad splits, so I switched to prefixes/sets instead. But then I saw this solution based on the same idea, which looked pretty impressive. During the discussion with the author we came to the conclusion that bad splits won't degrade the runtime to O(n^2), because the recursion depth is limited by the number of bits anyway.

    Unlike that solution, my original idea was to use a mask that indicates which bits could possibly be set. Say, if a certain bit is 1 in all numbers, or is 0 in all numbers, there is no way to get 1 in that position. That means we should only consider bits that are set in some numbers, but not in all of them. "Set in some" = OR, "set in all" = AND, "not set in all" = NOT AND, and therefore the mask for such bits is or & ~and.
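
    A minimal sketch of that mask computation, in case it helps to see it in isolation. The helper name candidateMask is made up for illustration, and it assumes non-negative inputs, like the 0x7FFFFFFF seed used in the code further down.

```cpp
#include <numeric>
#include <vector>

// Hypothetical helper (name is illustrative, not from the original code):
// returns the bits that are set in at least one number but not in all of them.
int candidateMask(const std::vector<int>& nums) {
    // OR of everything: bits set in some number.
    int anySet = std::accumulate(nums.cbegin(), nums.cend(), 0,
                                 [](int a, int b) { return a | b; });
    // AND of everything: bits set in every number (assumes non-negative inputs).
    int allSet = std::accumulate(nums.cbegin(), nums.cend(), 0x7FFFFFFF,
                                 [](int a, int b) { return a & b; });
    return anySet & ~allSet;
}
```

    For 4, 6, 7 this gives 3: bit 4 is set in every number, so it drops out of the mask.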

    While I was at it, I realized that the original solution by @bartoszkp had a bug that wasn't detected by the OJ. It started with the MSB of the maximum element, but if that bit is set in all numbers, then the very first split will be wrong. The use of my mask incidentally fixed that too. A simple test case to demonstrate the problem: [4, 6, 7].

    Another interesting test case: [8, 10, 2]. The code below fails it if the two lines testing for (mask & msb) == 0 in the helper function are commented out. That's because it tries to partition the array on the mask 4, but that bit is cleared in all numbers. And yet, it passes the OJ too.

    Now here is a version of the code that is also wrong. I'm posting it because it clearly demonstrates a flaw in this approach in general.

    class Solution {
    public:
        int findMaximumXOR(vector<int>& nums) {
            auto orOp = [](int a, int b) { return a | b; };
            auto andOp = [](int a, int b) { return a & b; };
            mask = accumulate(nums.cbegin(), nums.cend(), 0, orOp)
                & ~accumulate(nums.cbegin(), nums.cend(), 0x7FFFFFFF, andOp);
            auto msb = computeMsb(mask);
            auto msbSplit = msbPartition(nums.begin(), nums.end(), msb);
            return findMaximumXor(nums.begin(), msbSplit, msbSplit, nums.end(), msb >> 1);
        }
        
        int computeMsb(int n) {
            auto msb = n;
            msb |= msb >> 1;
            msb |= msb >> 2;
            msb |= msb >> 4;
            msb |= msb >> 8;
            msb |= msb >> 16;
            return msb - (msb >> 1);
        }
        
        vector<int>::iterator msbPartition(const vector<int>::iterator &beginIt,
                                           const vector<int>::iterator &endIt,
                                           int msb) {
            auto msbSet = [msb](int n) { return (n & msb) != 0; };
            return partition(beginIt, endIt, msbSet);
        }
        
        int findMaximumXor(const vector<int>::iterator& beginLeft,
                           const vector<int>::iterator& endLeft,
                           const vector<int>::iterator& beginRight,
                           const vector<int>::iterator& endRight,
                           int msb) {
            if (distance(beginLeft, endLeft) == 1 && distance(beginRight, endRight) == 1)
                return *beginLeft ^ *beginRight;
            if (msb == 0 || beginLeft == endLeft || beginRight == endRight)
                return 0;
            if ((mask & msb) == 0)
                return findMaximumXor(beginLeft, endLeft, beginRight, endRight, msb >> 1);
            auto splitLeft = msbPartition(beginLeft, endLeft, msb);
            auto splitRight = msbPartition(beginRight, endRight, msb);
            auto result1 = findMaximumXor(beginLeft, splitLeft, splitRight, endRight, msb >> 1);
            auto result2 = findMaximumXor(splitLeft, endLeft, beginRight, splitRight, msb >> 1);
            return max(result1, result2);
        }
    
    private:
        int mask;
    };
    

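    The test case where it fails is [14, 15, 9, 3, 2] (not in the OJ either). It goes like this: first we split on bit 8 into 14, 15, 9 / 3, 2, then on bit 4 we split the left part into 14, 15 / 9, while the right part stays whole because neither 3 nor 2 has that bit. And then, when we try to match 14, 15 with 3, 2, we get a problem. The next candidate bit is 2 (that is, 2^1), but we can't set it in the result because it is set in all four of these numbers now. And yet it wasn't set in all numbers to begin with (9 doesn't have it), so our mask fails to skip it. Both recursive calls then get an empty side, and the function returns 0 instead of 13.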
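
    To make the failure concrete, here is a small sketch that computes the same kind of mask over an arbitrary group of numbers. The name variableBits is hypothetical and the inputs are assumed to be non-negative. It only illustrates why the global mask stops being reliable once we are down to a subgroup.

```cpp
#include <vector>

// Hypothetical helper: OR & ~AND over any group of non-negative numbers,
// i.e. the bits that actually differ somewhere inside that group.
int variableBits(const std::vector<int>& group) {
    int anySet = 0;
    int allSet = 0x7FFFFFFF;
    for (int n : group) {
        anySet |= n;  // bits set in at least one number
        allSet &= n;  // bits set in every number
    }
    return anySet & ~allSet;
}
```

    For the whole array 14, 15, 9, 3, 2 the result is 15, so bit 2 looks like a candidate. For the subgroup 14, 15, 3, 2 the result is 13, so bit 2 is no longer one.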

    This flaw comes from the assumption that we always match subsets having different values of a certain bit. However, there is no guarantee that such subsets even exist for that bit. Our mask only provides a guarantee for the MSB, and later on it's just a hint. That means we need to check for that again after we partition. The fixed code is below, and I hope I got it right this time:

    class Solution {
    public:
        int findMaximumXOR(vector<int>& nums) {
            auto orOp = [](int a, int b) { return a | b; };
            auto andOp = [](int a, int b) { return a & b; };
            mask = accumulate(nums.cbegin(), nums.cend(), 0, orOp)
                & ~accumulate(nums.cbegin(), nums.cend(), 0x7FFFFFFF, andOp);
            if (mask == 0)
                return 0;
            auto msb = computeMsb(mask);
            auto msbSplit = msbPartition(nums.begin(), nums.end(), msb);
            return findMaximumXor(nums.begin(), msbSplit, msbSplit, nums.end(), msb >> 1);
        }
        
        int computeMsb(int n) {
            auto msb = n;
            msb |= msb >> 1;
            msb |= msb >> 2;
            msb |= msb >> 4;
            msb |= msb >> 8;
            msb |= msb >> 16;
            return msb - (msb >> 1);
        }
        
        vector<int>::iterator msbPartition(const vector<int>::iterator &beginIt,
                                           const vector<int>::iterator &endIt,
                                           int msb) {
            auto msbSet = [msb](int n) { return (n & msb) != 0; };
            return partition(beginIt, endIt, msbSet);
        }
        
        int findMaximumXor(const vector<int>::iterator& beginLeft,
                           const vector<int>::iterator& endLeft,
                           const vector<int>::iterator& beginRight,
                           const vector<int>::iterator& endRight,
                           int msb) {
            if (msb == 0 || (distance(beginLeft, endLeft) == 1 && distance(beginRight, endRight) == 1))
                return *beginLeft ^ *beginRight;
            if ((mask & msb) == 0)
                return findMaximumXor(beginLeft, endLeft, beginRight, endRight, msb >> 1);
            auto splitLeft = msbPartition(beginLeft, endLeft, msb);
            auto splitRight = msbPartition(beginRight, endRight, msb);
            auto result = 0;
            if (distance(beginLeft, splitLeft) > 0 && distance(splitRight, endRight) > 0)
                result = findMaximumXor(beginLeft, splitLeft, splitRight, endRight, msb >> 1);
            if (distance(splitLeft, endLeft) > 0 && distance(beginRight, splitRight) > 0)
                result = max(result, findMaximumXor(splitLeft, endLeft, beginRight, splitRight, msb >> 1));
            if (result == 0) // no way to set this bit to 1
                result = findMaximumXor(beginLeft, endLeft, beginRight, endRight, msb >> 1);
            return result;
        }
    
    private:
        int mask;
    };
    

    It runs in 26 ms, beating 99%.

    The result == 0 line executes when both ifs above fail to run. That happens if we have bad splits on both sides, just like in the last mentioned test case.

    One more funny test case is [15, 15, 9, 3, 2]. The problem description doesn't say numbers can't be duplicated. Well, the code above passes it thanks to the msb == 0 check at the beginning of the recursive function. It looks kind of funny because in that case we just return the XOR of the first elements of both sides without even checking how many elements there are. However, by the time we get to msb == 0, we have already split both sides on every possible bit, so either both subsets have size 1, or each of them consists of duplicates only, and therefore picking the first elements is just as fine.

    Last but not least is the test case where all numbers are duplicates. That is checked by mask == 0 in the top-level function. Funny thing, that test could be removed, but only because partitioning happens to use (n & msb) != 0. If I change it to (n & msb) == 0 (which is OK in general), then we'll have a problem: the first MSB partitioning will generate an N/0 partition, which means beginRight will be end(), and we'll get UB by trying to dereference it. With (n & msb) != 0 it generates a 0/N partition, so both beginLeft and beginRight point to the same element, and therefore eventually we return zero.
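
    A quick sketch of that 0/N versus N/0 behavior, using std::partition with msb = 0 (which is what computeMsb yields for a zero mask). The helper splitIndex and its setFirst flag are made up for this illustration and are not part of the solution.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical helper: returns the index of the split point that std::partition
// produces for the given bit.
// setFirst == true  uses the predicate (n & msb) != 0, as in the code above;
// setFirst == false uses the predicate (n & msb) == 0.
std::size_t splitIndex(std::vector<int> nums, int msb, bool setFirst) {
    auto it = std::partition(nums.begin(), nums.end(), [=](int n) {
        return setFirst ? (n & msb) != 0 : (n & msb) == 0;
    });
    return static_cast<std::size_t>(it - nums.begin());
}
```

    With msb = 0 and three equal numbers, the != 0 predicate puts the split at index 0 (so the right range starts at a valid element), while the == 0 predicate puts it at index 3, which is end().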

    Recap of the test cases to be added: [4, 6, 7], [8, 10, 2], [14, 15, 9, 3, 2], [15, 15, 9, 3, 2].
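
    For checking any solution against these cases, a brute-force O(n^2) reference is enough. This is only a sketch for verification and not the approach discussed above; the function name bruteForceMaxXor is made up.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Reference implementation for testing: tries every pair and keeps the best XOR.
// Returns 0 for arrays with fewer than two elements.
int bruteForceMaxXor(const std::vector<int>& nums) {
    int best = 0;
    for (std::size_t i = 0; i < nums.size(); ++i)
        for (std::size_t j = i + 1; j < nums.size(); ++j)
            best = std::max(best, nums[i] ^ nums[j]);
    return best;
}
```

    The expected answers for the four cases are 3, 10, 13 and 13 respectively.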


  • 1

    Very nice analysis, @SergeyTachenov. I have added all four test cases you suggested, thank you so much.


  • 0

    I think your solution is pretty cool! The code looks quite professional, which may scare some people a bit, me included. I rewrote it in C++, assuming I understand it right. It runs in 22 ms and currently beats 99.54%. It partitions the array in place, like quicksort.

    Here is my understanding:

    1. Work from the most significant bit on the left towards the right. Obviously, if a more significant bit is 1, the XOR value is greater than any value with this bit = 0.
    2. Similar to quicksort, we partition a certain range of the array nums in place. The left subrange has current bit = 1, and the right subrange has current bit = 0. Let's name them A and B.
    3. In order to find the greatest XOR value, we have to take one number from A and one number from B. To set the next bit to 1, there must be a subrange in A and a subrange in B having opposite values of that bit; otherwise, the next bit will be 0. So we partition ranges A and B recursively.
    4. When partitioning ranges A and B, there are 3 cases. (1) A has both a bit 1 and a bit 0 subrange. We check whether B has a bit 0 subrange, a bit 1 subrange, or both. (2) A has only bit 1. We check whether B has bit 0. (3) A has only bit 0. We check whether B has bit 1.

    The code with comments is as below.

    class Solution {
    public:
        int findMaximumXOR(vector<int>& nums) {
            int n = nums.size();
            return helper(nums, 0, n-1, 0, n-1, 0, 30);
        }
    private:
        // (ls, le) and (rs, re) are two inclusive ranges of nums which give the max xor value up to the current bit;
        // bit decreases from 30 to 0, i.e., working from the most significant bit on the left towards the right.
        // Similar to quicksort, partition (ls, le) into two ranges (ls, j-1) and (j, le) by swapping elements:
        // the range on the left has current bit = 1, and the range on the right has 0. We do the same to (rs, re).
        // In order to set the current bit in the answer, i.e. val, to 1, the left (ls, le) and right (rs, re)
        // ranges must have subranges with opposite bits. If so, val = (val << 1) + 1; otherwise, val = val << 1.
        int helper(vector<int>& nums, int ls, int le, int rs, int re, int val, int bit) {
            if (bit == -1) return val;
            int mask = 1<<bit, j = ls, k = rs;
            for (int i = ls; i <= le; i++) 
                if (nums[i]&mask) swap(nums[i], nums[j++]);
            for (int i = rs; i <= re; i++) 
                if (nums[i]&mask) swap(nums[i], nums[k++]);
            // the left range has two subranges, the answer is max of (bit 1 subrange on the left and bit 0 subrange on the right) or (bit 0 subrange on the left and bit 1 subrange on the right)
            if (j > ls && j <= le) {
                int ans = 0;
                if (k > rs) 
                    ans = helper(nums, j, le, rs, k-1, val*2+1, bit-1);
                if (k <= re) 
                    ans = max(ans, helper(nums, ls, j-1, k, re, val*2+1, bit-1));
                return ans;
            }
            // the left range has only bit 0 subrange
            else if (j <= ls) {
                // check whether the right range has bit 1 subrange
                if (k > rs) 
                    return helper(nums, ls, le, rs, k-1, val*2+1, bit-1);
                else 
                    return helper(nums, ls, le, rs, re, val*2, bit-1);
            }
            // the left range has only bit 1 subrange
            else {
                // check whether the right range has bit 0 subrange
                if (k <= re) 
                    return helper(nums, ls, le, k, re, val*2+1, bit-1);
                else 
                    return helper(nums, ls, le, rs, re, val*2, bit-1);
            }
        }
    };
    
