Hello Potato World
[ํฌํ ์ดํ ๋ ผ๋ฌธ ๋ฆฌ๋ทฐ] Siamese Contrastive Embedding Network for Compositional Zero-Shot Learning ๋ณธ๋ฌธ
[ํฌํ ์ดํ ๋ ผ๋ฌธ ๋ฆฌ๋ทฐ] Siamese Contrastive Embedding Network for Compositional Zero-Shot Learning
Heosuab 2023. 1. 8. 18:02โ ๏ฝก ห โ๏ธ ห ๏ฝก โ ๏ฝก ห โฝ ห ๏ฝก โ
[Zero-shot learning paper review]
Compositional Zero-shot Learning (CZSL)
์ธ๊ฐ์ ์ด๋ฏธ ์๊ณ ์๋ ๊ฐ์ฒด์ ๋ํ ์ ๋ณด๋ค์ ์กฐํฉํ๊ณ ๊ตฌ์ฑํ์ฌ ์๋ก์ด ๊ฐ์ฒด์ ์ผ๋ฐํํ๋ ๋ฅ๋ ฅ์ ๊ฐ์ง๊ณ ์๋ค. ๋ค์ ๋งํด์, ์ด๋ฏธ ์๊ณ ์๋ "whole apple"๊ณผ "sliced banana"์ ์ ๋ณด๋ฅผ ์กฐํฉํด์ ์๋ก์ด ๊ฐ์ฒด์ธ "sliced apple" ๋๋ "whole banana"๋ฅผ ์๊ฐํด๋ผ ์ ์๋ค. ์ด ๋ฅ๋ ฅ์ AI ์์คํ ์์ ๋ชจ๋ฐฉํ๊ณ ์ ํ๋ task๋ฅผ Compositional Zero-shot Learning (CZSL)์ด๋ผ๊ณ ํ๋ค.
CZSL์์๋ ๊ฐ๊ฐ์ ๊ฐ์ฒด(composition)์ ๋ ๊ฐ์ง์ ๊ตฌ์ฑ ์์, state์ object๋ก ๋ถํดํ๋ค. "whole", "sliced"์ฒ๋ผ ๊ฐ์ฒด์ ์ํ๋ฅผ ํํํ๋ ์์๋ state, "apple", "banana"์ฒ๋ผ ๊ฐ์ฒด์ ํํ์ ์ข
๋ฅ๋ฅผ ๊ตฌ๋ถ์ง๋ ์์๋ object๋ก ์ ์๋๋ค. CZSL์ ๋ชฉํ๋ Training data์ test data๊ฐ ๊ณตํต ์์๋ฅผ ๊ฐ์ง์ง ์๊ณ ๋ถ๋ฆฌ๋์ด ์๋ค๋ ๊ฐ์ ํ์, ์๋ก์ด(unseen) test composition์ ์๋ณํ๋ ๊ฒ์ด๋ค.
State์ object๋ฅผ ๋ชจ๋ ์๋ณํ๊ธฐ ์ํด CZSL์์ ์ฌ์ฉํ๋ ๋ํ์ ์ธ ๋ฐฉ๋ฒ์ผ๋ก๋,
- ๋ ๊ฐ์ง์ classifier๋ฅผ ๋ฐ๋ก ๋์ด state์ object๋ฅผ ๋ฐ๋ก ํ์ตํ๋ ๊ตฌ์กฐ๋ฅผ ์ฌ์ฉํ ์ ์๋ค. ํ์ง๋ง ์ด ๋ฐฉ๋ฒ์ state-object ์ฌ์ด์ ์ํธ ์์ฉ ๋๋ entanglement๋ฅผ ๋ฌด์ํ๊ฒ ๋๋ค.
- ๋ ๋ชจ๋ composition๊ณผ visual feature๋ค์ด ํ๋ฒ์ ํฌ์๋ ์ ์๋ ๊ณตํต embedding space๋ฅผ ํ์ตํ์ฌ embedding ์ฌ์ด์ ๊ฑฐ๋ฆฌ๋ฅผ ๊ณ์ฐํ์ฌ ํ์ฉํ๋ ๋ฐฉ๋ฒ์ด ์๋๋ฐ, training๊ณผ test composition ์ฌ์ด์ ์ฐจ์ด๋ฅผ ๋ฌด์ํ์ฌ ๋น์ทํ ๊ฐ์ฒด๋ค์ ํผ๋ํ ์ ์๋ค. (e.g., young cat and young tiger)
๋ฐ๋ผ์ ํด๋น ๋ ผ๋ฌธ์์๋ state์ object ๊ฐ๊ฐ์ ๊ตฌ๋ถ์ ์ธ prototype๋ฅผ ํ์ฉํ์ง๋ง ๋ ์ฌ์ด์ joint representation๋ ํจ๊ป ํ์ตํ๋ Siamese Contrastive Embedding Network (SCEN)์ ์ ์ํ๋ค.
Siamese Contrastive Embedding Network (SCEN)
States์ ์งํฉ์ $A$, objects์ ์งํฉ์ $O$๋ผ๊ณ ํ๋ฉด, state-object์ ์์ผ๋ก ๊ตฌ์ฑ๋๋ components์ ์งํฉ $C$๋ ๋ค์๊ณผ ๊ฐ์ด ํํ๋ ์ ์์ผ๋ฉฐ,
\[ C = A \times O = \{(a,o) \mid a \in A, o \in O\} \]
Training dataset์์์ images ์งํฉ์ $I^s$, ๊ทธ์ ๋์ํ๋ component๋ฅผ $C^s$ ($C^s \subset C$) ๋ผ๊ณ ํ๋ฉด, image-component์ ์์ผ๋ก ๊ตฌ์ฑ๋๋ training dataset $D_{tr}$์ ๋ค์๊ณผ ๊ฐ์ด ํํ๋๋ค.
\[ D_{tr} = \{((i,c) \mid i \in I^s, c \in C^s\} \]
CZSL task์ ์ ์์ ๋ฐ๋ผ training data์ test data๋ ๊ณตํต ์์๋ฅผ ๊ฐ์ง์ง ์์์ผ ํ๊ธฐ ๋๋ฌธ์, training data์ composition์ $C^s$, test data์ composition์ $C^u$๋ผ๊ณ ํ๋ค๋ฉด $C^s \cap C^u = \emptyset$์ ๋ง์กฑํด์ผ ํ๋ค. ๋ํ ์๋ก์ด ์ด๋ฏธ์ง๋ฅผ seen, unseen composition ์ค์์ ์์ธกํด์ผ ํ๊ธฐ ๋๋ฌธ์, $\{I^s, C^s\}$๋ก ํ์ต๋ mapping function $I \to C^s \cup C^u$๋ฅผ ํ์ตํ๋ ๊ฒ์ ๋ชฉํ๋ก ํ๋ค.
SCEN ๊ตฌ์กฐ๋ฅผ ์ธ ๊ฐ์ง ๋ชจ๋๋ก ์์ฝํ๋ฉด,
- Encoding : states์ objects ๊ฐ๊ฐ์ encoding
- Contrastive learning : state/object ๊ฐ๊ฐ์ contrastive space์์์ prototypes ์ถ์ถ
- Augmentation : State Transition Module(STM)์ ํตํด ๊ฐ์์ composition ์์ฑ
- Module 1. Encoding
ํ๋์ ์ด๋ฏธ์ง๊ฐ feature extractor FC๋ฅผ ํต๊ณผํด์ ์ป์ visual feature $x$๋, state/object ๊ตฌ์ฑ ์์๋ก ๋ถํดํ๊ธฐ ์ํด ๋ ๊ฐ์ embedding์ผ๋ก ์ธ์ฝ๋ฉ๋๋ค. State-specific Encoder $E_s$๋ state๋ฅผ ์ ํํํ๊ธฐ ์ํด, Object-specific Encoder $E_o$๋ object๋ฅผ ์ ํํํ๊ธฐ ์ํด ํ์ต๋๋ค.
\[ h_s = E_s(x) \]
\[ h_o = E_o(x) \]
State-object์ ๋ค์ํ ์กฐํฉ์ ํตํด ์ฌ๋ฌ composition์ ๊ตฌ์ฑํ ์ ์๊ธฐ ๋๋ฌธ์, ์ด๋ฌํ joint representation์ ํ์ตํ๊ธฐ ์ํด ์ธ ๊ฐ์ง์ ๋ฐ์ดํฐ๋ฒ ์ด์ค๋ฅผ ์ ์ํ๋ค.
- ๊ณ ์ ๋ state์ ๋ค์ํ objects๋ฅผ ์กฐํฉํ๋ State-constant database $D_s$
- ๊ณ ์ ๋ object์ ๋ค์ํ states๋ฅผ ์กฐํฉํ๋ Object-constant database $D_o$
- ๋ค์ํ objects-states ์กฐํฉ ์ค์์ input instance์ ๊ด๋ จ์ด ์๋ Irrelevant database $D_{ir}$
state $\hat{a}$์ object $\hat{o}$๋ก ์ด๋ฃจ์ด์ง input instance $x=(\hat{a},\hat{o}) \in I^s$๊ฐ ์ ๋ ฅ๋๋ฉด, ์ธ ๊ฐ์ง์ ๋ฐ์ดํฐ๋ฒ ์ด์ค๋ ๋ค์๊ณผ ๊ฐ์ด ๊ตฌ์ฑ๋๋ค.
\[ D_s = \{(a,o) \mid a=\hat{a}, (a,o) \in C^s\} \]
\[ D_o = \{(a,o) \mid o=\hat{o}, (a,o) \in C^s\} \]
\[ D_{ir} = \{(a,o) \mid a \ne \hat{a}, o \ne \hat{o}, (a,o) \in C^s \} \]
- Module 2. Contrastive learning
$E_s, E_o$๋ฅผ ํตํด ๋ ๊ฐ์ ๋
๋ฆฝ๋ embedding space (Siamese contrastive space)๋ก ํฌ์๋ $h_s, h_o$๋ contrastive learning์ ํตํด state์ object ๊ฐ๊ฐ์ ๊ฐ์ฅ ์ ํํํ ์ ์๋ prototype์ผ๋ก ํ์ต๋๋ค. ํ์ง๋ง ๊ธฐ์กด์ contrastive loss๋ฅผ ์ฌ์ฉํ์ฌ state์ object๋ฅผ ๋ฐ๋ก ํ์ตํ๋ฉด state-object interaction์ด ๋ฌด์๋๊ธฐ ๋๋ฌธ์, ์ ์๋ ์ธ ๊ฐ์ง์ ๋ฐ์ดํฐ๋ฒ ์ด์ค๋ฅผ ์ฌ์ฉํ์ฌ loss๋ฅผ ์๋ก ์ ์ํ๋ค.
- State-based contrastive loss $\mathcal{L}_{scl}$
input $x$์ state encoding $h_s$๊ฐ state-based contrastive space์ anchor๋ก ์ค์ ๋๋ค. input $x$์ ๋์ผํ state๋ฅผ ๊ฐ์ง๋ ๋ฐ์ดํฐ๋ฒ ์ด์ค $D_s$๋ก๋ถํฐ positive sample $h_s^{ss}$์ ์ถ์ถํ๋ฉฐ, ๋์ผํ์ง ์์ state๋ฅผ ๊ฐ์ง๋ ๋ฐ์ดํฐ๋ฒ ์ด์ค $D_{ir}$๋ก๋ถํฐ $k$๊ฐ์ negative samples $\{ h_{s_1}^{ir}, ..., h_{s_k}^{ir} \}$๋ฅผ ์ถ์ถํ๋ค.
anchor์ positive ์ฌ์ด์ ๊ฑฐ๋ฆฌ๋ ๊ฐ๊น์์ง๋๋ก, anchor์ negative ์ฌ์ด์ ๊ฑฐ๋ฆฌ๋ ๋ฉ์ด์ง๋๋ก ํ์ตํ๋ contrastive loss๋ ๋ค์๊ณผ ๊ฐ์ด ์ ์๋๋ค. ($\tau_s > 0$ : temperature parameter)
\[ \mathcal{L}_{scl} = -log \frac{exp((h_s)^{\top} h_s^{ss} / \tau_s)}{exp((h_s)^{\top} h_s^{ss}/\tau_s) + \sum\nolimits_{i=1}^K exp((h_s)^{\top} h_{s_i}^{ir}/ \tau_s)} \] - Object-based contrastive loss $\mathcal{L}_{ocl}$
object-based contrastive space์ anchor๋ $h_o$๋ก ์ค์ ๋๋ค. ๋์ผํ object๋ฅผ ๊ฐ์ง๋ ๋ฐ์ดํฐ๋ฒ ์ด์ค $D_o$๋ก๋ถํฐ positive sample $h_o^{os}$๋ฅผ ์ถ์ถํ๋ฉฐ, ๋์ผํ์ง ์์ object๋ฅผ ๊ฐ์ง๋ ๋ฐ์ดํฐ๋ฒ ์ด์ค $D_{ir}$๋ก๋ถํฐ $k$๊ฐ์ negative samples $\{ h_{o_1}, ..., h_{o_k}^{ir} \}$๋ฅผ ์ถ์ถํ๋ค.
state-object interaction์ ๊ณ ๋ คํ๊ธฐ ์ํด $D_{ir}$๋ก๋ถํฐ ์ถ์ถํ๋ negative samples๋ ๋ ๊ฐ์ง์ loss์์ ๋์ผํ ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํ๋ค. ($\tau_o > 0$ : temperature parameter)
\[ \mathcal{L}_{ocl} = -log \frac{exp((h_o)^{\top} h_o^{os} / \tau_o)}{exp((h_o)^{\top} h_o^{os}/\tau_o) + \sum\nolimits_{j=1}^K exp((h_o)^{\top} h_{o_j}^{ir}/ \tau_o)} \] - Classification Loss $\mathcal{L}_{cls}$
Classifier๊ฐ state์ object ๊ฐ๊ฐ์ prototype์ ํตํด ๊ตฌ๋ณํ ์ ์๋๋ก, ๋ ๊ณต๊ฐ์์์ classification loss๋ฅผ ๋ ๋ฆฝ์ ์ผ๋ก ๊ณ์ฐํ๋ค. $C_a$ ์ $C_o$๋ฅผ state์ object ๊ฐ๊ฐ์ ๋ํ classification์ ํ๋ fully connected layers๋ผํ ๋, ์ ์ฒด classification loss๋ ๋ค์๊ณผ ๊ฐ๋ค.
\[ \mathcal{L} = C_a(h_s, a) + C_o(h_o, o) \]
์์์ ์ ์ํ ์ธ ๊ฐ์ง์ loss๋ฅผ ํตํด Siamese Contrastive Space์ ์ ์ฒด loss $L_{cts}$๊ฐ ์ ์๋๋ค.
\[ \mathcal{L}_{cts} = \mathcal{L}_{scl} + \mathcal{L}_{ocl} + \mathcal{L}_{cls} \]
- Module 3. Augmentation
Training data์ ๋ฑ์ฅํ์ง ์๋ unseen composition์ ๋ํ ์ผ๋ฐํ ์ฑ๋ฅ์ ๋์์ผ๋ก์จ training๊ณผ test ์ฌ์ด์ ์ฐจ์ด๋ฅผ ์ค์ด๊ธฐ ์ํด, ๊ฐ์์ composition์ ์์ฑํ๋ State Transition Module (STM) ๊ตฌ์กฐ๋ฅผ ์ ์ํ๋ค. STM์ Training data์ ๋ ๊ฐ์ง์ composition, "sliced apple", "red fox"๊ฐ ์๋ค๊ณ ํ์ ๋, ๋ฐ์ดํฐ์๋ ์์ง๋ง ์ค์ ๋ก ์กด์ฌํ ๋ฒํ "red apple"์ ์์ฑํ๋, ๋ฐ์ดํฐ์๋ ์๊ณ ์ค์ ๋ก๋ ์กด์ฌํ์ง ์๋ "sliced fox"๋ ๊ตฌ๋ณํด๋ด๋ ๊ฒ์ ๋ชฉํ๋ก ํ๋ค.
- ์
๋ ฅ ์ด๋ฏธ์ง์ object์ ๋ค๋ฅธ ์ด๋ฏธ์ง๋ค์ ๋ค์ํ state๋ฅผ ์กฐํฉํ๊ธฐ ์ํด ๊ฐ๊ฐ์ prototype์ ์ป๋๋ค. Object-specific encoder๋ฅผ ํตํด input $x$์ object prototype $h_o$๋ฅผ ์ถ์ถํ๊ณ , state-specific encoder๋ฅผ ํตํด ๋ค๋ฅธ ์ํ $\{ s_1, s_2, ..., s_n \}$์ state prototype $ h_{\tilde{s}} = \{ h_{s_1}, h_{s_2}, ..., h_{s_n} \}$๋ฅผ ์ถ์ถํ๋ค.
- Generator $G$ ๋ ์ถ์ถํ prototype์ ์กฐํฉํ์ฌ ๊ฐ์์ composition์ ๋ง๋ ๋ค.
\[ G(h_{\tilde{s}}, h_o) = \hat{x}_{\tilde{s},o} \] - Discriminator $D$ ๋ ์์ฑ๋ ๊ฐ์์ composition ๋ด์์ ์ค์ ๋ก ์กด์ฌํ์ง ์์๋งํ ๋ฐ์ดํฐ (irrational composition)์ ํ๋ณํ๋ค.
\[ \underset{D}{max} \underset{G,E_s,E_o}{min} V(G, D) = \mathbb{E}_{s,o} (logD(x_{a,o})) + \mathbb{E}_{h_{\tilde{s}},h_o} (log(1-D(G(h_{\tilde{s}},h_o)))) \] - ์๋ก ์์ฑ๋ ๋ฐ์ดํฐ๋ก $E_s$ ์ $E_o$ ์ ์ฑ๋ฅ์ ๋์ด๋ ๊ฒ์ด ๋ชฉ์ ์ด์ง๋ง, ์์ฑ๋ ์ด๋ฏธ์ง๋ label์ด ์๊ธฐ ๋๋ฌธ์ re-encode ๊ณผ์ ์ ๊ฑฐ์น๋ค. ๋ค์ ์ธ์ฝ๋ฉ ๊ณผ์ ์ ๊ฑฐ์ณ์ ์ถ์ถ๋ state/object prototype์ ํตํด ํ์ตํ๋ re-classification loss๋ฅผ ์ ์ํ ์ ์๋ค.
\[ \mathcal{L}_{cls_{re}} = C_a(E_s(G(h_{\tilde{s}}, h_o)), \tilde{a}) + C_o(E_o(G(h_{\tilde{s}}, h_o)), o) \]
์์์ ์ ์ํ ๋ ๊ฐ์ง์ loss๋ฅผ ํตํด State Transition Module (STM)์ ์ ์ฒด loss $\mathcal{L}_{stm}$๊ฐ ์ ์๋๋ค.
\[ \mathcal{L}_{stm} = \underset{D}{max} \underset{G,E_s,E_o}{min} V(G, D) + \mathcal{L}_{cls_{re}} \]
๋ณธ ๋ ผ๋ฌธ์์ ์ ์ํ SCEN framework์ final loss๋ $\mathcal{L}_{cts}, \mathcal{L}_{stm}$์ weighted sum์ผ๋ก ์ ์๋๋ค.
\[ \mathcal{L}_{total} = \alpha \mathcal{L}_{cts} + \beta \mathcal{L}_{stm} \]
Results
CZSL์ ๋ํ์ ์ธ benchmark dataset ์ธ ๊ฐ์ง์์ ์คํ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ธ๋ค : MIT-States, UT-Zappos, C-GQA
MIT-States ๋ฐ์ดํฐ์ ์์ test AUC score๋ฅผ ๊ธฐ์ค์ผ๋ก ํ์ ๋ ๊ธฐ์กด์ SOTA์ธ 5.1%๋ฅผ ๋ฐ์ด๋๋ 5.3% (+0.2%)๋ฅผ ๊ธฐ๋กํ์ผ๋ฉฐ, Harmonic Mean(HM) score๋ฅผ ๊ธฐ์ค์ผ๋ก ํ์ ๋ 18.4% (+1.2%)๋ฅผ ๊ธฐ๋กํ๋ค. State์ object ๊ฐ๊ฐ์ ๋ํ ์์ธก์ accuracy๋ก ๋ณด์๋, 28.2% (+0.3%)์ 32.2% (+0.4%)์ ์ต๊ณ ์ฑ๋ฅ์ ๋ณด์ธ๋ค. ๋ง์ฐฌ๊ฐ์ง๋ก UT-Zappos์์๋ SOTA performance๋ฅผ ๊ธฐ๋กํ์๋ค.
๊ฐ์ฅ ์ต๊ทผ ๋ฐํ๋ C-GQA ๋ฐ์ดํฐ์ ์์๋ AUC, HM, state/object accuracy ๋ชจ๋ ํฅ์๋ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ธ๋ค.