My 30 ms Java solution using bidirectional BFS and DFS


  • 11
    H

    In terms of speed, this 30 ms solution beats 100% of other Java submissions.

    A couple of things make this solution fast:

    1. We use bidirectional BFS, which always expands from the direction with fewer nodes.

    2. We use a char[] to build candidate strings, which is fast.

    3. Instead of scanning the dictionary each time, we build new strings from the existing string and check whether each one is in the dictionary.

    Below is my commented code.

    /**
     * Returns every shortest transformation sequence from beginWord to endWord, where each
     * step changes exactly one letter and every intermediate word comes from wordList.
     *
     * A bidirectional BFS fills a begin-to-end adjacency map layer by layer and stops at the
     * first layer where the two frontiers touch. A DFS then enumerates every path in that map
     * from beginWord to endWord.
     *
     * @param beginWord first word of every ladder
     * @param endWord   last word of every ladder
     * @param wordList  allowed words; not modified (the search runs on a private copy)
     * @return all shortest ladders, or an empty list if the two searches never connect
     */
    public List<List<String>> findLadders(String beginWord, String endWord, Set<String> wordList) {
        // isConnected is an instance field: reset it, otherwise a second call on the same
        // object would start out "connected", never expand, and return wrong results.
        isConnected = false;

        // BFS removes visited words from the dictionary it is given; search a copy so the
        // caller's set is left intact and read-only sets are accepted.
        Set<String> dict = new HashSet<String>(wordList);

        Set<String> fwd = new HashSet<String>();
        fwd.add(beginWord);

        Set<String> bwd = new HashSet<String>();
        bwd.add(endWord);

        // Adjacency map (word -> next words), always oriented from beginWord towards endWord.
        Map<String, List<String>> hs = new HashMap<String, List<String>>();
        BFS(fwd, bwd, dict, false, hs);

        List<List<String>> result = new ArrayList<List<String>>();

        // The two halves never met: no ladder exists.
        if (!isConnected) return result;

        // No edge points at the start word, so the path must be seeded with it.
        List<String> temp = new ArrayList<String>();
        temp.add(beginWord);

        DFS(result, temp, beginWord, endWord, hs);

        return result;
    }

    // Set by BFS once a word on one frontier is one letter away from a word on the other;
    // findLadders resets it at the start of every call.
    boolean isConnected = false;
    
    /**
     * One step of the bidirectional search, recursing until the frontiers meet or one runs dry.
     *
     * The smaller frontier is always the one expanded. Every discovered edge is stored in hs
     * oriented from the begin side to the end side; {@code swap} tells which way that is for
     * the current {@code forward} set. When a neighbour lies in {@code backward}, isConnected
     * is set and no deeper layer is explored, so hs only holds shortest-path edges (plus
     * harmless dead ends).
     *
     * @param forward  frontier to expand (after any swap, the smaller one)
     * @param backward the opposite frontier
     * @param dict     remaining unvisited words; frontier words are removed from it
     * @param swap     true when {@code forward} is actually the end-side frontier
     * @param hs       adjacency map being built (begin -> end orientation)
     */
    public void BFS(Set<String> forward, Set<String> backward, Set<String> dict, boolean swap, Map<String, List<String>> hs){
        // One side has nothing left to expand, so the halves can never meet.
        if (backward.isEmpty() || forward.isEmpty()) {
            return;
        }

        // Expand whichever frontier is smaller; flipping `swap` keeps edge orientation right.
        if (backward.size() < forward.size()) {
            BFS(backward, forward, dict, !swap, hs);
            return;
        }

        // Frontier words are already placed; drop them from the pool of candidates.
        dict.removeAll(forward);
        dict.removeAll(backward);

        Set<String> nextLayer = new HashSet<String>();

        for (String word : forward) {
            char[] letters = word.toCharArray();
            for (int pos = 0; pos < letters.length; pos++) {
                char saved = letters[pos];
                for (char c = 'a'; c <= 'z'; c++) {
                    letters[pos] = c;
                    String candidate = new String(letters);

                    boolean meetsOtherSide = backward.contains(candidate);
                    if (!meetsOtherSide && !dict.contains(candidate)) {
                        continue; // neither a meeting point nor an unvisited word
                    }

                    // Store the edge pointing from the begin side towards the end side.
                    String from = swap ? candidate : word;
                    String to = swap ? word : candidate;
                    if (!hs.containsKey(from)) {
                        hs.put(from, new ArrayList<String>());
                    }
                    List<String> edges = hs.get(from);

                    if (meetsOtherSide) {
                        edges.add(to);
                        isConnected = true;
                    }

                    // Keep growing the next layer only until a connection has been seen.
                    if (!isConnected && dict.contains(candidate)) {
                        edges.add(to);
                        nextLayer.add(candidate);
                    }
                }
                letters[pos] = saved;
            }
        }

        // Descend only while still disconnected, so every recorded ladder is shortest.
        if (!isConnected) {
            BFS(nextLayer, backward, dict, swap, hs);
        }
    }
    
    /**
     * Backtracking enumeration of every path from {@code start} to {@code end} in hs.
     *
     * @param result receives a copy of each complete path
     * @param temp   path built so far; its last element is {@code start}. It is restored
     *               before returning.
     * @param start  current word
     * @param end    target word
     * @param hs     adjacency map (word -> next words)
     */
    public void DFS(List<List<String>> result, List<String> temp, String start, String end, Map<String, List<String>> hs){
        // Target reached: record a snapshot of the current path.
        if (start.equals(end)) {
            result.add(new ArrayList<String>(temp));
            return;
        }

        // A word with no outgoing edges is a dead end and lies on no shortest ladder.
        if (!hs.containsKey(start)) {
            return;
        }

        int depth = temp.size();
        for (String next : hs.get(start)) {
            temp.add(next);
            DFS(result, temp, next, end, hs);
            temp.remove(depth); // undo the add; the recursive call restored everything after it
        }
    }
    

    The main idea comes from another excellent solution.


  • 1
    O

    Great, clear code; I learned a lot from your implementation. Excellent explanation.
    My version below runs in 27 ms, beating 99.7% of submissions.

    /**
     * Word Ladder II solver. A bidirectional BFS records begin-to-end edges between
     * consecutive layers; a DFS then enumerates every path from beginWord to endWord.
     */
    public class Solution {
        /**
         * Returns every shortest transformation sequence from beginWord to endWord.
         *
         * @param beginWord first word of every ladder
         * @param endWord   last word of every ladder
         * @param wordList  allowed words; not modified (the search runs on a private copy)
         * @return all shortest ladders, or an empty list when endWord cannot be reached
         */
        public List<List<String>> findLadders(String beginWord, String endWord, Set<String> wordList) {
            Set<String> fwdQueue = new HashSet<String>();
            Set<String> bckQueue = new HashSet<String>();
            HashMap<String, ArrayList<String>> h = new HashMap<String, ArrayList<String>>();
            fwdQueue.add(beginWord);
            bckQueue.add(endWord);
            // findLadder deletes visited words from the dictionary, so search a copy: the
            // caller's set stays intact and unmodifiable sets are accepted.
            Set<String> dict = new HashSet<String>(wordList);
            findLadder(fwdQueue, bckQueue, dict, true, h);

            List<List<String>> ans = new ArrayList<List<String>>();
            List<String> cur = new ArrayList<String>();
            cur.add(beginWord);
            printPath(beginWord, endWord, cur, h, ans);
            return ans;
        }

        /**
         * Expands one layer of the bidirectional BFS and recurses until the frontiers meet.
         *
         * @param fwdQueue  frontier to expand (after any swap, the smaller one)
         * @param bckQueue  opposite frontier
         * @param wordList  unvisited words; frontier words are removed from it
         * @param direction true when fwdQueue is on the begin side, which fixes edge orientation
         * @param h         adjacency map being built (begin -> end orientation)
         */
        private void findLadder(Set<String> fwdQueue, Set<String> bckQueue, Set<String> wordList,
                boolean direction, HashMap<String, ArrayList<String>> h) {
            // Either frontier exhausted: the two searches can never meet.
            if (fwdQueue.isEmpty() || bckQueue.isEmpty()) return;
            // Always expand the smaller frontier; flipping direction keeps edges begin -> end.
            if (fwdQueue.size() > bckQueue.size()) {
                findLadder(bckQueue, fwdQueue, wordList, !direction, h);
                return;
            }
            wordList.removeAll(fwdQueue);
            wordList.removeAll(bckQueue);

            boolean found = false;
            Set<String> setNew = new HashSet<String>();
            for (String s : fwdQueue) {
                char[] chs = s.toCharArray();
                for (int i = 0; i < chs.length; i++) {
                    char tmpc = chs[i];
                    for (char j = 'a'; j <= 'z'; j++) {
                        if (j == tmpc) {
                            continue;
                        }
                        chs[i] = j;
                        String tmp = new String(chs);
                        if (bckQueue.contains(tmp)) {
                            found = true;
                            addPath(s, tmp, direction, h);
                        } else if (!found && wordList.contains(tmp)) {
                            setNew.add(tmp);
                            addPath(s, tmp, direction, h);
                        }
                    }
                    chs[i] = tmpc;
                }
            }

            // Stop as soon as the frontiers touch so only shortest ladders are recorded;
            // otherwise continue with the untouched opposite frontier against the new layer.
            if (!found) {
                findLadder(bckQueue, setNew, wordList, !direction, h);
            }
        }

        /** Records the edge s-t in h, oriented s -> t when dir is true and t -> s otherwise. */
        private void addPath(String s, String t, boolean dir, HashMap<String, ArrayList<String>> h) {
            String key = dir ? s : t;
            String val = dir ? t : s;
            ArrayList<String> l = h.get(key);
            if (l == null) {
                l = new ArrayList<String>();
                h.put(key, l);
            }
            l.add(val);
        }

        /**
         * Backtracking DFS: appends to ans every path in h from s to target, each prefixed by
         * cur (whose last element is s). cur is restored before returning.
         */
        private void printPath(String s, String target, List<String> cur,
                HashMap<String, ArrayList<String>> h, List<List<String>> ans) {
            if (s.equals(target)) {
                ans.add(new ArrayList<String>(cur));
                return;
            }
            // No outgoing edges: dead end, not on any shortest ladder.
            if (!h.containsKey(s)) return;
            for (String next : h.get(s)) {
                cur.add(next);
                printPath(next, target, cur, h, ans);
                cur.remove(cur.size() - 1);
            }
        }
    }
    

  • 0
    O

    Here is my earlier, uglier code for comparison:

    /**
     * Returns every shortest word ladder from beginWord to endWord. This earlier version
     * alternates one head layer and one tail layer per round.
     *
     * Each side keeps a parent map (word -> words it was reached from on that side). The
     * sentinel "." marks that side's root. When the searches meet, wrSolution turns every
     * meeting word into complete ladders.
     *
     * @param beginWord first word of every ladder
     * @param endWord   last word of every ladder
     * @param wordList  allowed intermediate words; visited words are removed from it
     * @return all shortest ladders, each listed once; empty if the searches never meet
     */
    public static List<List<String>> findLadders(String beginWord, String endWord, Set<String> wordList) {
        HashMap<String, ArrayList<String>> shead = new HashMap<String, ArrayList<String>>();
        HashMap<String, ArrayList<String>> stail = new HashMap<String, ArrayList<String>>();
        List<List<String>> ans = new ArrayList<List<String>>();
        if (beginWord.equals(endWord)) {
            List<String> cur = new ArrayList<String>();
            cur.add(beginWord);
            ans.add(cur);
            return ans;
        }
        putOneMore(shead, beginWord, ".");
        putOneMore(stail, endWord, ".");

        List<String> headLayer = new ArrayList<String>();
        headLayer.add(beginWord);
        List<String> tailLayer = new ArrayList<String>();
        tailLayer.add(endWord);
        // Meeting words, each kept once even when reached from several parents; a repeated
        // entry would make wrSolution emit the same ladders again.
        java.util.LinkedHashSet<String> connect = new java.util.LinkedHashSet<String>();

        while (!headLayer.isEmpty() && !tailLayer.isEmpty()) {
            headLayer = expandLayer(headLayer, shead, stail, wordList, connect);
            if (!connect.isEmpty()) break;
            tailLayer = expandLayer(tailLayer, stail, shead, wordList, connect);
            if (!connect.isEmpty()) break;
        }
        wrSolution(ans, new ArrayList<String>(connect), shead, stail);
        return ans;
    }

    /**
     * Expands one BFS layer of one side and returns that side's next layer.
     *
     * Parent links go into {@code own}. Neighbours already known to the {@code other} side are
     * meeting points and go into {@code connect}. Each word of the new layer appears once,
     * and the whole layer is removed from wordList afterwards.
     */
    private static List<String> expandLayer(List<String> layer, HashMap<String, ArrayList<String>> own,
            HashMap<String, ArrayList<String>> other, Set<String> wordList, Set<String> connect) {
        // De-duplicated: a word reachable from several words of this layer must be expanded
        // only once, otherwise its children record the same parent twice and ladders repeat.
        java.util.LinkedHashSet<String> next = new java.util.LinkedHashSet<String>();
        for (String word : layer) {
            char[] chars = word.toCharArray();
            for (int pos = 0; pos < chars.length; pos++) {
                char saved = chars[pos];
                for (char r = 'a'; r <= 'z'; r++) {
                    if (r == saved) {
                        continue;
                    }
                    chars[pos] = r;
                    String candidate = new String(chars);
                    if (other.containsKey(candidate)) {
                        // The searches meet here; every parent is kept so all ladders survive.
                        connect.add(candidate);
                        putOneMore(own, candidate, word);
                    } else if (connect.isEmpty() && wordList.contains(candidate)) {
                        // After a meeting point is found no further layer is explored.
                        next.add(candidate);
                        putOneMore(own, candidate, word);
                    }
                }
                chars[pos] = saved;
            }
        }
        // Removed only after the whole layer, so every same-layer parent could be recorded.
        wordList.removeAll(next);
        return new ArrayList<String>(next);
    }
     
    /**
     * Builds the final ladders: for each meeting word, joins every begin-side prefix with
     * every end-side suffix and appends the results to ans.
     *
     * @param ans     receives the complete ladders
     * @param connect words where the two searches met
     * @param shead   begin-side parent map ("." marks beginWord)
     * @param stail   end-side parent map ("." marks endWord)
     */
    private static void wrSolution(List<List<String>> ans, List<String> connect, HashMap<String, ArrayList<String>> shead, HashMap<String, ArrayList<String>> stail) {
        for (String meet : connect) {
            List<List<String>> prefixes = new ArrayList<List<String>>();
            List<List<String>> suffixes = new ArrayList<List<String>>();
            solve(meet, shead, prefixes, new ArrayList<String>(), true);
            solve(meet, stail, suffixes, new ArrayList<String>(), false);
            for (List<String> prefix : prefixes) {
                for (List<String> suffix : suffixes) {
                    List<String> ladder = new ArrayList<String>(prefix.size() + 1 + suffix.size());
                    ladder.addAll(prefix);
                    ladder.add(meet);
                    ladder.addAll(suffix);
                    ans.add(ladder);
                }
            }
        }
    }
     
    /**
     * Follows parent links from s back to the root of one side (the word whose first parent
     * is "."). Each route found is added to list, excluding s itself.
     *
     * @param s          word to start from
     * @param shead      parent map of the side being walked
     * @param list       receives one route per root reached
     * @param cur        words collected so far after s; restored before returning
     * @param headOrTail true for the begin side (route is reversed to read begin -> s),
     *                   false for the end side (route already reads s -> end)
     */
    private static void solve(String s, HashMap<String, ArrayList<String>> shead, List<List<String>> list, List<String> cur, boolean headOrTail) {
        List<String> parents = shead.get(s);
        boolean atRoot = parents.get(0).equals(".");
        if (!atRoot) {
            for (String parent : parents) {
                cur.add(parent);
                solve(parent, shead, list, cur, headOrTail);
                cur.remove(cur.size() - 1);
            }
            return;
        }
        if (headOrTail) {
            // Begin side: cur runs from s towards beginWord, so flip it.
            List<String> reversed = new ArrayList<String>();
            for (String w : cur) {
                reversed.add(0, w);
            }
            list.add(reversed);
        } else {
            list.add(new ArrayList<String>(cur));
        }
    }
     
    /** Appends val to the list stored under key in h, creating the list on first use. */
    private static void putOneMore(HashMap<String, ArrayList<String>> h, String key, String val) {
        ArrayList<String> bucket;
        if (h.containsKey(key)) {
            bucket = h.get(key);
        } else {
            bucket = new ArrayList<String>();
        }
        bucket.add(val);
        h.put(key, bucket);
    }
    

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.