Matrix Chain Multiplication Using Dynamic Programming

Efficient way to multiply a chain of matrices

The idea of this algorithm is to find the minimum number of scalar multiplication operations needed to multiply a chain of matrices. Because matrix multiplication is associative, the chain can be parenthesized in more than one way, and each way can incur a different number of multiplication operations.

Consider a chain of matrices: ( A1, A2, A3 )
The chain A1, A2 & A3 can be multiplied in the following ways:
( ( A1 . A2 ) . A3 )
( A1 . ( A2 . A3 ) )

Note : Matrix A1 can be multiplied with matrix A2 only if the number of columns of A1 is equal to the number of rows of A2. Similarly, A2 can be multiplied with A3 only if the number of columns of A2 is equal to the number of rows of A3.


Example of Matrix Chain Multiplication

Matrices : A1 dimensions: ( 3 * 5 ) , A2 dimensions: ( 5 * 4 ) and A3 dimensions: ( 4 * 6 )
Option 1 : ( ( A1 . A2 ) . A3 ) = ( ( 3 * 5 ) . ( 5 * 4 ) ) . ( 4 * 6 )
Option 2 : ( A1 . ( A2 . A3 ) ) = ( 3 * 5 ) . ( ( 5 * 4 ) . ( 4 * 6) )

Steps | Option 1                                                         | Option 2
1     | Multiplication operations ( 3 * 5 ) . ( 5 * 4 ) = 3 . 5 . 4 = 60 | Multiplication operations ( 5 * 4 ) . ( 4 * 6 ) = 5 . 4 . 6 = 120
2     | Resultant matrix ( 3 * 5 ) . ( 5 * 4 ) = ( 3 * 4 )               | Resultant matrix ( 5 * 4 ) . ( 4 * 6 ) = ( 5 * 6 )
3     | Multiplication operations ( 3 * 4 ) . ( 4 * 6 ) = 3 . 4 . 6 = 72 | Multiplication operations ( 3 * 5 ) . ( 5 * 6 ) = 3 . 5 . 6 = 90
4     | Resultant matrix ( 3 * 4 ) . ( 4 * 6 ) = ( 3 * 6 )               | Resultant matrix ( 3 * 5 ) . ( 5 * 6 ) = ( 3 * 6 )
5     | Total Operations = 60 + 72 = 132                                 | Total Operations = 120 + 90 = 210

Option 1 is clearly more efficient than Option 2: it needs 132 multiplication operations instead of 210.
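
The arithmetic of the two options can be checked with a few lines of Java. This is only a minimal sketch: the class name ParenthesizationCost and the helpers multiplyCost, leftFirst and rightFirst are hypothetical names introduced for illustration, and the code covers only the three-matrix case of this example.

```java
class ParenthesizationCost {

    // Scalar multiplications needed to multiply a (rows * inner) matrix by an (inner * cols) matrix.
    static int multiplyCost(int rows, int inner, int cols) {
        return rows * inner * cols;
    }

    // Cost of ( ( A1 . A2 ) . A3 ), where A1 is p0 * p1, A2 is p1 * p2 and A3 is p2 * p3.
    static int leftFirst(int p0, int p1, int p2, int p3) {
        // A1 . A2 gives a p0 * p2 matrix, which is then multiplied by A3.
        return multiplyCost(p0, p1, p2) + multiplyCost(p0, p2, p3);
    }

    // Cost of ( A1 . ( A2 . A3 ) ) for the same dimensions.
    static int rightFirst(int p0, int p1, int p2, int p3) {
        // A2 . A3 gives a p1 * p3 matrix, which is then multiplied by A1 from the left.
        return multiplyCost(p1, p2, p3) + multiplyCost(p0, p1, p3);
    }

    public static void main(String[] args) {
        System.out.println("Option 1 : " + leftFirst(3, 5, 4, 6));  // 60 + 72 = 132
        System.out.println("Option 2 : " + rightFirst(3, 5, 4, 6)); // 120 + 90 = 210
    }
}
```

With only three matrices there are just two parenthesizations to compare; the number of possibilities grows quickly with the chain length, which is why a systematic method is needed.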

Generalizing based on the above example, consider

Matrices | Dimensions ( Rows X Columns )
A1       | P0 X P1
A2       | P1 X P2
A3       | P2 X P3
...      | ...
AN       | PN-1 X PN

Let M[1, N] represent the minimum number of multiplications needed for computing the product of A1, A2, ..., AN. More generally, M[i, j] is the minimum number of multiplications needed for computing the product of Ai, Ai+1, ..., Aj.

M[1, N] = minimum of ( M[1, 1] + M[2, N] + P0 . P1 . PN ,
                       M[1, 2] + M[3, N] + P0 . P2 . PN ,
                       M[1, 3] + M[4, N] + P0 . P3 . PN ,
                       ... 
                       M[1, N-1] + M[N, N] + P0 . PN-1. PN )

M[1, 1] + M[2, N] + P0 . P1 . PN indicates A1 . ( A2 . A3 … AN )
Here A1 is a P0 X P1 matrix and ( A2 . A3 … AN ) is a P1 X PN matrix, so the final multiplication costs P0 . P1 . PN. The inner dimensions P2, P3, P4, …, PN-1 cancel out.
M[1, 2] + M[3, N] + P0 . P2 . PN indicates ( A1 . A2 ) . ( A3 … AN )
Here ( A1 . A2 ) is a P0 X P2 matrix and ( A3 … AN ) is a P2 X PN matrix, so the inner dimensions P1, P3, P4, …, PN-1 cancel out.

M[1, N-1] + M[N, N] + P0 . PN-1 . PN indicates ( A1 . A2 . A3 … AN-1 ) . AN
Here ( A1 . A2 . A3 … AN-1 ) is a P0 X PN-1 matrix and AN is a PN-1 X PN matrix, so the inner dimensions P1, P2, …, PN-2 cancel out.

Thus we can generalize
M[ i, i ] = 0
For i < j, take the minimum over every split point k from i up to j-1
            M[ i, j ] = min ( M[ i, k ] + M[ k+1, j ] + P[ i-1 ] . P[ k ] . P[ j ] )
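
The recurrence can also be written almost literally as a recursive function with memoization. The sketch below is a top-down alternative to the bottom-up approach this article uses, not the article's implementation; the class name MatrixChainMemo and the method minCost are hypothetical names, and the array p is assumed to hold P0 … PN with at least one matrix in the chain (p.length >= 2).

```java
import java.util.Arrays;

class MatrixChainMemo {

    // Minimum multiplications for the product Ai . Ai+1 ... Aj, where Ai has dimensions p[i-1] * p[i].
    static int minCost(int[] p, int i, int j, int[][] memo) {
        if (i == j) {
            return 0; // A single matrix needs no multiplication.
        }
        if (memo[i][j] != -1) {
            return memo[i][j]; // Sub-problem already solved.
        }
        int best = Integer.MAX_VALUE;
        // Try every split point k with i <= k < j.
        for (int k = i; k < j; k++) {
            int cost = minCost(p, i, k, memo) + minCost(p, k + 1, j, memo)
                     + p[i - 1] * p[k] * p[j];
            best = Math.min(best, cost);
        }
        memo[i][j] = best;
        return best;
    }

    // Assumes p.length >= 2, i.e. the chain contains at least one matrix.
    static int minCost(int[] p) {
        int n = p.length - 1; // Number of matrices.
        int[][] memo = new int[n + 1][n + 1];
        for (int[] row : memo) {
            Arrays.fill(row, -1); // -1 marks "not computed yet".
        }
        return minCost(p, 1, n, memo);
    }

    public static void main(String[] args) {
        // A1 ( 3 * 5 ), A2 ( 5 * 4 ), A3 ( 4 * 6 ) from the earlier example.
        System.out.println(minCost(new int[] {3, 5, 4, 6})); // 132
    }
}
```

Memoization ensures each pair ( i, j ) is solved only once, so the top-down version does the same amount of work as filling the table bottom-up.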

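Below is an example of bottom-up calculations for finding the minimum number of multiplication operations needed for multiplying a chain of 4 matrices with P0 = 5, P1 = 4, P2 = 6, P3 = 2 and P4 = 7, i.e. A1 ( 5 * 4 ), A2 ( 4 * 6 ), A3 ( 6 * 2 ) and A4 ( 2 * 7 ).

Matrix Chain Calculations

The number of multiplications needed for a matrix chain of length 1 is 0.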

M[1,1] = 0, M[2,2] = 0, M[3,3] = 0, M[4,4] = 0

Finding the least number of multiplications needed for matrix chains of length 2

for(1<=k<2)  M[1,2]  = min( M[1,1] + M[2,2] + P0.P1.P2 i.e 0 + 0 + 5.4.6 = 120 ) = min(120) = 120
for(2<=k<3)  M[2,3]  = min( M[2,2] + M[3,3] + P1.P2.P3 i.e 0 + 0 + 4.6.2 = 48 ) = min(48) = 48
for(3<=k<4)  M[3,4]  = min( M[3,3] + M[4,4] + P2.P3.P4 i.e 0 + 0 + 6.2.7 = 84 ) = min(84) = 84

Finding the least number of multiplications needed for matrix chains of length 3

for(1<=k<3)  M[1,3]  = min( ( M[1,1] + M[2,3] + P0.P1.P3 ) i.e 0 + 48 + 5.4.2 = 88,
                            ( M[1,2] + M[3,3] + P0.P2.P3 ) i.e 120 + 0 + 5.6.2 = 180 ) = min(88, 180) = 88
for(2<=k<4)  M[2,4]  = min( ( M[2,2] + M[3,4] + P1.P2.P4 ) i.e 0 + 84 + 4.6.7 = 252,
                            ( M[2,3] + M[4,4] + P1.P3.P4 ) i.e 48 + 0 + 4.2.7 = 104 ) = min(252, 104) = 104

Finding the least number of multiplications needed for the matrix chain of length 4

for(1<=k<4)  M[1,4]  = min( ( M[1,1] + M[2,4] + P0.P1.P4 ) i.e 0 + 104 + 5.4.7 = 244,
( M[1,2] + M[3,4] + P0.P2.P4 ) i.e 120 + 84 + 5.6.7 = 414,
( M[1,3] + M[4,4] + P0.P3.P4 ) i.e 88 + 0 + 5.2.7 = 158 ) = min(244, 414, 158) = 158
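
The hand calculations above can be reproduced by filling the whole table M and printing every entry. The following is a sketch under the assumption that the dimensions are P = {5, 4, 6, 2, 7}, as used in the calculations; the class name MatrixChainTable and the method costTable are hypothetical names chosen for this illustration.

```java
class MatrixChainTable {

    // Fills and returns the table M for dimensions p = {P0, P1, ..., PN}; row and column 0 are unused.
    static int[][] costTable(int[] p) {
        int n = p.length - 1; // Number of matrices.
        int[][] m = new int[n + 1][n + 1]; // M[i][i] is 0 by default (chains of length 1).
        for (int len = 2; len <= n; len++) {
            for (int i = 1; i + len - 1 <= n; i++) {
                int j = i + len - 1;
                m[i][j] = Integer.MAX_VALUE;
                // Try every split point k with i <= k < j and keep the cheapest.
                for (int k = i; k < j; k++) {
                    int cost = m[i][k] + m[k + 1][j] + p[i - 1] * p[k] * p[j];
                    m[i][j] = Math.min(m[i][j], cost);
                }
            }
        }
        return m;
    }

    public static void main(String[] args) {
        int[] p = {5, 4, 6, 2, 7};
        int n = p.length - 1;
        int[][] m = costTable(p);
        // Print the entries chain length by chain length, in the order of the hand calculation.
        for (int len = 1; len <= n; len++) {
            for (int i = 1; i + len - 1 <= n; i++) {
                int j = i + len - 1;
                System.out.println("M[" + i + "," + j + "] = " + m[i][j]);
            }
        }
    }
}
```

For these dimensions the printed entries should match the values derived by hand, ending with M[1,4] = 158.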

Time complexity of matrix chain multiplication : O(n^3), where n is the number of matrices.
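
To see where the cubic bound comes from, one can count how often the innermost loop over the split point k runs. The sketch below mirrors the three nested loops of the bottom-up approach without doing any cost arithmetic; the class name MatrixChainOpCount and the method innerIterations are hypothetical names used only for this illustration.

```java
class MatrixChainOpCount {

    // Counts how many times the body of the innermost (k) loop runs for a chain of n matrices.
    static long innerIterations(int n) {
        long count = 0;
        for (int len = 2; len <= n; len++) {           // Chain length.
            for (int i = 1; i <= n - len + 1; i++) {   // Start of the chain.
                int j = i + len - 1;                   // End of the chain.
                for (int k = i; k < j; k++) {          // Split point.
                    count++;
                }
            }
        }
        return count;
    }

    public static void main(String[] args) {
        for (int n : new int[] {4, 10, 100}) {
            System.out.println(n + " matrices : " + innerIterations(n) + " inner iterations");
        }
    }
}
```

The count equals ( n^3 - n ) / 6, for example 10 for n = 4 (the 3 + 4 + 3 candidate splits evaluated in the worked example), which grows as O(n^3). The table M itself holds O(n^2) entries.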




Java : Matrix Chain Multiplication in Java

class Matrix {

    int max_val = 999999999;

    public int MatrixChainMultiplication (int p[]) {

        //Number of matrices
        int N = p.length - 1;

        // We want to return M[1][N] as the optimal cost of product of multiplying 1..N matrices
        int M[][] = new int[N+1][N+1];


        // Multiplications needed for a single matrix is 0
        for( int i = 0; i <= N; i++) {
            M[i][i] = 0;
        }

        // Loop from chain len of size 2 upto the N (number of matrices)
        for (int len=2; len <= N; len++) {

            /* Chain of length 2 [A1,A2], [A2,A3] and [A3,A4]
               Chain of length 3 [A1,A2,A3] and [A2,A3,A4]
               Chain of length 4 [A1,A2,A3,A4]
               For example: For a chain of length 2, i goes from 1 upto 3 and j goes from 2 upto 4
               i  j
               A1 A2
               A2 A3
               A3 A4
            */
            for (int i=1; i <= N-len+1; i++) {
                int j = i+len-1;
                M[i][j] = max_val;
                for (int k = i; k <= j-1; k++) { // Since i<=k<j, k goes from i upto j-1;
                    int q = M[i][k] + M[k+1][j] + p[i-1] * p[k] * p[j];
                    if (M[i][j] > q){
                        M[i][j] = q;
                    }
                }
            }
        }

        return M[1][N];
    }

    public static void main(String args[]) {

        Matrix m = new Matrix();

        int[] dimensions1 = {40, 20, 30, 10, 30}; // 26000
        System.out.println( m.MatrixChainMultiplication (dimensions1) );

        int[] dimensions2 =  {10, 20, 30, 40, 30}; // 30000
        System.out.println( m.MatrixChainMultiplication (dimensions2) );

        int[] dimensions3 =  {10, 20, 30}; // 6000
        System.out.println( m.MatrixChainMultiplication (dimensions3) );

        int[] dimensions4 =  {5, 4, 6, 2, 7}; // 158
        System.out.println( m.MatrixChainMultiplication (dimensions4) );
    }
}

Output

26000
30000
6000
158

Copyright (c) 2019-2021, Algotree.org.
All rights reserved.