C++ divide-and-conquer (bitwise partition) solution that beats 93%


  • 0
    J
    class Solution {
    public:
        // Returns the maximum of nums[i] ^ nums[j] over all pairs i != j.
        //
        // Approach (divide and conquer on bits, not DP):
        //   1. Find the most significant bit on which the values disagree. Any
        //      maximal pair must take one value from each side of that bit.
        //   2. Walk the lower bits. At each bit keep only the (left, right)
        //      sub-group pairs that can still produce a 1 there. If no pair can,
        //      the bit stays 0 and the groups carry on unchanged.
        // Each element lives in at most one active sub-group pair per bit, so the
        // total work is O(32 * n).
        //
        // nums: input values. NOTE(review): assumes non-negative values (the
        //       LeetCode constraint); negative inputs are not meaningfully handled.
        // returns: the maximum pairwise XOR, or 0 when nums has fewer than two
        //          elements or all elements are equal.
        int findMaximumXOR(std::vector<int>& nums) {
            if (nums.size() < 2) {
                return 0;
            }
            // A bit differs somewhere in nums iff it is set in the OR of all
            // values but not in their AND. One pass replaces re-splitting nums
            // once per bit.
            int all_or = 0;
            int all_and = ~0;
            for (int num : nums) {
                all_or |= num;
                all_and &= num;
            }
            const unsigned differing = static_cast<unsigned>(all_or & ~all_and);
            if (differing == 0) {
                return 0;  // all values are equal, so every pairwise XOR is 0
            }
            // Locate the most significant differing bit.
            int bit_idx = 31;
            while (((differing >> bit_idx) & 1u) == 0) {
                --bit_idx;
            }
            std::vector<int> group0, group1;
            splitByBit(nums, bit_idx, group0, group1);
            // Both groups are non-empty by construction, so the answer has bit_idx
            // set. Continue with the next lower bit.
            return maxXOR(group0, group1, bit_idx - 1, 1 << bit_idx);
        }

        // Recursively maximizes a ^ b for a in group0, b in group1.
        //
        // Precondition: every cross pair (a, b) already XORs to `prefix` on all
        // bits above bit_idx, and both groups are non-empty.
        // group0, group1: the two candidate sides (not modified).
        // bit_idx: the bit to decide next; a negative value means all bits are done.
        // prefix: the XOR bits fixed so far (bits above bit_idx).
        // returns: the best XOR reachable from these groups.
        int maxXOR(const std::vector<int>& group0, const std::vector<int>& group1,
                   int bit_idx, int prefix) {
            if (bit_idx < 0) {
                return prefix;
            }
            std::vector<int> group00, group01;
            splitByBit(group0, bit_idx, group00, group01);
            std::vector<int> group10, group11;
            splitByBit(group1, bit_idx, group10, group11);

            // A 1 at bit_idx needs a left value and a right value that disagree there.
            const bool can_pair_0_with_1 = !group00.empty() && !group11.empty();
            const bool can_pair_1_with_0 = !group01.empty() && !group10.empty();
            if (!can_pair_0_with_1 && !can_pair_1_with_0) {
                // No cross pair can set this bit; it stays 0 in the answer.
                return maxXOR(group0, group1, bit_idx - 1, prefix);
            }
            const int next_prefix = prefix | (1 << bit_idx);
            int result = 0;
            if (can_pair_0_with_1) {
                result = std::max(result, maxXOR(group00, group11, bit_idx - 1, next_prefix));
            }
            if (can_pair_1_with_0) {
                result = std::max(result, maxXOR(group01, group10, bit_idx - 1, next_prefix));
            }
            return result;
        }

        // Appends each value of nums to group0 if its bit_idx-th bit is 0, and to
        // group1 otherwise. Existing contents of group0/group1 are kept, and input
        // order is preserved within each group.
        void splitByBit(const std::vector<int>& nums, int bit_idx,
                        std::vector<int>& group0, std::vector<int>& group1) {
            const unsigned mask = 1u << bit_idx;
            for (int num : nums) {
                if ((static_cast<unsigned>(num) & mask) == 0) {
                    group0.push_back(num);
                } else {
                    group1.push_back(num);
                }
            }
        }
    };
    

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.