From 53fcdf679f836e21fa3f9e69b5391ac033f30c38 Mon Sep 17 00:00:00 2001 From: Michael Zhang Date: Sun, 1 Oct 2023 19:24:05 -0500 Subject: [PATCH] 3c --- assignments/hwk01/hw1solve.py | 57 ++++++++++++++++++++++++++++++----- 1 file changed, 49 insertions(+), 8 deletions(-) diff --git a/assignments/hwk01/hw1solve.py b/assignments/hwk01/hw1solve.py index 66b22b2..5d1a5a0 100644 --- a/assignments/hwk01/hw1solve.py +++ b/assignments/hwk01/hw1solve.py @@ -1,14 +1,55 @@ -from sympy.abc import i, k, m, n, x -import sympy +from itertools import product -def prob_2a(): - # f = sympy.Function('f') - def f(x, theta): return (1.0 / theta) * sympy.exp(- x / theta) +def calc_posterior(p_c1: float, D: int, p_ij: dict[tuple[int, int], float]): + priors = { + 1: p_c1, + 2: 1 - p_c1, + } - log = sympy.Sum(f, (k, 1, n)) + def p_x_given_Ci(xs: list[int], i: int): + s = 0 + for j in range(len(xs)): + s += pow(p_ij[i, j], 1.0 - xs[j]) * pow(1.0 - p_ij[i, j], xs[j]) + return s - print(log) + # print("===") + # for i in priors.keys(): + # for xs in product([0, 1], repeat=2): + # xs = list(xs) + # print("xs", xs) + # print(i, xs, p_x_given_Ci(xs, i)) + # print("===") + + posteriors = {} + for i in [1, 2]: + for xs in product([0, 1], repeat=D): + numer = p_x_given_Ci(xs, i) * priors[i] + + def each_denom(k): return p_x_given_Ci(xs, k) * priors[k] + denom = sum(map(each_denom, priors.keys())) + posteriors[*xs, i] = numer / denom + + print("Priors:", priors) + for xs in product([0, 1], repeat=D): + print(f"{xs = }") + for i in [1, 2]: + prob = posteriors[*xs, i] + print(f" * C{i}: {prob:0.3f}") + print() -prob_2a() +def prob_3c(): + D = 2 + p_ij = {} + p_ij[1, 0] = 0.6 + p_ij[1, 1] = 0.1 + p_ij[2, 0] = 0.6 + p_ij[2, 1] = 0.9 + + calc_posterior(0.2, D, p_ij) + calc_posterior(0.8, D, p_ij) + + +# prob_2a() +prob_3c()