Saturday, 27 August 2016

Trying to code Strassen's Matrix Multiplication Algorithm (JAVA)

Strassen's Matrix Multiplication was devised to improve on the O(n^3) running time of normal matrix multiplication. It uses seven recursive multiplications of half-size matrices instead of eight, which produces the recurrence

T(n) = 7T(n/2) + O(n^2)

which I have yet to learn how to solve.
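
For what it's worth, the master theorem seems to handle this recurrence directly. Here is a sketch of the calculation, using the usual symbols a, b and f(n) from the theorem's statement (these names are assumptions brought in for the derivation, not something from this post):

```latex
\begin{align*}
T(n) &= a\,T(n/b) + f(n), \qquad a = 7,\ b = 2,\ f(n) = O(n^2) \\
n^{\log_b a} &= n^{\log_2 7}, \qquad \log_2 7 \approx 2.81 \\
f(n) &= O\!\left(n^{\log_2 7 - \epsilon}\right) \quad \text{for } \epsilon = 0.8 \\
\Rightarrow\quad T(n) &= \Theta\!\left(n^{\log_2 7}\right)
\end{align*}
```

So the running time should come out at roughly n^2.81, which grows more slowly than n^3.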

So far I have managed to build the intermediate submatrices. I will hopefully upload the complete code later.

The Steps involved in Strassen's Matrix Multiplication are:

1. Divide the input matrices A and B and the output matrix C into n/2 x n/2 submatrices. This step takes O(1) time (we will not make new arrays, but will use index calculations to locate the appropriate elements).
2. Create 10 matrices S1, S2, ..., S10, each of which is n/2 x n/2.
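
The index calculation in step 1 can be sketched roughly as follows. This is a minimal illustration, assuming square `int[][]` arrays of even order; the class name `QuadrantView` and the method `get` are made up for this example and do not appear in the code further down.

```java
// Hypothetical helper (not from the original post): reads element (i, j) of a
// quadrant of m without copying, by adding a row/column offset of n/2.
class QuadrantView {
    // blockRow and blockCol are 0 or 1:
    // (0,0) = M11, (0,1) = M12, (1,0) = M21, (1,1) = M22.
    static int get(int[][] m, int blockRow, int blockCol, int i, int j) {
        int half = m.length / 2;
        return m[blockRow * half + i][blockCol * half + j];
    }

    public static void main(String[] args) {
        int[][] a = {
            { 1,  2,  3,  4},
            { 5,  6,  7,  8},
            { 9, 10, 11, 12},
            {13, 14, 15, 16}
        };
        // A12 is the top-right quadrant, so its (0, 0) entry is a[0][2].
        System.out.println(get(a, 0, 1, 0, 0)); // prints 3
        // A21 is the bottom-left quadrant, so its (1, 1) entry is a[3][1].
        System.out.println(get(a, 1, 0, 1, 1)); // prints 14
    }
}
```

Because only offsets are computed, no new arrays are allocated, which is why the step counts as constant time.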

Since I have only reached this step so far, I am going to share this much.
The matrices would be the following:

S1=B12-B22
S2=A11+A12
S3=A21+A22
S4=B21-B11
S5=A11+A22
S6=B11+B22
S7=A12-A22
S8=B21+B22
S9=A11-A21
S10=B11+B12

Here is the code:



package javaapplication1;
import java.util.*;
public class StrassensMatrixMultiply {

    public static void main(String[] args) {
        Scanner obj=new Scanner(System.in);
        System.out.println("Enter the order of the matrices you want to multiply (note: enter only a power of 2): ");
        int n=obj.nextInt();
        int A[][]=new int[n][n];
        int B[][]=new int[n][n];
       
        System.out.println("Enter the elements of the 1st Matrix: ");
        for(int i=0;i<n;i++){
            for(int j=0;j<n;j++){
                A[i][j]=obj.nextInt();
            }
        }
       
        System.out.println("Enter the elements of the 2nd Matrix: ");
        for(int i=0;i<n;i++){
            for(int j=0;j<n;j++){
                B[i][j]=obj.nextInt();
            }
        }
       
        int s1[][]=new int[n/2][n/2];
        int s2[][]=new int[n/2][n/2];
        int s3[][]=new int[n/2][n/2];
        int s4[][]=new int[n/2][n/2];
        int s5[][]=new int[n/2][n/2];
        int s6[][]=new int[n/2][n/2];
        int s7[][]=new int[n/2][n/2];
        int s8[][]=new int[n/2][n/2];
        int s9[][]=new int[n/2][n/2];
        int s10[][]=new int[n/2][n/2];
       
        //Fill all ten S matrices in a single pass. (i,j) is a position inside a
        //quadrant; adding half to a row or column index selects the lower or
        //right-hand quadrant.
        int half=n/2;
        for(int i=0;i<half;i++){
            for(int j=0;j<half;j++){
                s1[i][j]=B[i][j+half]-B[i+half][j+half];        //S1=B12-B22
                s2[i][j]=A[i][j]+A[i][j+half];                  //S2=A11+A12
                s3[i][j]=A[i+half][j]+A[i+half][j+half];        //S3=A21+A22
                s4[i][j]=B[i+half][j]-B[i][j];                  //S4=B21-B11
                s5[i][j]=A[i][j]+A[i+half][j+half];             //S5=A11+A22
                s6[i][j]=B[i][j]+B[i+half][j+half];             //S6=B11+B22
                s7[i][j]=A[i][j+half]-A[i+half][j+half];        //S7=A12-A22
                s8[i][j]=B[i+half][j]+B[i+half][j+half];        //S8=B21+B22
                s9[i][j]=A[i][j]-A[i+half][j];                  //S9=A11-A21
                s10[i][j]=B[i][j]+B[i][j+half];                 //S10=B11+B12
            }
        }
       
       
       
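        //Here I am just trying to perform the multiplication, which is the next
        //step and will be updated soon. This is the first product, P1=A11*S1,
        //done with ordinary multiplication for now. The result goes into its own
        //matrix, because writing it back into s1 would overwrite entries that
        //later iterations still need.
        int p1[][]=new int[n/2][n/2];
        for(int i=0;i<n/2;i++){
            for(int j=0;j<n/2;j++){
                int sum=0;
                for(int k=0;k<n/2;k++){
                    sum=sum+A[i][k]*s1[k][j];
                }
                p1[i][j]=sum;
            }
        }
       
        for(int i=0;i<n/2;i++){
            for(int j=0;j<n/2;j++){
                System.out.println(p1[i][j]);
            }
        }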
       
       
       
    }
   
}
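
For orientation, here is a sketch of the steps the code above has not reached yet, restricted to 2x2 matrices so that every submatrix is a single number. The product and combination formulas follow the standard textbook presentation of Strassen's method that goes with the S1 to S10 definitions listed earlier. The class name `StrassenTwoByTwo` and the method `multiply` are illustrative assumptions, and a complete version would recurse on n/2 x n/2 blocks instead of multiplying plain ints.

```java
import java.util.Arrays;

// Illustrative sketch only: Strassen's seven products for the 2x2 case,
// where each "submatrix" A11, B12, ... is a single int.
class StrassenTwoByTwo {
    static int[][] multiply(int[][] a, int[][] b) {
        int a11 = a[0][0], a12 = a[0][1], a21 = a[1][0], a22 = a[1][1];
        int b11 = b[0][0], b12 = b[0][1], b21 = b[1][0], b22 = b[1][1];

        // The ten sums and differences listed in the post.
        int s1 = b12 - b22, s2 = a11 + a12, s3 = a21 + a22, s4 = b21 - b11;
        int s5 = a11 + a22, s6 = b11 + b22, s7 = a12 - a22, s8 = b21 + b22;
        int s9 = a11 - a21, s10 = b11 + b12;

        // Seven multiplications instead of the usual eight.
        int p1 = a11 * s1;
        int p2 = s2 * b22;
        int p3 = s3 * b11;
        int p4 = a22 * s4;
        int p5 = s5 * s6;
        int p6 = s7 * s8;
        int p7 = s9 * s10;

        // Combine the products into the four entries of C.
        int c11 = p5 + p4 - p2 + p6;
        int c12 = p1 + p2;
        int c21 = p3 + p4;
        int c22 = p5 + p1 - p3 - p7;
        return new int[][] { {c11, c12}, {c21, c22} };
    }

    public static void main(String[] args) {
        int[][] a = { {1, 2}, {3, 4} };
        int[][] b = { {5, 6}, {7, 8} };
        // Same result as ordinary multiplication of a and b.
        System.out.println(Arrays.deepToString(multiply(a, b))); // prints [[19, 22], [43, 50]]
    }
}
```

Note that the order of the factors matters once the entries are matrices rather than numbers: P1 is A11 times S1, not S1 times A11.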


