-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmatrix_chain.py
More file actions
42 lines (37 loc) · 994 Bytes
/
matrix_chain.py
File metadata and controls
42 lines (37 loc) · 994 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import utils
# matrix chain multiplication
# let T[i,j] = minimum number of scalar multiplications needed to compute (Ai...Aj), with T[i,i] = 0
# = min_{i<=k<j} (rows(Ai)*cols(Ak)*cols(Aj) + T[i,k] + T[k+1,j])
# = min_{i<=k<j} (P(i-1)*P(k)*P(j) + T[i,k] + T[k+1,j])
# since matrix i is P(i-1)xP(i) (1-indexed here; the code below is 0-indexed)
# also record S[i,j] = the optimal split point k for multiplying (Ai...Aj)
def matrix_chain(p):
    """Compute an optimal multiplication order for a chain of matrices.

    Matrix A_i (0-indexed) has dimensions p[i] x p[i+1], so a sequence of
    n+1 dimensions describes a chain of n matrices.

    Args:
        p: sequence of matrix dimensions, of length n+1.

    Returns:
        A tuple (T, S) of two n x n lists of lists.
        - For i <= j, T[i][j] is the minimum number of scalar
          multiplications needed to compute A_i...A_j (T[i][i] == 0).
        - For i < j, S[i][j] is the split index k of that optimum, i.e. the
          product is computed as (A_i...A_k)(A_{k+1}...A_j). On ties the
          smallest k is kept.
        Entries below the diagonal are left at 0. An input with fewer than
        two dimensions yields ([], []).
    """
    n = len(p) - 1
    # Build the tables locally. The bare name `matrix` is not in scope in
    # this module (only `import utils` is done), and the zero fill is what
    # makes T[i][i] == 0 for the recurrence.
    T = [[0] * n for _ in range(n)]
    S = [[0] * n for _ in range(n)]
    # Walk i downwards and j upwards. T[k+1][j] (row k+1 > i) is then
    # finished in an earlier i pass, and T[i][k] (k < j) in an earlier j pass.
    for i in range(n - 1, -1, -1):
        for j in range(i + 1, n):
            # Use +inf rather than a fixed "large" constant: any finite
            # sentinel silently yields wrong answers once real costs exceed it.
            best = float('inf')
            for k in range(i, j):
                # A_i is p[i] x p[i+1], so the final product of
                # (A_i..A_k) and (A_{k+1}..A_j) costs p[i]*p[k+1]*p[j+1].
                v = p[i] * p[k + 1] * p[j + 1] + T[i][k] + T[k + 1][j]
                if v < best:
                    best = v
                    S[i][j] = k
            T[i][j] = best
    return T, S
def print_matrix_chain(S, i, j):
    """Return the parenthesization of A_i..A_j encoded by split table S.

    S[i][j] holds the index k at which the product A_i..A_j is split into
    (A_i..A_k)(A_{k+1}..A_j). Despite its name, this function prints
    nothing; it builds and returns the string. A single matrix renders as
    " A_<index> ", and every split is wrapped in one pair of parentheses.
    """
    if i != j:
        split = S[i][j]
        left_part = print_matrix_chain(S, i, split)
        right_part = print_matrix_chain(S, split + 1, j)
        return "".join(["(", left_part, right_part, ")"])
    # Leaf: a lone matrix, padded with spaces on both sides.
    return "".join([" A_", str(i), " "])
def print_matrix_mult(S, i, j):
    """Print the adjacent matrix pairs multiplied directly in the order S.

    S[i][j] holds the split index k for the product A_i..A_j, i.e.
    (A_i..A_k)(A_{k+1}..A_j). For every split that leaves two adjacent
    single matrices, a line "<i> x <i+1>" is printed. Products involving an
    already-multiplied subchain are not printed.
    """
    if i == j:
        # A single matrix needs no multiplication. Without this base case the
        # recursion reads S[i][i] and never terminates.
        return
    if i == j - 1:
        # print(...) on one string behaves the same under Python 2 and 3.
        print(str(i) + " x " + str(j))
    else:
        k = S[i][j]
        print_matrix_mult(S, i, k)
        print_matrix_mult(S, k + 1, j)