Hello Potato World
[ํฌํ ์ดํ ๋ ผ๋ฌธ ๋ฆฌ๋ทฐ] Siamese Contrastive Embedding Network for Compositional Zero-Shot Learning ๋ณธ๋ฌธ
[ํฌํ ์ดํ ๋ ผ๋ฌธ ๋ฆฌ๋ทฐ] Siamese Contrastive Embedding Network for Compositional Zero-Shot Learning
Heosuab 2023. 1. 8. 18:02โ ๏ฝก ห โ๏ธ ห ๏ฝก โ ๏ฝก ห โฝ ห ๏ฝก โ
[Zero-shot learning paper review]
Compositional Zero-shot Learning (CZSL)
์ธ๊ฐ์ ์ด๋ฏธ ์๊ณ ์๋ ๊ฐ์ฒด์ ๋ํ ์ ๋ณด๋ค์ ์กฐํฉํ๊ณ ๊ตฌ์ฑํ์ฌ ์๋ก์ด ๊ฐ์ฒด์ ์ผ๋ฐํํ๋ ๋ฅ๋ ฅ์ ๊ฐ์ง๊ณ ์๋ค. ๋ค์ ๋งํด์, ์ด๋ฏธ ์๊ณ ์๋ "whole apple"๊ณผ "sliced banana"์ ์ ๋ณด๋ฅผ ์กฐํฉํด์ ์๋ก์ด ๊ฐ์ฒด์ธ "sliced apple" ๋๋ "whole banana"๋ฅผ ์๊ฐํด๋ผ ์ ์๋ค. ์ด ๋ฅ๋ ฅ์ AI ์์คํ ์์ ๋ชจ๋ฐฉํ๊ณ ์ ํ๋ task๋ฅผ Compositional Zero-shot Learning (CZSL)์ด๋ผ๊ณ ํ๋ค.

CZSL์์๋ ๊ฐ๊ฐ์ ๊ฐ์ฒด(composition)์ ๋ ๊ฐ์ง์ ๊ตฌ์ฑ ์์, state์ object๋ก ๋ถํดํ๋ค. "whole", "sliced"์ฒ๋ผ ๊ฐ์ฒด์ ์ํ๋ฅผ ํํํ๋ ์์๋ state, "apple", "banana"์ฒ๋ผ ๊ฐ์ฒด์ ํํ์ ์ข
๋ฅ๋ฅผ ๊ตฌ๋ถ์ง๋ ์์๋ object๋ก ์ ์๋๋ค. CZSL์ ๋ชฉํ๋ Training data์ test data๊ฐ ๊ณตํต ์์๋ฅผ ๊ฐ์ง์ง ์๊ณ ๋ถ๋ฆฌ๋์ด ์๋ค๋ ๊ฐ์ ํ์, ์๋ก์ด(unseen) test composition์ ์๋ณํ๋ ๊ฒ์ด๋ค.
State์ object๋ฅผ ๋ชจ๋ ์๋ณํ๊ธฐ ์ํด CZSL์์ ์ฌ์ฉํ๋ ๋ํ์ ์ธ ๋ฐฉ๋ฒ์ผ๋ก๋,
- ๋ ๊ฐ์ง์ classifier๋ฅผ ๋ฐ๋ก ๋์ด state์ object๋ฅผ ๋ฐ๋ก ํ์ตํ๋ ๊ตฌ์กฐ๋ฅผ ์ฌ์ฉํ ์ ์๋ค. ํ์ง๋ง ์ด ๋ฐฉ๋ฒ์ state-object ์ฌ์ด์ ์ํธ ์์ฉ ๋๋ entanglement๋ฅผ ๋ฌด์ํ๊ฒ ๋๋ค.
- ๋ ๋ชจ๋ composition๊ณผ visual feature๋ค์ด ํ๋ฒ์ ํฌ์๋ ์ ์๋ ๊ณตํต embedding space๋ฅผ ํ์ตํ์ฌ embedding ์ฌ์ด์ ๊ฑฐ๋ฆฌ๋ฅผ ๊ณ์ฐํ์ฌ ํ์ฉํ๋ ๋ฐฉ๋ฒ์ด ์๋๋ฐ, training๊ณผ test composition ์ฌ์ด์ ์ฐจ์ด๋ฅผ ๋ฌด์ํ์ฌ ๋น์ทํ ๊ฐ์ฒด๋ค์ ํผ๋ํ ์ ์๋ค. (e.g., young cat and young tiger)
๋ฐ๋ผ์ ํด๋น ๋ ผ๋ฌธ์์๋ state์ object ๊ฐ๊ฐ์ ๊ตฌ๋ถ์ ์ธ prototype๋ฅผ ํ์ฉํ์ง๋ง ๋ ์ฌ์ด์ joint representation๋ ํจ๊ป ํ์ตํ๋ Siamese Contrastive Embedding Network (SCEN)์ ์ ์ํ๋ค.
Siamese Contrastive Embedding Network (SCEN)
States์ ์งํฉ์
Training dataset์์์ images ์งํฉ์
CZSL task์ ์ ์์ ๋ฐ๋ผ training data์ test data๋ ๊ณตํต ์์๋ฅผ ๊ฐ์ง์ง ์์์ผ ํ๊ธฐ ๋๋ฌธ์, training data์ composition์

SCEN ๊ตฌ์กฐ๋ฅผ ์ธ ๊ฐ์ง ๋ชจ๋๋ก ์์ฝํ๋ฉด,
- Encoding : states์ objects ๊ฐ๊ฐ์ encoding
- Contrastive learning : state/object ๊ฐ๊ฐ์ contrastive space์์์ prototypes ์ถ์ถ
- Augmentation : State Transition Module(STM)์ ํตํด ๊ฐ์์ composition ์์ฑ
- Module 1. Encoding
ํ๋์ ์ด๋ฏธ์ง๊ฐ feature extractor FC๋ฅผ ํต๊ณผํด์ ์ป์ visual feature
State-object์ ๋ค์ํ ์กฐํฉ์ ํตํด ์ฌ๋ฌ composition์ ๊ตฌ์ฑํ ์ ์๊ธฐ ๋๋ฌธ์, ์ด๋ฌํ joint representation์ ํ์ตํ๊ธฐ ์ํด ์ธ ๊ฐ์ง์ ๋ฐ์ดํฐ๋ฒ ์ด์ค๋ฅผ ์ ์ํ๋ค.
- ๊ณ ์ ๋ state์ ๋ค์ํ objects๋ฅผ ์กฐํฉํ๋ State-constant database
- ๊ณ ์ ๋ object์ ๋ค์ํ states๋ฅผ ์กฐํฉํ๋ Object-constant database
- ๋ค์ํ objects-states ์กฐํฉ ์ค์์ input instance์ ๊ด๋ จ์ด ์๋ Irrelevant database
state
- Module 2. Contrastive learning
- State-based contrastive loss
input ์ state encoding ๊ฐ state-based contrastive space์ anchor๋ก ์ค์ ๋๋ค. input ์ ๋์ผํ state๋ฅผ ๊ฐ์ง๋ ๋ฐ์ดํฐ๋ฒ ์ด์ค ๋ก๋ถํฐ positive sample ์ ์ถ์ถํ๋ฉฐ, ๋์ผํ์ง ์์ state๋ฅผ ๊ฐ์ง๋ ๋ฐ์ดํฐ๋ฒ ์ด์ค ๋ก๋ถํฐ ๊ฐ์ negative samples ๋ฅผ ์ถ์ถํ๋ค.
anchor์ positive ์ฌ์ด์ ๊ฑฐ๋ฆฌ๋ ๊ฐ๊น์์ง๋๋ก, anchor์ negative ์ฌ์ด์ ๊ฑฐ๋ฆฌ๋ ๋ฉ์ด์ง๋๋ก ํ์ตํ๋ contrastive loss๋ ๋ค์๊ณผ ๊ฐ์ด ์ ์๋๋ค. ( : temperature parameter) - Object-based contrastive loss
object-based contrastive space์ anchor๋ ๋ก ์ค์ ๋๋ค. ๋์ผํ object๋ฅผ ๊ฐ์ง๋ ๋ฐ์ดํฐ๋ฒ ์ด์ค ๋ก๋ถํฐ positive sample ๋ฅผ ์ถ์ถํ๋ฉฐ, ๋์ผํ์ง ์์ object๋ฅผ ๊ฐ์ง๋ ๋ฐ์ดํฐ๋ฒ ์ด์ค ๋ก๋ถํฐ ๊ฐ์ negative samples ๋ฅผ ์ถ์ถํ๋ค.
state-object interaction์ ๊ณ ๋ คํ๊ธฐ ์ํด ๋ก๋ถํฐ ์ถ์ถํ๋ negative samples๋ ๋ ๊ฐ์ง์ loss์์ ๋์ผํ ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํ๋ค. ( : temperature parameter) - Classification Loss
Classifier๊ฐ state์ object ๊ฐ๊ฐ์ prototype์ ํตํด ๊ตฌ๋ณํ ์ ์๋๋ก, ๋ ๊ณต๊ฐ์์์ classification loss๋ฅผ ๋ ๋ฆฝ์ ์ผ๋ก ๊ณ์ฐํ๋ค. ์ ๋ฅผ state์ object ๊ฐ๊ฐ์ ๋ํ classification์ ํ๋ fully connected layers๋ผํ ๋, ์ ์ฒด classification loss๋ ๋ค์๊ณผ ๊ฐ๋ค.
์์์ ์ ์ํ ์ธ ๊ฐ์ง์ loss๋ฅผ ํตํด Siamese Contrastive Space์ ์ ์ฒด loss
- Module 3. Augmentation
Training data์ ๋ฑ์ฅํ์ง ์๋ unseen composition์ ๋ํ ์ผ๋ฐํ ์ฑ๋ฅ์ ๋์์ผ๋ก์จ training๊ณผ test ์ฌ์ด์ ์ฐจ์ด๋ฅผ ์ค์ด๊ธฐ ์ํด, ๊ฐ์์ composition์ ์์ฑํ๋ State Transition Module (STM) ๊ตฌ์กฐ๋ฅผ ์ ์ํ๋ค. STM์ Training data์ ๋ ๊ฐ์ง์ composition, "sliced apple", "red fox"๊ฐ ์๋ค๊ณ ํ์ ๋, ๋ฐ์ดํฐ์๋ ์์ง๋ง ์ค์ ๋ก ์กด์ฌํ ๋ฒํ "red apple"์ ์์ฑํ๋, ๋ฐ์ดํฐ์๋ ์๊ณ ์ค์ ๋ก๋ ์กด์ฌํ์ง ์๋ "sliced fox"๋ ๊ตฌ๋ณํด๋ด๋ ๊ฒ์ ๋ชฉํ๋ก ํ๋ค.

- ์
๋ ฅ ์ด๋ฏธ์ง์ object์ ๋ค๋ฅธ ์ด๋ฏธ์ง๋ค์ ๋ค์ํ state๋ฅผ ์กฐํฉํ๊ธฐ ์ํด ๊ฐ๊ฐ์ prototype์ ์ป๋๋ค. Object-specific encoder๋ฅผ ํตํด input
์ object prototype ๋ฅผ ์ถ์ถํ๊ณ , state-specific encoder๋ฅผ ํตํด ๋ค๋ฅธ ์ํ ์ state prototype ๋ฅผ ์ถ์ถํ๋ค. - Generator
๋ ์ถ์ถํ prototype์ ์กฐํฉํ์ฌ ๊ฐ์์ composition์ ๋ง๋ ๋ค. - Discriminator
๋ ์์ฑ๋ ๊ฐ์์ composition ๋ด์์ ์ค์ ๋ก ์กด์ฌํ์ง ์์๋งํ ๋ฐ์ดํฐ (irrational composition)์ ํ๋ณํ๋ค. - ์๋ก ์์ฑ๋ ๋ฐ์ดํฐ๋ก
์ ์ ์ฑ๋ฅ์ ๋์ด๋ ๊ฒ์ด ๋ชฉ์ ์ด์ง๋ง, ์์ฑ๋ ์ด๋ฏธ์ง๋ label์ด ์๊ธฐ ๋๋ฌธ์ re-encode ๊ณผ์ ์ ๊ฑฐ์น๋ค. ๋ค์ ์ธ์ฝ๋ฉ ๊ณผ์ ์ ๊ฑฐ์ณ์ ์ถ์ถ๋ state/object prototype์ ํตํด ํ์ตํ๋ re-classification loss๋ฅผ ์ ์ํ ์ ์๋ค.
์์์ ์ ์ํ ๋ ๊ฐ์ง์ loss๋ฅผ ํตํด State Transition Module (STM)์ ์ ์ฒด loss
๋ณธ ๋
ผ๋ฌธ์์ ์ ์ํ SCEN framework์ final loss๋
Results
CZSL์ ๋ํ์ ์ธ benchmark dataset ์ธ ๊ฐ์ง์์ ์คํ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ธ๋ค : MIT-States, UT-Zappos, C-GQA

MIT-States ๋ฐ์ดํฐ์ ์์ test AUC score๋ฅผ ๊ธฐ์ค์ผ๋ก ํ์ ๋ ๊ธฐ์กด์ SOTA์ธ 5.1%๋ฅผ ๋ฐ์ด๋๋ 5.3% (+0.2%)๋ฅผ ๊ธฐ๋กํ์ผ๋ฉฐ, Harmonic Mean(HM) score๋ฅผ ๊ธฐ์ค์ผ๋ก ํ์ ๋ 18.4% (+1.2%)๋ฅผ ๊ธฐ๋กํ๋ค. State์ object ๊ฐ๊ฐ์ ๋ํ ์์ธก์ accuracy๋ก ๋ณด์๋, 28.2% (+0.3%)์ 32.2% (+0.4%)์ ์ต๊ณ ์ฑ๋ฅ์ ๋ณด์ธ๋ค. ๋ง์ฐฌ๊ฐ์ง๋ก UT-Zappos์์๋ SOTA performance๋ฅผ ๊ธฐ๋กํ์๋ค.

๊ฐ์ฅ ์ต๊ทผ ๋ฐํ๋ C-GQA ๋ฐ์ดํฐ์ ์์๋ AUC, HM, state/object accuracy ๋ชจ๋ ํฅ์๋ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ธ๋ค.