Evolving from the straightforward solution to the optimal one


  • 2

    1. O(mnk). Regular matrix multiplication: fill the output matrix row by row, computing each entry as a row-by-column dot product, just as we multiply matrices by hand.

    // Naive O(m*n*k) matrix product: res[i][j] is the dot product of row i of A
    // with column j of B, exactly as done by hand.
    // A is m x n, B is n x k; returns the m x k product.
    // Empty inputs are tolerated: an empty A yields {}, an empty B yields m empty rows.
    vector<vector<int>> multiply(vector<vector<int>>& A, vector<vector<int>>& B) {
        // Guard the dimension reads: A[0] / B[0] on an empty matrix is undefined behavior.
        int m = A.size();
        int n = m ? A[0].size() : 0;
        int k = B.empty() ? 0 : B[0].size();
        vector<vector<int>> res(m, vector<int>(k));
        for (int i = 0; i < m; i++)
            for (int j = 0; j < k; j++)
                for (int p = 0; p < n; p++)
                    res[i][j] += A[i][p] * B[p][j];
        return res;
    }
    
    2. O(mn'k). For sparse matrices we do not need to multiply a whole row of A with a whole column of B; we only need to consider the non-zero entries. Here n' is the average number of non-zero entries in the shorter of row i of A and column j of B (the code iterates over whichever list is smaller). Ideally, we would only consider the intersection of the non-zero positions of row i of A and column j of B, so there is still some redundancy in this approach.
    // Sparse O(m*k*n') matrix product, where n' is the number of non-zero entries in
    // the shorter of row i of A and column j of B.
    // A is m x n, B is n x k; returns the m x k product.
    // First record the column indices of the non-zeros in each row of A and the row
    // indices of the non-zeros in each column of B. Then, for every output cell, walk
    // whichever list is shorter. A zero on the other side just adds 0, so the result
    // is exact.
    vector<vector<int>> multiply(vector<vector<int>>& A, vector<vector<int>>& B) {
        // Guard the dimension reads: A[0] / B[0] on an empty matrix is undefined behavior.
        int m = A.size();
        int n = m ? A[0].size() : 0;
        int k = B.empty() ? 0 : B[0].size();
        // Indices are discovered in increasing order and never repeat, so a plain
        // vector is enough; it is contiguous and cheaper than an unordered_set.
        vector<vector<int>> As(m), Bs(k);
        for (int i = 0; i < m; i++)
            for (int j = 0; j < n; j++)
                if (A[i][j]) As[i].push_back(j);
        for (int i = 0; i < n; i++)
            for (int j = 0; j < k; j++)
                if (B[i][j]) Bs[j].push_back(i);
        vector<vector<int>> res(m, vector<int>(k));
        for (int i = 0; i < m; i++)
            for (int j = 0; j < k; j++) {
                // Iterate the shorter of the two index lists.
                const vector<int>& idx = As[i].size() < Bs[j].size() ? As[i] : Bs[j];
                for (int p : idx)
                    res[i][j] += A[i][p] * B[p][j];
            }
        return res;
    }
    
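    3. O(A'k). Another way to improve over #1 is to process A one entry at a time and skip the zeros. Each non-zero A[i][j] is multiplied with row j of B and accumulated into row i of the result. Here A' is the number of non-zero entries in A; scanning A itself still costs O(mn).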
    // O(A'*k) matrix product (plus an O(m*n) scan of A), where A' is the number of
    // non-zero entries of A. A is m x n, B is n x k; returns the m x k product.
    // Each non-zero A[i][j] is scattered across row i of the result: it adds
    // A[i][j] * B[j][p] to res[i][p] for every column p.
    // Empty inputs are tolerated: an empty A yields {}, an empty B yields m empty rows.
    vector<vector<int>> mult(vector<vector<int>> &A, vector<vector<int>> &B) {
        // Guard the dimension reads: A[0] / B[0] on an empty matrix is undefined behavior.
        int m = A.size();
        int n = m ? A[0].size() : 0;
        int k = B.empty() ? 0 : B[0].size();
        vector<vector<int>> res(m, vector<int>(k, 0));
        for (int i = 0; i < m; i++)
            for (int j = 0; j < n; j++) {
                const int a = A[i][j];
                if (!a) continue;   // a zero entry contributes nothing to row i
                for (int p = 0; p < k; p++)
                    res[i][p] += a * B[j][p];
            }
        return res;
    }
    
    4. O(A'k'). #3 does not exploit the fact that B is also sparse. We can extract the non-zero entries of each row of B first; k' is the average number of non-zero entries in a row of B. Only pairs of non-zero entries are multiplied, so there are no wasted multiplications. So this should be the optimal solution.
    // O(A'*k') matrix product (plus O(m*n + n*k) scans), where A' is the number of
    // non-zero entries of A and k' the average number of non-zero entries per row of B.
    // A is m x n, B is n x k; returns the m x k product.
    // Only pairs of non-zero entries are ever multiplied.
    vector<vector<int>> mult2(vector<vector<int>>& A, vector<vector<int>>& B) {
        int m = A.size(), n = B.size();
        int k = n ? B[0].size() : 0;   // B[0] is undefined behavior when B is empty
        vector<vector<int>> C(m, vector<int>(k));
        // Column indices of the non-zeros in each row of B. They are produced in
        // increasing order without duplicates, so a vector is enough and is
        // cheaper to build and iterate than an unordered_set.
        vector<vector<int>> Bs(n);
        for (int i = 0; i < n; i++)
            for (int j = 0; j < k; j++)
                if (B[i][j]) Bs[i].push_back(j);
        for (int i = 0; i < m; i++)
            for (int j = 0; j < n; j++) {
                const int a = A[i][j];
                if (!a) continue;   // a zero entry contributes nothing to row i
                for (int b : Bs[j])
                    C[i][b] += a * B[j][b];
            }
        return C;
    }
    

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.