Short Java Binary Index Tree BEAT 97.33% With Detailed Explanation


  • 30

    This is a Binary Indexed Tree (BIT, also known as a Fenwick tree) solution.

    Here is a very good explanation: "What is Binary Index Tree".

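    The basic idea is:

    1, Shift every element so the smallest one becomes 1: nums2[i] = nums[i] - min + 1. This removes negative values, and BIT indices start from 1, so index 0 stays unused.
    (The subtraction can overflow int when the values span a very wide range, and even with long the tree would be too large to allocate; mapping each value to its rank in sorted order avoids both problems. But there is no such case in the test cases so far.)
    2, Build an int array of length max + 1 as the BIT, where max is the largest shifted value.
    3, Scan from right to left with the standard BIT operations: for each element, the prefix sum up to nums2[i] - 1 is the number of smaller elements to its right; then add 1 at index nums2[i].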
    

    Here is the code, feel free to critique it:

    import java.util.LinkedList;
    import java.util.List;

    public class Solution {
        public List<Integer> countSmaller(int[] nums) {
            List<Integer> res = new LinkedList<Integer>();
            if (nums == null || nums.length == 0) {
                return res;
            }
            // find the min value, then shift every element by (1 - min) so the smallest becomes 1
            int min = Integer.MAX_VALUE;
            int max = Integer.MIN_VALUE;
            for (int i = 0; i < nums.length; i++) {
                min = Math.min(min, nums[i]);
            }
            int[] nums2 = new int[nums.length];
            for (int i = 0; i < nums.length; i++) {
                // overflows int if nums[i] - min >= Integer.MAX_VALUE
                nums2[i] = nums[i] - min + 1;
                max = Math.max(nums2[i], max);
            }
            // BIT indexed by shifted value: tree[1..max], tree[0] unused
            int[] tree = new int[max + 1];
            // scan right to left: count smaller values already inserted, then insert nums2[i]
            for (int i = nums2.length - 1; i >= 0; i--) {
                res.add(0, get(nums2[i] - 1, tree));
                update(nums2[i], tree);
            }
            return res;
        }
        // prefix sum: how many inserted values lie in [1, i]
        private int get(int i, int[] tree) {
            int num = 0;
            while (i > 0) {
                num += tree[i];
                i -= i & (-i);
            }
            return num;
        }
        // insert one occurrence of value i
        private void update(int i, int[] tree) {
            while (i < tree.length) {
                tree[i]++;
                i += i & (-i);
            }
        }
    }
    

  • -1

    Wow, this is an awesome solution. It's very fast when the elements are close to each other, but not when the value range is large, since the tree needs one slot per value in the range.


  • 8

    Actually, you can sort the array first and map each element to its rank, so the tree only needs to be as large as the array.
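This idea is usually called coordinate compression. A minimal sketch, assuming dense 1-based ranks are what we want (the class and method names `RankCompression.toRanks` are hypothetical, not from the thread): sort a copy, drop duplicates, then binary-search each original value. Equal values share a rank and the largest rank is at most nums.length, so the tree needs at most nums.length + 1 slots.

```java
import java.util.Arrays;

// Hypothetical helper: maps each value to its dense 1-based rank in sorted order.
class RankCompression {
    static int[] toRanks(int[] nums) {
        int[] sorted = nums.clone();
        Arrays.sort(sorted);
        // compact the sorted copy in place so each distinct value appears once
        int distinct = 0;
        for (int i = 0; i < sorted.length; i++) {
            if (i == 0 || sorted[i] != sorted[i - 1]) {
                sorted[distinct++] = sorted[i];
            }
        }
        int[] ranks = new int[nums.length];
        for (int i = 0; i < nums.length; i++) {
            // search only the distinct prefix, where each value has a unique position
            ranks[i] = Arrays.binarySearch(sorted, 0, distinct, nums[i]) + 1;
        }
        return ranks;
    }
}
```

With ranks in hand, the original get/update code works unchanged on a tree of length nums.length + 1, and an input like [2147483647, -2147483648, -1, 0] becomes [4, 1, 2, 3].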


  • 1

    Excellent solution!

    Well, in the case of [2147483647,-2147483648,-1,0], there will be an error like Line 35: java.lang.ArrayIndexOutOfBoundsException: -2147483647.

    I have no idea how to solve it :(

    I'll leave the question here. If you have a solution, please reply to my comment. Thank you.
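Here is where it breaks: with int arithmetic, nums[i] - min + 1 wraps around once the range no longer fits in an int. A minimal sketch (the class `ShiftOverflow` is only for illustration) that computes the same shift in int and in long:

```java
// Illustration only: the shift nums[i] - min + 1 from the original solution,
// computed in int (wraps around on overflow) and in long (exact for any two ints).
class ShiftOverflow {
    static int shiftInt(int x, int min) {
        return x - min + 1; // silently wraps when x - min exceeds Integer.MAX_VALUE
    }

    static long shiftLong(int x, int min) {
        return (long) x - min + 1; // promoted to long before subtracting, so no wraparound
    }
}
```

For the failing input, shiftInt(0, Integer.MIN_VALUE) returns -2147483647, exactly the index in the exception, while shiftLong gives 2147483649. Using long fixes the arithmetic but not the tree: the largest shifted value here, 4294967296, is beyond the maximum Java array length, so mapping values to their rank, as suggested earlier in the thread, is the more robust fix.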


  • 11

    Map each number to its index in sorted order first.

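    import java.util.Arrays;
    import java.util.LinkedList;
    import java.util.List;
    import java.util.stream.IntStream;

    class Solution {
    	// Fenwick tree over indices 1..size (index 0 unused)
    	static class BIT {
    		int n;
    		int[] bit;
    
    		BIT(int size) {
    			this.n = size + 1;
    			this.bit = new int[this.n];
    		}
    
    		// add 1 at index i; an index past size is skipped, which is safe here
    		// because no query ever reaches it
    		void update(int i) {
    			while (i <= n - 1) {
    				bit[i]++;
    				i = i + (i & -i);
    			}
    		}
    
    		// number of values stored at indices 1..i
    		int sum(int i) {
    			int ans = 0;
    			while (i > 0) {
    				ans += bit[i];
    				i = i - (i & -i);
    			}
    			return ans;
    		}
    	}
    
    	public List<Integer> countSmaller(int[] nums) {
    		List<Integer> counts = new LinkedList<Integer>();
    
    		if (nums == null || nums.length == 0)
    			return counts;
    
    		// rank + 1 in sorted order (1..n); binarySearch returns the same
    		// position for equal values, so duplicates share one rank
    		int[] orderedNums = nums.clone();
    		Arrays.sort(orderedNums);
    		int[] nums2 = IntStream.of(nums)
    				.map(x -> Arrays.binarySearch(orderedNums, x) + 1).toArray();
    
    		BIT bit = new BIT(nums2.length);
    		for (int i = nums2.length - 1; i >= 0; i--) {
    			// values stored at indices <= nums2[i] have a strictly smaller rank
    			counts.add(0, bit.sum(nums2[i]));
    			// store one slot above the rank so equal values are not counted
    			bit.update(nums2[i] + 1);
    		}
    
    		return counts;
    	}
    }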

  • 0

    Good point. Otherwise, [Integer.MIN_VALUE, Integer.MAX_VALUE] will give you a hard time.


  • 0

    It's a great post. I didn't come up with the binary indexed tree solution myself. Thanks a lot.

    int[] tree = new int[max+1]; is really brilliant. Sizing the tree by the largest shifted value lets us build our own binary indexed tree indexed directly by value, so we can update it and query the count of elements smaller than the current one.


  • 1

    @Joyce_Lee
    Change the bit array type from int to long.


  • 0

    @emmonenirvana Thanks!


  • 6

    @monkeyGoCrazy Really brilliant solution; I didn't think this could be solved with a Fenwick tree until I saw your code. I'd like to add more comments to make the code clearer. Here is my version:

    import java.util.ArrayList;
    import java.util.List;

    public class Solution {
        public List<Integer> countSmaller(int[] nums) {
            if(nums == null || nums.length == 0) return new ArrayList<>();
    
            // find the min value, then shift every element by (1 - min) so the smallest becomes 1
            // note: this overwrites the input array nums
            int min = Integer.MAX_VALUE, max = Integer.MIN_VALUE;
            for(int i = 0; i < nums.length; i++) min = Math.min(min, nums[i]);
            for(int i = 0; i < nums.length; i++) {
                nums[i] = nums[i] - min + 1;
                max = Math.max(max, nums[i]);
            }
    
            List<Integer> res = new ArrayList<>();
            int[] fenwickTree = new int[max + 1];
            for(int i = nums.length - 1; i >= 0; i--) {
                // the index of nums[i] is nums[i] - 1
                // we need to find the sum (-INF, nums[i] - 1], so the index is nums[i] - 2
                res.add(0, getSum(fenwickTree, nums[i] - 2));
                
                // after searching, we need to update the fenwick tree for the next round
                // the new added number is nums[i], but its index of original is nums[i] - 1
                updateFenwickTree(fenwickTree, nums[i] - 1, 1);
            }
            return res;
        }
        
        // the index is the index of original array
        private void updateFenwickTree(int[] fenwickTree, int index, int value) {
            // the index of fenwick tree is one larger than the index of original array
            for(int i = index + 1; i < fenwickTree.length; i += i & (-i)) {
                fenwickTree[i] += value;
            }
        }
        
        // the index is the index of original array
        private int getSum(int[] fenwickTree, int index) {
            int sum = 0;
            // the index of fenwick tree is one larger than the index of original array
            for(int i = index + 1; i > 0; i -= i & (-i)) {
                sum += fenwickTree[i];
            }
            return sum;
        }
    }
    

  • 6

    Another BIT Solution written in Java

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    public class Solution {
        public List<Integer> countSmaller(int[] nums) {
            if(nums == null || nums.length == 0) return new ArrayList<>();
            
            // clone the original array and sort it, store <value, position> into hash map
            // (for duplicates the last position wins, so equal values share one index)
            Map<Integer, Integer> map = new HashMap<>();
            int[] sortedNum = nums.clone();
            Arrays.sort(sortedNum);
            for(int i = 0; i < nums.length; i++) map.put(sortedNum[i], i);
            
            // create fenwick tree whose length is one larger than the original array
            int[] fenwickTree = new int[nums.length + 1];
            List<Integer> res = new ArrayList<>();
            for(int i = nums.length - 1; i >= 0; i--) {
                res.add(0, getSum(fenwickTree, map.get(nums[i]) - 1));
                updateFenwickTree(fenwickTree, map.get(nums[i]), 1);
            }
            return res;
        }
    
        // the index is the index of original array
        private void updateFenwickTree(int[] fenwickTree, int index, int value) {
            // the index of fenwick tree is one larger than the index of original array
            for(int i = index + 1; i < fenwickTree.length; i += i & (-i)) {
                fenwickTree[i] += value;
            }
        }
    
        // the index is the index of original array
        private int getSum(int[] fenwickTree, int index) {
            int sum = 0;
            // the index of fenwick tree is one larger than the index of original array
            for(int i = index + 1; i > 0; i -= i & (-i)) {
                sum += fenwickTree[i];
            }
            return sum;
        }
    }
    

  • 0

    What is the complexity? Since the array "int[] tree" can be much longer than "int[] nums", the "update" and "get" operations are no longer O(log n), where n is the length of nums, right? So the overall complexity is no longer O(n log n)?


  • 1

    @xietao0221 Very smart idea to compress the original array and reassign the values to 0~n-1, since only the relative order matters, not the absolute values. Duplicates are mapped to the index of the last one in sorted order, so equal values are never counted: they all add to the same BIT node instead of being spread over adjacent indices that getSum() would include.
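The duplicate handling is easy to check against an O(n^2) brute force on inputs with repeated values. A minimal sketch (the class `DuplicateCheck` and its methods are hypothetical, not from the thread); the BIT half uses the same last-position rank map as the solution above, with rank r stored at tree[r + 1]. What actually matters is that all copies of a value share one BIT node; mapping them to the first occurrence would work just as well.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical test helper: compares the rank-compressed BIT against brute force.
class DuplicateCheck {
    // O(n^2) reference: for each i, count strictly smaller elements to its right
    static List<Integer> bruteForce(int[] nums) {
        List<Integer> res = new ArrayList<>();
        for (int i = 0; i < nums.length; i++) {
            int count = 0;
            for (int j = i + 1; j < nums.length; j++) {
                if (nums[j] < nums[i]) count++;
            }
            res.add(count);
        }
        return res;
    }

    // BIT version: duplicates map to the last position of their value in sorted order
    static List<Integer> withBit(int[] nums) {
        int[] sorted = nums.clone();
        Arrays.sort(sorted);
        Map<Integer, Integer> rank = new HashMap<>();
        for (int i = 0; i < sorted.length; i++) rank.put(sorted[i], i); // later puts overwrite
        int[] tree = new int[nums.length + 1];
        Integer[] res = new Integer[nums.length];
        for (int i = nums.length - 1; i >= 0; i--) {
            int r = rank.get(nums[i]);
            int sum = 0;
            // ranks 0..r-1 are stored at tree[1..r]
            for (int k = r; k > 0; k -= k & -k) sum += tree[k];
            res[i] = sum;
            // rank r is stored at tree[r + 1]
            for (int k = r + 1; k < tree.length; k += k & -k) tree[k]++;
        }
        return Arrays.asList(res);
    }
}
```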


  • 0

    Can anyone tell me why we "find min value and minus min by each elements, plus 1 to avoid 0 element"? Does the index of a Binary Indexed Tree start from 1?


  • 0

    Memory-efficient version, in case the values are sparse over a large range.
    The map holds at most O(n log M) entries, where M = max - min + 1, because each update touches O(log M) indices.

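    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    class Solution {
        public List<Integer> countSmaller(int[] nums) {
            // sparse BIT: only touched indices are stored; long keys avoid int overflow
            Map<Long, Integer> bit = new HashMap<>();
            List<Integer> ret = new ArrayList<>();
            if (nums.length == 0) return ret;
            int max = Integer.MIN_VALUE, min = Integer.MAX_VALUE;
            for (int i = 0; i < nums.length; i++) {
                max = Math.max(nums[i], max);
                min = Math.min(nums[i], min);
            }
            
            // shift so every value is >= 0; a value n is stored at BIT index n + 1
            long adjust = 0;
            if (min < 0) adjust = -(long) min;
            long limit = max + adjust + 1; // largest BIT index ever updated
            for (int i = nums.length - 1; i >= 0; i--) {
                long n = nums[i] + adjust;
                ret.add(get(bit, n)); // inserted values stored at 1..n, i.e. smaller than nums[i]
                update(bit, n + 1, 1, limit);
            }
            Collections.reverse(ret);
            return ret;
        }
        
        int get(Map<Long, Integer> bit, long idx) {
            int sum = 0;
            while (idx > 0) {
                sum += bit.getOrDefault(idx, 0);
                idx -= (idx & -idx);
            }
            return sum;
        }
        
        void update(Map<Long, Integer> bit, long idx, int val, long limit) {
            while (idx <= limit) {
                bit.put(idx, bit.getOrDefault(idx, 0) + val);
                idx += (idx & -idx);
            }
        }
    }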
    

  • 0

    @jason.junchen

    counts.add(0, bit.sum(nums2[i]));
    bit.update(nums2[i] + 1);
    

    Could you explain how to decide whether to add one in the sum call or in the update call?
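One way to think about it: the two calls only have to agree on where a value is stored. If a 1-based value v is stored at BIT index v, the count of strictly smaller values is the prefix sum up to v - 1; if it is stored at v + 1 (as in the code quoted above), the prefix sum up to v gives the same count. A minimal sketch comparing both conventions (the class `OffByOne` and its methods are hypothetical):

```java
// Hypothetical illustration: two equivalent off-by-one conventions for a
// count-smaller BIT over 1-based values v in [1, maxValue].
class OffByOne {
    static int query(int[] tree, int i) {
        int sum = 0;
        for (; i > 0; i -= i & -i) sum += tree[i];
        return sum;
    }

    static void add(int[] tree, int i) {
        for (; i < tree.length; i += i & -i) tree[i]++;
    }

    // Convention A: store v at index v, count smaller values with query(v - 1)
    static int[] storeAtV(int[] vals, int maxValue) {
        int[] tree = new int[maxValue + 1];
        int[] res = new int[vals.length];
        for (int i = vals.length - 1; i >= 0; i--) {
            res[i] = query(tree, vals[i] - 1);
            add(tree, vals[i]);
        }
        return res;
    }

    // Convention B: store v at index v + 1, count smaller values with query(v)
    static int[] storeAtVPlusOne(int[] vals, int maxValue) {
        int[] tree = new int[maxValue + 2];
        int[] res = new int[vals.length];
        for (int i = vals.length - 1; i >= 0; i--) {
            res[i] = query(tree, vals[i]);
            add(tree, vals[i] + 1);
        }
        return res;
    }
}
```

Either is fine as long as the query stops just below the slot where equal values are stored. In the quoted code the tree has nums.length + 1 slots, so the update for the largest rank falls off the end and is skipped; that is harmless because no later query reaches that index.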


  • 0

    Thanks for this BIT idea. I came up with two versions of the BIT solution. One uses the rank in the sorted array as the BIT index and iterates from the back of the original array; the other uses the index counted from the back of the original array as the BIT index and iterates through the sorted array.
    Please also check my other three solutions :)

    1. Binary Indexed Tree, iterate from the back of the array, use the rank as BIT index:
        import java.util.Arrays;
        import java.util.HashMap;
        import java.util.LinkedList;
        import java.util.List;

        class Solution {
            // add val at rank index, stored at BIT[index + 1]
            private void update(int[] BIT, int index, int val) {
                index++;
                while (index < BIT.length) {
                    BIT[index] += val;
                    index += index & (-index);
                }
            }
            // sum over ranks 0..index, stored at BIT[1..index + 1]
            private int getSum(int[] BIT, int index) {
                index++;
                int sum = 0;
                while (index > 0) {
                    sum += BIT[index];
                    index -= index & (-index);
                }
                return sum;
            }
            public List<Integer> countSmaller(int[] nums) {
                LinkedList<Integer> result = new LinkedList<>();
                if (nums.length == 0) return result;
                int[] sorted = nums.clone();
                Arrays.sort(sorted);
                // rank = last position of each value in sorted order
                HashMap<Integer, Integer> map = new HashMap<>();
                for (int i = 0; i < sorted.length; i++) {
                    map.put(sorted[i], i);
                }
                int[] BIT = new int[nums.length + 1];
                for (int i = nums.length - 1; i >= 0; i--) {
                    // count ranks strictly below this value's rank, then insert it
                    result.addFirst(getSum(BIT, map.get(nums[i]) - 1));
                    update(BIT, map.get(nums[i]), 1);
                }
                return result;
            }
        }
    
    2. Binary Indexed Tree, iterate according to sorted order and use original index (from the back) as BIT index:
        import java.util.Arrays;
        import java.util.HashMap;
        import java.util.LinkedList;
        import java.util.List;

        class Solution {
            // index = position counted from the back of nums (0-based), stored at BIT[index + 1]
            private void update(int[] BIT, int index, int val) {
                index++;
                while (index < BIT.length) {
                    BIT[index] += val;
                    index += index & (-index);
                }
            }
            private int getSum(int[] BIT, int index) {
                index++;
                int sum = 0;
                while (index > 0) {
                    sum += BIT[index];
                    index -= index & (-index);
                }
                return sum;
            }
            public List<Integer> countSmaller(int[] nums) {
                int[] sorted = nums.clone();
                Arrays.sort(sorted);
                // value -> list of its positions in nums
                HashMap<Integer, LinkedList<Integer>> map = new HashMap<>();
                for (int i = 0; i < nums.length; i++) {
                    map.computeIfAbsent(nums[i], k -> new LinkedList<>()).add(i);
                }
                int[] BIT = new int[nums.length + 1], count = new int[nums.length];
                for (int i = 0; i < sorted.length; i++) {
                    if (i > 0 && sorted[i] == sorted[i - 1]) continue; // each distinct value once
                    LinkedList<Integer> list = map.get(sorted[i]);
                    // BIT holds only strictly smaller values; back indices 0..n-1-j are positions j..n-1
                    for (int j : list) {
                        count[j] = getSum(BIT, nums.length - 1 - j);
                    }
                    // insert equal values only after all of them have been counted
                    for (int j : list) {
                        update(BIT, nums.length - 1 - j, 1);
                    }
                }
                LinkedList<Integer> result = new LinkedList<>();
                for (int c : count) result.add(c);
                return result;
            }
        }
    

  • 0

    My C++ code using BIT:
    I use the array id to convert each nums[i] to its relative rank. For example:

        nums:            5 2 1 6 0 0
        id:              0 1 2 3 4 5
        id after sort:   4 5 2 1 0 3
        num (rank):      4 3 2 5 1 1

    Then scan i from len - 1 down to 0 and, for each num[i], count the smaller ranks already inserted with the BIT.

    It is similar to your answer, with a few small differences.
    Code:
    #include <algorithm>
    #include <vector>
    using namespace std;

    class Solution {
    public:
        vector<int> v;  // BIT array, 1-based
        int j;          // number of distinct values (largest rank)

        int lowbit(int x)
        {
            return x & (-x);
        }
        void insert(int id)
        {
            for (int i = id; i <= j; i += lowbit(i))
                v[i]++;
        }
        int query(int id)
        {
            int sum = 0;
            while (id > 0)
            {
                sum += v[id];
                id -= lowbit(id);
            }
            return sum;
        }
        vector<int> countSmaller(vector<int>& nums) {
            if (nums.empty())
                return nums;

            vector<int> res;
            vector<int> id;
            vector<int> num;
            int len = nums.size();
            id.resize(len, 0);
            num.resize(len, 0);
            res.resize(len, 0);

            for (int i = 0; i < len; i++)
                id[i] = i;
            // sort indices by value, then assign dense ranks starting at 1
            sort(id.begin(), id.end(), [&](int a, int b){return nums[a] < nums[b];});
            j = 0;
            for (int i = 0; i < len; i++)
            {
                if (i && (nums[id[i]] == nums[id[i - 1]]))
                    num[id[i]] = j;
                else
                    num[id[i]] = ++j;
            }

            v.assign(j + 5, 0);  // reset the tree; resize would keep stale counts if the object is reused
            for (int i = len - 1; i >= 0; i--)
            {
                res[i] = query(num[i] - 1);  // ranks strictly smaller than num[i]
                insert(num[i]);
            }
            return res;
        }
    };


  • 0

    @YangCao The index of a Binary Indexed Tree must start from 1 because lowbit(0) = 0 & -0 = 0, so an update starting at 0 would never advance and the loop would run forever.
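If you would rather keep 0-based indices, a well-known variant replaces the lowbit steps with i | (i + 1) for updates and (i & (i + 1)) - 1 for queries, so no index is ever stuck. A minimal sketch of that variant (the class name `ZeroBasedBIT` is hypothetical), swapped in here for comparison rather than taken from the thread:

```java
// 0-based Fenwick tree variant: indices 0..n-1, no unused slot.
class ZeroBasedBIT {
    private final int[] tree;

    ZeroBasedBIT(int n) {
        tree = new int[n];
    }

    // add delta at index i; i | (i + 1) sets the lowest zero bit, so i always grows
    void add(int i, int delta) {
        for (; i < tree.length; i = i | (i + 1)) tree[i] += delta;
    }

    // sum of values at indices 0..i (returns 0 when i < 0)
    int prefixSum(int i) {
        int sum = 0;
        for (; i >= 0; i = (i & (i + 1)) - 1) sum += tree[i];
        return sum;
    }
}
```

For count-smaller, rank r would be stored at index r and queried with prefixSum(r - 1), the same pattern as the 1-based code with everything shifted down by one.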


  • 0

    @conquerTheCode
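    The max size doesn't matter much for the per-operation time: each update/get touches only O(log M) indices of a tree of length M, so you never visit every index. Allocating the tree still costs O(M) time and memory, though.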
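The per-operation cost can be checked by counting loop iterations. A minimal sketch (the class `BitSteps` and its helpers are hypothetical) that counts how many tree slots one update or get touches for a tree of length M: both stay within about log2(M) steps, no matter how many elements have been inserted. Allocating new int[M] is still a separate O(M) cost, since Java zero-fills the array.

```java
// Hypothetical helpers: count the slots touched by one BIT update or query.
class BitSteps {
    // iterations of: while (i < len) { ...; i += i & -i; }
    static int updateSteps(int i, int len) {
        int steps = 0;
        for (; i < len; i += i & -i) steps++;
        return steps;
    }

    // iterations of: while (i > 0) { ...; i -= i & -i; }
    static int getSteps(int i) {
        int steps = 0;
        for (; i > 0; i -= i & -i) steps++;
        return steps;
    }
}
```

So with n elements and a tree of length M the total is O(n log M + M): the value range enters logarithmically per operation and linearly through the allocation.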

