Java DP+DFS, Memoization+DFS, and DP Pruning Solutions with Analysis


  • 16
    L

    I've been struggling with this problem for a long time, and I'd like to share the three strategies I tried. All three were accepted (AC).

    Method 1: DP + DFS. Very similar to Word Break I, but instead of a boolean dp array, I use an array of Lists that stores, for every end position, all of its valid start positions. Classic backtracking over that table then produces all solutions. The time complexity is O(n*m) + O(n * number of solutions), where n is the length of the input string and m is the length of the longest word in the dictionary (counting each substring lookup as one step). The run time was 6 ms. It is efficient because the DP pass only records start positions that are reachable from index 0, so the backtracking never walks into a dead end.

    public List<String> wordBreak(String s, Set<String> wordDict) {
        List<Integer>[] starts = new List[s.length() + 1]; // valid start positions
        starts[0] = new ArrayList<Integer>();
        
        int maxLen = getMaxLen(wordDict);
        for (int i = 1; i <= s.length(); i++) {
            for (int j = i - 1; j >= i - maxLen && j >= 0; j--) {
                if (starts[j] == null) continue;
                String word = s.substring(j, i);
                if (wordDict.contains(word)) {
                    if (starts[i] == null) {
                        starts[i] = new ArrayList<Integer>();
                    }
                    starts[i].add(j);
                }
            }
        }
        
        List<String> rst = new ArrayList<>();
        if (s.isEmpty() || starts[s.length()] == null) { // empty input or no valid segmentation
            return rst;
        }
        
        dfs(rst, "", s, starts, s.length());
        return rst;
    }
    
    
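    // Walks the starts table backwards from end; path holds the words chosen so far.
    private void dfs(List<String> rst, String path, String s, List<Integer>[] starts, int end) {
        if (end == 0) {
            rst.add(path.substring(1)); // drop the leading space
            return;
        }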
        
        for (Integer start: starts[end]) {
            String word = s.substring(start, end);
            dfs(rst, " " + word + path, s, starts, start);
        }
    }
    
    private int getMaxLen(Set<String> wordDict) {
        int max = 0;
        for (String s : wordDict) {
            max = Math.max(max, s.length());
        }
        return max;
    }
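
To make the starts table concrete, here is a minimal sketch that builds it for a small input and prints it. The input "catsanddog" and its dictionary are assumed for illustration and are not from the post; `StartsTableDemo` and `buildStarts` are hypothetical names, and the table is held in a list of lists instead of an array only to keep the sketch free of unchecked warnings.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

class StartsTableDemo {
    // Same DP as Method 1. Entry i lists every j such that s[j, i) is a word
    // and the prefix s[0, j) is itself breakable; null means "not reachable".
    static List<List<Integer>> buildStarts(String s, Set<String> wordDict) {
        int maxLen = 0;
        for (String w : wordDict) {
            maxLen = Math.max(maxLen, w.length());
        }

        List<List<Integer>> starts = new ArrayList<>();
        for (int i = 0; i <= s.length(); i++) {
            starts.add(null);
        }
        starts.set(0, new ArrayList<>()); // the empty prefix is always reachable

        for (int i = 1; i <= s.length(); i++) {
            for (int j = i - 1; j >= i - maxLen && j >= 0; j--) {
                if (starts.get(j) == null) continue; // prefix s[0, j) is not breakable
                if (wordDict.contains(s.substring(j, i))) {
                    if (starts.get(i) == null) {
                        starts.set(i, new ArrayList<>());
                    }
                    starts.get(i).add(j);
                }
            }
        }
        return starts;
    }

    public static void main(String[] args) {
        // Assumed sample input, not taken from the post.
        Set<String> dict = new HashSet<>(Arrays.asList("cat", "cats", "and", "sand", "dog"));
        List<List<Integer>> starts = buildStarts("catsanddog", dict);
        for (int end = 0; end < starts.size(); end++) {
            System.out.println(end + " <- " + starts.get(end));
        }
    }
}
```

For this input, index 7 lists both 4 ("and") and 3 ("sand"), and index 10 lists only 7 ("dog"), so the backtracking from the end has exactly two paths to follow.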
    

    Method 2: Memoization + Backtracking. Before I came up with Method 1, I also tried using a HashMap to memoize all the sentences that can be formed starting from index i. I referred to this post from @Pixel_
    The time complexity is O(len(wordDict) ^ len(s / minWordLenInDict)), as @Pixel_ mentioned. The space cost is higher than for the other methods, though, because the memo stores every sentence suffix as a string. Here is my code:

    public List<String> wordBreak(String s, Set<String> wordDict) {
        HashMap<Integer, List<String>> memo = new HashMap<>(); // <Starting index, rst list>
        return dfs(s, 0, wordDict, memo);
    }
    
    private List<String> dfs(String s, int start, Set<String> dict, HashMap<Integer, List<String>> memo) {
        if (memo.containsKey(start)) {
            return memo.get(start);
        }
        
        List<String> rst = new ArrayList<>();
        if (start == s.length()) {
            rst.add(""); // one empty sentence marks a complete segmentation
            return rst;
        }
        
        String curr = s.substring(start);
        for (String word: dict) {
            if (curr.startsWith(word)) {
                List<String> sublist = dfs(s, start + word.length(), dict, memo);
                for (String sub : sublist) {
                    rst.add(word + (sub.isEmpty() ? "" : " ") + sub);
                }
            }
        }
        
        memo.put(start, rst);
        return rst;
    }
    

    Method 3: DP Pruning + Backtracking. My very first solution used a boolean array to memoize whether the substring from position i to the end is breakable. This works well for unbreakable worst cases like: s = "aaaaaaaaaaaab", dict = ["a", "aa", "aaa", "aaaa"]. However, for cases like: s = "aaaaaaaaaaaaa", dict = ["a", "aa", "aaa", "aaaa"], every position is breakable, nothing gets pruned, and the time complexity is still O(2^n). Here is the code:

    public List<String> wordBreak(String s, Set<String> wordDict) {
        List<String> rst = new ArrayList<>();
        if (s == null || s.length() == 0 || wordDict == null) {
            return rst;
        }
        
        boolean[] canBreak = new boolean[s.length()];
        Arrays.fill(canBreak, true);
        StringBuilder sb = new StringBuilder();
        dfs(rst, sb, s, wordDict, canBreak, 0);
        return rst;
    }
    
    private void dfs(List<String> rst, StringBuilder sb, String s, Set<String> dict, 
        boolean[] canBreak, int start) {
        if (start == s.length()) {
            rst.add(sb.substring(1));
            return;
        }
        
        if (!canBreak[start]) {
            return;
        }
        
        for (int i = start + 1; i <= s.length(); i++) {
            String word = s.substring(start, i);
            if (!dict.contains(word)) continue;
            
            int sbBeforeAdd = sb.length();
            sb.append(" " + word);
            
            int rstBeforeDFS = rst.size();
            dfs(rst, sb, s, dict, canBreak, i);
            if (rst.size() == rstBeforeDFS) {
                canBreak[i] = false; // no sentence was completed from i, so prune it next time
            }
            sb.delete(sbBeforeAdd, sb.length());
        }
    }
    
    // Not called in Method 3; same helper as in Method 1.
    private int getMaxLen(Set<String> wordDict) {
        int max = 0;
        for (String s : wordDict) {
            max = Math.max(max, s.length());
        }
        return max;
    }
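
The effect of the canBreak array on the two inputs above can be checked by counting recursive calls with and without pruning. This is only a sketch: `PruningCallCounter`, `countCalls`, and the `calls` counter are hypothetical names, and the search counts sentences instead of building them so that only the pruning logic remains.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

class PruningCallCounter {
    static long calls; // number of dfs invocations in the current run

    // Returns how many sentences can be completed from start.
    // Passing canBreak == null disables pruning.
    static int dfs(String s, Set<String> dict, boolean[] canBreak, int start) {
        calls++;
        if (start == s.length()) return 1;
        if (canBreak != null && !canBreak[start]) return 0;

        int found = 0;
        for (int i = start + 1; i <= s.length(); i++) {
            if (!dict.contains(s.substring(start, i))) continue;
            int sub = dfs(s, dict, canBreak, i);
            if (sub == 0 && canBreak != null) {
                canBreak[i] = false; // nothing reachable from i, skip it next time
            }
            found += sub;
        }
        return found;
    }

    static long countCalls(String s, Set<String> dict, boolean prune) {
        calls = 0;
        boolean[] canBreak = null;
        if (prune) {
            canBreak = new boolean[s.length()];
            Arrays.fill(canBreak, true);
        }
        dfs(s, dict, canBreak, 0);
        return calls;
    }

    public static void main(String[] args) {
        Set<String> dict = new HashSet<>(Arrays.asList("a", "aa", "aaa", "aaaa"));
        for (String s : new String[] {"aaaaaaaaaaaab", "aaaaaaaaaaaaa"}) {
            System.out.println(s + ": pruned=" + countCalls(s, dict, true)
                    + ", unpruned=" + countCalls(s, dict, false));
        }
    }
}
```

On the input ending in "b" the pruned run should need far fewer calls, while on the all-"a" input both runs make the same number of calls, because no position is ever marked unbreakable.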
    

  • 0
    B

    For the 2nd method, what if the input string cannot be broken into sentences?
    In your code, every sublist seems to be added to the result whether or not the substring can be parsed.
    Doesn't that mean the input string will always produce some sentences?
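
A quick way to check this is to run the Method 2 logic on an input that cannot be segmented. The sketch below repeats that logic in a self-contained class; `MemoEmptyResultDemo` is a hypothetical name and the sample strings and dictionary are assumed, not taken from the thread. When no word matches at some index, the list for that index stays empty, and an empty sublist contributes nothing to its callers.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

class MemoEmptyResultDemo {
    static List<String> wordBreak(String s, Set<String> dict) {
        return dfs(s, 0, dict, new HashMap<>());
    }

    // Returns every sentence that can be formed from s[start..]; empty if none.
    static List<String> dfs(String s, int start, Set<String> dict,
                            Map<Integer, List<String>> memo) {
        if (memo.containsKey(start)) {
            return memo.get(start);
        }

        List<String> rst = new ArrayList<>();
        if (start == s.length()) {
            rst.add(""); // one empty sentence marks a complete segmentation
            return rst;
        }

        String curr = s.substring(start);
        for (String word : dict) {
            if (curr.startsWith(word)) {
                // If the remainder is unbreakable, sublist is empty and the
                // inner loop adds nothing to rst.
                List<String> sublist = dfs(s, start + word.length(), dict, memo);
                for (String sub : sublist) {
                    rst.add(word + (sub.isEmpty() ? "" : " ") + sub);
                }
            }
        }

        memo.put(start, rst);
        return rst;
    }

    public static void main(String[] args) {
        // Assumed sample inputs for illustration.
        Set<String> dict = new HashSet<>(Arrays.asList("cat", "cats", "and", "sand", "dog"));
        System.out.println(wordBreak("catsanddog", dict)); // two sentences, order may vary
        System.out.println(wordBreak("catsandog", dict));  // empty list
    }
}
```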


  • 0
    B

    And for the 1st answer, could you explain it a bit more? I'm confused by the time complexity analysis of the backtracking...
    Why isn't it O(n*m + number of solutions)? As you said, the cost of the backtracking should be just the number of solutions, since it doesn't waste time on any other combinations.

    Thanks in advance.


  • 0
    B

    Understood.
    Sorry for bothering.


  • 0
    X

    I believe that your time complexity analysis for the 1st solution is wrong. Consider this test case:
    dict = ["a", "aa", "aaa", "aaaa", "aaaaa", "aaaaaa"...], input string: aaaaaaaaaaaaaaaaaaaaaaaaaaaaa
    The first level makes k recursive calls.
    The second level (there are k calls at that level) makes k-1, k-2, k-3, k-4, ..., 1 calls.
    The third level expands each of those (k-1 + k-2 + ... + 1) calls in turn,
    and so on.

    There are n / k recursion levels on average.

    Thus, the total number of recursive calls would still be k ^ (n / k).
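
For an all-"a" case like this one, it can help to see how large the output itself gets, since both estimates in this thread depend on it. The sketch below counts segmentations with a simple DP without building them; `SegmentationCounter`, `countSegmentations`, `runsOfA`, and `repeatA` are hypothetical helpers, not code from the posts.

```java
import java.util.HashSet;
import java.util.Set;

class SegmentationCounter {
    // ways[i] = number of ways to split the prefix s[0, i) into dictionary words.
    static long countSegmentations(String s, Set<String> dict) {
        long[] ways = new long[s.length() + 1];
        ways[0] = 1; // the empty prefix has exactly one (empty) segmentation
        for (int i = 1; i <= s.length(); i++) {
            for (int j = 0; j < i; j++) {
                if (ways[j] > 0 && dict.contains(s.substring(j, i))) {
                    ways[i] += ways[j];
                }
            }
        }
        return ways[s.length()];
    }

    // Builds the string "aa...a" of the given length.
    static String repeatA(int length) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < length; i++) {
            sb.append('a');
        }
        return sb.toString();
    }

    // Builds the dictionary {"a", "aa", ..., "a" repeated maxLen times}.
    static Set<String> runsOfA(int maxLen) {
        Set<String> dict = new HashSet<>();
        for (int len = 1; len <= maxLen; len++) {
            dict.add(repeatA(len));
        }
        return dict;
    }

    public static void main(String[] args) {
        for (int n : new int[] {5, 10, 20}) {
            System.out.println("n=" + n + ": "
                    + countSegmentations(repeatA(n), runsOfA(n)) + " sentences");
        }
    }
}
```

When the dictionary holds every run of "a" up to length n, each of the n-1 gaps between characters is either a cut or not, so the count is 2^(n-1). Any method that lists all sentences has to spend at least that much time on this input.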


  • 0

    @l-wang I arrived at the same DP + DFS method starting from my Word Break solution, but its runtime was about 30 ms. Then I found that I was comparing strings with String's equals() instead of using Set's contains(). After that change, my solution became as fast as yours.
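
The code behind this reply is not shown, so the following is only a guess at the kind of difference involved: checking a candidate against every word with equals() is a linear scan per lookup, while a HashSet answers contains() with a single hash lookup on average. `LookupComparison` and both method names are hypothetical.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

class LookupComparison {
    // Linear scan: compares the candidate with each word via equals().
    static boolean containsByScan(List<String> words, String candidate) {
        for (String w : words) {
            if (w.equals(candidate)) {
                return true;
            }
        }
        return false;
    }

    // Hash lookup: one contains() call on the set.
    static boolean containsByHash(Set<String> words, String candidate) {
        return words.contains(candidate);
    }

    public static void main(String[] args) {
        // Assumed sample dictionary for illustration.
        List<String> list = Arrays.asList("cat", "cats", "and", "sand", "dog");
        Set<String> set = new HashSet<>(list);
        for (String candidate : new String[] {"sand", "san"}) {
            System.out.println(candidate + ": scan=" + containsByScan(list, candidate)
                    + ", hash=" + containsByHash(set, candidate));
        }
    }
}
```

Both versions give the same answers; only the cost per lookup differs, which adds up inside the nested DP loops.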

