Java rolling hash solution



    While the KMP solutions are pretty awesome, I had no idea about KMP when I first looked at this problem. So I started with the dumbest solution: prepend the last characters of the string to the beginning, one at a time, and check after each step whether the result is a palindrome. That is, of course, O(n^2) in the average and worst cases, although the best case is O(n), when the string is a palindrome to begin with.
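    For reference, that naive approach can be sketched roughly as follows. The class name BruteForceShortestPalindrome is just an illustrative label, not part of my submitted solution.

```java
final class BruteForceShortestPalindrome {
    // Naive approach: copy characters from the end of s to the front,
    // one at a time, until the whole string is a palindrome.
    static String shortestPalindrome(String s) {
        StringBuilder added = new StringBuilder();
        for (int last = s.length() - 1; !isPalindrome(added + s); --last) {
            added.append(s.charAt(last));
        }
        return added + s;
    }

    // Two-pointer palindrome check, O(length) per call.
    static boolean isPalindrome(CharSequence s) {
        for (int i = 0, j = s.length() - 1; i < j; ++i, --j) {
            if (s.charAt(i) != s.charAt(j)) {
                return false;
            }
        }
        return true;
    }
}
```

    Each palindrome check is O(n), and up to n - 1 characters may have to be added, which is where the quadratic bound comes from.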

    So the next step was to look at possible improvements. It occurred to me that we don't change the string much when adding a single character, so it should be possible to use an rsync-like rolling hash to check whether the two parts are likely to be reverse-equal. Of course, we still need to double-check with a real comparison when the hashes are equal, but most of the time it's just updating and comparing hashes, which is O(1).

    My hash function is based on Fletcher's checksum, but a bit simpler. sum0 is just the sum of all characters. sum1 is the sum of all values that sum0 takes during the computation. For example, if we consider a contrived string consisting of characters with codes 1, 2, 3, 4, 5, the hash will be

    sum0 = 1 + 2 + 3 + 4 + 5
    sum1 = 1 + (1 + 2) + (1 + 2 + 3) + (1 + 2 + 3 + 4) + (1 + 2 + 3 + 4 + 5)
         = 5 * 1 + 4 * 2 + 3 * 3 + 2 * 4 + 1 * 5
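    The same computation as a small Java sketch. SimpleChecksum and its of helper are names made up for this illustration; the append step is the same one used in the solution below.

```java
final class SimpleChecksum {
    int sum0 = 0; // sum of all character codes appended so far
    int sum1 = 0; // sum of every value sum0 has taken so far

    // Appending one code updates both sums in O(1).
    void append(int code) {
        sum0 += code;
        sum1 += sum0;
    }

    // Illustrative helper: hash a sequence of character codes from scratch.
    static SimpleChecksum of(int... codes) {
        SimpleChecksum h = new SimpleChecksum();
        for (int code : codes) {
            h.append(code);
        }
        return h;
    }
}
```

    For the codes 1, 2, 3, 4, 5 this gives sum0 = 15 and sum1 = 35, matching 5 * 1 + 4 * 2 + 3 * 3 + 2 * 4 + 1 * 5.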
    

    It is quite easy to update this hash when appending or removing a character at the end, and it is not that hard to compute the hash of two concatenated strings either. The sum0 values are simply added up. The sum1 values are also added, but an additional correction is needed to take into account that the right-hand string no longer starts at index 0. This is done by adding the product of the left-hand sum0 and the length of the right-hand string, because that is exactly how many times the left-hand sum0 would have been added if the hash had been calculated for the whole string.
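    A minimal sketch of the concatenation rule, written as an immutable value class so it can be checked against hashing the whole string directly. ConcatHash, of and concat are illustrative names; the solution below applies the same formula inside its equal method instead.

```java
final class ConcatHash {
    final int sum0;  // sum of all characters
    final int sum1;  // sum of all running values of sum0
    final int count; // number of characters hashed

    ConcatHash(int sum0, int sum1, int count) {
        this.sum0 = sum0;
        this.sum1 = sum1;
        this.count = count;
    }

    // Hash a whole string from left to right.
    static ConcatHash of(String s) {
        int sum0 = 0, sum1 = 0;
        for (int i = 0; i < s.length(); ++i) {
            sum0 += s.charAt(i);
            sum1 += sum0;
        }
        return new ConcatHash(sum0, sum1, s.length());
    }

    // Hash of (left + right) computed from the two hashes alone.
    // Every running sum inside right is shifted up by left.sum0,
    // and there are right.count such running sums.
    static ConcatHash concat(ConcatHash left, ConcatHash right) {
        return new ConcatHash(
                left.sum0 + right.sum0,
                left.sum1 + right.sum1 + left.sum0 * right.count,
                left.count + right.count);
    }
}
```

    With this, concat(of("abc"), of("de")) has the same sum0, sum1 and count as of("abcde").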

    It is not as efficient as KMP, but as long as false hash matches are rare it has about the same amortized complexity. It runs in 15 ms (beating 45%), while my implementation of KMP runs in 10 ms.

    public String shortestPalindrome(String s) {
        StringBuilder added = new StringBuilder();
        int len = s.length();
        int lastIndex = len - 1;
        RollingHash hashAdded = new RollingHash();
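        // Left half of s, hashed left to right.
        RollingHash hashLeft = new RollingHash(s, 0, len / 2);
        // Right half of s, hashed right to left (start > end), so it can be
        // compared with added + left; the middle char is skipped if len is odd.
        RollingHash hashRight = new RollingHash(s, len, (len + 1) / 2);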
        while (!RollingHash.equal(hashAdded, hashLeft, hashRight)
               || !isPalindrome(added + s)) {
            char newChar = s.charAt(lastIndex--);
            if (len % 2 == 0) {
                // len is to become odd, the right part remains
                added.append(newChar);
                hashAdded.append(newChar);
                ++len;
                int mid = len / 2 - added.length();
                hashLeft.remove(s.charAt(mid)); // mid char
            } else {
                // about to become even
                int mid = len / 2 - added.length();
                hashRight.append(s.charAt(mid)); // mid char
                added.append(newChar);
                hashAdded.append(newChar);
                ++len;
            }
        }
        return added + s;
    }
    
    private static boolean isPalindrome(CharSequence s) {
        for (int i = 0, j = s.length() - 1; i < j; ++i, --j) {
            if (s.charAt(i) != s.charAt(j))
                return false;
        }
        return true;
    }
    
    private static class RollingHash {
        private int sum0 = 0, sum1 = 0, count = 0;
    
        private RollingHash() {
        }
    
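        // Hashes s[start, end) left to right when end >= start; otherwise
        // hashes s[end, start) right to left, i.e. the reversed substring.
        private RollingHash(CharSequence s, int start, int end) {
            for (int i = end >= start ? start : start - 1;
                    end >= start ? i < end : i >= end;
                    i = (end >= start) ? i + 1 : i - 1) {
                append(s.charAt(i));
            }
        }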
        
        private void append(char c) {
            sum0 += c;
            sum1 += sum0;
            ++count;
        }
        
        private void remove(char c) {
            sum1 -= sum0;
            sum0 -= c;
            --count;
        }
        
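        // True if hash(added + left) equals hash(right); the
        // left.count * added.sum0 term is the concatenation correction.
        public static boolean equal(RollingHash added, RollingHash left, RollingHash right) {
            return added.sum0 + left.sum0 == right.sum0
                    && added.sum1 + left.sum1 + left.count * added.sum0 == right.sum1;
        }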
    }
