Clear order statistic (binary search) tree implementation



    Consider all ranges that start at index i. For any j >= i we have a[i] + sum[i+1 to j] + suffixSum[j+1] = suffixSum[i], so the range sum is sum[i to j] = suffixSum[i] - suffixSum[j+1]. The condition lower <= suffixSum[i] - suffixSum[j+1] <= upper rearranges to suffixSum[i] - upper <= suffixSum[j+1] <= suffixSum[i] - lower < suffixSum[i] - lower + 1. Each j that satisfies this constraint forms a valid range starting at i. To count these j, we keep the suffix sums suffixSum[k] for k > i in an order statistic tree (a binary search tree augmented with left-subtree counts). If rank(x) is 1 + (number of stored values < x), the count for i is rank(suffixSum[i] - lower + 1) - rank(suffixSum[i] - upper).
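As a cross-check of the reduction above, here is a minimal O(n^2) brute-force sketch that uses the same identity sum[i to j] = suffixSum[i] - suffixSum[j+1] and simply tests every (i, j) pair. The class name BruteForceRangeSum is hypothetical and not part of the original answer; it is only meant for validating the tree-based version on small inputs.

```java
// Hypothetical helper (not from the original post): O(n^2) brute force
// built on sum(i..j) = suffix[i] - suffix[j + 1].
class BruteForceRangeSum {
  static int countRangeSum(int[] arr, int lower, int upper) {
    int n = arr.length;
    // suffix[k] = arr[k] + ... + arr[n - 1], with suffix[n] = 0; long avoids int overflow
    long[] suffix = new long[n + 1];
    for (int k = n - 1; k >= 0; k--) {
      suffix[k] = suffix[k + 1] + arr[k];
    }
    int count = 0;
    for (int i = 0; i < n; i++) {
      for (int j = i; j < n; j++) {
        long rangeSum = suffix[i] - suffix[j + 1];
        // same test as suffix[i] - upper <= suffix[j + 1] <= suffix[i] - lower
        if (lower <= rangeSum && rangeSum <= upper) {
          count++;
        }
      }
    }
    return count;
  }

  public static void main(String[] args) {
    // ranges [-2], [-1] and [-2, 5, -1] fall in [-2, 2]
    System.out.println(countRangeSum(new int[] {-2, 5, -1}, -2, 2)); // prints 3
  }
}
```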

    public class Solution {
      class TreeNode {
        long val; // use long to deal with int arithmetic overflow
        TreeNode left, right;
        int lcount; // number of nodes in the left subtree (all strictly smaller than val)
        int size; // for getting rank k element; not used in this problem
    
        TreeNode(long val) {
          this.val = val;
          size = 1;
        }
      }
    
      TreeNode insert(TreeNode root, long val) {
        if (root == null)
          return new TreeNode(val);
    
        root.size++; // not used in this problem
        if (val < root.val) {
          root.lcount++;
          root.left = insert(root.left, val);
        } else {
          root.right = insert(root.right, val);
        }
    
        return root;
      }
    
      // Returns 1 + (number of stored values strictly less than val).
      int rank(TreeNode root, long val) {
        if (root == null) {
          return 1; // as if the non-existent element gets inserted here
        }
    
        if (val == root.val) {
          return root.lcount + 1;
        } else if (val < root.val) {
          return rank(root.left, val);
        } else {
          return root.lcount + 1 + rank(root.right, val);
        }
      }
    
      public int countRangeSum(int[] arr, int lower, int upper) {
        int count = 0;
        long suffixSum = 0;
        TreeNode root = insert(null, suffixSum);
        for (int i = arr.length - 1; i >= 0; i--) {
          suffixSum += arr[i]; // current suffix sum
          int rankr = rank(root, suffixSum - lower + 1);
          int rankl = rank(root, suffixSum - upper);
          count += rankr - rankl;
          root = insert(root, suffixSum);
        }
    
        return count;
      }
    }
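One caveat: the tree above is not self-balancing, so when the suffix sums arrive in sorted order (for example, an all-positive input) it degenerates into a chain. The recursive insert and rank then recurse once per node on that chain, which can throw a StackOverflowError on long inputs. The sketch below is an assumed iterative rewrite with the same semantics (duplicates go right, lcount counts the left subtree, rank(x) = 1 + number of stored values < x). The names IterativeRankTree and Node are hypothetical. It removes the recursion but not the imbalance, so the worst case is still O(n^2) time; an O(n log n) guarantee would need a balanced tree.

```java
// Hypothetical iterative variant of the post's insert/rank (names are illustrative).
class IterativeRankTree {
  private static final class Node {
    final long val;
    Node left, right;
    int lcount; // number of nodes in the left subtree (all strictly smaller than val)

    Node(long val) {
      this.val = val;
    }
  }

  private Node root;

  // Walks down without recursion; duplicates go to the right subtree.
  void insert(long val) {
    Node fresh = new Node(val);
    if (root == null) {
      root = fresh;
      return;
    }
    Node cur = root;
    while (true) {
      if (val < cur.val) {
        cur.lcount++; // the new node lands in cur's left subtree
        if (cur.left == null) {
          cur.left = fresh;
          return;
        }
        cur = cur.left;
      } else {
        if (cur.right == null) {
          cur.right = fresh;
          return;
        }
        cur = cur.right;
      }
    }
  }

  // Returns 1 + (number of stored values strictly less than val).
  int rank(long val) {
    int smaller = 0; // values < val found so far on the path
    Node cur = root;
    while (cur != null) {
      if (val == cur.val) {
        return smaller + cur.lcount + 1;
      } else if (val < cur.val) {
        cur = cur.left;
      } else {
        smaller += cur.lcount + 1; // left subtree plus cur itself
        cur = cur.right;
      }
    }
    return smaller + 1;
  }

  // Same loop as countRangeSum in the answer, using the iterative tree.
  static int countRangeSum(int[] arr, int lower, int upper) {
    IterativeRankTree tree = new IterativeRankTree();
    tree.insert(0L); // suffix sum of the empty suffix
    long suffixSum = 0;
    int count = 0;
    for (int i = arr.length - 1; i >= 0; i--) {
      suffixSum += arr[i];
      count += tree.rank(suffixSum - lower + 1) - tree.rank(suffixSum - upper);
      tree.insert(suffixSum);
    }
    return count;
  }

  public static void main(String[] args) {
    System.out.println(countRangeSum(new int[] {-2, 5, -1}, -2, 2)); // prints 3
  }
}
```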
    
