Hello Potato World
[ํฌํ ์ดํ ๋ ผ๋ฌธ ๋ฆฌ๋ทฐ] Learning Deep Features for Discriminative Localization ๋ณธ๋ฌธ
[ํฌํ ์ดํ ๋ ผ๋ฌธ ๋ฆฌ๋ทฐ] Learning Deep Features for Discriminative Localization
Heosuab 2021. 8. 8. 04:32
โ ๏ฝก ห โ๏ธ ห ๏ฝก โ ๏ฝก ห โฝ ห ๏ฝก โ
[XAI paper review]
Interpretable Machine Learning ์ฑ ์์ ์๊ฐ๋ Grad-CAM์ ๋ด์ฉ์ด ๊ถ๊ธํด์ ธ์ ์ฐพ์๋ณด๋ค๊ฐ, ์ฐ๊ด๋ Grad-CAM, Grad-CAM++, Guided Grad-CAM ๋ฑ๋ฑ์ ๊ธฐ๋ฐ์ด ๋๋ CAM(Class Activation Maps)์ ๋ค๋ฃจ๋ ๋ ผ๋ฌธ์ ๋จผ์ ๋ฆฌ๋ทฐํ๊ฒ ๋์๋ค.
Global Average Pooling(GAP) vs Global Max Pooling(GMP)
์ด ๋ ผ๋ฌธ์์ ๊ฐ์ฅ ์ค์ํ ๊ฐ๋ ์ธ Global Average Pooling์ ๋จผ์ ๋ณด๋ฉด, ์ฐ์ Pooling layer๋ CNN ๋ด์ ๋ง์ Convolution layer๋ด์ ์กด์ฌํ๋ filter(parameter)์ ๊ฐ์๊ฐ ๋๋ฌด ๋ง์์ ธ์ Overfitting์ด ๋ฐ์ํ๋ ๊ฒ์ ๋ฐฉ์งํ๊ธฐ ์ํด, parameter ์๋ฅผ ์ค์ผ ์ ์๋๋ก ์ฌ์ฉ๋๋ layer์ด๋ค.
๊ทธ ์ค Max Pooling์ด๋ ๊ฐ ์์ญ(์ง์ญ) ๋ด์์ ๊ฐ์ฅ ํฐ ๊ฐ์ ์ ํํ์ฌ ์ค์ด๋ ๋ฐฉ์์ด๊ณ , Global Max Pooling์ด๋ ์ ์ฒด ์์ญ(์ ์ญ)์ ํ๋ฒ์ ๊ณ ๋ คํด์ (heigt, width, channel)ํํ์ 3์ฐจ์์ (channel, )ํํ์ 1์ฐจ์ ๋ฒกํฐ๋ก ๊ทน๋จ์ ์ธ feature์ ๊ฐ์๋ฅผ ๋ง๋๋ ๋ฐฉ์์ด๋ค.
๋, ์ ์ฒด ์์ญ ๋ด์์ ๊ฐ์ฅ ํฐ ๊ฐ์ ์ฌ์ฉํ๋ ๋ฐฉ๋ฒ์ Global Max Pooling(GMP)๋ผ๊ณ ํ์ง๋ง, ๋ชจ๋ ๊ฐ์ ๊ณ ๋ คํ์ฌ ํ๊ท ๊ฐ์ ์ฌ์ฉํ๋ ๋ฐฉ๋ฒ์ Global Average Pooling(GAP)์ด๋ผ๊ณ ํ๋ค.
๋ณดํต CNN์ ๊ตฌ์กฐ์์๋ ๋ง์ง๋ง layer๋ก FC layer๋ฅผ ์ฌ์ฉํ๊ฒ ๋๋๋ฐ, ์ด FC layer๋ parameter์ ๊ฐ์๋ฅผ ๋งค์ฐ ์ปค์ง๋๋ก ๋ง๋ค๊ธฐ ๋๋ฌธ์ overfitting ์ํ์ด ์ฆ๊ฐํ ์ ์๊ณ , Feature map(pooling์ด์ )์ ์กด์ฌํ๋ object๋ค์ ์์น์ ๋ณด๊ฐ ์์ค๋๋ค๋ ๋จ์ ์ด ์๋ค. ์ด ๋ ผ๋ฌธ์์๋ CNN์ ๋ง์ง๋ง FC layer๋ฅผ Global Average Pooling์ผ๋ก ๋์ฒดํ์ฌ overfitting์ ๋ฐฉ์งํ ์ ์๋ regularization์ ์ญํ ์ ํ๋ฉฐ, ์์น์ ๋ณด๋ฅผ ์์คํ์ง ์์ ์ ์๋๋ก ํ์๋ค.
Learning Deep Features for Discriminative Localization
์ด ๋ ผ๋ฌธ์์ ์ ์ํ CAM(Class Activation Maps)์ key point 2๊ฐ์ง๋ ๋ค์๊ณผ ๊ฐ๋ค.
- Weakly-supervised object localization
- Visualizing CNNs
์์์ ์ธ๊ธํ๋ ๊ฒ์ฒ๋ผ FC layer๋์ ์ Global Average Pooling์ ์ฌ์ฉํจ์ผ๋ก์จ ์์น์ ๋ณด๋ฅผ ์์คํ์ง ์์ ์ ์๋๋ก ๋ง๋ค์๋๋ฐ, ๋๋ถ์ ๋จ ํ๋ฒ์ forward-pass๋ง์ ํตํด ์ฌ๋ฌ๊ฐ์ง Task๋ฅผ ์ํํ๊ฒ ๋์๋ค. ์๋ฅผ ๋ค์ด, Object Classification๋ง์ ์ํด ํ์ต๋ CNN ๋ชจ๋ธ์ด ์ด๋ฏธ์ง๋ฅผ classifyํ ์ ์์๋ฟ๋ง ์๋๋ผ localization๋ ์ํํ ์
์๊ฒ ๋์๋ค. ์ฆ, ๊ฐ ์ด๋ฏธ์ง์ label๋ง ์ฃผ์ด์ง ์ํฉ์์ ์ฃผ์ด์ ธ์์ง ์์ localization์ ๋ณด๋ฅผ ์์ธกํ ์ ์๊ฒ ๋๋ค. (Weakly supervised learning : ํ์ต์ ์ฃผ์ด์ง ์ ๋ณด๋ณด๋ค ์์ธกํ๋ ค๋ ์ ๋ณด๊ฐ ๋ ๋ํ ์ผํ ๊ฒฝ์ฐ)
์๋ ๊ทธ๋ฆผ์ Global Average Pooling์ ์ฌ์ฉํ์ฌ CAM์ ์๊ฐํํ ๊ฒ์ธ๋ฐ, ๊ฐ ์ด๋ฏธ์ง๋ค์ ๋ํด classifyํ๋ฉด์๋ object๋ค์ด ์์นํ๋ ์์ญ๋ ์ฐพ์๋ผ ์ ์์์ ๋ณผ ์ ์๋ค.
Class Activation Mapping
CAM์ ์์์ ๋ณธ ๊ทธ๋ฆผ์ฒ๋ผ, CNN์ด input image์ ๋ํ prediction์ ๋ง๋ค์ด๋์ ๋, ํด๋น class๋ก ํ๋ณํ๋๋ฐ ์ค์ํ๊ฒ ์๊ฐํ๋ ์์ญ์ ํ์ํ์ฌ ์๊ฐํํ๋ ์๊ณ ๋ฆฌ์ฆ์ด๋ค. ๊ตฌ์กฐ๋ ์๋์ ๊ฐ๋ค.
1. ๋ง์ง๋ง Convolution layer์ feature map์ $f_k(x,y)$๋ผ๊ณ ํ๋ฉด, ๊ฐ๊ฐ์ unit $k$์ ๋ํด GAP์ ์ํํด์ $k$๊ฐ์ ๊ฐ์ ์ถ๋ ฅํ๋ค. (GAP์ ๊ฒฐ๊ณผ $F_k$)
2. ๊ฐ๊ฐ์ $F_k$์ ๋ํด์, class c์ ๋ํ ๊ฐ์ค์น $w_k^c$์ weighted sum์ ๊ณ์ฐํ์ฌ $S_c$๋ฅผ ์ถ๋ ฅํ๋ค. ์ด ๋์ ์ถ๋ ฅ $S_c$๋ softmax์ input์ผ๋ก ์ฌ์ฉ๋๋ค.
3. Softmax์ฐ์ฐ์ ๊ฑฐ์น๋ฉด ๊ฐ class c์ ๋ํ ๊ฒฐ๊ณผ $P_c$๊ฐ ์ถ๋ ฅ๋๋ค. bias๋ classification์ ์ฑ๋ฅ์ ์ํฅ์ ๊ฑฐ์ ๋ฏธ์น์ง ์๋๋ค๊ณ ๊ฐ์ ํ๊ณ , bias๋ 0์ผ๋ก ์ค์ ํ์ฌ ๊ณ์ฐํ๋ค.
4. Class c์ ๋ํ CAM์ $M_c$๋ผ๊ณ ์ ์ํ๊ณ , $S_c$์ ์์์ ๋ณํํ์ฌ ๊ตฌํ ์ ์๋ ํํ๋ก ์ฌ์ฉํ๋ค.
๋ฐ๋ผ์ ๊ณต๊ฐ์ ์ขํ(x,y)์ activation๊ฐ์ ์ค์๋๋ฅผ ๋ํ๋ด๋ ๊ฐ CAM $M_c(s,y)$๋ ์ด๋ฏธ์ง๊ฐ class C๋ก classify๋๋๋ฐ ๋ฏธ์น๋ ์ํฅ์ ๋ํ๋ธ๋ค.
๋ง์ง๋ง convolution layer์์์ CAM์ ์๊ฐํํ๊ธฐ ๋๋ฌธ์, ์ต์ข CAM์ ์ฒ์ input image์ ๊ฐ์ ํฌ๊ธฐ๋ก unsamplingํ๋ฉด, input image๋ด์์ class c์ ๊ด๋ จ๋์ด์๋ ์์ญ์ด ์ด๋์ธ์ง ํ์ธํ ์ ์๋ค. ์์ ๊ทธ๋ฆผ์์๋ input image๊ฐ "Australian terrier"์ class๋ก ๊ตฌ๋ถ๋๋๋ฐ ์ํฅ์ ๋ฏธ์น๋ ์์ญ์ ํ์ด๋ผ์ดํธํ CAM ๊ฒฐ๊ณผ๋ฅผ ๋ณผ ์ ์๊ณ , ๊ฐ์์ง๊ฐ ์์นํ ์์ญ์ localization๋ ํจ๊ป ์ํํ ๊ฒ์ ์ ์ ์๋ค.
Results
ILSVRC์ ์ด 4๊ฐ์ง class์ CAM์ ์๊ฐํํ ๊ทธ๋ฆผ.
์ฒซ๋ฒ์งธ์ ๋๋ฒ์งธ ๊ทธ๋ฆผ์์๋ "briard"์ "hen"์ ๋จธ๋ฆฌ ๋ถ๋ถ์ด prediction์ ํฐ ์ํฅ์ ๋ฏธ์ณค๊ณ , ์ธ ๋ฒ์งธ ๊ทธ๋ฆผ์์์ "barbell"์ ์ํ ๋ถ๋ถ, ๋ค ๋ฒ์งธ ๊ทธ๋ฆผ์์์ "bell cote"์ bell ๋ถ๋ถ์ด ์ํฅ์ ๋ง์ด ๋ฏธ์น ๊ฒ์ ๋ณผ ์ ์๋ค.
์ด๋ฒ์ ํ๋์ ์ด๋ฏธ์ง๋ฅผ ๊ฐ์ง๊ณ , ์ฌ๋ฌ class์ ๋ํ CAM์ ์์ธกํด์ ์๊ฐํํ ์๋ฃ์ด๋ค.
์ค์ Ground Truth label์ Dome์ด๊ณ , ๊ฐ์ฅ ํ๋ฅ ์ด ๋์ Top5์ class์ ๋ํด ์๊ฐํํ์๋ค. ํ๋์ class์ ๋ํด ์ฌ๋ฌ ์ด๋ฏธ์ง์ CAM์ ๋น๊ตํ์ ๋๋, ํด๋น class๋ฅผ ๋ํํ ์ ์๋ ํน์ง๋ค์ ์ผ๊ด๋๊ฒ ํ์ด๋ผ์ดํธํ ๊ฒ์ ๋ณผ ์ ์์์ง๋ง, ์ด ๊ทธ๋ฆผ์ฒ๋ผ ์ฌ๋ฌ class์ ๋ํ CAM์ ๋น๊ตํ์ ๋๋ ๊ฐ๊ฐ ๋ค๋ฅธ ๋ถ๋ถ๋ค์ ํ์ด๋ผ์ดํธํ ๊ฒ์ ๋ณผ ์ ์๋ค.
CAM์ผ๋ก ๊ตฌํด์ง segmentation map์ ์ ๋ถ ์ปค๋ฒํ ์ ์๋ ๊ฐ์ฅ ํฐ bounding box๋ฅผ ์์ฑํด์ Localization์ ์ํํ๋ค. ๊ฐ a)์ b)์ ์๋จ ๊ทธ๋ฆผ์ GoogleNet-GAP๋ก ๊ตฌํด์ง ๊ฒฐ๊ณผ์ด๊ณ , ํ๋จ ๊ทธ๋ฆผ๋ค์ AlexNet์ ์ฌ์ฉํ ๊ฒฐ๊ณผ์ด๋ค.
๊ฐ ์ด๋ฏธ์ง์์ Gound Truth๋ ๋ น์์ผ๋ก ํ์๋ Bounding box์ด๋ฉฐ, CAM์ ์ฌ์ฉํ์ฌ ์์ธกํ box๋ ๋นจ๊ฐ์์ผ๋ก ํ์๋์๋ค.
References
[1] Zhou et al, Learning Deep Features for Discriminative Localization, 2016