Skip to content

Commit 6023b45

Browse files
authored
Add Strassen Matrix Multiplication (TheAlgorithms#2490)
1 parent cdbcb5e commit 6023b45

File tree

1 file changed

+181
-0
lines changed

1 file changed

+181
-0
lines changed
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
// Java Program to Implement Strassen Algorithm
2+
3+
// Class Strassen matrix multiplication
4+
public class StrassenMatrixMultiplication {
5+
6+
// Method 1
7+
// Function to multiply matrices
8+
public int[][] multiply(int[][] A, int[][] B)
9+
{
10+
int n = A.length;
11+
12+
int[][] R = new int[n][n];
13+
14+
if (n == 1)
15+
16+
R[0][0] = A[0][0] * B[0][0];
17+
18+
else {
19+
// Dividing Matrix into parts
20+
// by storing sub-parts to variables
21+
int[][] A11 = new int[n / 2][n / 2];
22+
int[][] A12 = new int[n / 2][n / 2];
23+
int[][] A21 = new int[n / 2][n / 2];
24+
int[][] A22 = new int[n / 2][n / 2];
25+
int[][] B11 = new int[n / 2][n / 2];
26+
int[][] B12 = new int[n / 2][n / 2];
27+
int[][] B21 = new int[n / 2][n / 2];
28+
int[][] B22 = new int[n / 2][n / 2];
29+
30+
// Dividing matrix A into 4 parts
31+
split(A, A11, 0, 0);
32+
split(A, A12, 0, n / 2);
33+
split(A, A21, n / 2, 0);
34+
split(A, A22, n / 2, n / 2);
35+
36+
// Dividing matrix B into 4 parts
37+
split(B, B11, 0, 0);
38+
split(B, B12, 0, n / 2);
39+
split(B, B21, n / 2, 0);
40+
split(B, B22, n / 2, n / 2);
41+
42+
// Using Formulas as described in algorithm
43+
44+
// M1:=(A1+A3)×(B1+B2)
45+
int[][] M1
46+
= multiply(add(A11, A22), add(B11, B22));
47+
48+
// M2:=(A2+A4)×(B3+B4)
49+
int[][] M2 = multiply(add(A21, A22), B11);
50+
51+
// M3:=(A1−A4)×(B1+A4)
52+
int[][] M3 = multiply(A11, sub(B12, B22));
53+
54+
// M4:=A1×(B2−B4)
55+
int[][] M4 = multiply(A22, sub(B21, B11));
56+
57+
// M5:=(A3+A4)×(B1)
58+
int[][] M5 = multiply(add(A11, A12), B22);
59+
60+
// M6:=(A1+A2)×(B4)
61+
int[][] M6
62+
= multiply(sub(A21, A11), add(B11, B12));
63+
64+
// M7:=A4×(B3−B1)
65+
int[][] M7
66+
= multiply(sub(A12, A22), add(B21, B22));
67+
68+
// P:=M2+M3−M6−M7
69+
int[][] C11 = add(sub(add(M1, M4), M5), M7);
70+
71+
// Q:=M4+M6
72+
int[][] C12 = add(M3, M5);
73+
74+
// R:=M5+M7
75+
int[][] C21 = add(M2, M4);
76+
77+
// S:=M1−M3−M4−M5
78+
int[][] C22 = add(sub(add(M1, M3), M2), M6);
79+
80+
join(C11, R, 0, 0);
81+
join(C12, R, 0, n / 2);
82+
join(C21, R, n / 2, 0);
83+
join(C22, R, n / 2, n / 2);
84+
}
85+
86+
return R;
87+
}
88+
89+
// Method 2
90+
// Function to subtract two matrices
91+
public int[][] sub(int[][] A, int[][] B)
92+
{
93+
int n = A.length;
94+
95+
int[][] C = new int[n][n];
96+
97+
for (int i = 0; i < n; i++)
98+
for (int j = 0; j < n; j++)
99+
C[i][j] = A[i][j] - B[i][j];
100+
101+
return C;
102+
}
103+
104+
// Method 3
105+
// Function to add two matrices
106+
public int[][] add(int[][] A, int[][] B)
107+
{
108+
109+
int n = A.length;
110+
111+
int[][] C = new int[n][n];
112+
113+
for (int i = 0; i < n; i++)
114+
for (int j = 0; j < n; j++)
115+
C[i][j] = A[i][j] + B[i][j];
116+
117+
return C;
118+
}
119+
120+
// Method 4
121+
// Function to split parent matrix
122+
// into child matrices
123+
public void split(int[][] P, int[][] C, int iB, int jB)
124+
{
125+
for (int i1 = 0, i2 = iB; i1 < C.length; i1++, i2++)
126+
for (int j1 = 0, j2 = jB; j1 < C.length; j1++, j2++)
127+
C[i1][j1] = P[i2][j2];
128+
}
129+
130+
// Method 5
131+
// Function to join child matrices
132+
// into (to) parent matrix
133+
public void join(int[][] C, int[][] P, int iB, int jB)
134+
135+
{
136+
for (int i1 = 0, i2 = iB; i1 < C.length; i1++, i2++)
137+
for (int j1 = 0, j2 = jB; j1 < C.length; j1++, j2++)
138+
P[i2][j2] = C[i1][j1];
139+
}
140+
141+
// Method 5
142+
// Main driver method
143+
public static void main(String[] args)
144+
{
145+
System.out.println("Strassen Multiplication Algorithm Implementation For Matrix Multiplication :\n");
146+
147+
StrassenMatrixMultiplication s = new StrassenMatrixMultiplication();
148+
149+
// Size of matrix
150+
// Considering size as 4 in order to illustrate
151+
int N = 4;
152+
153+
// Matrix A
154+
// Custom input to matrix
155+
int[][] A = { { 1, 2, 5, 4 },
156+
{ 9, 3, 0, 6 },
157+
{ 4, 6, 3, 1 },
158+
{ 0, 2, 0, 6 } };
159+
160+
// Matrix B
161+
// Custom input to matrix
162+
int[][] B = { { 1, 0, 4, 1 },
163+
{ 1, 2, 0, 2 },
164+
{ 0, 3, 1, 3 },
165+
{ 1, 8, 1, 2 } };
166+
167+
// Matrix C computations
168+
169+
// Matrix C calling method to get Result
170+
int[][] C = s.multiply(A, B);
171+
172+
System.out.println("\nProduct of matrices A and B : ");
173+
174+
// Print the output
175+
for (int i = 0; i < N; i++) {
176+
for (int j = 0; j < N; j++)
177+
System.out.print(C[i][j] + " ");
178+
System.out.println();
179+
}
180+
}
181+
}

0 commit comments

Comments
 (0)