2023-10-02 00:24:05 +00:00
|
|
|
from itertools import product
|
2023-10-02 03:47:05 +00:00
|
|
|
from sympy import symbols, log, diff, exp, Product, Pow
|
2023-10-01 23:09:50 +00:00
|
|
|
|
|
|
|
|
2023-10-02 00:24:05 +00:00
|
|
|
def calc_posterior(p_c1: float, D: int, p_ij: dict[tuple[int, int], float]) -> dict[tuple[int, ...], float]:
    """Compute, print, and return naive-Bayes class posteriors for binary features.

    Every binary feature vector ``xs`` of length ``D`` is enumerated and Bayes'
    rule is applied for the two classes ``1`` and ``2``, with priors
    ``P(C_1) = p_c1`` and ``P(C_2) = 1 - p_c1``. The class-conditional
    likelihood is the product of independent per-feature factors.

    Args:
        p_c1: Prior probability of class 1; class 2 receives ``1 - p_c1``.
        D: Number of binary features in each vector.
        p_ij: Per-feature parameters keyed by ``(class, feature)`` with classes
            ``1``/``2`` and features ``0 .. D-1``. A feature value of 0
            contributes ``p_ij[i, j]`` to the likelihood and a value of 1
            contributes ``1 - p_ij[i, j]``.

    Returns:
        Dict mapping ``(*xs, i)`` to the posterior ``P(C_i | xs)``. The priors
        and every posterior are also printed to stdout.

    Raises:
        KeyError: if ``p_ij`` is missing a ``(class, feature)`` entry.
        ZeroDivisionError: if some ``xs`` has zero probability under both classes.
    """
    priors = {
        1: p_c1,
        2: 1 - p_c1,
    }

    def p_x_given_Ci(xs: tuple[int, ...], i: int) -> float:
        """Return the likelihood P(x = xs | C_i) as a product over features."""
        # NOTE(review): x_j == 0 selects p_ij[i, j], i.e. p_ij is used as
        # P(x_j = 0 | C_i); confirm this matches the problem's convention.
        s = 1.0
        for j, x_j in enumerate(xs):
            s *= pow(p_ij[i, j], 1.0 - x_j) * pow(1.0 - p_ij[i, j], x_j)
        return s

    posteriors: dict[tuple[int, ...], float] = {}
    for xs in product([0, 1], repeat=D):
        # Joint P(x, C_k) for every class; their sum is the evidence P(x),
        # which is shared by all classes, so compute it once per xs.
        joints = {k: p_x_given_Ci(xs, k) * prior for k, prior in priors.items()}
        evidence = sum(joints.values())
        for i in priors:
            posteriors[(*xs, i)] = joints[i] / evidence

    print("Priors:", priors)
    for xs in product([0, 1], repeat=D):
        print(f"{xs = }")
        for i in priors:
            prob = posteriors[(*xs, i)]
            print(f" * C{i}: {prob:0.3f}")
        print()

    return posteriors
|
|
|
|
|
|
|
|
|
|
|
|
def prob_3c():
    """Print naive-Bayes posteriors for problem 3(c) under two class priors.

    One two-feature parameter table is evaluated first with P(C_1) = 0.2 and
    then with P(C_1) = 0.8.
    """
    num_features = 2
    feature_params = {
        (1, 0): 0.6,
        (1, 1): 0.1,
        (2, 0): 0.6,
        (2, 1): 0.9,
    }

    for class1_prior in (0.2, 0.8):
        calc_posterior(class1_prior, num_features, feature_params)
|
|
|
|
|
|
|
|
|
2023-10-02 03:47:05 +00:00
|
|
|
def prob_2a():
    """Print the derivative of the log-likelihood w.r.t. the model parameter.

    Two models are handled: the Bernoulli pmf ``p**x * (1 - p)**(1 - x)`` is
    differentiated with respect to ``p``, and the exponential pdf
    ``exp(-x/theta) / theta`` with respect to ``theta``. Setting each printed
    expression to zero gives the stationarity condition for the MLE.
    """
    p, x, theta, k, n = symbols("p x theta k n")

    def get_mle(expr, param):
        """Return d/d``param`` of log(prod_{k=1}^{n} expr), simplified.

        ``param`` must be the model parameter: the MLE condition comes from
        differentiating with respect to the parameter, not the observation ``x``.
        """
        # NOTE(review): ``expr`` does not depend on the product index ``k``, so
        # this likelihood treats all n observations as the same value ``x``.
        # For distinct samples, index the observation (e.g. an IndexedBase
        # ``x[k]``) -- confirm which form the problem intends.
        likelihood_func = Product(expr, (k, 1, n))
        log_likelihood_func = log(likelihood_func)
        return diff(log_likelihood_func, param).simplify()

    print(get_mle(Pow(p, x) * Pow(1 - p, 1 - x), p))
    print(get_mle((1 / theta) * exp(-x / theta), theta))
|
|
|
|
|
|
|
|
|
|
|
|
# Run each homework problem in order: 2(a) first, then 3(c).
for run_problem in (prob_2a, prob_3c):
    run_problem()
|