Skip to content

Commit 8b5fa2d

Browse files
author
Hamid Gasmi
committed
#183 is completed: DP + Buttom-up + backtracking
1 parent 51bdf43 commit 8b5fa2d

File tree

13 files changed

+184
-0
lines changed

13 files changed

+184
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import sys
2+
import math
3+
4+
# 1. Express a solution mathematically:
5+
# We know: Pr(x | P ) = Product(i=1..n, Pr(xi|Pi)) = Product(i=1..n, emission(Pi, xi))
6+
# Pr(P) = Product(i=1..n, Pr(Pi-1 --> Pi)) = Product(i=1..n, transition(Pi-1, Pi))
7+
# Then Pr(P, x) = Pr(x | P ) * Pr(P) = Product(i=1..n, emission(Pi, xi)) * Product(i=1..n, transition(Pi-1, Pi))
8+
# So, Pr(optimal P, x) = Max over P of [Pr(x | P) * Pr(P)] = Max over P of [Product(i=1..n, emission(Pi, xi)) * Product(i=1..n, transition(Pi-1, Pi))]
9+
# 2. Proof:
10+
# 3. Solution: we can solve this problem with a DP technique + bottom-up solution:
11+
# 1. We build the corresponding graph matrix (|states| . |x|)
12+
# - M[state_i, p] = max over state_j of (M[state_j, p - 1] * transition[state_j][state_i] * emission[state_i][xp])
13+
# - To avoid floating-point underflow when multiplying many small probabilities, we'll use logarithms:
14+
# - log(M[state_i, p]) = max over state_j of (log(M[state_j, p - 1]) + log(transition[state_j][state_i]) + log(emission[state_i][xp]))
15+
# 2. From the matrix, M, we'll use the backtracking approach to build the hidden path
16+
def build_optimal_path_backtrack(x, states, transition, emission, M):
    """Rebuild the optimal hidden path by walking the score matrix backwards.

    Args:
        x: Emitted string (sequence of symbols).
        states: Ordered list of state names; row ``s`` of ``M`` belongs to
            ``states[s]``.
        transition: ``transition[a][b]`` is the probability of moving from
            state ``a`` to state ``b``.
        emission: ``emission[a][c]`` is the probability that state ``a``
            emits symbol ``c``.
        M: Log-score matrix indexed ``M[state_index][position]``.

    Returns:
        A list with one state name per position of ``x``; an empty list
        when ``x`` is empty.
    """
    # Nothing was emitted, so there is no path to rebuild. Without this guard
    # the code below would index column -1 of empty rows and crash.
    if not x:
        return []

    last = len(x) - 1
    optimal_hidden_path = [''] * len(x)

    # The path ends in the state with the best score in the last column.
    opt_state = util_column_max_value_index(M, len(states), last)
    optimal_hidden_path[last] = states[opt_state]

    for p in range(last, 0, -1):
        curr_x = x[p]
        target_score = M[opt_state][p]
        opt_name = states[opt_state]

        # Find the first predecessor whose score, extended by one step,
        # reproduces the current cell. Exact float equality is intentional:
        # the expression mirrors the one used to fill M, so the bits match.
        for s, state in enumerate(states):
            step = util_sum_log_vals((transition[state][opt_name], emission[opt_name][curr_x]))
            if target_score == M[s][p - 1] + step:
                opt_state = s
                break
        optimal_hidden_path[p - 1] = states[opt_state]

    return optimal_hidden_path
32+
33+
def optimal_path(x, sigma, states, transition, emission):
    """Return the most probable hidden path for the emitted string ``x``.

    Bottom-up dynamic programming in log space: ``M[s][p]`` holds the best
    log-probability of a path that emits ``x[0..p]`` and ends in
    ``states[s]``. The path itself is recovered by backtracking through M.

    Args:
        x: Emitted string (sequence of symbols).
        sigma: Emission alphabet; kept for interface compatibility, not used
            by the computation.
        states: Ordered list of state names.
        transition: ``transition[a][b]`` is the probability of moving from
            state ``a`` to state ``b``.
        emission: ``emission[a][c]`` is the probability that state ``a``
            emits symbol ``c``.

    Returns:
        The state names of the optimal path joined into one string; the
        empty string when ``x`` is empty.
    """
    # An empty emission has an empty hidden path. Without this guard the
    # backtracking step would index into empty score rows and crash.
    if not x:
        return ''

    states_count = len(states)
    M = [[0] * len(x) for _ in range(states_count)]

    # The transitions from the initial state occur with equal probability:
    # Prob(P0 = Si) = 1/|states| * Pr(x = x0 | P0)
    for s, state in enumerate(states):
        M[s][0] = util_sum_log_vals((1 / states_count, emission[state][x[0]]))

    for p in range(len(x) - 1):
        next_x = x[p + 1]

        for next_s, next_state in enumerate(states):
            best = -math.inf
            for s, state in enumerate(states):
                # An unreachable predecessor cannot extend any path.
                if M[s][p] == -math.inf:
                    continue

                candidate_probability = M[s][p] + util_sum_log_vals(
                    (transition[state][next_state], emission[next_state][next_x])
                )
                if candidate_probability > best:
                    best = candidate_probability
            M[next_s][p + 1] = best

    optimal_hidden_path = build_optimal_path_backtrack(x, states, transition, emission, M)

    return ''.join(optimal_hidden_path)
56+
57+
def util_sum_log_vals(vals):
    """Return the sum of the natural logs of ``vals``.

    Values are processed left to right. The first zero short-circuits the
    result to ``-math.inf`` (the log of a zero probability); values after it
    are not inspected. A negative value trips the assertion.

    Args:
        vals: Iterable of non-negative numbers.

    Returns:
        The sum of ``math.log(v)`` over ``vals``, ``-math.inf`` if a zero is
        met, or ``0`` for an empty iterable.
    """
    total = 0
    for probability in vals:
        assert probability >= 0
        if probability == 0:
            # log(0) is treated as minus infinity; no need to look further.
            return -math.inf
        total += math.log(probability)

    return total
68+
69+
def util_column_max_value_index(matrix, rows_count, c):
    """Return the row index holding the largest value in column ``c``.

    Only the first ``rows_count`` rows are examined. On ties the lowest row
    index wins, because a later row replaces the current best only when it
    is strictly greater.

    Args:
        matrix: Row-major sequence of rows.
        rows_count: Number of leading rows to scan; must be positive.
        c: Column index; must be less than the length of the first row.

    Returns:
        Index of the row with the maximum value in column ``c``.
    """
    assert rows_count > 0
    assert c < len(matrix[0])

    best_row = 0
    for row in range(1, rows_count):
        if matrix[row][c] > matrix[best_row][c]:
            best_row = row

    return best_row
79+
80+
if __name__ == "__main__":
    # Input layout on stdin, sections separated by one delimiter line:
    #   emitted string / alphabet / states / transition matrix / emission matrix
    read_line = sys.stdin.readline

    x = read_line().strip()
    read_line()  # delimiter

    sigma = read_line().strip().split()
    read_line()  # delimiter

    states = read_line().strip().split()
    read_line()  # delimiter

    # Transition matrix: a header row of column labels, then one row per
    # state whose first field is the row label.
    chars = read_line().strip().split()
    transition_rows = [read_line().strip().split() for _ in states]
    transition = {}
    for row in transition_rows:
        transition[row[0]] = dict(zip(chars, map(float, row[1:])))
    read_line()  # delimiter

    # Emission matrix: same layout, with emitted symbols as column labels.
    chars = read_line().strip().split()
    emission_rows = [read_line().strip().split() for _ in states]
    emission = {}
    for row in emission_rows:
        emission[row[0]] = dict(zip(chars, map(float, row[1:])))

    print(optimal_path(x, sigma, states, transition, emission))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
AAABA
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
xyxzz
2+
--------
3+
x y z
4+
--------
5+
A B
6+
--------
7+
A B
8+
A 0.641 0.359
9+
B 0.729 0.271
10+
--------
11+
x y z
12+
A 0.117 0.691 0.192
13+
B 0.097 0.42 0.483
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
BABA
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
xyxy
2+
--------
3+
x y
4+
--------
5+
A B
6+
--------
7+
A B
8+
A 0.5 0.5
9+
B 0.5 0.5
10+
--------
11+
x y
12+
A 0.1 0.9
13+
B 0.9 0.1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
AAAA or
2+
BBBB
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
xyxy
2+
--------
3+
x y
4+
--------
5+
A B
6+
--------
7+
A B
8+
A 0.9 0.1
9+
B 0.1 0.9
10+
--------
11+
x y
12+
A 0.5 0.5
13+
B 0.5 0.5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
A
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
x
2+
--------
3+
x y
4+
--------
5+
A B
6+
--------
7+
A B
8+
A 0.4 0.6
9+
B 0.2 0.8
10+
--------
11+
x y
12+
A 0.55 0.45
13+
B 0.5 0.5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
AAAAA
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
zxyxy
2+
--------
3+
x y z
4+
--------
5+
A
6+
--------
7+
A
8+
A 1
9+
--------
10+
x y z
11+
A 0.5 0.5 0
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
BC
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
xx
2+
--------
3+
x y
4+
--------
5+
A B C
6+
--------
7+
A B C
8+
A 0.7 0.1 0.2
9+
B 0.5 0.3 0.2
10+
C 1 0 0
11+
--------
12+
x y
13+
A 0 1
14+
B 0.5 0.5
15+
C 1 0

0 commit comments

Comments
 (0)