Sharing my bitmask dynamic programming solution (4ms)


  • 11
    D

    I reused the idea from zzz1322's Java solution and reimplemented it with bitmask-based DP.

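    dp[len][state][endNode] is the number of patterns of length len whose set of visited nodes is the bitmask state (bit i-1 is set when node i has been visited) and whose last node is endNode. The state transition is straightforward. A pattern ending at endNode can be extended to any unvisited node i, provided the straight move from endNode to i does not pass over a node that has not been visited yet.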

    // Shared DP tables, filled lazily by the first numberOfPatterns() call.
    // dp[len][state][endNode]: number of patterns of length len that visit
    // exactly the nodes in bitmask `state` (bit i-1 set <=> node i visited)
    // and whose last node is endNode.
    int dp[10][1 << 9][10];
    // skip[a][b]: the node lying between a and b on a straight line, or 0 if none.
    int skip[10][10];
    bool initialized;
    
    class Solution {
    private:
        // Records, for every pair of keypad nodes, the node that a straight
        // move between them passes over (0 when it passes over no node).
        void InitSkipArray() {
            memset(skip, 0, sizeof(skip));
            skip[1][3] = skip[3][1] = 2;
            skip[1][7] = skip[7][1] = 4;
            skip[3][9] = skip[9][3] = 6;
            skip[7][9] = skip[9][7] = 8;
            skip[1][9] = skip[9][1] = skip[3][7] = skip[7][3] = skip[2][8] = skip[8][2] = skip[4][6] = skip[6][4] = 5;
        }
        
        // Fills dp for lengths 1..9: every single node is a length-1 pattern,
        // and each pattern of length len is extended by one unvisited node.
        void calcDP() {
            initialized = true;
            InitSkipArray();
            memset(dp, 0, sizeof(dp));
            for (int i = 1; i <= 9; i++)
                dp[1][1 << (i - 1)][i] = 1;
            for (int len = 1; len < 9; len++) {
                for (int state = 0; state < (1 << 9); state++) {
                    for (int endNode = 1; endNode <= 9; endNode++) {
                        if (dp[len][state][endNode]) {
                            for (int i = 1; i <= 9; i++) {
                                // A node may appear at most once in a pattern.
                                if ((state & (1 << (i - 1))) == 0) {
                                    int nextState = (state | (1 << (i - 1)));
                                    // The move endNode -> i is legal if it passes over no
                                    // node, or over a node that is already visited.
                                    if (skip[endNode][i] == 0 || ((state & (1 << (skip[endNode][i] - 1))) != 0))
                                        dp[len + 1][nextState][i] += dp[len][state][endNode];
                                }
                            }
                        }
                    }
                }
            }
        }
    public:
        // Returns the number of unlock patterns whose length lies in [m, n].
        // Lengths outside 1..9 have no patterns. The range is clamped so that
        // dp[len] stays in bounds (dp only has rows 0..9) for any m and n.
        int numberOfPatterns(int m, int n) {
            if (!initialized) {
                calcDP();
            }
            const int lo = m < 1 ? 1 : m;
            const int hi = n > 9 ? 9 : n;
            int ans = 0;
            for (int len = lo; len <= hi; len++) {
                for (int state = 0; state < (1 << 9); state++) {
                    for (int endNode = 1; endNode <= 9; endNode++) {
                        ans += dp[len][state][endNode];
                    }
                }
            }
            return ans;
        }
    };
    

  • 0
    B

    It runs in 0 ms for me. Could you please explain your DP idea in more detail?


  • 1
    G

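    Here is a commented version for those who still want it.
    The idea is that state is a bitmask of the nodes the path has visited (bit i-1 is set when node i is visited). It records which nodes are used, not the order in which they were visited:

    000000101 means the path visits 1 and 3
    010100010 means the path visits 2, 6 and 8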

    // dp[len][state][endNode]: number of patterns of length len that visit
    // exactly the nodes in bitmask `state` (bit i-1 set <=> node i visited)
    // and end at endNode. Filled once, on the first numberOfPatterns() call.
    int dp[10][1 << 9][10];
    // skip[a][b]: the node a straight move a -> b passes over, or 0 if none.
    int skip[10][10];
    bool initialized;
    
    class Solution {
    private:
        // Records which node lies between each pair of nodes on a straight line.
        void InitSkipArray() {
            memset(skip, 0, sizeof(skip));
            skip[1][3] = skip[3][1] = 2;
            skip[1][7] = skip[7][1] = 4;
            skip[3][9] = skip[9][3] = 6;
            skip[7][9] = skip[9][7] = 8;
            skip[1][9] = skip[9][1] = skip[3][7] = skip[7][3] = skip[2][8] = skip[8][2] = skip[4][6] = skip[6][4] = 5;
        }
        
        void calcDP() {
            initialized = true;
            InitSkipArray();
            memset(dp, 0, sizeof(dp));
            // Base case: for every keypad node there is exactly one pattern of
            // length 1, which visits only that node and ends there (nine in total).
            for (int i = 1; i <= 9; i++)
                dp[1][1 << (i - 1)][i] = 1;
            for (int len = 1; len < 9; len++) {
                for (int state = 0; state < (1 << 9); state++) {
                    for (int endNode = 1; endNode <= 9; endNode++) {
                        // Zero means no pattern of this length visits exactly `state`
                        // and ends at endNode, so there is nothing to extend.
                        if (dp[len][state][endNode]) {
                            // i is a candidate next node of the pattern.
                            for (int i = 1; i <= 9; i++) {
                                // Only extend to nodes that have not been visited yet:
                                // a node may appear at most once in a pattern.
                                if ((state & (1 << (i - 1))) == 0) {
                                    // The visited set after appending i.
                                    int nextState = (state | (1 << (i - 1)));
                                    // Skipping rule: the move endNode -> i is legal if it passes
                                    // over no node, or over a node that is already visited.
                                    if (skip[endNode][i] == 0 || ((state & (1 << (skip[endNode][i] - 1))) != 0))
                                        dp[len + 1][nextState][i] += dp[len][state][endNode];
                                }
                            }
                        }
                    }
                }
            }
        }
    public:
        // Returns the number of unlock patterns whose length lies in [m, n].
        // The range is clamped to 1..9 so that dp[len] stays in bounds for any input.
        int numberOfPatterns(int m, int n) {
            if (!initialized) {
                calcDP();
            }
            const int lo = m < 1 ? 1 : m;
            const int hi = n > 9 ? 9 : n;
            int ans = 0;
            for (int len = lo; len <= hi; len++) {
                for (int state = 0; state < (1 << 9); state++) {
                    for (int endNode = 1; endNode <= 9; endNode++) {
                        ans += dp[len][state][endNode];
                    }
                }
            }
            return ans;
        }
    };
    

  • 3
    G

    I made a simple revision to your code that reduces dp to a two-dimensional array, and it is accepted in 0 ms. The len dimension is redundant because a pattern's length always equals the number of set bits in state.
    For an explanation of the function "vector<int> countBits(int num)", please refer to problem 338. Counting Bits: https://leetcode.com/problems/counting-bits/#/description

    // dp[state][endNode]: number of patterns that visit exactly the nodes in
    // bitmask `state` (bit i-1 set <=> node i visited) and end at endNode.
    // A pattern's length equals the number of set bits in `state`, so no
    // separate length dimension is needed.
    int dp[1 << 9][10];
    // skip[a][b]: the node a straight move a -> b passes over, or 0 if none.
    int skip[10][10];
    bool initialized;
    
    class Solution {
    private:
        // Records which node lies between each pair of nodes on a straight line.
        void InitSkipArray() {
            memset(skip, 0, sizeof(skip));
            skip[1][3] = skip[3][1] = 2;
            skip[1][7] = skip[7][1] = 4;
            skip[3][9] = skip[9][3] = 6;
            skip[7][9] = skip[9][7] = 8;
            skip[1][9] = skip[9][1] = skip[3][7] = skip[7][3] = skip[2][8] = skip[8][2] = skip[4][6] = skip[6][4] = 5;
        }
        
        // Fills dp. Each transition adds one bit, so nextState > state.
        // Scanning states in increasing numeric order therefore finishes
        // dp[state][*] before any of its entries is extended.
        void calcDP() {
            initialized = true;
            InitSkipArray();
            memset(dp, 0, sizeof(dp));
            for (int i = 1; i <= 9; i++) {
                dp[1 << (i - 1)][i] = 1;
            }
            for (int state = 0; state < (1 << 9); state++) {
                for (int endNode = 1; endNode <= 9; endNode++) {
                    if (dp[state][endNode] == 0) {
                        continue;
                    }
                    for (int i = 1; i <= 9; i++) {
                        // A node may appear at most once in a pattern.
                        if ((state & (1 << (i - 1))) != 0) {
                            continue;
                        }
                        int nextState = (state | (1 << (i - 1)));
                        // Legal if the move passes over no node, or over one already visited.
                        if (skip[endNode][i] == 0 || ((state & (1 << (skip[endNode][i] - 1))) != 0)) {
                            dp[nextState][i] += dp[state][endNode];
                        }
                    }
                }
            }
        }
        
        // Returns bits where bits[i] is the number of set bits in i, for 0 <= i <= num.
        // (i & (i-1)) clears the lowest set bit of i, so it has one fewer bit.
        vector<int> countBits(int num) {
            vector<int> bits(num + 1, 0);
            for (int i = 1; i <= num; i++) {
                bits[i] = bits[i & (i - 1)] + 1;
            }
            return bits;
        }
        
    public:
        // Returns the number of unlock patterns whose length lies in [m, n].
        int numberOfPatterns(int m, int n) {
            if (!initialized) {
                calcDP();
            }
            // Popcount of every 9-bit state. It is built once on the first call
            // instead of on every call; a state's popcount is its pattern length.
            static const vector<int> count = countBits(1 << 9);
            int ans = 0;
            for (int state = 0; state < (1 << 9); state++) {
                if (count[state] < m || count[state] > n) {
                    continue;
                }
                for (int endNode = 1; endNode <= 9; endNode++) {
                    ans += dp[state][endNode];
                }
            }
            return ans;
        }
    };

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.