63 ms Java solution (beats 90% of submissions)


  • 0
    S
    /**
     * Finds all word squares that can be built from a list of words, using a prefix trie to
     * look up the candidate words for each row.
     *
     * <p>A word square of size n is a sequence of n words of length n in which the k-th row
     * reads the same as the k-th column. A word may be used more than once within a square.
     * Words are assumed to contain only lowercase ASCII letters {@code 'a'}..{@code 'z'}.
     */
    public class Solution {
        /**
         * Trie node. {@code indexes} holds the positions (in the input array) of every word
         * whose path passes through this node, i.e. every word having this node's prefix.
         * Static so nodes do not carry a hidden reference to the enclosing Solution.
         */
        private static final class TrieNode {
            final TrieNode[] children = new TrieNode[26];
            final List<Integer> indexes = new ArrayList<>();
        }

        /**
         * Returns every word square whose side length equals {@code words[0].length()}.
         *
         * @param words candidate words; words whose length differs from {@code words[0]}
         *     cannot appear in a square of that size and are ignored
         * @return all squares, each as a list of its rows, in search order; empty if
         *     {@code words} is {@code null} or empty
         */
        public List<List<String>> wordSquares(String[] words) {
            List<List<String>> results = new ArrayList<>();
            if (words == null || words.length == 0) {
                return results;
            }
            int size = words[0].length();
            TrieNode root = new TrieNode();
            for (int i = 0; i < words.length; i++) {
                // Only words of exactly `size` letters can form rows: shorter ones would make
                // charAt(row) throw during prefix lookup, longer ones break the square shape.
                if (words[i].length() == size) {
                    addWord(root, words[i], i);
                }
            }

            dfs(words, new ArrayList<>(), 0, size, results, root);
            return results;
        }

        /**
         * Backtracking step: tries every word whose prefix matches the column already fixed
         * by the rows placed so far, then recurses into the next row.
         *
         * @param usedWords rows placed so far (mutated during the search, restored on return)
         * @param row index of the row being filled; equals {@code usedWords.size()}
         * @param maxRow side length of the square
         */
        private void dfs(String[] words, List<String> usedWords, int row, int maxRow,
                List<List<String>> results, TrieNode root) {
            if (row == maxRow) {
                results.add(new ArrayList<>(usedWords));
                return;
            }
            // With no rows placed the prefix is empty, so this yields root.indexes (all words).
            List<Integer> candidates = getIndexes(usedWords, row, root);
            if (candidates == null) {
                return;
            }
            for (int i : candidates) {
                usedWords.add(words[i]);
                dfs(words, usedWords, row + 1, maxRow, results, root);
                usedWords.remove(usedWords.size() - 1);
            }
        }

        /**
         * Returns the indexes of the words that start with the prefix formed by the
         * {@code row}-th character of each already placed word, or {@code null} if no word
         * has that prefix.
         */
        private List<Integer> getIndexes(List<String> usedWords, int row, TrieNode root) {
            TrieNode node = root;
            for (String w : usedWords) {
                node = node.children[w.charAt(row) - 'a'];
                if (node == null) {
                    return null;
                }
            }
            return node.indexes;
        }

        /**
         * Inserts {@code word} into the trie, recording {@code index} on every node along its
         * path, including the root.
         */
        private void addWord(TrieNode root, String word, int index) {
            TrieNode node = root;
            node.indexes.add(index);
            for (char c : word.toCharArray()) {
                int idx = c - 'a';
                if (node.children[idx] == null) {
                    node.children[idx] = new TrieNode();
                }
                node = node.children[idx];
                node.indexes.add(index);
            }
        }
    }
    

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.