One way to optimize multiplication of two sparse matrices.

    Let the input matrices be A and B.
    Refresher: Matrix multiplication requires calculating the dot product (a vector product with a scalar result) of each row of A with each column of B. For this to be defined, the number of columns of the first matrix must equal the number of rows of the second one. If A is m x n and B is n x p, then the result will be m x p. The time complexity of the straightforward algorithm is O(m*n*p). If the matrices are n x n square matrices, the complexity becomes O(n^3).
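
For reference, the plain triple-loop product described in the refresher can be sketched as follows. This is only a minimal baseline, not part of the solution itself; the names `Matrix` and `naiveMultiply` are made up for illustration, and the sketch assumes both inputs are non-empty, rectangular, and have compatible dimensions.

```cpp
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<int>>;

// Naive product of an m x n matrix A and an n x p matrix B.
// Assumption: A and B are non-empty, rectangular, and A has as many
// columns as B has rows. No validity checks are done here.
Matrix naiveMultiply(const Matrix& A, const Matrix& B) {
    const std::size_t m = A.size(), n = B.size(), p = B[0].size();
    Matrix res(m, std::vector<int>(p, 0));
    for (std::size_t i = 0; i < m; ++i) {          // each row of A
        for (std::size_t j = 0; j < p; ++j) {      // each column of B
            for (std::size_t k = 0; k < n; ++k) {  // dot product of row i and column j
                res[i][j] += A[i][k] * B[k][j];
            }
        }
    }
    return res;
}
```

The three nested loops run m, p, and n times respectively, which is where the O(m*n*p) bound comes from.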

    Solution: In sparse matrices, many rows and columns contain only zeros, so multiplying those rows and columns is unnecessary: the corresponding entries of the result stay zero. Ignoring the all-zero rows of A and the all-zero columns of B can greatly reduce the processing time. One way of doing that is to keep the indices of the non-zero rows of A in one set (nzRows) and the indices of the non-zero columns of B in another set (nzCols). Below is C++ code doing that.


    #include <unordered_set>
    #include <vector>
    using namespace std;

    vector<vector<int>> multiply(vector<vector<int>>& A, vector<vector<int>>& B) {
        // Check the validity of the matrices
        if (A.size() == 0 || A[0].size() == 0 || A[0].size() != B.size() || B[0].size() == 0) return vector<vector<int>>();
        int m = A.size(), n = A[0].size(), p = B[0].size();
        vector<vector<int>> res(m, vector<int>(p, 0)); // Result matrix starts as all zeros
        unordered_set<int> nzRows, nzCols; // nonzero rows of A and nonzero columns of B
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                if (A[i][j] != 0) {
                    nzRows.insert(i); // Mark the nonzero rows of A
                }
            }
        }
        for (int j = 0; j < p; j++) {
            for (int i = 0; i < n; i++) {
                if (B[i][j] != 0) {
                    nzCols.insert(j); // Mark the nonzero columns of B
                }
            }
        }
        for (int i = 0; i < m; i++) { // All the rows of A
            if (nzRows.count(i) == 0) continue; // if row i is all zeros, jump to the next iteration
            for (int j = 0; j < p; j++) { // All the columns of B
                if (nzCols.count(j) == 0) continue; // if column j is all zeros, jump to the next iteration
                for (int k = 0; k < n; k++) res[i][j] += (A[i][k] * B[k][j]);
            }
        }
        return res;
    }
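
To get a feel for how much work the skipping saves, here is a rough sketch that counts how many dot products still have to be evaluated: the number of non-zero rows of A times the number of non-zero columns of B, compared with m*p for the plain algorithm. The helper names (`countNonZeroRows`, `countNonZeroCols`, `dotProductsNeeded`) are hypothetical and not part of the code above, and the sketch assumes rectangular matrices.

```cpp
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<int>>;

// Hypothetical helper: number of rows of M containing at least one nonzero entry.
std::size_t countNonZeroRows(const Matrix& M) {
    std::size_t count = 0;
    for (const auto& row : M) {
        for (int v : row) {
            if (v != 0) { ++count; break; }  // one nonzero entry is enough
        }
    }
    return count;
}

// Hypothetical helper: number of columns of M containing at least one nonzero entry.
// Assumption: M is rectangular (every row has the same length).
std::size_t countNonZeroCols(const Matrix& M) {
    if (M.empty()) return 0;
    std::size_t count = 0;
    for (std::size_t j = 0; j < M[0].size(); ++j) {
        for (std::size_t i = 0; i < M.size(); ++i) {
            if (M[i][j] != 0) { ++count; break; }
        }
    }
    return count;
}

// Dot products the row/column-skipping approach still evaluates for A * B.
std::size_t dotProductsNeeded(const Matrix& A, const Matrix& B) {
    return countNonZeroRows(A) * countNonZeroCols(B);
}
```

For example, if A is 3 x 3 with one all-zero row and B is 3 x 3 with one all-zero column, only 2 * 2 = 4 of the 9 dot products are evaluated. The preliminary scans cost O(m*n + n*p), which is small next to the O(m*n*p) of the full product.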
