Divide and Conquer (Strassen’s Matrix Multiplication)
Given two square matrices A and B of size n x n each, find their product matrix.
Naive Method
The following is a simple way to multiply two matrices.
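#define N 4 /* matrix size; any compile-time constant works (4 is only an example) */

/* C = A * B for N x N matrices, using the straightforward triple loop */
void multiply(int A[][N], int B[][N], int C[][N])
{
    for (int i = 0; i < N; i++)
    {
        for (int j = 0; j < N; j++)
        {
            C[i][j] = 0;
            for (int k = 0; k < N; k++)
            {
                C[i][j] += A[i][k] * B[k][j];
            }
        }
    }
}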
The time complexity of the above method is O(N^3).
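For comparison with the Java solution further down, here is a minimal Java sketch of the same triple loop. It is an illustration, not part of the assignment: the class name NaiveMultiply is hypothetical, and the size is taken from the array length instead of a compile-time constant N. Each of the n x n result entries is a dot product of n terms, which is where the n^3 multiplications come from.

```java
import java.util.Arrays;

// Java version of the triple loop above (hypothetical class name).
// The size n is read from the input array, which is assumed to be square.
class NaiveMultiply {
    static int[][] multiply(int[][] a, int[][] b) {
        int n = a.length;
        int[][] c = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                int sum = 0;
                // Dot product of row i of a with column j of b: n multiplications.
                for (int k = 0; k < n; k++) {
                    sum += a[i][k] * b[k][j];
                }
                c[i][j] = sum;
            }
        }
        return c;
    }

    public static void main(String[] args) {
        int[][] a = {{1, 2}, {3, 4}};
        int[][] b = {{5, 6}, {7, 8}};
        // n * n entries, each costing n multiplications: n^3 in total.
        System.out.println(Arrays.deepToString(multiply(a, b))); // [[19, 22], [43, 50]]
    }
}
```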
Divide and Conquer
The following is a simple divide-and-conquer method for multiplying two square matrices.
1) Divide matrices A and B into four sub-matrices of size N/2 x N/2 each: A into blocks a, b (top row) and c, d (bottom row), and B into blocks e, f (top row) and g, h (bottom row).
2) Recursively compute the four blocks of the result: ae + bg (top left), af + bh (top right), ce + dg (bottom left) and cf + dh (bottom right).
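The two steps above can be sketched in Java as follows, assuming n is a power of two so that every split is exact (the names BlockMultiply, block and place are illustrative, not from the original). The method makes eight recursive multiplications of N/2 x N/2 blocks plus O(N^2) work for the additions, so its running time satisfies T(N) = 8T(N/2) + O(N^2), which solves to O(N^3): splitting alone gives no improvement over the naive method.

```java
import java.util.Arrays;

// Simple divide and conquer: A = [a b; c d], B = [e f; g h] (hypothetical names).
// Assumes n is a power of two so every split is exact.
class BlockMultiply {
    // Copy the size x size block of m whose top-left corner is (row, col).
    static int[][] block(int[][] m, int row, int col, int size) {
        int[][] r = new int[size][size];
        for (int i = 0; i < size; i++)
            for (int j = 0; j < size; j++)
                r[i][j] = m[row + i][col + j];
        return r;
    }

    // Element-wise sum of two equally sized square matrices.
    static int[][] add(int[][] x, int[][] y) {
        int n = x.length;
        int[][] r = new int[n][n];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                r[i][j] = x[i][j] + y[i][j];
        return r;
    }

    // Write src into dst with its top-left corner at (row, col).
    static void place(int[][] src, int[][] dst, int row, int col) {
        for (int i = 0; i < src.length; i++)
            for (int j = 0; j < src.length; j++)
                dst[row + i][col + j] = src[i][j];
    }

    static int[][] multiply(int[][] A, int[][] B) {
        int n = A.length;
        if (n == 1) return new int[][]{{A[0][0] * B[0][0]}};
        int half = n / 2;
        int[][] a = block(A, 0, 0, half), b = block(A, 0, half, half);
        int[][] c = block(A, half, 0, half), d = block(A, half, half, half);
        int[][] e = block(B, 0, 0, half), f = block(B, 0, half, half);
        int[][] g = block(B, half, 0, half), h = block(B, half, half, half);
        int[][] C = new int[n][n];
        // Eight recursive multiplications of half-size blocks.
        place(add(multiply(a, e), multiply(b, g)), C, 0, 0);       // ae + bg
        place(add(multiply(a, f), multiply(b, h)), C, 0, half);    // af + bh
        place(add(multiply(c, e), multiply(d, g)), C, half, 0);    // ce + dg
        place(add(multiply(c, f), multiply(d, h)), C, half, half); // cf + dh
        return C;
    }

    public static void main(String[] args) {
        int[][] A = {{1, 2}, {3, 4}};
        int[][] B = {{5, 6}, {7, 8}};
        System.out.println(Arrays.deepToString(multiply(A, B))); // [[19, 22], [43, 50]]
    }
}
```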
Implement Strassen’s algorithm in Java or Python, as described above, to multiply two square matrices of size n x n.
Submit your solution online (BB) together with the generated output.
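Strassen's method needs only seven products of half-size blocks instead of eight. Written with the names used in Main.java below (X for the blocks of the first matrix, Y for the second, Z for the result), the products, the way they are combined and the resulting recurrence are:

```latex
\begin{aligned}
M_1 &= (X_{11} + X_{22})(Y_{11} + Y_{22}) \\
M_2 &= (X_{21} + X_{22})\,Y_{11} \\
M_3 &= X_{11}\,(Y_{12} - Y_{22}) \\
M_4 &= X_{22}\,(Y_{21} - Y_{11}) \\
M_5 &= (X_{11} + X_{12})\,Y_{22} \\
M_6 &= (X_{21} - X_{11})(Y_{11} + Y_{12}) \\
M_7 &= (X_{12} - X_{22})(Y_{21} + Y_{22}) \\[4pt]
Z_{11} &= M_1 + M_4 - M_5 + M_7, \qquad Z_{12} = M_3 + M_5 \\
Z_{21} &= M_2 + M_4, \qquad Z_{22} = M_1 - M_2 + M_3 + M_6 \\[4pt]
T(n) &= 7\,T(n/2) + O(n^2) \;\Rightarrow\; T(n) = O(n^{\log_2 7}) \approx O(n^{2.81})
\end{aligned}
```

As a worked check, take the first matrix [[1, 2], [3, 4]] and the second [[5, 6], [7, 8]], so every block is a single number. Then M1 = 65, M2 = 35, M3 = -2, M4 = 8, M5 = 24, M6 = 22 and M7 = -30, which give Z11 = 19, Z12 = 22, Z21 = 43 and Z22 = 50, the same result as the naive method.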
Code - Main.java
import java.util.Scanner;

public class Main
{
    // multiply two square matrices with Strassen's algorithm
    // (the size must be a power of two so that every split is exact)
    public int[][] multiplyMatrices(int[][] first, int[][] second)
    {
        // get the size of the matrices
        int size = first.length;
        // this matrix will store the result
        int[][] resultMatrix = new int[size][size];
        /** base case: 1 x 1 matrices **/
        if (size == 1)
            resultMatrix[0][0] = first[0][0] * second[0][0];
        else
        {
            // initialize the sub-matrices with size/2 rows and columns
            int[][] X11 = new int[size/2][size/2];
            int[][] X12 = new int[size/2][size/2];
            int[][] X21 = new int[size/2][size/2];
            int[][] X22 = new int[size/2][size/2];
            int[][] Y11 = new int[size/2][size/2];
            int[][] Y12 = new int[size/2][size/2];
            int[][] Y21 = new int[size/2][size/2];
            int[][] Y22 = new int[size/2][size/2];
            // split the first matrix into its 4 quarters
            splitMatrices(first, X11, 0, 0);
            splitMatrices(first, X12, 0, size/2);
            splitMatrices(first, X21, size/2, 0);
            splitMatrices(first, X22, size/2, size/2);
            // split the second matrix into its 4 quarters
            splitMatrices(second, Y11, 0, 0);
            splitMatrices(second, Y12, 0, size/2);
            splitMatrices(second, Y21, size/2, 0);
            splitMatrices(second, Y22, size/2, size/2);
            // compute Strassen's seven products
            int[][] M1 = multiplyMatrices(addMatrices(X11, X22), addMatrices(Y11, Y22));
            int[][] M2 = multiplyMatrices(addMatrices(X21, X22), Y11);
            int[][] M3 = multiplyMatrices(X11, subMatrices(Y12, Y22));
            int[][] M4 = multiplyMatrices(X22, subMatrices(Y21, Y11));
            int[][] M5 = multiplyMatrices(addMatrices(X11, X12), Y22);
            int[][] M6 = multiplyMatrices(subMatrices(X21, X11), addMatrices(Y11, Y12));
            int[][] M7 = multiplyMatrices(subMatrices(X12, X22), addMatrices(Y21, Y22));
            // combine M1 ... M7 into the four quarters Z11, Z12, Z21, Z22 of the result
            int[][] Z11 = addMatrices(subMatrices(addMatrices(M1, M4), M5), M7);
            int[][] Z12 = addMatrices(M3, M5);
            int[][] Z21 = addMatrices(M2, M4);
            int[][] Z22 = addMatrices(subMatrices(addMatrices(M1, M3), M2), M6);
            // join the four quarters into the result matrix
            joinMatrices(Z11, resultMatrix, 0, 0);
            joinMatrices(Z12, resultMatrix, 0, size/2);
            joinMatrices(Z21, resultMatrix, size/2, 0);
            joinMatrices(Z22, resultMatrix, size/2, size/2);
        }
        /** return result **/
        return resultMatrix;
    }

    // subtract two matrices element by element
    public int[][] subMatrices(int[][] first, int[][] second)
    {
        int size = first.length;
        int[][] mat = new int[size][size];
        for (int i = 0; i < size; i++)
            for (int j = 0; j < size; j++)
                mat[i][j] = first[i][j] - second[i][j];
        return mat;
    }

    // add two matrices element by element
    public int[][] addMatrices(int[][] first, int[][] second)
    {
        int size = first.length;
        int[][] mat = new int[size][size];
        for (int i = 0; i < size; i++)
            for (int j = 0; j < size; j++)
                mat[i][j] = first[i][j] + second[i][j];
        return mat;
    }

    // copy the block of P whose top-left corner is (iB, jB) into mat
    public void splitMatrices(int[][] P, int[][] mat, int iB, int jB)
    {
        for (int i1 = 0, i2 = iB; i1 < mat.length; i1++, i2++)
            for (int j1 = 0, j2 = jB; j1 < mat.length; j1++, j2++)
                mat[i1][j1] = P[i2][j2];
    }

    // copy mat into P with its top-left corner at (iB, jB)
    public void joinMatrices(int[][] mat, int[][] P, int iB, int jB)
    {
        for (int i1 = 0, i2 = iB; i1 < mat.length; i1++, i2++)
            for (int j1 = 0, j2 = jB; j1 < mat.length; j1++, j2++)
                P[i2][j2] = mat[i1][j1];
    }

    public static void main(String[] args)
    {
        Scanner scan = new Scanner(System.in);
        System.out.println("Strassen Matrix Multiplication");
        // create an object of the Main class
        Main s = new Main();
        System.out.println("Enter size of matrix ");
        int N = scan.nextInt();
        // read the first matrix
        System.out.println("Enter first matrix : ");
        int[][] first = new int[N][N];
        for (int i = 0; i < N; i++)
            for (int j = 0; j < N; j++)
                first[i][j] = scan.nextInt();
        // read the second matrix
        System.out.println("Enter second matrix : ");
        int[][] second = new int[N][N];
        for (int i = 0; i < N; i++)
            for (int j = 0; j < N; j++)
                second[i][j] = scan.nextInt();
        // Strassen's recursion halves the size at every level, so it only works
        // for powers of two: copy the input into zero-padded m x m matrices
        int m = 1;
        while (m < N)
            m *= 2;
        int[][] paddedFirst = new int[m][m];
        int[][] paddedSecond = new int[m][m];
        for (int i = 0; i < N; i++)
            for (int j = 0; j < N; j++)
            {
                paddedFirst[i][j] = first[i][j];
                paddedSecond[i][j] = second[i][j];
            }
        // the top-left N x N block of the padded product is the actual product
        int[][] mat = s.multiplyMatrices(paddedFirst, paddedSecond);
        // print the product matrix
        System.out.println("\nMultiplication of first and second matrices is : ");
        for (int i = 0; i < N; i++)
        {
            for (int j = 0; j < N; j++)
                System.out.print(mat[i][j] + " ");
            System.out.println();
        }
        scan.close();
    }
}
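One way to check the output before submitting is to compare Strassen's result with the naive triple loop on random inputs. The sketch below is self-contained and hypothetical (the class StrassenCheck and its helpers are not part of the original): it re-implements the same seven products compactly, zero-pads sizes that are not powers of two, and compares against the naive product with Arrays.deepEquals.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical cross-check: compact Strassen (same M1..M7 as Main.java)
// versus the naive triple loop on random matrices of several sizes.
class StrassenCheck {
    // x + sign * y, element by element (sign is +1 or -1).
    static int[][] add(int[][] x, int[][] y, int sign) {
        int n = x.length;
        int[][] r = new int[n][n];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                r[i][j] = x[i][j] + sign * y[i][j];
        return r;
    }

    // size x size block of src starting at (row, col); cells outside src stay 0,
    // so this also zero-pads (size > src.length) or crops (size < src.length).
    static int[][] block(int[][] src, int row, int col, int size) {
        int[][] r = new int[size][size];
        for (int i = 0; i < size; i++)
            for (int j = 0; j < size; j++)
                if (row + i < src.length && col + j < src.length)
                    r[i][j] = src[row + i][col + j];
        return r;
    }

    // Write src into dst with its top-left corner at (row, col).
    static void place(int[][] src, int[][] dst, int row, int col) {
        for (int i = 0; i < src.length; i++)
            for (int j = 0; j < src.length; j++)
                dst[row + i][col + j] = src[i][j];
    }

    // Strassen recursion; the size must be a power of two.
    static int[][] strassen(int[][] X, int[][] Y) {
        int n = X.length;
        if (n == 1) return new int[][]{{X[0][0] * Y[0][0]}};
        int h = n / 2;
        int[][] x11 = block(X, 0, 0, h), x12 = block(X, 0, h, h);
        int[][] x21 = block(X, h, 0, h), x22 = block(X, h, h, h);
        int[][] y11 = block(Y, 0, 0, h), y12 = block(Y, 0, h, h);
        int[][] y21 = block(Y, h, 0, h), y22 = block(Y, h, h, h);
        int[][] m1 = strassen(add(x11, x22, 1), add(y11, y22, 1));
        int[][] m2 = strassen(add(x21, x22, 1), y11);
        int[][] m3 = strassen(x11, add(y12, y22, -1));
        int[][] m4 = strassen(x22, add(y21, y11, -1));
        int[][] m5 = strassen(add(x11, x12, 1), y22);
        int[][] m6 = strassen(add(x21, x11, -1), add(y11, y12, 1));
        int[][] m7 = strassen(add(x12, x22, -1), add(y21, y22, 1));
        int[][] Z = new int[n][n];
        place(add(add(add(m1, m4, 1), m5, -1), m7, 1), Z, 0, 0); // Z11
        place(add(m3, m5, 1), Z, 0, h);                          // Z12
        place(add(m2, m4, 1), Z, h, 0);                          // Z21
        place(add(add(add(m1, m3, 1), m2, -1), m6, 1), Z, h, h); // Z22
        return Z;
    }

    // Pad to the next power of two, multiply, then crop back to n x n.
    static int[][] multiply(int[][] A, int[][] B) {
        int n = A.length, m = 1;
        while (m < n) m *= 2;
        return block(strassen(block(A, 0, 0, m), block(B, 0, 0, m)), 0, 0, n);
    }

    static int[][] naive(int[][] A, int[][] B) {
        int n = A.length;
        int[][] C = new int[n][n];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                for (int k = 0; k < n; k++)
                    C[i][j] += A[i][k] * B[k][j];
        return C;
    }

    static int[][] random(int n, Random rnd) {
        int[][] r = new int[n][n];
        for (int[] row : r)
            for (int j = 0; j < n; j++)
                row[j] = rnd.nextInt(21) - 10; // entries in [-10, 10]
        return r;
    }

    public static void main(String[] args) {
        Random rnd = new Random(42); // fixed seed so runs are repeatable
        for (int n = 1; n <= 9; n++) {
            int[][] A = random(n, rnd), B = random(n, rnd);
            boolean ok = Arrays.deepEquals(multiply(A, B), naive(A, B));
            System.out.println("n = " + n + ": " + (ok ? "match" : "MISMATCH"));
        }
    }
}
```

With a correct implementation every size should report match. Integer arithmetic has no rounding error, so the two results must agree exactly, and any mismatch points to a wrong sign or block in one of the seven products or in how they are combined.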