cvpr 2021 LIIF continuous image neural representation 연구,, 달러($)표시가 있는 문자가 보이면 수식 로딩중이니 새로고침
Abstract
본 논문에서는 이미지를 연속적으로 표현하는 법을 제안한다. 3D reconstruction에서의 *implicit nueral representation에 영감을 받아 이미지의 좌표, 좌표 주변 2D deep feature를 input으로 받아 주어진 좌표의 output의 RGB value를 예측하는 Local Implicit Image Function(LIIF)를 제안한다. 좌표가 연속적이기 때문에 LIIF도 arbitrary resolution을 표현할 수 있다.
이미지의 연속적인 표현을 생성하기 위해 본 논문에서는 encoder와 super-resolution의 self-supervised 방법을 통해 LIIF representation을 학습한다. 학습된 continuius representation은 arbitrary resolution을 나타낼 수 있다.(학습되지 않은 x30 해상도까지 확장 가능). 더욱이 LIIF 표현법은 2D에서 이산적 표현과 연속적 표현 간의 다리 역할을 해주고, 다양한 사이즈의 gt에 대한 학습 task를 지원하며 gt를 resize 하는 방법보다 훨씬 성능이 좋다.
*implicit neural representation
neural implicit representation이란 특정 input을 신경망의 output으로 표현하는 표현법을 말한다. 신경망을 통해 output이 결정되기 때문에 그 함수를 정의할 수 없어 "implicit"이라고 함. 이때의 신경망을 nerual implicit fuction이라고 한다.
예를 들어, 이미지 좌표의 x,y를 넣었을 때, f(x, y) = (a1, a2, a3, a4) 꼴로 표현하는 신경망 f(.)가 있다고 하면, 이때의 f(.)는 정의할 수 없는 식(신경망)이기 때문에 neural implicit function이고, 입력 x, y는 f(x, y), 다시 말해 (a1, a2, a3, a4)로 neural implicit representation 할 수 있다.
Problem Definition
보통 이미지는 이미지 픽셀의 2D array로 표현하고 저장하게 되는데, 이러면 complexity와 precision간의 상충 관계(trade-off)가 해상도(resolution)에 의해 결정된다. 이와 같은 pixel-based 표현은 다양한 vision task에 성공적으로 적용되지만, resolution에 제약이 있다는 단점을 가진다. 예를 들어, 우리가 convolution 신경망을 훈련하고자 할 때, 보통 이미지를 같은 사이즈로 resize 해야 해서, fidelity를 포기해야 하는 경우가 생긴다. 이미지를 고정된 resolution으로 표현하는 것 대신에, 본 논문에서는 continuous representation을 제안한다. 이미지를 연속적인 도메인에서의 function으로 모델링하여, 본 논문에서는 이미지를 원하는 임의의 resolution으로 저장하거나 생성할 수 있다.
Introduction
어떻게 이미지를 연속적인 함수로 표현할 수 있을까?
본 연구는 최근 발전하고 있는 implicit neural representaion에서 영감을 받았다. implicit neural representation의 key idea는 좌표를 관련 신호(signal)로 mapping하는 function(신경망)으로 객체를 표현하는 것이다. 각 객체마다 각각의 함수를 fitting하는 것 대신에 객체 전체가 지식을 공유하기 위해 encoder-based 방법들이 제안되었다. encoder-based 방법들은 각각의 객체들마다의 latent code를 예측하고, decoding 함수는 모든 객체가 공유하며 좌표 input에 추가적인 input으로 latent code를 받는 방식이다. 이런 방법이 3D에서는 성공했음에도 불구하고, 이미지를 digit처럼 단순하게 표현해서 높은 fidelity를 가진 진짜 natural image를 표현하는 데는 실패했다. 본 논문에서는 자연스럽고 복잡한 이미지를 연속적인 방법으로 표현하기 위한 Local Implicit Image Function (LIIF)를 제안한다.
contirbution
1) 자연적이고 복잡한 이미지를 연속적인 방법으로 표현하는 novel method를 제안함
2) 학습되지 않은 x30배의 resolution 확장도 가능함
3) 다양한 사이즈의 GT image를 학습하는 task들에 효과적임.
Method
Local Implicit Image Function
LIIF representation에서, 각각의 continuous image $I^{(i)}$는 2D feature map $M^{(i)}$ 로 표현된다. decoding function $f_{\theta}$ 는 모든 이미지가 공유하며, MLP로 매개변수화 된다.
$s = f_{\theta}(z,x)$
$z$ : 벡터
$x$ : 좌표
$s$ : 는 예측된 signal ( RGB 값 )
$x$의 범위는 두 차원에 대해 $[0,2H]$, $[0,2W]$으로 가정했다. ( 범위가 왜 [0, 2x] 냐면 나중에 [-1,1]로 vs 범위를 지정함 ). 정의된 $f_{\theta}$ (좌표 -> RGB값 mapping 함수)에서 각각의 벡터 $z$는 $f_{\theta}(z,.)$ 로 표현할 수 있다.
본 논문에서는 2D feature map $M^{(i)}$의 $H \times W$ feature vector(latent code)가 이미지 $I^{(i)}$의 연속적인 이미지 도메인의 2D 공간에서 고르게 펴져있다고 가정했고(한 픽셀당 하나, 그림 2의 파란 동그라미) 각각의 feature vector에 2D 좌표를 할당했다. 연속적인 이미지 $I^{(i)}$의 좌표 $x_{q}$에서의 RGB 값은 다음과 같이 정의된다.
$z^{*}$ : $M^{(i)}$ 의 $x_{q}$에서 가장 가까운 latent code
$v^{*}$ : 이미지 도메인에서의 latent code $z^{*}$의 좌표
그림 2를 예시로 보면, $z_{11}^{*}$는 정의한 대로 $x_{q}$에 대한 $z^{*}$이다. $v^{*}$ 는 $z_{11}^{*}$의 좌표에 의해 정의되고, $x_{q}$의 RGB 값은 $x_{q}$와 가장 가까운 픽셀의 latent code $z^{*}$와 그 latent code 좌표까지의 거리(offset) $x_{q} - v^{*}$ 으로 결정된다.
Local implicit image function을 요약하면, 함수 $f_{\theta}$는 모든 이미지가 공유하고, 연속적인 이미지는 2D feature map $M^{(i)}$로 표현될 수 있다. 2D feature map $M^{(i)}$는 2D 도메인에 고르게 뿌려져 있는 $H \times W$개의 latent code라고 볼 수 있다. 2D feature map $M^{(i)}$의 개별 latent code $z$는 연속적인 이미지의 어떤 local piece를 표현하고, latent code 자기 자신과 가장 가까운 좌표 집합의 signal을 예측하는 것을 담당한다.
Feature unfolding
$M^{(i)}$안의 각각의 lartent code가 포함하는 정보를 더 풍부하게 하기 위해, 본 논문에서는 $M^{(i)}$를 feature unfolding 하여 $\widehat {M} ^ {(i)}$를 구했다. $\widehat{M} ^ {(i)}$에서의 latent code는 ${M} ^ {(i)}$의 이웃하는 3X3 latent code를 concatenation한 것이다. 식으로 정리하면 feature unfolding은 다음과 같이 정의된다.
feature unfolding 후에, $\widehat{M} ^ {(i)}$는 모든 계산에서 ${M} ^ {(i)}$를 대체한다.
Local ensemble
식 2는 불연속적인 예측에 대해 문제가 있다. 특히, $x_{q}$에서의 signal prediction은 $M^{(i)}$의 nearest latent code $z^{*}$를 쿼리 하여 수행되는데, $x_{q}$가 2D 도메인으로 이동하면, 선택되는 $z^{*}$가 갑자기 다른 값으로 바뀔 수 있다 (즉, nearest latent code가 바뀔 수 있다). 예를 들어, 그림 2에서 $z_{11}^{*}$로 바뀔 때, 점선 부분을 지날 때, 거의 같다고 볼 수 있을 만큼 점선 전후의 가까운 두 지점의 값이 다른 latent code로부터 예측되게 된다는 문제가 있다. 그래서 $z^{*}$가 바뀌는 경계에서 불연속적인 패턴이 나타날 수 있다. 이 문제를 해결하기 위해, 식 2를 다음과 같이 확장했다.
이미지의 한 좌표 $x_{q}$값과 가장 가까운 $z^{*}$로 latent code 값 하나로 representation을 구하는 게 아니라, $x_{q}$주변의 왼쪽 위, 오른쪽 위, 왼쪽 아래, 오른쪽 아래 부분 공간의 nearest latent code로 representation을 구한다.
그림 2에서 원래는 $x_{q}$와 가장 가까운 $z_{11}^{*}$만 사용했다면, 확장한 식에서는 $z_{00}^{*}$, $z_{01}^{*}$, $z_{10}^{*}$, $z_{11}^{*}$ 네 개의 값을 사용한다. $S_t$는 $v_t'^{*}$와 $x_{q}$사이의 직사각형 영역을 말한다.(t' : t와 대각선 맞은편) 그림 2에서 확인할 수 있다.
이렇게 설정함으로써 의도적으로 local latent code로부터 표현되는 local pieces가 그 주변 pieces와 오버랩되게 한다. 그래서 각 좌표마다 독립적으로 signal을 예측하는 4개의 latent code가 가지게 된다. 각 latent code에 대한 4개의 prediction은 normalized confidence에 의해 voting 되어 합쳐지고, 이 confidence는 직사각형 영역 $S$에 비례한다. 결국 query 좌표가 latent code 가까울수록 맞은편 latent code과 query point 간의 영역이 커지므로, confidence가 높아지게 된다. 그래서 $z^{*}$가 바뀌는 지점에서도 연속적인 transition을 가능하게 한다.
Cell decoding
본 논문에서는 LIIF representation이 임의의 해상도에서도 pixel-based 형태로 표현될 수 있도록 만들고자 한다. 원하는 해상도가 주어졌다고 가정했을 때, pixel-based 형태로 나타내는 가장 직관적인 방법은 연속적인 표현 $I^{(i)}(x)$ 에서의 pixel 중앙값의 좌표에서 rgb 값을 쿼리 하는 것이다. 이 방법이 잘 동작하긴 하지만, query pixel의 예측된 RGB값은 size에 독립적이고 center value 외의 픽셀 영역의 정보들은 모두 무시되기 때문에 최적의 방법은 아니다.
이 문제를 해결하기 위해서 본 논문에서는 그림 3과 같이 cell decoding을 추가하여 식 1을 재정의했다.
이 식은 만약 좌표 x에 중심을 맞춘 모양 c로 픽셀을 렌더링 할 때의 RGB 값은 어떠해야 하는지를 나타낸다.
Learning Conrinuous Image Representation
마지막으로 학습 방법이다.
먼저 이미지를 이미지의 LIIF 표현으로 encoding 해줄 encoder를 학습한다. neural implicit function $f_{\theta}$는 모든 이미지가 공유한다. 생성된 LIIF 표현은 input을 복원할 수 있을 뿐만 아니라, 더 중요한 건 연속적인 표현이기 때문에 더 높은 해상도를 표현할 때도 high fidelity를 유지해야 한다. 그러므로, 본 논문에서는 super-resolution의 self-supervised task의 프레임워크를 훈련했다.
그림 4처럼 하나의 학습용 이미지를 예시로 보자. 먼저, 이미지를 random 한 스케일로 down sampling 해 input을 생성한다. 그리고 GT는 학습용 이미지의 픽셀 샘플 $x_{hr}$, $s_{hr}$로 표현할 수 있다. $x_{hr}$은 이미지 도메인에서 픽셀의 중앙 좌표이고, $s_{hr}$은 픽셀의 RGB값이다. 좌표 $x_{hr}$는 LIIF 표현을 쿼리 하는 데 사용되고, $s_{pred}$는 예측되는 signal을 의미하고, L1 loss를 통해 $s_{hr}$간의 차이로 학습된다.
코드
https://github.com/yinboc/liif/blob/main/models/liif.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import models
from models import register
from utils import make_coord
@register('liif')
class LIIF(nn.Module):
def __init__(self, encoder_spec, imnet_spec=None,
local_ensemble=True, feat_unfold=True, cell_decode=True):
super().__init__()
self.local_ensemble = local_ensemble
self.feat_unfold = feat_unfold
self.cell_decode = cell_decode
self.encoder = models.make(encoder_spec)
if imnet_spec is not None:
imnet_in_dim = self.encoder.out_dim
if self.feat_unfold:
imnet_in_dim *= 9
imnet_in_dim += 2 # attach coord
if self.cell_decode:
imnet_in_dim += 2
self.imnet = models.make(imnet_spec, args={'in_dim': imnet_in_dim})
else:
self.imnet = None
def gen_feat(self, inp):
self.feat = self.encoder(inp)
return self.feat
def query_rgb(self, coord, cell=None):
feat = self.feat
if self.imnet is None:
ret = F.grid_sample(feat, coord.flip(-1).unsqueeze(1),
mode='nearest', align_corners=False)[:, :, 0, :] \
.permute(0, 2, 1)
return ret
if self.feat_unfold:
feat = F.unfold(feat, 3, padding=1).view(
feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3])
if self.local_ensemble:
vx_lst = [-1, 1]
vy_lst = [-1, 1]
eps_shift = 1e-6
else:
vx_lst, vy_lst, eps_shift = [0], [0], 0
# field radius (global: [-1, 1])
rx = 2 / feat.shape[-2] / 2
ry = 2 / feat.shape[-1] / 2
feat_coord = make_coord(feat.shape[-2:], flatten=False).cuda() \
.permute(2, 0, 1) \
.unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:])
preds = []
areas = []
for vx in vx_lst:
for vy in vy_lst:
coord_ = coord.clone()
coord_[:, :, 0] += vx * rx + eps_shift
coord_[:, :, 1] += vy * ry + eps_shift
coord_.clamp_(-1 + 1e-6, 1 - 1e-6)
q_feat = F.grid_sample(
feat, coord_.flip(-1).unsqueeze(1),
mode='nearest', align_corners=False)[:, :, 0, :] \
.permute(0, 2, 1)
q_coord = F.grid_sample(
feat_coord, coord_.flip(-1).unsqueeze(1),
mode='nearest', align_corners=False)[:, :, 0, :] \
.permute(0, 2, 1)
rel_coord = coord - q_coord
rel_coord[:, :, 0] *= feat.shape[-2]
rel_coord[:, :, 1] *= feat.shape[-1]
inp = torch.cat([q_feat, rel_coord], dim=-1)
if self.cell_decode:
rel_cell = cell.clone()
rel_cell[:, :, 0] *= feat.shape[-2]
rel_cell[:, :, 1] *= feat.shape[-1]
inp = torch.cat([inp, rel_cell], dim=-1)
bs, q = coord.shape[:2]
pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1)
preds.append(pred)
area = torch.abs(rel_coord[:, :, 0] * rel_coord[:, :, 1])
areas.append(area + 1e-9)
tot_area = torch.stack(areas).sum(dim=0)
if self.local_ensemble:
t = areas[0]; areas[0] = areas[3]; areas[3] = t
t = areas[1]; areas[1] = areas[2]; areas[2] = t
ret = 0
for pred, area in zip(preds, areas):
ret = ret + pred * (area / tot_area).unsqueeze(-1)
return ret
def forward(self, inp, coord, cell):
self.gen_feat(inp)
return self.query_rgb(coord, cell)