|
| 1 | +import sys |
| 2 | +import math |
| 3 | + |
| 4 | +# 1. Express a solution mathematically: |
| 5 | +# We know: Pr(x | P ) = Product(i=1..n, Pr(xi|Pi)) = Product(i=1..n, emission(Pi, xi)) |
| 6 | +# Pr(P) = Product(i=1..n, Pr(Pi-1 --> Pi)) = Product(i=1..n, transition(Pi-1, Pi)) |
| 7 | +# Then Pr(P, x) = Pr(x | P ) * Pr(P) = Product(i=1..n, emission(Pi, xi)) * Product(i=1..n, transition(Pi-1, Pi)) |
| 8 | +# So, Pr(optimal P, x) = Max Pr(x | P ) * Pr(P) = Product(i=1..n, emission(Pi, xi)) * Product(i=1..n, transition(Pi-1, Pi)) |
| 9 | +# 2. Proof: |
| 10 | +# 3. Solutions: we can solve this problem wit a DP technique + Bottom up solution: |
| 11 | +# 1. We build the corresponding graph matrix (|states| . |x|) |
| 12 | +# - M[state_i,p] = max(M[state_j, p - 1] * transition[state_j][state_i] * emission[state_i][xp] |
| 13 | +# - To avoid the risk of stackoverflow error, we'll use logarithm: |
| 14 | +# - log(M[state_i,p]) = max(M[state_j, p - 1] + log(transition[state_j][state_i]) + log(emission[state_i][xp])) |
| 15 | +# 2. From the matrix, M, we'll use the backtracking approach to build the hidden path |
| 16 | +def build_optimal_path_backtrack(x, states, transition, emission, M): |
| 17 | + |
| 18 | + optimal_hidden_path = ['' for _ in range(len(x))] |
| 19 | + |
| 20 | + opt_state = util_column_max_value_index(M, len(states), len(x) - 1) |
| 21 | + optimal_hidden_path[len(x) - 1] = states[opt_state] |
| 22 | + for p in range(len(x) - 1, 0, -1): |
| 23 | + curr_x = x[p] |
| 24 | + |
| 25 | + for s in range(len(states)): |
| 26 | + if M[opt_state][p] == M[s][p - 1] + util_sum_log_vals((transition[ states[s] ][ states[opt_state] ], emission[ states[opt_state] ][ curr_x ])): |
| 27 | + opt_state = s |
| 28 | + break |
| 29 | + optimal_hidden_path[p - 1] = states[opt_state] |
| 30 | + |
| 31 | + return optimal_hidden_path |
| 32 | + |
| 33 | +def optimal_path(x, sigma, states, transition, emission): |
| 34 | + |
| 35 | + # The transitions from the initial state occur with equal probability: Prob(P0 = Si) = 1/|states| * Pr(x = x0 | P0) |
| 36 | + M = [ [util_sum_log_vals((1/len(states), emission[ states[s] ][ x[0] ])) if p == 0 else 0 for p in range(len(x))] for s in range(len(states)) ] |
| 37 | + |
| 38 | + for p in range(len(x) - 1): |
| 39 | + next_x = x[p + 1] |
| 40 | + |
| 41 | + for s in range(len(states)): |
| 42 | + M[s][p + 1] = -math.inf if M[0][p] == -math.inf else M[0][p] + util_sum_log_vals((transition[ states[0] ][ states[s] ], emission[ states[s] ][ next_x ])) |
| 43 | + |
| 44 | + for s in range(1, len(states)): |
| 45 | + for next_s in range(len(states)): |
| 46 | + if M[s][p] == - math.inf: |
| 47 | + continue |
| 48 | + |
| 49 | + candidate_probability = M[s][p] + util_sum_log_vals((transition[ states[s] ][ states[next_s] ], emission[ states[next_s] ][ next_x ])) |
| 50 | + if candidate_probability > M[next_s][p + 1]: |
| 51 | + M[next_s][p + 1] = candidate_probability |
| 52 | + |
| 53 | + optimal_hidden_path = build_optimal_path_backtrack(x, states, transition, emission, M) |
| 54 | + |
| 55 | + return ''.join(optimal_hidden_path) |
| 56 | + |
| 57 | +def util_sum_log_vals(vals): |
| 58 | + sum_log = 0 |
| 59 | + for val in vals: |
| 60 | + assert(val >= 0) |
| 61 | + if val == 0: |
| 62 | + sum_log = - math.inf |
| 63 | + break |
| 64 | + else: |
| 65 | + sum_log += math.log(val) |
| 66 | + |
| 67 | + return sum_log |
| 68 | + |
| 69 | +def util_column_max_value_index(matrix, rows_count, c): |
| 70 | + assert(rows_count > 0) |
| 71 | + assert(c < len(matrix[0])) |
| 72 | + |
| 73 | + row_max_value = 0 |
| 74 | + for r in range(1, rows_count): |
| 75 | + if matrix[r][c] > matrix[row_max_value][c]: |
| 76 | + row_max_value = r |
| 77 | + |
| 78 | + return row_max_value |
| 79 | + |
| 80 | +if __name__ == "__main__": |
| 81 | + x = sys.stdin.readline().strip() |
| 82 | + sys.stdin.readline() # delimiter |
| 83 | + |
| 84 | + sigma = sys.stdin.readline().strip().split() |
| 85 | + sys.stdin.readline() # delimiter |
| 86 | + |
| 87 | + states = sys.stdin.readline().strip().split() |
| 88 | + sys.stdin.readline() # delimiter |
| 89 | + |
| 90 | + chars = sys.stdin.readline().strip().split() |
| 91 | + transition = [sys.stdin.readline().strip().split() for _ in range(len(states))] |
| 92 | + transition = {line[0]:dict(zip(chars, map(float,line[1:]))) for line in transition} |
| 93 | + sys.stdin.readline() # delimiter |
| 94 | + |
| 95 | + chars = sys.stdin.readline().strip().split() |
| 96 | + emission = [sys.stdin.readline().strip().split() for _ in range(len(states))] |
| 97 | + emission = {line[0]:dict(zip(chars, map(float,line[1:]))) for line in emission} |
| 98 | + |
| 99 | + print(optimal_path(x,sigma,states,transition,emission)) |
0 commit comments