The goal of this algorithm is to find the minimum number of scalar multiplication operations needed to multiply a chain of matrices. Because matrix multiplication is associative, the chain can be parenthesized in more than one way, and each way can incur a different number of multiplication operations.
Consider a chain of matrices: ( A1, A2, A3 )
The chain A1, A2 & A3 can be multiplied in the following ways:
( ( A1 . A2 ) . A3 )
( A1 . ( A2 . A3 ) )
Note : Matrix A1 can be multiplied with Matrix A2 only if the number of columns of A1 is equal to the number of rows of A2. Similarly, A2 can be multiplied with A3 only if the number of columns of A2 is equal to the number of rows of A3.
Example
Matrices : A1 dimensions: ( 3 * 5 ), A2 dimensions: ( 5 * 4 ) and A3 dimensions: ( 4 * 6 )
Option 1 : ( ( A1 . A2 ) . A3 ) = ( ( 3 * 5 ) . ( 5 * 4 ) ) . ( 4 * 6 )
Option 2 : ( A1 . ( A2 . A3 ) ) = ( 3 * 5 ) . ( ( 5 * 4 ) . ( 4 * 6 ) )
Steps | Option 1 | Option 2 |
---|---|---|
1 | Multiplication operations ( 3 * 5 ) . ( 5 * 4 ) = 3 . 5 . 4 = 60 | Multiplication operations ( 5 * 4 ) . ( 4 * 6 ) = 5 . 4 . 6 = 120 |
2 | Resultant matrix ( 3 * 4 ) | Resultant matrix ( 5 * 6 ) |
3 | Multiplication operations ( 3 * 4 ) . ( 4 * 6 ) = 3 . 4 . 6 = 72 | Multiplication operations ( 3 * 5 ) . ( 5 * 6 ) = 3 . 5 . 6 = 90 |
4 | Resultant matrix ( 3 * 6 ) | Resultant matrix ( 3 * 6 ) |
5 | Total Operations = 60 + 72 = 132 | Total Operations = 120 + 90 = 210 |
Option 1 is clearly more efficient than Option 2: it needs 132 multiplication operations instead of 210.
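The arithmetic in the table can be reproduced with a few lines of Python. This is only a sketch: the helper name `multiply_cost` and the `(rows, columns)` tuple representation are assumptions made for illustration, not part of the algorithm itself.

```python
from typing import Tuple

Shape = Tuple[int, int]  # (rows, columns)

def multiply_cost(a: Shape, b: Shape) -> Tuple[int, Shape]:
    """Return (multiplication operations, resultant shape) for the product a . b."""
    if a[1] != b[0]:
        raise ValueError(f"cannot multiply {a} by {b}: columns of a != rows of b")
    return a[0] * a[1] * b[1], (a[0], b[1])

a1, a2, a3 = (3, 5), (5, 4), (4, 6)

# Option 1: ( ( A1 . A2 ) . A3 )
cost_12, shape_12 = multiply_cost(a1, a2)
cost_12_3, _ = multiply_cost(shape_12, a3)
option1 = cost_12 + cost_12_3

# Option 2: ( A1 . ( A2 . A3 ) )
cost_23, shape_23 = multiply_cost(a2, a3)
cost_1_23, _ = multiply_cost(a1, shape_23)
option2 = cost_23 + cost_1_23

print(option1, option2)  # 132 210
```

Both options produce a 3 * 6 result; only the number of multiplication operations differs.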
Generalizing from the above example, consider a chain of N matrices with the following dimensions:
Matrices | Dimensions |
---|---|
A1 | (Rows) P0 X (Columns) P1 |
A2 | (Rows) P1 X (Columns) P2 |
A3 | (Rows) P2 X (Columns) P3 |
… | … |
AN | (Rows) PN-1 X (Columns) PN |
Let M[1, N] represent the minimum number of multiplications needed for computing the product of A1, A2, ..., AN.
M[1, N] = minimum of (
M[1, 1] + M[2, N] + P0 . P1 . PN ,
M[1, 2] + M[3, N] + P0 . P2 . PN ,
M[1, 3] + M[4, N] + P0 . P3 . PN ,
...
M[1, N-1] + M[N, N] + P0 . PN-1 . PN )
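M[1, 1] + M[2, N] + P0 . P1 . PN indicates A1 . ( A2 . A3 … AN ). The product ( A2 . A3 … AN ) is a P1 X PN matrix, so the inner dimensions P2, P3, P4, …, PN-1 cancel out and the final multiplication costs P0 . P1 . PN.
M[1, 2] + M[3, N] + P0 . P2 . PN indicates ( A1 . A2 ) . ( A3 … AN ). The product ( A1 . A2 ) is a P0 X P2 matrix and ( A3 … AN ) is a P2 X PN matrix, so P1 and P3, P4, …, PN-1 cancel out and the final multiplication costs P0 . P2 . PN.
…
M[1, N-1] + M[N, N] + P0 . PN-1 . PN indicates ( A1 . A2 . A3 … AN-1 ) . AN. The product ( A1 . A2 . A3 … AN-1 ) is a P0 X PN-1 matrix, so P1, P2, …, PN-2 cancel out and the final multiplication costs P0 . PN-1 . PN.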
Thus we can generalize:
M[ i, i ] = 0
M[ i, j ] = min over i <= k < j of ( M[ i, k ] + M[ k+1, j ] + P[ i-1 ] . P[ k ] . P[ j ] ), for i < j
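The recurrence can also be evaluated directly, top-down, by caching each M[ i, j ] the first time it is computed. The sketch below assumes the dimensions are passed as a sequence `p` in which matrix Ai has `p[i-1]` rows and `p[i]` columns; the function name `min_multiplications` is hypothetical.

```python
from functools import lru_cache
from typing import Sequence

def min_multiplications(p: Sequence[int]) -> int:
    """Evaluate M[1, N] by memoized recursion on the recurrence."""
    n = len(p) - 1  # number of matrices in the chain

    @lru_cache(maxsize=None)
    def m(i: int, j: int) -> int:
        if i == j:
            return 0  # a single matrix needs no multiplication
        # Try every split point k with i <= k < j and keep the cheapest.
        return min(
            m(i, k) + m(k + 1, j) + p[i - 1] * p[k] * p[j]
            for k in range(i, j)
        )

    return m(1, n) if n >= 1 else 0

# The three-matrix example: A1 ( 3 * 5 ), A2 ( 5 * 4 ), A3 ( 4 * 6 )
print(min_multiplications([3, 5, 4, 6]))  # 132
```

The cache is what keeps this from recomputing the same sub-chain many times; without it the recursion would revisit each M[ i, j ] once per enclosing parenthesization.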
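Below is an example of bottom-up calculations for finding the minimum number of multiplication operations needed for multiplying four matrices with dimensions P0 = 5, P1 = 4, P2 = 6, P3 = 2, P4 = 7, i.e. A1 is 5 X 4, A2 is 4 X 6, A3 is 6 X 2 and A4 is 2 X 7.
The number of multiplications needed for a matrix chain of length 1 is 0.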
M[1,1] = 0, M[2,2] = 0, M[3,3] = 0, M[4,4] = 0
Finding the least number of multiplications needed for matrix chains of length 2
for(1<=k<2)
M[1,2] = min( M[1,1] + M[2,2] + P0.P1.P2 i.e 0 + 0 + 5.4.6 = 120 ) = min(120) = 120
for(2<=k<3)
M[2,3] = min( M[2,2] + M[3,3] + P1.P2.P3 i.e 0 + 0 + 4.6.2 = 48 ) = min(48) = 48
for(3<=k<4)
M[3,4] = min( M[3,3] + M[4,4] + P2.P3.P4 i.e 0 + 0 + 6.2.7 = 84 ) = min(84) = 84
Finding the least number of multiplications needed for matrix chains of length 3
for(1<=k<3)
M[1,3] = min( ( M[1,1] + M[2,3] + P0.P1.P3 ) i.e 0 + 48 + 5.4.2 = 88,
( M[1,2] + M[3,3] + P0.P2.P3 ) i.e 120 + 0 + 5.6.2 = 180 ) = min(88, 180) = 88
for(2<=k<4)
M[2,4] = min( ( M[2,2] + M[3,4] + P1.P2.P4 ) i.e 0 + 84 + 4.6.7 = 252,
( M[2,3] + M[4,4] + P1.P3.P4 ) i.e 48 + 0 + 4.2.7 = 104 ) = min(252, 104) = 104
Finding the least number of multiplications needed for the matrix chain of length 4
for(1<=k<4)
M[1,4] = min( ( M[1,1] + M[2,4] + P0.P1.P4 ) i.e 0 + 104 + 5.4.7 = 244,
( M[1,2] + M[3,4] + P0.P2.P4 ) i.e 120 + 84 + 5.6.7 = 414,
( M[1,3] + M[4,4] + P0.P3.P4 ) i.e 88 + 0 + 5.2.7 = 158 ) = min(244, 414, 158) = 158
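The hand calculation above can be cross-checked with a short sketch that fills every M[ i, j ] by increasing chain length and prints the intermediate values. The helper name `cost_table` and the dictionary keyed by `(i, j)` are assumptions made for illustration; the dimensions 5, 4, 6, 2, 7 are the ones used in the worked example.

```python
from typing import Dict, List, Tuple

def cost_table(p: List[int]) -> Dict[Tuple[int, int], int]:
    """Return M[i, j] for every chain Ai..Aj, computed shortest chains first."""
    n = len(p) - 1  # number of matrices
    m = {(i, i): 0 for i in range(1, n + 1)}  # chains of length 1 cost 0
    for length in range(2, n + 1):
        for i in range(1, n - length + 2):
            j = i + length - 1
            m[i, j] = min(
                m[i, k] + m[k + 1, j] + p[i - 1] * p[k] * p[j]
                for k in range(i, j)
            )
    return m

table = cost_table([5, 4, 6, 2, 7])

# Dictionaries keep insertion order, so entries come out grouped by chain length.
for (i, j), cost in table.items():
    if i < j:
        print(f"M[{i},{j}] = {cost}")
```

The printed entries should match the values derived by hand: 120, 48 and 84 for length 2, 88 and 104 for length 3, and 158 for the full chain.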
Time complexity : O ( n^3 ), where n is the number of matrices. There are O ( n^2 ) entries M[ i, j ] to fill, and each entry tries up to n-1 split points k.
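To get a feel for the cubic growth, the sketch below counts how many ( i, j, k ) combinations the bottom-up loops visit for a chain of n matrices and compares the count with the closed form ( n^3 - n ) / 6. The function name `count_split_evaluations` is hypothetical and the loop bounds simply mirror the recurrence.

```python
def count_split_evaluations(n: int) -> int:
    """Count the (i, j, k) combinations visited for a chain of n matrices."""
    count = 0
    for length in range(2, n + 1):
        for i in range(1, n - length + 2):
            j = i + length - 1
            count += j - i  # k takes the values i, i+1, ..., j-1
    return count

for n in (4, 8, 16, 32):
    # Columns: n, counted evaluations, closed form (n^3 - n) / 6
    print(n, count_split_evaluations(n), (n**3 - n) // 6)
```

For n = 4 the count is 10, which matches the worked example: 3 evaluations for chains of length 2, 4 for length 3 and 3 for length 4.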
Implementation of Matrix Chain Multiplication
import math
from typing import List # For annotations

# Below function computes the minimum number of multiplications needed to
# find the product of the chain of matrices in bottom up fashion
def ChainMultiplication (p : List[int]) -> int :

    # Number of matrices; matrix Ai has p[i-1] rows and p[i] columns
    N = len(p) - 1

    # We want to return M[1][N] as the optimal cost of product of multiplying 1..N matrices
    # Multiplications needed for a single matrix is 0, so M[i][i] starts (and stays) at 0
    M = [[0] * (N + 1) for _ in range(N + 1)]

    # Loop from chain len of size 2 upto the N (number of matrices)
    for length in range(2, N + 1) :
        # Chain of length 2 [A1,A2], [A2,A3] and [A3,A4]
        # Chain of length 3 [A1,A2,A3] and [A2,A3,A4]
        # Chain of length 4 [A1,A2,A3,A4]
        # For example: For a chain of length 2, i goes from 1 upto 3 and j goes from 2 upto 4
        #  i   j
        #  A1  A2
        #  A2  A3
        #  A3  A4
        #
        for i in range(1, N - length + 2) :
            j = i + length - 1
            # Start from infinity so that any real cost replaces it
            M[i][j] = math.inf
            for k in range(i, j) : # Since i <= k < j, k goes from i upto j-1
                q = M[i][k] + M[k+1][j] + p[i-1] * p[k] * p[j]
                if M[i][j] > q :
                    M[i][j] = q

    return M[1][N]

dimensions1 = [40, 20, 30, 10, 30] # 26000
print(ChainMultiplication(dimensions1))

dimensions2 = [10, 20, 30, 40, 30] # 30000
print(ChainMultiplication(dimensions2))

dimensions3 = [10, 20, 30] # 6000
print(ChainMultiplication(dimensions3))

dimensions4 = [5, 4, 6, 2, 7] # 158
print(ChainMultiplication(dimensions4))
Output
26000
30000
6000
158
#include<iostream>
#include<vector>
#include<climits>

using namespace std;

class Matrix {

    public:

    // Below function computes the minimum number of multiplications needed to
    // find the product of the chain of matrices in bottom up fashion
    int ChainMultiplication (const vector<int>& p) {

        // Number of matrices; matrix Ai has p[i-1] rows and p[i] columns
        int N = p.size() - 1;

        // We want to return M[1][N] as the optimal cost of product of multiplying 1..N matrices
        // A vector is used because variable-length arrays are not standard C++.
        // Multiplications needed for a single matrix is 0, so M[i][i] stays 0.
        vector<vector<int>> M(N+1, vector<int>(N+1, 0));

        // Loop from chain len of size 2 upto the N (number of matrices)
        for (int len=2; len <= N; len++) {
            /* Chain of length 2 [A1,A2], [A2,A3] and [A3,A4]
               Chain of length 3 [A1,A2,A3] and [A2,A3,A4]
               Chain of length 4 [A1,A2,A3,A4]
               For example: For a chain of length 2, i goes from 1 upto 3 and j goes from 2 upto 4
               i   j
               A1  A2
               A2  A3
               A3  A4
            */
            for (int i=1; i <= N-len+1; i++) {
                int j = i+len-1;
                M[i][j] = INT_MAX;
                for (int k=i; k<=j-1; k++) { // Since i<=k<j, k goes from i upto j-1
                    int q = M[i][k] + M[k+1][j] + p[i-1]*p[k]*p[j];
                    if (M[i][j] > q) {
                        M[i][j] = q;
                    }
                }
            }
        }
        return M[1][N];
    }
};

int main() {

    Matrix m;

    vector<int> dimensions1 = {40, 20, 30, 10, 30}; // 26000
    cout << m.ChainMultiplication (dimensions1) << endl;

    vector<int> dimensions2 = {10, 20, 30, 40, 30}; // 30000
    cout << m.ChainMultiplication (dimensions2) << endl;

    vector<int> dimensions3 = {10, 20, 30}; // 6000
    cout << m.ChainMultiplication (dimensions3) << endl;

    vector<int> dimensions4 = {5, 4, 6, 2, 7}; // 158
    cout << m.ChainMultiplication (dimensions4) << endl;

    return 0;
}
Output
26000
30000
6000
158
class Matrix {

    public int ChainMultiplication (int[] p) {

        // Number of matrices; matrix Ai has p[i-1] rows and p[i] columns
        int N = p.length - 1;

        // We want to return M[1][N] as the optimal cost of product of multiplying 1..N matrices
        // Java zero-initializes the array, so M[i][i] is already 0:
        // multiplications needed for a single matrix is 0
        int[][] M = new int[N+1][N+1];

        // Loop from chain len of size 2 upto the N (number of matrices)
        for (int len = 2; len <= N; len++) {
            /* Chain of length 2 [A1,A2], [A2,A3] and [A3,A4]
               Chain of length 3 [A1,A2,A3] and [A2,A3,A4]
               Chain of length 4 [A1,A2,A3,A4]
               For example: For a chain of length 2, i goes from 1 upto 3 and j goes from 2 upto 4
               i   j
               A1  A2
               A2  A3
               A3  A4
            */
            for (int i = 1; i <= N-len+1; i++) {
                int j = i+len-1;
                M[i][j] = Integer.MAX_VALUE;
                for (int k = i; k <= j-1; k++) { // Since i<=k<j, k goes from i upto j-1
                    int q = M[i][k] + M[k+1][j] + p[i-1] * p[k] * p[j];
                    if (M[i][j] > q) {
                        M[i][j] = q;
                    }
                }
            }
        }
        return M[1][N];
    }

    public static void main (String[] args) {

        Matrix m = new Matrix();

        int[] dimensions1 = {40, 20, 30, 10, 30}; // 26000
        System.out.println( m.ChainMultiplication (dimensions1) );

        int[] dimensions2 = {10, 20, 30, 40, 30}; // 30000
        System.out.println( m.ChainMultiplication (dimensions2) );

        int[] dimensions3 = {10, 20, 30}; // 6000
        System.out.println( m.ChainMultiplication (dimensions3) );

        int[] dimensions4 = {5, 4, 6, 2, 7}; // 158
        System.out.println( m.ChainMultiplication (dimensions4) );
    }
}
Output
26000
30000
6000
158