-
-
Notifications
You must be signed in to change notification settings - Fork 46.8k
Add matrix_multiplication #10045
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add matrix_multiplication #10045
Changes from all commits
7672d1a
642af23
5e6100c
3819098
5a5a84d
2cbbe3a
f233a54
19a7879
d8040f5
6c5b418
ea541e6
86dc5a9
de92d3d
8325b64
e6fda58
cccf2a9
5f9331a
35f8b7a
aa0da39
d451d62
dabd6c9
4a118f3
39b3363
8656afd
552a873
641c063
2fd6024
667055a
0d83b21
2af89bc
b2e2308
1c72a65
0ce40a2
34942b3
891da5e
f52f3b5
24d265b
435efe6
e36fa6a
84c0273
9550244
3718979
e30c558
5824fcb
1bcbd84
532e385
b1b6219
84adad8
44cd487
ab314c6
2cf9090
2e4151c
9765fea
fbe1dfc
8bd1970
2502119
37a5dab
909331b
e05b524
39d0871
5203d8b
620ba2c
c6c5cfd
aa7e917
8bfeeb2
bf6d95a
e42fafc
d87b51f
ac92bef
1422b87
21f6956
f4db7d9
cc0b4aa
22c19e3
d4e7c77
057a2b4
7340528
8ef6244
9d8c6f1
9a486c9
306bba0
09674b1
09aca9f
7b31464
46837d0
87e37dd
5a044dd
d7fc696
5650d8b
a3052c2
7d28a5c
b72b27f
25131f6
278149b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
# @Author : ojas-wani | ||
# @File : matrix_multiplication_recursion.py | ||
# @Date : 10/06/2023 | ||
|
||
|
||
""" | ||
Perform matrix multiplication using a recursive algorithm. | ||
https://en.wikipedia.org/wiki/Matrix_multiplication | ||
""" | ||
# type Matrix = list[list[int]] # psf/black currenttly fails on this line | ||
Matrix = list[list[int]] | ||
cclauss marked this conversation as resolved.
Show resolved
Hide resolved
cclauss marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Variable and function names should follow the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Variable and function names should follow the |
||
|
||
matrix_1_to_4 = [ | ||
[1, 2], | ||
[3, 4], | ||
] | ||
|
||
matrix_5_to_8 = [ | ||
[5, 6], | ||
[7, 8], | ||
] | ||
|
||
matrix_5_to_9_high = [ | ||
[5, 6], | ||
[7, 8], | ||
[9], | ||
] | ||
|
||
matrix_5_to_9_wide = [ | ||
[5, 6], | ||
[7, 8, 9], | ||
] | ||
|
||
matrix_count_up = [ | ||
[1, 2, 3, 4], | ||
[5, 6, 7, 8], | ||
[9, 10, 11, 12], | ||
[13, 14, 15, 16], | ||
] | ||
|
||
matrix_unordered = [ | ||
[5, 8, 1, 2], | ||
[6, 7, 3, 0], | ||
[4, 5, 9, 1], | ||
[2, 6, 10, 14], | ||
] | ||
matrices = ( | ||
matrix_1_to_4, | ||
matrix_5_to_8, | ||
matrix_5_to_9_high, | ||
matrix_5_to_9_wide, | ||
matrix_count_up, | ||
matrix_unordered, | ||
) | ||
|
||
|
||
def is_square(matrix: Matrix) -> bool: | ||
cclauss marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
>>> is_square([]) | ||
True | ||
>>> is_square(matrix_1_to_4) | ||
True | ||
>>> is_square(matrix_5_to_9_high) | ||
False | ||
""" | ||
len_matrix = len(matrix) | ||
return all(len(row) == len_matrix for row in matrix) | ||
|
||
|
||
def matrix_multiply(matrix_a: Matrix, matrix_b: Matrix) -> Matrix: | ||
cclauss marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
>>> matrix_multiply(matrix_1_to_4, matrix_5_to_8) | ||
[[19, 22], [43, 50]] | ||
""" | ||
return [ | ||
[sum(a * b for a, b in zip(row, col)) for col in zip(*matrix_b)] | ||
for row in matrix_a | ||
] | ||
|
||
|
||
def matrix_multiply_recursive(matrix_a: Matrix, matrix_b: Matrix) -> Matrix: | ||
""" | ||
:param matrix_a: A square Matrix. | ||
:param matrix_b: Another square Matrix with the same dimensions as matrix_a. | ||
:return: Result of matrix_a * matrix_b. | ||
:raises ValueError: If the matrices cannot be multiplied. | ||
|
||
>>> matrix_multiply_recursive([], []) | ||
[] | ||
>>> matrix_multiply_recursive(matrix_1_to_4, matrix_5_to_8) | ||
[[19, 22], [43, 50]] | ||
>>> matrix_multiply_recursive(matrix_count_up, matrix_unordered) | ||
[[37, 61, 74, 61], [105, 165, 166, 129], [173, 269, 258, 197], [241, 373, 350, 265]] | ||
>>> matrix_multiply_recursive(matrix_1_to_4, matrix_5_to_9_wide) | ||
Traceback (most recent call last): | ||
... | ||
ValueError: Invalid matrix dimensions | ||
>>> matrix_multiply_recursive(matrix_1_to_4, matrix_5_to_9_high) | ||
Traceback (most recent call last): | ||
... | ||
ValueError: Invalid matrix dimensions | ||
>>> matrix_multiply_recursive(matrix_1_to_4, matrix_count_up) | ||
Traceback (most recent call last): | ||
... | ||
ValueError: Invalid matrix dimensions | ||
""" | ||
if not matrix_a or not matrix_b: | ||
return [] | ||
if not all( | ||
(len(matrix_a) == len(matrix_b), is_square(matrix_a), is_square(matrix_b)) | ||
): | ||
raise ValueError("Invalid matrix dimensions") | ||
|
||
# Initialize the result matrix with zeros | ||
result = [[0] * len(matrix_b[0]) for _ in range(len(matrix_a))] | ||
|
||
# Recursive multiplication of matrices | ||
def multiply( | ||
cclauss marked this conversation as resolved.
Show resolved
Hide resolved
|
||
i_loop: int, | ||
j_loop: int, | ||
k_loop: int, | ||
matrix_a: Matrix, | ||
matrix_b: Matrix, | ||
result: Matrix, | ||
) -> None: | ||
""" | ||
:param matrix_a: A square Matrix. | ||
:param matrix_b: Another square Matrix with the same dimensions as matrix_a. | ||
:param result: Result matrix | ||
:param i: Index used for iteration during multiplication. | ||
:param j: Index used for iteration during multiplication. | ||
:param k: Index used for iteration during multiplication. | ||
>>> 0 > 1 # Doctests in inner functions are never run | ||
True | ||
""" | ||
if i_loop >= len(matrix_a): | ||
return | ||
if j_loop >= len(matrix_b[0]): | ||
return multiply(i_loop + 1, 0, 0, matrix_a, matrix_b, result) | ||
if k_loop >= len(matrix_b): | ||
return multiply(i_loop, j_loop + 1, 0, matrix_a, matrix_b, result) | ||
result[i_loop][j_loop] += matrix_a[i_loop][k_loop] * matrix_b[k_loop][j_loop] | ||
return multiply(i_loop, j_loop, k_loop + 1, matrix_a, matrix_b, result) | ||
|
||
# Perform the recursive matrix multiplication | ||
multiply(0, 0, 0, matrix_a, matrix_b, result) | ||
return result | ||
|
||
|
||
if __name__ == "__main__": | ||
from doctest import testmod | ||
|
||
failure_count, test_count = testmod() | ||
if not failure_count: | ||
matrix_a = matrices[0] | ||
for matrix_b in matrices[1:]: | ||
print("Multiplying:") | ||
for row in matrix_a: | ||
print(row) | ||
print("By:") | ||
for row in matrix_b: | ||
print(row) | ||
print("Result:") | ||
try: | ||
result = matrix_multiply_recursive(matrix_a, matrix_b) | ||
for row in result: | ||
print(row) | ||
assert result == matrix_multiply(matrix_a, matrix_b) | ||
except ValueError as e: | ||
print(f"{e!r}") | ||
print() | ||
matrix_a = matrix_b | ||
|
||
print("Benchmark:") | ||
from functools import partial | ||
from timeit import timeit | ||
|
||
mytimeit = partial(timeit, globals=globals(), number=100_000) | ||
for func in ("matrix_multiply", "matrix_multiply_recursive"): | ||
print(f"{func:>25}(): {mytimeit(f'{func}(matrix_count_up, matrix_unordered)')}") |
Uh oh!
There was an error while loading. Please reload this page.