6 min to read
resilient Multiple Choice Learning
rMCL
Resilient Multiple Choice Learning
Introduction
머신러닝 모델은 일반적으로 주어진 입력에 대해 단일한 예측값을 출력하도록 학습되지만,
데이터의 본질적 특성이나 다양한 불확실성 요인으로 인해 조건부 출력 분포가 Multimodal 형태를 띄는 경우가 많다.
이러한 경우 조건부 확률분포의 평균은 저밀도 영역에 위치할 수 있으며, 단일 예측 대신 여러개의 가설을 동시에 예측하는 것이 더 바람직할 수 있다.
이러한 맥락에서 Multiple Choice Learning(MCL)은 각 head마다 하나의 예측가설을 생성할 수 있어 실용적인 해결책으로 제안되고 있다. 지도학습에서는 현재 입력 sample에 대해 가장 우수한 head에만 gradient를 계산하는 Winner-Takes-All (WTA) 방식은 각 head가 출력 공간의 서로다른 영역에 특화되도록 유도한다.
하지만 MCL은 hyppotheses collapse와 overconfidence라는 중대한 문제점을 갖고있어 이를 완화하기 위해 제안된 resilient Multiple Choice Learning (rMCL) 이 제안되었다.
Multiple Choice Learning
하나의 입력 x에 대해 여러개의 후보 예측을 동시에 학습하고 정답과 가장 잘 맞는 하나만 업데이트하는 학습 방식이다
MCL의 목표는 하나의 입력에 대해 여러 개의 출력을 동시에 학습함으로써, 단일 예측으로는 표현하기 어려운 조건부 출력 분포를 붕괴없이 근사하는 것이다
Model Uncertainty
모델 불확실성은 주로 학습데이터나 부족하거나 입력 공간이 충분히 관측되지 않았을때 나타날 수 있다
Data Uncertainty
데이터 불확실성은 관측 노이즈, 오차, 또는 생성과정에서의 내재적 확률성으로부터 발생할 수 있으며, 데이터 수가 증가해도 근본적으로 사라지지 않는다
다음과 같이 입력 x 하나에 여러개의 예측가설이 존재할때,
\[f_1(x), f_2(x), ....,f_K(x)\]K개의 예측중에서 정답에 가장 가까운 예측하나만 선택해서 손실을 계산한다
\[L(x,y) = \min_{k=1,...,K} l(y,f_k(x))\]하지만 기존 MCL에는 중대한 2가지의 문제점이 존재한다
Hypothesis collapse
한 head가 주도권을 잡으면 다른 head는 선택되지 않게되고, 한개의 head만 계속해서 업데이트되어 다른 head들은 훈련되지 않아 의미없는 출력이 생성될 수 있다
Overconfidence
확률 개념이 없어, rare mode도 하나의 대표 point가 되어 학습될 수 있고 이는 곧 분포를 왜곡 시킬 수 있다
Stochastic Multiple Choice Learning
지도학습 상황에서 입력 공간과 출력공간을 각각 다음과 같이 주어졌을때
\[Let \ \mathcal{X} \subseteq \mathbb R^d\] \[\mathcal Y \subseteq \mathbb R^q\]학습데이터는 다음과 같이 결합분포로부터 sampling된 데이터들로 구성되어 있다고 가정한다
\[p(x,y), \ \mathcal X \times\mathcal Y\] \[\mathcal D =(x_s,y_s)\]MCL은 하나의 입력에 대해 출력이 어려개인 조건부 분포가 multimodal인 경우를 해결하기 위해 제안되었고, 이를 딥러닝 환경에 맞게 확장한 sMCL이 제안되었다
sMCL은 K hypotheses 이라 불리는 모델을 사용하고, Winner-Takes-All (WTA) 손실을 통해 학습된다
\[f_\theta \triangleq (f_\theta^1,....,f_\theta^K) \in \mathcal F(\mathcal X, \mathcal Y^K)\]주어진 손실함수에 대해, 현재 batch에 있는 각각의 sample에 다음 계산을 진행한다
\[\mathcal L(f_\theta(x_s), y_s) \triangleq \min_{K\in[1,K]}\mathcal l(f_\theta^k(x_s),y_s)\]그후, winner hypothesis 즉, 손실을 최소화 하는 가설에 대해서만 backpropagation을 수행한다
이제는 하나의 입력에 대해 여러개의 target이 주어지는 상황이 주어졌을때
\[Y_s \subseteq \mathcal Y\]다음과 같은 meta-loss로 일반화될 수 있다
\[\mathcal L(f_\theta(x_s),Y_s) = \sum_{y \in Y} \sum_{k=1}^K 1(y\in \mathcal Y^k(x_s))\ l(f_\theta^k(x_s),y)\]where
\[\mathcal Y^k(x) \triangleq \{y\in \mathcal Y : \mathcal l (f_\theta^k(x),y)<l(f_\theta^r(x),y), \ \forall r \neq k\}\](i.e. k번째 head가 다른 모든 head보다 loss가 작다)
하지만 이런 손실함수는 한개나 몇몇개의 hypotheses 만 과도하게 역전파가 진행되어서 hypothesis collapse 로 이어질 수 있다
또한 Overconfidence 문제도 존재한다. 이러한 위험을 최소화 하기 위한 neccessary condition은 다음과 같다
\[\int_\mathcal X \sum_{k=1}^K \int_{\mathcal Y^k(x)} l(f^k_\theta(x),y)p(x,y)dydx\]위 식은 다음과 같이 도출된다. 기대 손실을 다음과 같이 정의할때
\[\mathbb E_{x,y} [\min_k l (f^k_\theta(x), y)]\] \[\sum_{k=1}^K \mathbb E_{x,y} [l (f^k_\theta(x), y) 1_{y \in \mathcal Y^k(x)}]\]이를 적분형으로 정의하면 위 식을 도출할 수 있다
이때 각 k에 대해, 확률질량이 0이 아닌 Voronoi Cell 내에서 각 K hypothesis 가 조건부 평균에 해당해야 한다
그리고 다음 식을 centroidal Voronoi tessellation (각 셀의 중심이 그 셀의 centroid) 을 이룬다고 한다
하지만 이 정리는 아주 작은 영역에서의 예측에서는 아무런 의미가 없을 것이며, 단순히 예측된 가설만으로는 작은 확률질량의 출력 공간을 식별할 수 없다
따라서 다음과 같은 hypothesis-scoring head 이 제안되었으며
\[\gamma^1_\theta,......,\gamma^K_\theta \in \mathcal F (\mathcal X, [0,1])\]이는 새로운 입력 x가 주어졌을 때 실제 출력 y가 k번째 head의 Voronoi cell에 있을 확률을 예측하는 것을 목표로 한다
\[\mathbb P(\mathcal Y_x \in \mathcal Y^k(x))\]where
\[\mathcal Y_x \sim p(y|x)\]Resilient Multiple Choice Learning
각 입력에 대해 멀티모달 조건부 분포를 추정하는 문제를 고려한다
다중 출력 회귀 문제를 위해, 다음과 같이 scoring함수를 포함한 무작위로 초기화된 다중 가설 모델을 통해
\[(f^1_\theta,......,f^K_\theta,\gamma^1_\theta,......,\gamma^K_\theta)\]MCL 학습은 다음과 같이 확장시킬 수 있다
각각의 다음과 같은 training sample에 대해
\[(x_s, Y_s)\]다음과 같이 각각 positive (or winner) 와 negetive hypothesis 집합으로 정의 할때
\[\mathcal K_+(x_s) \triangleq \{k^+ \in [1,K] : \exists y \in Y_s, k^+ \in \arg \min _k l(f^k_\theta(x_s), y)\}\] \[\mathcal K_-(x_s) \triangleq [1,K] -\mathcal K_+(x_s)\]이는 위에서 정의던 다음의 multi-target WTA loss L와 결합시킬 수 있다
\[\mathcal L(f_\theta(x_s),Y_s) = \sum_{y \in Y} \sum_{k=1}^K 1(y\in \mathcal Y^k(x_s))\ l(f_\theta^k(x_s),y)\] \[\mathcal L_{scoring}(\theta) \triangleq -(\sum_{k^+\in \mathcal K_+(x_s)}\log \gamma_\theta^{k^+}(x_s) + \sum_{k^-\in \mathcal K_-(x_s)}\log (1 - \gamma_\theta^{k^-}(x_s)))\]최종 손실함수는 L+βL_scoring의 형태로 정의된다
이는 기존 MCL 기법들과 달리 회귀문제에서 multimodal 분포를 예측할 수 있는점과 loss function에 의해 업데이트 되는 scroing branch를 도입한다는 novelty가 존재한다.
이때 각 scoring branch는 해당 hypothesis가 winner에 속할 확률이다
Algorithm : inference in the rMCL model
input :
Unlabelled input
\[x \in \mathcal X\]Trained hypotheses and score heads
\[(f_\theta^1,....,f_\theta^\mathcal K,\gamma_\theta^1,....,\gamma_\theta^\mathcal K)\in \mathcal F(\mathcal X, \mathcal Y)^\mathcal K \times\mathcal F(\mathcal X, [0,1])^\mathcal K\]Output :
- Prediction of the output conditional distribution
- Construct the associated Voronoi components
with
\[\mathcal Y = \cup_{k=1}^K \mathcal Y^k(x)\]- Normalize the predicted scores
- 출력이 조건부 분포일때, γ(x) > 0 인 모든 k에 대해 예측을 다음과 같이 해석하고 rMCL의 출력목표는 다음과 같다
Probabilistic interpretation at inference time
학습된 rMCL모델이 출력목표를 만족한다면 다음과 같은 law of total expectation을 도출할 수 있다
\[\mathbb E[Y_x] = \sum_{k=1}^k \gamma_\theta^k(x)f_\theta^k(x)\]만약 출력된 조건부 분포가 multimodal인 경우, 정보적으로 충분하지 않을수 있고, 실제로 완벽한 확률적 해석을 위해서는 어떻게 rMCL이 Y의 전체적인 조건부 밀도를 근사하는지 특정해야한다.
이는 출력공간의 각 Voronoi cell k 내에서 Y에 대한 분포를 고정하는 것만으로 rMCL의 예측분포를 표현할 수 있다
Proposition1 : Probabilistic interpretation of rMCL
각 입력 x ∈ X 에 대해 다음이 성립한다고 가정할때,
\[Y_x^k \sim \pi_k(·|x)\]i.e. 랜덤 출력변수가 Voronoi cell에 속할때 조건부 확률변수의 분포가 주어졌을때 다음을 만족한다
\[\mathbb E[Y_x^k] = f_\theta ^k (x)\]어떤 각 measurable set A ⊆ Y에 대해, 다음과 같이 도출할 수 있다
\[\mathbb P(Y_x \in A) = \Sigma_k \mathbb P (Y_x \in A \cap \mathcal Y^k(x))\] \[= \Sigma_k \mathbb P (Y_x \in \mathcal Y^k(x))\mathbb P (Y_x \in A | Y_x \in \mathcal Y^k(x))\] \[= \Sigma_k \gamma^k_\theta (x) \int _{y\in A} \pi _k (y|x) dy\]엔트로피 최대화원리에 따르면, 출력분포가 유한한 부피이면서, 각 가설들이 각 cell의 기하학적인 중심에 놓여져 있을대 Uniform distribution에서 가장 적은 정보를 갖는 분포가 된다
\[Y_x^k \sim \mathcal U(\mathcal Y^k(x))\]위 가정에서 예측 분포는 다음과 같이 해석할 수 있다
\[\hat p(y|x) = \sum_{k=1}^K \gamma_\theta^k(x) \frac{1_{y \in \mathcal Y^k(x)}}{\mathcal V (\mathcal Y ^k (x))}\] \[\mathcal V(\mathcal Y^k(x)) \triangleq \int_{y\in \mathcal Y^k(x)} dy\]유사하게, 출력분포를 Dirac deltas 혼합 분포로 사용할때, 각 영역 k에 대해 다음과 같이 가정하면
\[\gamma_\theta ^k (x) >0 \rightarrow Y_x^k \sim \delta_{f_\theta^k}(x)\]그러면 다음과 같이 도출된다
\[\hat p(y|x) = \sum_{k=1}^K \gamma_\theta^k(x)\delta_{f_\theta^k}(x)\]만약 Gaussian Mixture Model을 따른다면 \(\hat p(y|x) = \sum_{k=1}^K \gamma_\theta^k(x) \mathcal N (y | f_\theta ^k (x), \Sigma _k)\)
Reference
V. Letzelter, M. Fontaine, M. Chen, P. Pérez, S. Essid, and G. Richard,
“Resilient Multiple Choice Learning: A Learned Scoring Scheme with Application to Audio Scene Analysis,”
(NeurIPS 2023)
Comments