Complicated segmentree solution, hope to find a better one


  • 7
    M
    public class Solution {
        static class segmentTreeNode {
            int start, end, count;
            segmentTreeNode left, right;
            segmentTreeNode(int start, int end, int count) {
                this.start = start;
                this.end = end;
                this.count = count;
                left = null;
                right = null;
            }
        }
        public static List<Integer> countSmaller(int[] nums) {
            // write your code here
            List<Integer> result = new ArrayList<Integer>();
            
            int min = Integer.MAX_VALUE, max = Integer.MIN_VALUE;
            for (int i : nums) {
                min = Math.min(min, i);
                
            }
            if (min < 0) {
                for (int i = 0; i < nums.length; i++) {
                    nums[i] -= min;//deal with negative numbers, seems a dummy way
                }
            }
            for (int i : nums) {
                max = Math.max(max, i);
            }
            segmentTreeNode root = build(0, max);
            for (int i = 0; i < nums.length; i++) {
                updateAdd(root, nums[i]);
            }
            for (int i = 0; i < nums.length; i++) {
                updateDel(root, nums[i]);
                result.add(query(root, 0, nums[i] - 1));
            }
            return result;
        }
        public static segmentTreeNode build(int start, int end) {
            if (start > end) return null;
            if (start == end) return new segmentTreeNode(start, end, 0);
            int mid = (start + end) / 2;
            segmentTreeNode root = new segmentTreeNode(start, end, 0);
            root.left = build(start, mid);
            root.right = build(mid + 1, end);
            root.count = root.left.count + root.right.count;
            return root;
        }
    
        public static int query(segmentTreeNode root, int start, int end) {
            if (root == null) return 0;
            if (root.start == start && root.end == end) return root.count;
            int mid = (root.start + root.end) / 2;
            if (end < mid) {
                return query(root.left, start, end);
            } else if (start > end) {
                return query(root.right, start, end);
            } else {
                return query(root.left, start, mid) + query(root.right, mid + 1, end);
            }
        }
    
        public static void updateAdd(segmentTreeNode root, int val) {
            if (root == null || root.start > val || root.end < val) return;
            if (root.start == val && root.end == val) {
                root.count ++;
                return;
            }
            int mid = (root.start + root.end) / 2;
            if (val <= mid) {
                updateAdd(root.left, val);
            } else {
                updateAdd(root.right, val);
            }
            root.count = root.left.count + root.right.count;
        }
    
        public static void updateDel(segmentTreeNode root, int val) {
            if (root == null || root.start > val || root.end < val) return;
            if (root.start == val && root.end == val) {
                root.count --;
                return;
            }
            int mid = (root.start + root.end) / 2;
            if (val <= mid) {
                updateDel(root.left, val);
            } else {
                updateDel(root.right, val);
            }
            root.count = root.left.count + root.right.count;
        }
    }

  • 21
    L

    Using a Binary Indexed Tree (Fenwick tree) can shorten the code a lot. :P

    public class Solution {
    
        private void add(int[] bit, int i, int val) {
            for (; i < bit.length; i += i & -i) bit[i] += val;
        }
    
        private int query(int[] bit, int i) {
            int ans = 0;
            for (; i > 0; i -= i & -i) ans += bit[i];
            return ans;
        }
    
        public List<Integer> countSmaller(int[] nums) {
            int[] tmp = nums.clone();
            Arrays.sort(tmp);
            for (int i = 0; i < nums.length; i++) nums[i] = Arrays.binarySearch(tmp, nums[i]) + 1;
            int[] bit = new int[nums.length + 1];
            Integer[] ans = new Integer[nums.length];
            for (int i = nums.length - 1; i >= 0; i--) {
                ans[i] = query(bit, nums[i] - 1);
                add(bit, nums[i], 1);
            }
            return Arrays.asList(ans);
        }
    }

  • 0
    L

    It is better to do discretization first to make sure each number is mapped to [0, n). Otherwise, the size of the segment tree can be as large as the max key in the array, say 2147483647.


  • 0

    This is just amazing! Thanks for sharing this solution!


  • 0
    Y

    Could you explain the idea? Thanks!


  • 0
    L

    @zjuzhanxf, you can see here for a detailed tutorial on BIT (Fenwick Tree).


  • 0
    D

    Impressed, BIT is awesome!


  • 0
    Z

    If the input is very spread, your solution would take up a lot of memory. You should build remove duplicates and build tree from the sorted array.

    class Solution(object):
    
    def countSmaller(self, nums):
        """
        :type nums: List[int]
        :rtype: List[int]
        """
        array = sorted(list(set(nums)))
        indexes = dict([(value,index) for index,value in enumerate(array)])
        tree = SegmentTree()
        root = tree.build(0,len(array)-1)
        rst = [0]*len(nums)
        for i in range(len(nums)-1,-1,-1):
            #find index to modify
            index = indexes[nums[i]]
            tree.modify(root,index,1)
            rst[i] = tree.query(root,0, index-1)
        return rst
    

    class SegmentTreeNode:

    def __init__(self, start, end, sum = 0):
        self.start, self.end, self.sum = start, end, sum
        self.left, self.right = None, None
    

    class SegmentTree:

    def build(self, start, end):
        # write your code here
        if start > end:
            return None
        if start == end:
            return SegmentTreeNode(start,end)
        node = SegmentTreeNode(start,end)
        node.left = self.build(start, (start+end)/2)
        node.right = self.build((start+end)/2+1,end)
        return node
    
    def modify(self, root, index, value):
        # write your code here
        if root is None:
            return 
        if root.start== root.end and root.start == index:
            root.sum += value
        middle = (root.start + root.end)/2
        if index > middle:
            self.modify(root.right,index,value)
        else:
            self.modify(root.left,index,value)
        if root.left and root.right:
            root.sum = root.left.sum + root.right.sum
        return
    
    def query(self, root, start, end):
        if start > end or root is None:
            return 0
        if start == root.start and end == root.end:
            return root.sum
        middle = (root.start + root.end)/2
        # interval in left
        if end <= middle:
            return self.query(root.left,start,end)
        # interval in right
        if start > middle:
            return self.query(root.right,start,end)
        # interval breaks into two section
        else:
            return self.query(root.left,start,middle) + self.query(root.right,middle+1,end)
    

    `


  • 0
    X

    why did you build the segment tree with value rather than the index of the array?

    segmentTreeNode root = build(0, max);

    change to :
    segmentTreeNode root = build(0, nums.length - 1);

    If you use the index of the array, then you won't worry the negative number


  • 0
    L

    @mickeyliu6 Recently, I learned a more elegant way of writing a bit without discretization.

    void add(Map<Long, Integer> bit, long x) {
        for (x += 1L << 31; x < 1L << 32; x += x & -x) bit.put(x, bit.getOrDefault(x, 0) + 1);
    }
    
    int sum(Map<Long, Integer> bit, long x) {
        int ans = 0;
        for (x += 1L << 31; x > 0; x -= x & -x) ans += bit.getOrDefault(x, 0);
        return ans;
    }
    
    public List<Integer> countSmaller(int[] nums) {
        Map<Long, Integer> bit = new HashMap<>();
        Integer[] res = new Integer[nums.length];
        for (int i = nums.length - 1; i >= 0; i--) {
            res[i] = sum(bit, (long) nums[i] - 1);
            add(bit, nums[i]);
        }
        return Arrays.asList(res);
    }
    

  • 0
    Q

    For segment tree solution, what about this?

    class Solution {
    public:
        vector<int> countSmaller(vector<int>& nums) {
            vector<int> res (nums.size());
            auto pos_num = sorted(enumerate(nums), [](std::pair<size_t, int> lhs, std::pair<size_t, int> rhs) { return lhs.second < rhs.second || (lhs.second == rhs.second && lhs.first < rhs.first); });
            const int N = nums.size();
            std::vector<int> seg(N*2);
            for (auto pair: pos_num) {
                int t = N + pair.first;
                for (++seg[t]; t > 1; t >>= 1) seg[t>>1] = seg[t] + seg[t^1];
                int sum = 0;
                for (int l = N + pair.first + 1, r = 2*N; l < r; l >>= 1, r >>= 1) {
                    if (l&1) sum += seg[l++];
                    if (r&1) sum += seg[--r];
                }
                res[pair.first] = sum;
            }
            return res;
        }
    
    // enumerate(), return a vector of pair<index, value>, index starting from zero.
    template <typename Iterator>
    std::vector<std::pair<size_t, typename Iterator::value_type>> enumerate(Iterator first, Iterator last) {
        std::vector<std::pair<size_t, typename Iterator::value_type>> res;
        for (size_t i = 0; first != last; ++i, ++first) res.push_back({i, *first});
        return res;
    }
    template <typename Container>
    std::vector<std::pair<size_t, typename Container::value_type>> enumerate(const Container &c) { return enumerate(std::begin(c), std::end(c)); }
    
    // sorted(), return a vector of sorted sequence.
    template <typename Iterator, typename Compare = std::less<typename Iterator::value_type>>
    std::vector<typename Iterator::value_type> sorted(Iterator first, Iterator last, Compare comp = Compare()) {
        auto res = std::vector<typename Iterator::value_type> (first, last);
        std::sort(res.begin(), res.end(), comp);
        return res;
    }
    template <typename Container, typename Compare = std::less<typename Container::value_type>>
    std::vector<typename Container::value_type> sorted(const Container &c, Compare comp = Compare()) { return sorted(std::begin(c), std::end(c), comp); }
    };

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.