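混合モデル: 複数の分布を混合したモデル
混合ガウスモデル (GMM): 有限個のガウス分布を混合したモデル
$$p(x) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x; \mu_k, \Sigma_k)$$
where $\pi_k \ge 0$, $\sum_{k=1}^{K} \pi_k = 1$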
混合ガウスモデルは
確率変数
★のデータ
次のGMMに従う確率変数 $x$ をサンプリングする: $p(x) = \frac{1}{2}\,\mathcal{N}(x; -5, 1) + \frac{1}{2}\,\mathcal{N}(x; 5, 1)$
import numpy as np
import matplotlib.pyplot as plt
# Fix NumPy's global random seed so the sampled data below are reproducible.
np.random.seed(0)
def gmm_sampling(sample_size):
    """Draw `sample_size` points from a two-component 1-D Gaussian mixture.

    A fair coin picks the component of each point; unit-variance normal noise
    is then shifted by -5.0 (coin = 0) or +5.0 (coin = 1).  Randomness comes
    from NumPy's global random state.

    Parameters
    ----------
    sample_size : int

    Returns
    -------
    numpy.array, shape (sample_size,)
    """
    # Draw order matters for reproducibility: component labels first, noise second.
    component = np.random.binomial(n=1, p=0.5, size=(sample_size,))
    noise = np.random.randn(sample_size)
    shifted = noise + 10.0 * component
    return shifted - 5.0
# Draw 10000 points from the mixture and show their histogram with 100 bins.
sample = gmm_sampling(10000)
plt.hist(sample, bins=100)
plt.show()
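各データの独立同分布性を仮定すると
$$\log p(x_1, \dots, x_N; \theta) = \sum_{n=1}^{N} \log p(x_n; \theta)$$
モデルの仮定では、
$$\sum_{n=1}^{N} \log p(x_n; \theta) = \sum_{n=1}^{N} \log \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x_n; \mu_k, \Sigma_k) \tag{1}$$
となる。
最尤推定量は式(1)を最大化することで求まる。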
最大化の手法は二種類
隠れ変数があるモデルの最尤推定量を求めるアルゴリズム
式(2)を
なので、
となる
(2)
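任意の分布 $q(z)$ に対して、下界を $\sum_{z} q(z) \log \dfrac{p(x, z; \theta)}{q(z)}$ とする。
$$\log p(x; \theta) \ge \sum_{z} q(z) \log \frac{p(x, z; \theta)}{q(z)} \tag{3}$$
を示せ (Jensenの不等式)。
尤度(式(3)の左側)を最大化する代わりに尤度の下界(式(3)の右側)を最大化する: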
式(2)で
よって
を解くと
適当な終了条件で繰り返しをやめる
補助的なものを持っていてもよい
内部状態として他のクラスのオブジェクトを持っていてもよい
正規分布クラス
GMM クラス
まずは正規分布クラスを実装してみよう
正規分布クラスを実装せよ(以前の資料からコピペでOK)
import numpy as np
class Gaussian:
    """Multivariate normal distribution with a full covariance matrix.

    The parameters live in `self.mean` (shape (dim,)) and `self.cov`
    (shape (dim, dim)); they are always written through `set_mean` /
    `set_cov`, which validate them.
    """

    def __init__(self, dim):
        """Create a `dim`-dimensional Gaussian.

        The mean is initialised with standard-normal draws from NumPy's
        global random state, the covariance with the identity matrix.

        Parameters
        ----------
        dim : int
        """
        self.dim = dim
        self.set_mean(np.random.randn(dim))  # random initial mean
        self.set_cov(np.identity(dim))

    def _validate(self, X):
        """Return `X` as an array after checking it is (sample_size, dim)."""
        X = np.asarray(X)
        # The ndim test turns a 1-D input into the documented ValueError
        # instead of an IndexError on X.shape[1].
        if X.ndim != 2 or X.shape[1] != self.dim:
            raise ValueError('X.shape must be (sample_size, dim)')
        return X

    def log_pdf(self, X):
        """Return the log of the probability density at every row of `X`.

        Parameters
        ----------
        X : numpy.array, shape (sample_size, dim)

        Returns
        -------
        log_pdf : array, shape (sample_size,)

        Raises
        ------
        ValueError
            If `X` is not of shape (sample_size, dim).
        """
        X = self._validate(X)
        centered = X - self.mean
        # Squared Mahalanobis distance, using a linear solve rather than an
        # explicit inverse of the covariance.
        mahalanobis = np.sum(
            centered * np.linalg.solve(self.cov, centered.T).T, axis=1)
        log_det = np.linalg.slogdet(self.cov)[1]
        return -0.5 * (mahalanobis
                       + self.dim * np.log(2.0 * np.pi)
                       + log_det)

    def fit(self, X):
        """Set the parameters to the maximum-likelihood estimate for `X`.

        The mean becomes the sample mean and the covariance the scatter
        matrix divided by `sample_size`.

        Parameters
        ----------
        X : numpy.array, shape (sample_size, dim)

        Raises
        ------
        ValueError
            If `X` has the wrong shape, is empty, or yields a covariance that
            `set_cov` rejects.
        """
        X = self._validate(X)
        if X.shape[0] == 0:
            raise ValueError('X must contain at least one sample')
        self.set_mean(np.mean(X, axis=0))
        centered = X - self.mean
        self.set_cov(centered.T @ centered / X.shape[0])

    def sample(self, sample_size):
        """Generate `sample_size` samples with the current parameters.

        Parameters
        ----------
        sample_size : int

        Returns
        -------
        X : numpy.array, shape (sample_size, dim)
            Each row follows the normal distribution with mean `self.mean`
            and covariance `self.cov`.
        """
        return np.random.multivariate_normal(
            self.mean, self.cov, size=sample_size)

    def set_mean(self, mean):
        """Store `mean` after checking that its shape is (dim,)."""
        if mean.shape != (self.dim,):
            raise ValueError('input shape inconsistency')
        self.mean = mean

    def set_cov(self, cov):
        """Store `cov` after checking shape and positive definiteness.

        Raises
        ------
        ValueError
            If `cov` is not (dim, dim) or its smallest eigenvalue is not
            strictly positive.
        """
        if cov.shape != (self.dim, self.dim):
            raise ValueError('input shape inconsistency')
        # eigvalsh returns eigenvalues in ascending order, so [0] is the
        # smallest one.  `not (... > 0)` also rejects NaN.
        if not np.linalg.eigvalsh(cov)[0] > 0:
            raise ValueError('covariance matrix must be positive definite.')
        self.cov = cov
__init__
# Let's write `__init__` (this line continues the exercise text above).
class GMM:
    """Gaussian mixture model holding one `Gaussian` per mixture component."""

    def __init__(self, dim, num_components):
        """Create `num_components` Gaussians of dimension `dim`, equally weighted.

        Parameters
        ----------
        dim : int
        num_components : int
        """
        self.dim = dim
        self.num_components = num_components
        # Every component starts with the same mixing weight 1 / num_components.
        self.weight = np.full(num_components, 1.0) / num_components
        self.gaussian_list = [Gaussian(dim) for _ in range(num_components)]
# A gmm object has now been created.
gmm = GMM(2, 10)
# Print the mean of each of the 10 component Gaussians ...
for each_gaussian in gmm.gaussian_list:
    print(each_gaussian.mean)
# ... and then the vector of mixing weights.
print(gmm.weight)
[ 0.11684931 -1.71658838] [-1.10708776 0.25183766] [-1.24486871 0.55898262] [ 1.77706673 -1.59593316] [ 1.5407428 -0.58973061] [-1.81424689 0.16981966] [-0.12951037 0.52775457] [ 0.24461285 -0.3235077 ] [1.2561859 0.75503971] [ 0.15138818 -0.42148386] [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
以降、EMアルゴリズムでパラメタ推定をする `fit` を書きたい
- `log_pdf`: 現在のパラメタを用いてサンプルの対数尤度を計算するメソッド
- `e_step`: Eステップを実行するメソッド
- `m_step`: Mステップを実行するメソッド
の3つを実装し、これらを組み合わせて `fit` を書く
log_pdf
を GMM
に実装せよX
# `X`: sample size (this line continues the exercise text above).
class GMM:
    """Gaussian mixture model; skeleton for the `log_pdf` exercise."""

    def __init__(self, dim, num_components):
        """Create `num_components` Gaussians of dimension `dim`, equally weighted."""
        self.dim = dim
        self.num_components = num_components
        self.gaussian_list = [Gaussian(dim) for _ in range(num_components)]
        self.weight = np.full(num_components, 1.0) / num_components

    def log_pdf(self, X):
        """Return the log-likelihood of every row of `X` (exercise placeholder).

        As written the likelihood is left at zero, so the result is
        `np.log` of an all-zero array of length `X.shape[0]`.
        """
        n_samples = X.shape[0]
        likelihood = np.zeros(n_samples)  # likelihood of each data point
        # Exercise: compute `likelihood` from the current internal state here.
        # Hint: every element of self.gaussian_list is a Gaussian object, and
        # Gaussian implements `log_pdf`, which returns the log-likelihood of
        # each data point.
        log_likelihood = np.log(likelihood)  # we want the log-likelihood
        return log_likelihood
- `self.gaussian_list[0].log_pdf(X)` はサンプル `X` のデータそれぞれに対する対数尤度を計算する
  - `X` が `sample_size` x `dim` とすると、返り値は長さ `sample_size` の配列
- `self.weight[0]` は第 0 コンポーネントの混合比 $\pi_0$ (スカラー)
- `self.weight[0] * np.exp(self.gaussian_list[0].log_pdf(X))` は `np.exp(self.gaussian_list[0].log_pdf(X))` の各要素に `self.weight[0]` を掛けた値を返す
  - `スカラー * 配列` と書くと、配列の各要素にスカラーを掛けてくれる(`numpy`
# ("... notation)" -- this line closes the NumPy note in the text above.)
class GMM:
    """Gaussian mixture model with a `log_pdf` method."""

    def __init__(self, dim, num_components):
        """Create `num_components` Gaussians of dimension `dim`, equally weighted."""
        self.dim = dim
        self.num_components = num_components
        self.gaussian_list = [Gaussian(dim) for _ in range(num_components)]
        self.weight = np.ones(self.num_components) / self.num_components

    def log_pdf(self, X):
        """Return the log-likelihood of every row of `X` under the mixture.

        Mathematically this is
        ``log(sum_k weight[k] * exp(gaussian_list[k].log_pdf(X)))``.
        It is evaluated in log space (log-sum-exp) so that points far from
        every component do not underflow to ``log(0) = -inf``.

        Parameters
        ----------
        X : numpy.array, shape (sample_size, dim)

        Returns
        -------
        log_pdf : array, shape (sample_size,)
        """
        # A zero weight legitimately maps to -inf; silence only that warning.
        with np.errstate(divide='ignore'):
            log_weight = np.log(self.weight)
        # log_joint[n, k] = log weight[k] + log N(x_n | component k)
        log_joint = np.stack(
            [gaussian.log_pdf(X) for gaussian in self.gaussian_list],
            axis=1) + log_weight
        return np.logaddexp.reduce(log_joint, axis=1)
# Smoke test: evaluate log_pdf on 100 random 2-D points.
gmm = GMM(2, 10)
gmm.log_pdf(np.random.randn(100, 2)) # at least something gets computed
array([-3.12648963, -3.86705701, -2.57161702, -3.06551493, -2.65614086, -3.0071172 , -4.62457074, -2.88583395, -2.95423388, -2.93725819, -2.66778609, -2.91656884, -3.43291286, -2.61947626, -3.44059698, -3.25383126, -2.84392495, -2.79003912, -2.66291126, -2.99469256, -4.03689065, -2.83441528, -2.75630803, -2.67636738, -3.75517664, -2.64458188, -2.79363847, -2.87221519, -3.10244486, -2.69670577, -3.02748979, -2.64780433, -2.63585142, -3.09722914, -2.58475377, -2.73534958, -3.05715522, -2.72387022, -4.5341286 , -3.08127239, -3.21928694, -3.54947664, -3.14211745, -2.74879231, -3.44220236, -3.49531004, -4.53836991, -3.18854928, -3.53895284, -2.55987735, -3.11221332, -2.57703954, -3.19613213, -3.35386561, -2.7605182 , -3.27197597, -2.80149132, -3.1221321 , -4.38693956, -3.09966228, -4.89665286, -2.68778129, -3.14177503, -2.7074782 , -3.16332312, -2.66797893, -2.89198018, -3.92306412, -3.82068393, -3.72077715, -3.35130862, -2.74938948, -3.1884004 , -3.23102166, -3.44943404, -2.59168975, -2.55253307, -3.3592686 , -4.24582831, -3.83563231, -3.23305835, -3.40850046, -3.27423081, -2.89178352, -3.01971239, -2.66453107, -2.55463812, -2.71222266, -3.1569365 , -2.73207968, -3.70887009, -3.55214124, -3.24868335, -2.5982095 , -3.72506253, -2.99258731, -3.05604243, -3.21803488, -3.46556688, -2.61541135])
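`e_step` を実装せよ
- 入力 `X`: `sample_size` × `dim` の配列
- 混合数: `n_components`
- 出力 `posterior`: `sample_size` × `n_components` の配列として出力
- `posterior[n, k]` = $q_n(k) = \dfrac{\pi_k \, \mathcal{N}(x_n; \mu_k, \Sigma_k)}{\sum_{j} \pi_j \, \mathcal{N}(x_n; \mu_j, \Sigma_j)}$ (E-step: for each $n$)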
class GMM:
    """Gaussian mixture model; skeleton for the `e_step` exercise."""

    def __init__(self, dim, num_components):
        """Create `num_components` Gaussians of dimension `dim`, equally weighted."""
        self.dim = dim
        self.num_components = num_components
        self.gaussian_list = [Gaussian(dim) for _ in range(num_components)]
        self.weight = np.ones(self.num_components) / self.num_components

    def log_pdf(self, X):
        """Return the log-likelihood of every row of `X` under the mixture.

        Equal to ``log(sum_k weight[k] * exp(gaussian_list[k].log_pdf(X)))``,
        evaluated in log space so that points far from every component do not
        underflow to ``log(0) = -inf``.

        Parameters
        ----------
        X : numpy.array, shape (sample_size, dim)

        Returns
        -------
        log_pdf : array, shape (sample_size,)
        """
        # A zero weight legitimately maps to -inf; silence only that warning.
        with np.errstate(divide='ignore'):
            log_weight = np.log(self.weight)
        # log_joint[n, k] = log weight[k] + log N(x_n | component k)
        log_joint = np.stack(
            [gaussian.log_pdf(X) for gaussian in self.gaussian_list],
            axis=1) + log_weight
        return np.logaddexp.reduce(log_joint, axis=1)

    def e_step(self, X):
        """Return the posterior over components per data point (exercise placeholder).

        As written the array is returned still filled with zeros.

        Returns
        -------
        posterior : array, shape (sample_size, num_components)
        """
        sample_size = X.shape[0]
        posterior = np.zeros((sample_size, self.num_components))
        # Exercise: fill in `posterior` here.
        # Hint: for each n the denominator is shared by all components.
        #   1. Compute the numerators first and store them in posterior.
        #   2. posterior[n, :].sum() then equals the denominator.
        #   3. posterior[n, :] = posterior[n, :] / posterior[n, :].sum()
        #      should do the job.
        return posterior
class GMM:
    """Gaussian mixture model with `log_pdf` and the EM E-step."""

    def __init__(self, dim, num_components):
        """Create `num_components` Gaussians of dimension `dim`, equally weighted."""
        self.dim = dim
        self.num_components = num_components
        self.gaussian_list = [Gaussian(dim) for _ in range(num_components)]
        self.weight = np.ones(self.num_components) / self.num_components

    def _log_joint(self, X):
        """Return log weight[k] + log N(x_n | component k), shape (sample_size, num_components)."""
        # A zero weight legitimately maps to -inf; silence only that warning.
        with np.errstate(divide='ignore'):
            log_weight = np.log(self.weight)
        return np.stack(
            [gaussian.log_pdf(X) for gaussian in self.gaussian_list],
            axis=1) + log_weight

    def log_pdf(self, X):
        """Return the log-likelihood of every row of `X` under the mixture.

        Equal to ``log(sum_k weight[k] * exp(gaussian_list[k].log_pdf(X)))``,
        evaluated in log space so that points far from every component do not
        underflow to ``log(0) = -inf``.

        Parameters
        ----------
        X : numpy.array, shape (sample_size, dim)

        Returns
        -------
        log_pdf : array, shape (sample_size,)
        """
        return np.logaddexp.reduce(self._log_joint(X), axis=1)

    def e_step(self, X):
        """Return the posterior probability of each component for each data point.

        ``posterior[n, k]`` is ``weight[k] * N(x_n | k)`` divided by the sum
        of that quantity over all components.  The normalisation is done in
        log space, so rows whose densities would all underflow to zero are
        still normalised instead of becoming 0 / 0 = NaN.

        Parameters
        ----------
        X : numpy.array, shape (sample_size, dim)

        Returns
        -------
        posterior : array, shape (sample_size, num_components)
            Each row sums to one.
        """
        log_joint = self._log_joint(X)
        # All rows are normalised at once (a per-row loop would also work,
        # only slower).
        log_norm = np.logaddexp.reduce(log_joint, axis=1, keepdims=True)
        return np.exp(log_joint - log_norm)
# q expresses which component each data point is closer to,
# so we would like the result to look that way.
gmm = GMM(2, 2)
# Use the mean of component 0 as the single query point.
X = gmm.gaussian_list[0].mean.reshape(1, 2)
print(gmm.e_step(X))
# Same check with the mean of component 1.
X = gmm.gaussian_list[1].mean.reshape(1, 2)
print(gmm.e_step(X))
[[0.86539486 0.13460514]] [[0.13460514 0.86539486]]
`m_step` を実装せよ
- 入力 `X`: `sample_size` × `dim` の配列
- 入力 `posterior`: `sample_size` × `n_components` の配列
class GMM:
    """Gaussian mixture model fitted with the EM algorithm.

    `weight` holds the mixing weights and `gaussian_list` one `Gaussian`
    object per component.
    """

    def __init__(self, dim, num_components):
        """Create `num_components` Gaussians of dimension `dim`, equally weighted."""
        self.dim = dim
        self.num_components = num_components
        self.gaussian_list = [Gaussian(dim) for _ in range(num_components)]
        self.weight = np.ones(self.num_components) / self.num_components

    def fit(self, X, eps=1e-8, max_iter=None):
        """Estimate the parameters from `X` by alternating E- and M-steps.

        Iteration stops when the relative change of the total log-likelihood
        is at most `eps`, or after `max_iter` iterations if that is given.

        Parameters
        ----------
        X : numpy.array, shape (sample_size, dim)
        eps : float
            Relative tolerance on the change of the log-likelihood.
        max_iter : int or None
            Upper bound on the number of EM iterations; None means no bound.

        Returns
        -------
        posterior : array, shape (sample_size, num_components)
            Posterior from the last E-step, i.e. computed with the parameters
            that were current before the final M-step.

        Raises
        ------
        ValueError
            If `max_iter` is less than 1, if the log-likelihood becomes
            non-finite, or if it decreases by more than the tolerance.
        """
        if max_iter is not None and max_iter < 1:
            raise ValueError('max_iter must be at least 1')
        old_ll = -np.inf
        n_iter = 0
        while max_iter is None or n_iter < max_iter:
            posterior = self.e_step(X)
            self.m_step(X, posterior)
            new_ll = self.log_pdf(X).sum()
            n_iter += 1
            # NaN compares false with everything and would otherwise keep the
            # loop running forever.
            if not np.isfinite(new_ll):
                raise ValueError('log-likelihood is not finite')
            # Convergence is tested before the monotonicity check so that a
            # decrease caused purely by floating-point round-off counts as
            # "converged" rather than as an error.
            if np.abs(old_ll - new_ll) <= eps * np.abs(new_ll):
                break
            if new_ll < old_ll:
                raise ValueError('likelihood decreases!')
            old_ll = new_ll
        return posterior

    def _log_joint(self, X):
        """Return log weight[k] + log N(x_n | component k), shape (sample_size, num_components)."""
        # A zero weight legitimately maps to -inf; silence only that warning.
        with np.errstate(divide='ignore'):
            log_weight = np.log(self.weight)
        return np.stack(
            [gaussian.log_pdf(X) for gaussian in self.gaussian_list],
            axis=1) + log_weight

    def e_step(self, X):
        """Return the posterior probability of each component for each data point.

        The normalisation is done in log space, so rows whose densities would
        all underflow to zero are still normalised instead of becoming
        0 / 0 = NaN.

        Parameters
        ----------
        X : numpy.array, shape (sample_size, dim)

        Returns
        -------
        posterior : array, shape (sample_size, num_components)
            Each row sums to one.
        """
        log_joint = self._log_joint(X)
        log_norm = np.logaddexp.reduce(log_joint, axis=1, keepdims=True)
        return np.exp(log_joint - log_norm)

    def m_step(self, X, posterior):
        """Update weights, means and covariances from the posterior.

        Each component receives the posterior-weighted mean and covariance of
        `X`; the mixing weights become the normalised column sums of
        `posterior`.

        Parameters
        ----------
        X : numpy.array, shape (sample_size, dim)
        posterior : numpy.array, shape (sample_size, num_components)
        """
        # Column sums of the posterior: total responsibility per component.
        resp_sum = posterior.sum(axis=0)
        for k, gaussian in enumerate(self.gaussian_list):
            resp = posterior[:, k]
            gaussian.set_mean(resp @ X / resp_sum[k])
            # The covariance is centred on the mean that was just updated.
            centered = X - gaussian.mean
            gaussian.set_cov((resp * centered.T) @ centered / resp_sum[k])
        self.weight = resp_sum / np.sum(resp_sum)

    def log_pdf(self, X):
        """Return the log-likelihood of every row of `X` under the mixture.

        Equal to ``log(sum_k weight[k] * exp(gaussian_list[k].log_pdf(X)))``,
        evaluated in log space so that points far from every component do not
        underflow to ``log(0) = -inf``.

        Parameters
        ----------
        X : numpy.array, shape (sample_size, dim)

        Returns
        -------
        log_pdf : array, shape (sample_size,)
        """
        return np.logaddexp.reduce(self._log_joint(X), axis=1)
# Toy data: 100 standard-normal points in 3-D stacked on top of 100 more
# that are shifted by +3 in every coordinate.
X = np.vstack((np.random.randn(100, 3), np.random.randn(100, 3) + 3))
# Fit a 2-component mixture to the 3-dimensional data.
gmm = GMM(3, 2)
posterior = gmm.fit(X)
# Print the fitted mean of each component.
print(gmm.gaussian_list[0].mean)
print(gmm.gaussian_list[1].mean)
[-0.15634628 0.0826993 0.04664785] [2.99283013 3.03949123 2.96590465]
# Display the posterior array returned by gmm.fit.
posterior
array([[9.99999972e-01, 2.80218650e-08], [9.97371053e-01, 2.62894687e-03], [9.99999983e-01, 1.65791801e-08], [9.99999198e-01, 8.01597957e-07], [9.97733881e-01, 2.26611868e-03], [9.99476715e-01, 5.23284529e-04], [1.00000000e+00, 5.67816857e-11], [9.99999944e-01, 5.63259856e-08], [9.99999202e-01, 7.98373347e-07], [9.99999991e-01, 8.98339412e-09], [9.99999993e-01, 7.14412263e-09], [9.99999999e-01, 7.20960106e-10], [9.99999990e-01, 9.79199031e-09], [9.99999938e-01, 6.21605991e-08], [9.99999869e-01, 1.30754432e-07], [9.99999999e-01, 7.47173169e-10], [9.99985013e-01, 1.49867177e-05], [9.99999998e-01, 1.76905802e-09], [9.99999962e-01, 3.82747890e-08], [1.00000000e+00, 6.07185145e-13], [9.99999127e-01, 8.72579183e-07], [9.99913754e-01, 8.62462203e-05], [9.99999913e-01, 8.65319842e-08], [9.99999983e-01, 1.73070999e-08], [9.99999986e-01, 1.39271511e-08], [9.99999892e-01, 1.07590678e-07], [9.88421571e-01, 1.15784293e-02], [9.99999964e-01, 3.59867788e-08], [9.99999663e-01, 3.37265939e-07], [9.99999985e-01, 1.49309649e-08], [9.99999572e-01, 4.27887629e-07], [9.59783355e-01, 4.02166454e-02], [9.97050613e-01, 2.94938667e-03], [9.99999961e-01, 3.91009276e-08], [1.00000000e+00, 9.10492798e-13], [9.99999878e-01, 1.22356081e-07], [9.99965215e-01, 3.47851055e-05], [1.00000000e+00, 7.48456732e-11], [9.99993642e-01, 6.35809155e-06], [9.99999986e-01, 1.39305840e-08], [9.99994125e-01, 5.87501848e-06], [1.00000000e+00, 4.47041114e-12], [9.99987397e-01, 1.26034971e-05], [1.00000000e+00, 1.11335739e-12], [9.99998851e-01, 1.14868081e-06], [9.99999993e-01, 6.87674223e-09], [9.99999998e-01, 2.02409122e-09], [9.99998894e-01, 1.10599962e-06], [9.99996332e-01, 3.66762392e-06], [9.99999995e-01, 5.14285315e-09], [9.99999997e-01, 3.22461769e-09], [9.99999402e-01, 5.97516589e-07], [9.99989099e-01, 1.09005786e-05], [9.99971098e-01, 2.89020294e-05], [9.99896801e-01, 1.03198835e-04], [9.99999920e-01, 8.00432335e-08], [9.99999660e-01, 3.39700553e-07], [9.99619582e-01, 3.80418356e-04], [1.00000000e+00, 
7.35290604e-15], [1.00000000e+00, 7.35528489e-12], [9.99979356e-01, 2.06439507e-05], [9.99998913e-01, 1.08707283e-06], [9.99999998e-01, 2.07980454e-09], [1.00000000e+00, 2.57652463e-10], [9.99997782e-01, 2.21792869e-06], [9.99999971e-01, 2.94073784e-08], [1.00000000e+00, 1.84888247e-13], [9.99999997e-01, 3.30709327e-09], [9.99999989e-01, 1.07981319e-08], [9.99999995e-01, 5.15835688e-09], [9.99999998e-01, 2.09368217e-09], [9.99675433e-01, 3.24567488e-04], [9.99896928e-01, 1.03071596e-04], [9.99996710e-01, 3.29020892e-06], [9.99999998e-01, 1.59474924e-09], [1.00000000e+00, 6.52173363e-11], [9.99999934e-01, 6.63339570e-08], [9.99999995e-01, 4.77295676e-09], [9.99998473e-01, 1.52695122e-06], [9.99999388e-01, 6.12160297e-07], [1.00000000e+00, 1.86621547e-11], [9.99999811e-01, 1.88552843e-07], [9.99999978e-01, 2.21031636e-08], [9.99999952e-01, 4.83580056e-08], [1.00000000e+00, 2.18167492e-10], [1.00000000e+00, 2.38769747e-10], [9.99999747e-01, 2.53122619e-07], [9.99999963e-01, 3.66072857e-08], [9.99997640e-01, 2.36039863e-06], [9.99999897e-01, 1.02749511e-07], [9.99999963e-01, 3.71438255e-08], [9.99999920e-01, 7.95790592e-08], [9.98109261e-01, 1.89073896e-03], [9.99999987e-01, 1.29753444e-08], [9.99999994e-01, 6.49605132e-09], [9.99979095e-01, 2.09052371e-05], [9.99999619e-01, 3.80542622e-07], [9.99962290e-01, 3.77098388e-05], [9.99999929e-01, 7.09338253e-08], [1.00000000e+00, 9.39713822e-12], [2.57839929e-10, 1.00000000e+00], [7.10030580e-10, 9.99999999e-01], [7.77263991e-05, 9.99922274e-01], [4.87672154e-07, 9.99999512e-01], [1.06497653e-13, 1.00000000e+00], [1.39125574e-09, 9.99999999e-01], [1.45855962e-07, 9.99999854e-01], [1.15844233e-10, 1.00000000e+00], [1.30885221e-07, 9.99999869e-01], [4.01342547e-03, 9.95986575e-01], [2.23846844e-06, 9.99997762e-01], [4.80096836e-12, 1.00000000e+00], [2.53902784e-07, 9.99999746e-01], [1.33699810e-06, 9.99998663e-01], [9.26144795e-13, 1.00000000e+00], [6.63946614e-09, 9.99999993e-01], [4.74685878e-05, 9.99952531e-01], 
[7.80325429e-12, 1.00000000e+00], [4.09353206e-09, 9.99999996e-01], [1.03403537e-07, 9.99999897e-01], [7.36491730e-11, 1.00000000e+00], [1.61149398e-08, 9.99999984e-01], [3.57696778e-09, 9.99999996e-01], [4.07324020e-10, 1.00000000e+00], [8.32553012e-08, 9.99999917e-01], [2.49562814e-04, 9.99750437e-01], [1.06010371e-12, 1.00000000e+00], [2.00498829e-08, 9.99999980e-01], [8.40632056e-13, 1.00000000e+00], [8.50063744e-10, 9.99999999e-01], [2.44062509e-07, 9.99999756e-01], [2.30363365e-08, 9.99999977e-01], [1.54354318e-12, 1.00000000e+00], [8.33215085e-06, 9.99991668e-01], [1.17876940e-09, 9.99999999e-01], [5.30530954e-09, 9.99999995e-01], [5.99631715e-05, 9.99940037e-01], [5.46127020e-10, 9.99999999e-01], [1.51642669e-05, 9.99984836e-01], [3.05980804e-07, 9.99999694e-01], [1.02057016e-05, 9.99989794e-01], [9.14538689e-06, 9.99990855e-01], [1.00902475e-02, 9.89909752e-01], [8.56886241e-08, 9.99999914e-01], [8.12302937e-06, 9.99991877e-01], [5.16071738e-06, 9.99994839e-01], [3.19395709e-06, 9.99996806e-01], [2.09877679e-08, 9.99999979e-01], [3.56352670e-07, 9.99999644e-01], [8.71899847e-10, 9.99999999e-01], [1.41279230e-06, 9.99998587e-01], [2.36027283e-05, 9.99976397e-01], [6.41034443e-01, 3.58965557e-01], [5.23253972e-09, 9.99999995e-01], [1.32300243e-03, 9.98676998e-01], [9.08053107e-11, 1.00000000e+00], [6.73462148e-06, 9.99993265e-01], [1.35862410e-05, 9.99986414e-01], [1.55411117e-09, 9.99999998e-01], [1.78898163e-08, 9.99999982e-01], [5.56246657e-07, 9.99999444e-01], [2.15857223e-04, 9.99784143e-01], [3.36432353e-07, 9.99999664e-01], [3.31609686e-07, 9.99999668e-01], [6.36769143e-08, 9.99999936e-01], [3.22594193e-10, 1.00000000e+00], [6.34359329e-08, 9.99999937e-01], [1.98389286e-07, 9.99999802e-01], [2.74824322e-08, 9.99999973e-01], [3.03502827e-09, 9.99999997e-01], [3.03011961e-07, 9.99999697e-01], [5.11191743e-09, 9.99999995e-01], [3.98535732e-04, 9.99601464e-01], [1.06161725e-06, 9.99998938e-01], [3.34384016e-06, 9.99996656e-01], [6.93024425e-10, 
9.99999999e-01], [3.38301508e-08, 9.99999966e-01], [2.04188713e-01, 7.95811287e-01], [4.71264963e-07, 9.99999529e-01], [1.91537294e-10, 1.00000000e+00], [6.60530787e-06, 9.99993395e-01], [1.36686530e-12, 1.00000000e+00], [1.24327968e-05, 9.99987567e-01], [3.92591447e-06, 9.99996074e-01], [2.39191117e-12, 1.00000000e+00], [5.18414949e-10, 9.99999999e-01], [6.89257801e-15, 1.00000000e+00], [4.46495438e-06, 9.99995535e-01], [2.60311464e-04, 9.99739689e-01], [3.14866310e-09, 9.99999997e-01], [5.24164837e-08, 9.99999948e-01], [1.59225146e-08, 9.99999984e-01], [3.06876781e-14, 1.00000000e+00], [4.60126666e-11, 1.00000000e+00], [1.57673959e-10, 1.00000000e+00], [1.62251036e-08, 9.99999984e-01], [1.77302902e-15, 1.00000000e+00], [7.44778750e-12, 1.00000000e+00], [6.25496083e-09, 9.99999994e-01], [3.59014236e-08, 9.99999964e-01]])