Sharing my Java solution: DFS + standard Trie


  • 0
    M
            private class TrieNode {
    		String s;
    		boolean isWord;
    		Map<Character, TrieNode> next;
    		public TrieNode() {
    			next = new HashMap<>();
    		}
    	}
    	/** Character trie built from a word list; each word's terminal node stores the word. */
    	private class Trie {
    	    TrieNode root = new TrieNode();

    	    public Trie() {}

    	    /** Builds a trie containing every entry of {@code words}. */
    	    public Trie(String[] words) {
    	        for (String w : words) {
    	            insert(w);
    	        }
    	    }

    	    /** Adds {@code word}, creating any missing nodes along its path. */
    	    private void insert(String word) {
    	        TrieNode node = root;
    	        for (char ch : word.toCharArray()) {
    	            node = node.next.computeIfAbsent(ch, key -> new TrieNode());
    	        }
    	        node.isWord = true;
    	        node.s = word;
    	    }
    	}
    	
    	/**
    	 * Returns every word square that can be formed from {@code words}.
    	 *
    	 * @param words candidate words; the square's side length is taken from {@code words[0]}
    	 * @return all squares found, each as its list of rows; empty for null/empty input
    	 *     or when the first word is empty
    	 */
    	public List<List<String>> wordSquares(String[] words) {
    	    List<List<String>> squares = new ArrayList<>();
    	    if (words == null || words.length == 0 || words[0].length() == 0) {
    	        return squares;
    	    }
    	    int size = words[0].length();
    	    Trie trie = new Trie(words);
    	    dfs(trie.root, trie.root, new ArrayList<String>(), 0, size, squares);
    	    return squares;
    	}
    	
    	/**
    	 * Backtracking search that fills the square row by row while walking the trie.
    	 *
    	 * @param root trie root, where the walk restarts for each new row
    	 * @param node current trie node for the row being built
    	 * @param pre rows chosen so far (mutated during the call and restored before returning)
    	 * @param j depth of {@code node} within the current row, i.e. index of the next character
    	 * @param n side length of the square
    	 * @param res output list receiving a copy of each completed square
    	 */
    	private void dfs (TrieNode root, TrieNode node, List<String> pre, int j, int n, List<List<String>> res) {
    	    if (j != n) {
    	        int row = pre.size();
    	        if (j < row) {
    	            // Symmetry: column j of this row must equal character `row` of row j.
    	            TrieNode child = node.next.get(pre.get(j).charAt(row));
    	            if (child != null) {
    	                dfs(root, child, pre, j + 1, n, res);
    	            }
    	        } else {
    	            // Unconstrained position: try every existing child.
    	            for (TrieNode child : node.next.values()) {
    	                dfs(root, child, pre, j + 1, n, res);
    	            }
    	        }
    	        return;
    	    }
    	    if (!node.isWord) {
    	        return;
    	    }
    	    pre.add(node.s);
    	    if (pre.size() == n) {
    	        res.add(new ArrayList<String>(pre));
    	    } else {
    	        // The trie is left unmodified so a word may appear in more than one row.
    	        dfs(root, root, pre, 0, n, res);
    	    }
    	    pre.remove(pre.size() - 1);
    	}
    

  • 0
    M

I just realized my original solution is too slow. Taking advantage of the constraint that every word contains only lowercase letters 'a' to 'z', I replaced each node's HashMap of children with a fixed 26-slot array, and the revised solution is now much faster.

        private class TrieNode {
            boolean isWord;
            String s;
            TrieNode[] next;
            public TrieNode() {
                next = new TrieNode[26];
            }
        }
        /** Trie over lowercase words; the terminal node of each word stores the word itself. */
        private class Trie {
            TrieNode root = new TrieNode();

            public Trie() {
            }

            /** Builds a trie holding every entry of {@code words}. */
            public Trie(String[] words) {
                for (String w : words) {
                    insert(w);
                }
            }

            /**
             * Adds {@code word}, allocating child nodes on demand.
             * Expects every character to be in 'a'..'z'; any other character indexes
             * outside the 26-slot child array.
             */
            public void insert(String word) {
                TrieNode node = root;
                for (char ch : word.toCharArray()) {
                    int slot = ch - 'a';
                    TrieNode child = node.next[slot];
                    if (child == null) {
                        child = new TrieNode();
                        node.next[slot] = child;
                    }
                    node = child;
                }
                node.isWord = true;
                node.s = word;
            }
        }
        
        /**
         * Backtracking search that builds the square one row at a time along trie paths.
         *
         * @param root trie root, where the walk for each new row starts
         * @param node current trie node, or {@code null} when the path does not exist
         * @param j depth of {@code node} within the current row, i.e. index of the next character
         * @param n side length of the square
         * @param pre rows chosen so far (mutated during the call and restored before returning)
         * @param res output list receiving a copy of each completed square
         */
        private void dfs(TrieNode root, TrieNode node, int j, int n, List<String> pre, List<List<String>> res) {
            if (node == null) {
                return; // dead end: no inserted word continues along this path
            }
            if (j != n) {
                int row = pre.size();
                if (j < row) {
                    // Symmetry: column j of this row is fixed by character `row` of row j.
                    dfs(root, node.next[pre.get(j).charAt(row) - 'a'], j + 1, n, pre, res);
                } else {
                    // Unconstrained position: try every slot; null children return immediately.
                    for (TrieNode child : node.next) {
                        dfs(root, child, j + 1, n, pre, res);
                    }
                }
                return;
            }
            if (!node.isWord) {
                return;
            }
            pre.add(node.s);
            if (pre.size() == n) {
                res.add(new ArrayList<String>(pre));
            } else {
                dfs(root, root, 0, n, pre, res);
            }
            pre.remove(pre.size() - 1);
        }
        
        /**
         * Returns every word square that can be built from {@code words}.
         *
         * @param words candidate words (expected to use only 'a'..'z'); the square's side
         *     length is taken from {@code words[0]}
         * @return all squares found, each as its list of rows; empty for null/empty input or
         *     when the first word is empty
         */
        public List<List<String>> wordSquares(String[] words) {
            List<List<String>> res = new ArrayList<>();
            if (words == null || words.length == 0) {
                return res;
            }
            int n = words[0].length();
            // Guard n == 0: inserting "" marks the root as a word, and dfs would then restart
            // at the root without bound (pre.size() can never equal 0 once a row is added),
            // ending in a StackOverflowError.
            if (n == 0) {
                return res;
            }
            Trie trie = new Trie(words);
            dfs(trie.root, trie.root, 0, n, new ArrayList<String>(), res);
            return res;
        }
    

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.