Java clear segment tree solution

  • 1

    Posting this because I haven't seen any clean solution using a segment tree.
    It is not as fast as the merge sort and BIT solutions, and it uses more space (the tree array holds up to roughly 4x the number of distinct values).
    O(n log n) time, O(n) space.

    /**
     * Counts, for every element of an array, how many smaller elements appear to its right.
     *
     * <p>Approach: map the distinct values of the input, in ascending order, onto the ranks
     * {@code 0 .. D-1} (coordinate compression), then sweep the input from right to left while
     * keeping a segment tree of "how many times has each rank been seen so far". The answer
     * for position {@code i} is the number of already-seen ranks strictly below the rank of
     * {@code nums[i]}.
     */
    public class Solution {
        /**
         * Returns a list whose {@code i}-th entry is the number of elements to the right of
         * {@code nums[i]} that are strictly smaller than it.
         *
         * @param nums the input values; may be empty
         * @return a fixed-size list of the same length as {@code nums}
         */
        public List<Integer> countSmaller(int[] nums) {
            // Coordinate compression, e.g. [5,2,6,5,1] ---> [2,1,3,2,0].
            // Step 1: sort a copy,             [5,2,6,5,1] ---> [1,2,5,5,6].
            int[] numsSorted = Arrays.copyOf(nums, nums.length);
            Arrays.sort(numsSorted);

            // Step 2: give each distinct value the next free rank. Because the copy is sorted,
            // a smaller value always receives a smaller rank: [1,2,5,5,6] ---> [0,1,2,2,3].
            HashMap<Integer, Integer> nums2Idx = new HashMap<Integer, Integer>();
            int stIdx = 0;
            for (int value : numsSorted) {
                if (!nums2Idx.containsKey(value)) {
                    nums2Idx.put(value, stIdx++);
                }
            }

            // Sweep right to left: query first, then record the current element, so that an
            // element never counts itself.
            SegmentTree st = new SegmentTree(stIdx);
            Integer[] res = new Integer[nums.length];
            for (int i = nums.length - 1; i >= 0; i--) {
                int rank = nums2Idx.get(nums[i]);
                res[i] = st.getSum(0, rank - 1);   // occurrences seen so far with a smaller rank
                st.increment(rank, 1);             // record one more occurrence of this rank
            }
            return Arrays.asList(res);
        }
    }

    /**
     * Array-backed sum segment tree over the index range {@code [0, N-1]}; all counts start
     * at zero. The node at array index {@code si} has its children at {@code 2*si+1} and
     * {@code 2*si+2}.
     */
    class SegmentTree {
        private int[] st;
        private final int N;

        /**
         * Creates a tree covering the indices {@code 0 .. n-1}.
         *
         * @param n number of leaf positions; {@code 0} yields an empty tree
         * @throws IllegalArgumentException if {@code n} is negative
         */
        public SegmentTree(int n) {
            if (n < 0) {
                throw new IllegalArgumentException("n must be non-negative: " + n);
            }
            this.N = n;
            if (N == 0) {
                st = new int[0];
                return;
            }
            // Smallest power of two >= n, computed with integer arithmetic so that no
            // floating-point rounding can influence the allocation size.
            int leaves = (n <= 1) ? 1 : Integer.highestOneBit(n - 1) << 1;
            st = new int[2 * leaves - 1];
        }

        /**
         * Returns the sum of the values stored at indices {@code i .. j} (inclusive).
         * An empty range ({@code i > j}) or an empty tree yields {@code 0}.
         */
        public int getSum(int i, int j) {
            if (i > j || N == 0) {
                return 0;
            }
            return getSumUtil(0, N - 1, i, j, 0);
        }

        /**
         * Sums the part of the query range {@code [qs, qe]} that lies inside the segment
         * {@code [ss, se]} owned by node {@code si}.
         */
        private int getSumUtil(int ss, int se, int qs, int qe, int si) {
            if (qe < ss || se < qs) {
                return 0;                 // no overlap
            }
            if (qs <= ss && se <= qe) {
                return st[si];            // segment fully inside the query
            }
            int mid = getMid(ss, se);
            return getSumUtil(ss, mid, qs, qe, 2 * si + 1)
                 + getSumUtil(mid + 1, se, qs, qe, 2 * si + 2);
        }

        /**
         * Adds {@code val} to the value stored at index {@code idx}. An index outside
         * {@code [0, N-1]} is ignored.
         */
        public void increment(int idx, int val) {
            if (idx < 0 || idx >= N) {
                return;
            }
            incrementUtil(0, N - 1, idx, val, 0);
        }

        /**
         * Adds {@code val} to node {@code si} and to every descendant whose segment
         * contains {@code idx}; the caller guarantees {@code ss <= idx <= se}.
         */
        private void incrementUtil(int ss, int se, int idx, int val, int si) {
            st[si] += val;
            if (ss < se) {
                int mid = getMid(ss, se);
                // Only the child whose segment contains idx needs updating.
                if (idx <= mid) {
                    incrementUtil(ss, mid, idx, val, 2 * si + 1);
                } else {
                    incrementUtil(mid + 1, se, idx, val, 2 * si + 2);
                }
            }
        }

        /** Midpoint of {@code [s, e]}, written so that {@code s + e} cannot overflow. */
        private int getMid(int s, int e) {
            return s + (e - s) / 2;
        }
    }

Log in to reply

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.