Posting my 22 ms solution here. It uses a similar idea, but the traversal is a little different: it proceeds row by row, advancing along the diagonal. An array of Trie nodes tracks the current progress of each row (the horizontal direction), and a single progress Trie node tracks the vertical direction by following the next character chosen for each horizontal Trie node. Once the vertical progress reaches the diagonal, it is added as a new horizontal Trie node, and a new vertical traversal starts from the root.

/**
 * Word Squares: finds every sequence of words in which the k-th row reads the same as the k-th
 * column.
 *
 * <p>The search grows the square one column at a time. For the step that fills column
 * {@code end + 1}, each started row {@code 0..end} is extended by one character. At the same time,
 * a "vertical" trie path from the root spells row {@code end + 1} (by symmetry, row
 * {@code end + 1} equals column {@code end + 1}). When that vertical path reaches the diagonal, it
 * becomes the newest row.
 */
public class Solution {
    /** Trie over the input words; rebuilt on every call to {@link #wordSquares(String[])}. */
    private Trie root = new Trie();

    /**
     * Returns all word squares that can be built from {@code words}.
     *
     * <p>The side length of the square is {@code words[0].length()}. Words of any other length
     * cannot appear in such a square and are ignored. A word may be used more than once in a square.
     *
     * @param words candidate words, expected to contain only the characters {@code 'a'..'z'}
     * @return every word square as its list of rows; empty if {@code words} is null or empty
     */
    public List<List<String>> wordSquares(String[] words) {
        List<List<String>> result = new ArrayList<>();
        if (words == null || words.length == 0) {
            return result;
        }
        // Start from a fresh trie so words from an earlier call on this instance cannot leak in.
        root = new Trie();
        int len = words[0].length();
        for (String word : words) {
            // Only words of the square's side length can be rows. A longer word would otherwise
            // leave a word-less node at depth len and produce a null entry in a square.
            if (word.length() == len) {
                addWord(word);
            }
        }
        Trie[] heads = new Trie[len];
        for (int i = 0; i < root.childrenSize; i++) {
            int cid = root.childrenIndexs[i];
            // Seed row 0 with each possible first character (the top-left cell).
            heads[0] = root.children[cid];
            wordSquares(result, heads, 0, new Trie[len], 0, root, len);
        }
        return result;
    }

    /**
     * Extends the partial square by one column, one row at a time, recording completed squares.
     *
     * @param result   accumulator for completed squares
     * @param heads    trie nodes for rows {@code 0..end}, each positioned after {@code end + 1}
     *                 characters
     * @param end      index of the last row started so far
     * @param newHeads scratch array that receives rows {@code 0..newEnd-1} advanced by one character
     * @param newEnd   next slot of {@code newHeads} to fill; a value greater than {@code end} means
     *                 the diagonal cell of the new row is being chosen
     * @param progress trie node spelling row {@code end + 1} so far, read down column
     *                 {@code end + 1}
     * @param len      side length of the square
     */
    private void wordSquares(List<List<String>> result, Trie[] heads, int end, Trie[] newHeads,
            int newEnd, Trie progress, int len) {
        if (end == len - 1) {
            // Every row has reached depth len, i.e. a complete word.
            List<String> square = new ArrayList<>(len);
            for (Trie trie : heads) {
                square.add(trie.word);
            }
            result.add(square);
            return;
        }
        // Rows 0..end are extended from their own nodes. Past them, the vertical path itself
        // supplies the diagonal character of the new row.
        Trie current = newEnd > end ? progress : heads[newEnd];
        for (int i = 0; i < current.childrenSize; i++) {
            int cid = current.childrenIndexs[i];
            // By symmetry, the character appended to row newEnd is also row (end + 1)'s character
            // at position newEnd, so it must continue the vertical path too.
            if (progress.children[cid] != null) {
                newHeads[newEnd] = current.children[cid];
                if (newEnd > end) {
                    // The new row reached the diagonal: it becomes the last row, and a fresh
                    // vertical path starts from the root for the next column.
                    wordSquares(result, newHeads, newEnd, new Trie[len], 0, root, len);
                } else {
                    wordSquares(result, heads, end, newHeads, newEnd + 1,
                            progress.children[cid], len);
                }
            }
        }
    }

    /** Trie node over {@code 'a'..'z'} that also lists its children in insertion order. */
    static class Trie {
        /** Child per letter, indexed by {@code c - 'a'}; null when absent. */
        Trie[] children = new Trie[26];
        /** Letter indices of the existing children, in insertion order (first childrenSize valid). */
        int[] childrenIndexs = new int[26];
        /** Number of valid entries in {@link #childrenIndexs}. */
        int childrenSize;
        /** The word that ends at this node, or null if none does. */
        String word;
    }

    /** Inserts {@code word} into the trie at {@link #root}, recording each new child in order. */
    private void addWord(String word) {
        Trie node = root;
        for (int i = 0; i < word.length(); i++) {
            int cid = word.charAt(i) - 'a';
            if (node.children[cid] == null) {
                node.children[cid] = new Trie();
                node.childrenIndexs[node.childrenSize++] = cid;
            }
            node = node.children[cid];
        }
        node.word = word;
    }
}