#!/usr/bin/env python3

# Author: Leonardo Tamiano


def matrix_product(a, b):
    """
    Calculates the product of matrices a and b.

    Matrices are represented as lists of rows.

    :param a: matrix with rows_a rows and column_a columns
    :param b: matrix with rows_b rows and column_b columns
    :return: the rows_a x column_b matrix a*b, or None when the product
             is not defined (column_a != rows_b)
    """
    rows_a, column_a = len(a), len(a[0])
    rows_b, column_b = len(b), len(b[0])

    if column_a != rows_b:
        # matrix product is not defined
        return None

    return [[sum(a[i][k] * b[k][j] for k in range(column_a))
             for j in range(column_b)]
            for i in range(rows_a)]


def chain_matrix_dynamic(dims, n=None):
    """
    Computes the optimal splitting table for multiplying a chain of n matrices.

    Matrices are indexed from 0: matrix A_i has dimensions dims[i] x dims[i+1],
    so dims must contain at least n + 1 entries.

    The following sub-problems are solved bottom-up, by increasing chain length:

      opt[i][j]   := least number of scalar multiplications needed to compute
                     A_i * ... * A_j
      value[i][j] := best splitting value k for A_i * ... * A_j, i.e. the index
                     k such that computing (A_i ... A_k) and (A_k+1 ... A_j)
                     separately and then multiplying the two results needs the
                     least total number of multiplications.

    :param dims: dimensions of the matrices; matrix i has dimensions
                 dims[i] x dims[i+1]
    :param n: number of matrices to multiply (defaults to len(dims) - 1)
    :return: the n x n table value; value[i][j] (for i < j) is the optimal
             splitting index of A_i, ..., A_j
    """
    if n is None:
        n = len(dims) - 1

    opt = [[0] * n for _ in range(n)]
    value = [[0] * n for _ in range(n)]

    for d in range(1, n):
        for i in range(n - d):
            j = i + d
            best_cost, best_k = None, i
            # Single pass: find the cheapest split and remember where it is.
            #
            # TODO: is there a way to save the splitting values that
            # doesn't require memorizing a full table?
            for k in range(i, j):
                cost = (opt[i][k] + opt[k + 1][j]
                        + dims[i] * dims[k + 1] * dims[j + 1])
                # '<=' keeps the last optimal k on ties, so equal-cost
                # splits resolve deterministically to the rightmost one
                if best_cost is None or cost <= best_cost:
                    best_cost, best_k = cost, k
            opt[i][j] = best_cost
            value[i][j] = best_k

    return value


def chain_matrix_multiplication(matrices, dims=None):
    """
    Multiplies a list of matrices together, using the least number of
    scalar multiplications.

    :param matrices: list of matrices to be multiplied together, in order
    :param dims: dimensions of the matrices; the matrix at position i
                 (0-indexed) has dimensions dims[i] x dims[i+1]. When omitted,
                 they are read from the matrices themselves.
    :return: the matrix obtained by multiplying all the matrices in the list,
             or None when the list is empty or two adjacent matrices have
             incompatible dimensions
    """
    if not matrices:
        # the product of an empty chain has no well-defined size
        return None

    if dims is None:
        dims = [len(matrices[0])] + [len(m[0]) for m in matrices]

    opt_split_table = chain_matrix_dynamic(dims, len(matrices))
    return chain_matrix_multiplication_recursive(matrices, opt_split_table,
                                                 0, len(matrices) - 1)


def chain_matrix_multiplication_recursive(matrices, opt_split_table, i, j):
    """
    Recursively multiplies matrices[i], ..., matrices[j] (both ends included).
    It follows opt_split_table so that the least amount of total
    multiplications is performed.

    :param matrices: matrices to be multiplied together
    :param opt_split_table: table telling the best way to split the
                            multiplication of matrices A_i, ..., A_j
    :param i: left-most index (inclusive)
    :param j: right-most index (inclusive)
    :return: product of the matrices in matrices[i:j+1], or None if some
             intermediate product is not defined
    """
    if i == j:
        return matrices[i]
    elif j == i + 1:
        return matrix_product(matrices[i], matrices[j])

    split_value = opt_split_table[i][j]
    m1 = chain_matrix_multiplication_recursive(matrices, opt_split_table,
                                               i, split_value)
    m2 = chain_matrix_multiplication_recursive(matrices, opt_split_table,
                                               split_value + 1, j)
    if m1 is None or m2 is None:
        # an inner product was undefined, so the whole chain is undefined
        return None
    return matrix_product(m1, m2)