# Evolving from a straightforward solution to the optimal one

1. O(mnk). Regular matrix multiplication: just fill the output matrix row by row, as we do matrix multiplication by hand.

``````vector<vector<int>> multiply(vector<vector<int>>& A, vector<vector<int>>& B) {
int m = A.size(), n = A[0].size(), k = B[0].size();
vector<vector<int>> res(m,vector<int>(k));
for(int i=0; i<m;i++)
for(int j=0;j<k;j++)
for(int p=0;p<n;p++)
res[i][j]+=A[i][p]*B[p][j];
return res;
}
``````
1. O(mn'k). For sparse matrices we do not need to multiply a whole row of A with a whole column of B; only the non-zero entries matter. Here n' is the average number of non-zero entries in a row of A or a column of B, and for each output cell we scan whichever of the two index lists is shorter. Ideally we would only visit the intersection of row i of A and column j of B, so this approach still has some redundancy.
``````vector<vector<int>> multiply(vector<vector<int>>& A, vector<vector<int>>& B) {
int m = A.size(), n = A[0].size(), k = B[0].size();
vector<unordered_set<int>> As(m), Bs(k);
for(int i=0;i<m;i++)
for(int j=0; j<n;j++)
if(A[i][j]) As[i].insert(j);
for(int i=0;i<n;i++)
for(int j=0; j<k;j++)
if(B[i][j]) Bs[j].insert(i);
vector<vector<int>> res(m,vector<int>(k));
for(int i=0; i<m;i++)
for(int j=0;j<k;j++)
if (As[i].size() < Bs[j].size())
for(auto a:As[i])
res[i][j]+=A[i][a]*B[a][j];
else
for(auto b:Bs[j])
res[i][j]+=A[i][b]*B[b][j];
return res;
}
``````
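1. O(A'k). Another way to improve on #1 is to process the entries of A one by one and skip the zeros: each non-zero A[i][j] adds A[i][j]*B[j][p] to res[i][p] for every column p. Here A' is the number of non-zero entries in A.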
``````vector<vector<int>> mult(vector<vector<int>> &A, vector<vector<int>> &B) {
int m=A.size(), n=A[0].size(), k=B[0].size();
vector<vector<int>> res(m,vector<int>(k,0));
for(int i=0;i<m;i++)
for(int j=0;j<n;j++) {
if(!A[i][j]) continue;
for(int p=0;p<k;p++)
res[i][p]+=A[i][j]*B[j][p];
}
return res;
}
``````
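1. O(A'k'). #3 ignores the fact that B is also sparse. We can first extract the non-zero entries of each row of B; here k' is the average number of non-zero entries in a row of B. Now only pairs of non-zero entries are multiplied, so no multiplication is wasted, which makes this the best of the four approaches. (All of the bounds above leave out the one-time O(mn + nk) cost of scanning A and B for non-zeros.)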
``````vector<vector<int>> mult2(vector<vector<int>>& A, vector<vector<int>>& B)  {
int m = A.size(), n = B.size(), k = B[0].size();
vector<vector<int>> C(m,vector<int>(k));
vector<unordered_set<int>> Bs(n);
for(int i=0;i<n;i++)
for(int j=0; j<k;j++)
if(B[i][j]) Bs[i].insert(j);
for(int i=0;i<m;i++)
for(int j=0;j<n;j++) {
if(!A[i][j]) continue;
for(auto b:Bs[j])
C[i][b] += A[i][j]*B[j][b];
}
return C;
}
``````
