인공지능

torch tensor 반복 확장하기 expand vs. repeat vs. repeat_interleave 차이

감자156 2023. 4. 17. 14:27
반응형

pytorch에서 텐서를 반복하여 확장하는 방법들을 정리하겠음.

1) torch.expand()

dim=1인 축에 대해서만 확장이 가능하며, shape을 input으로 받음 

 

사용 예)

 

(1,2,2,3)의 shape을 가지는 tensor

 

dim 0만 3배 확장하기

 

여기서, 1이 아닌 차원을 확장하려 하면 runtime error 발생. 

 

dim=2를 4로 확장하려고 하면 runtime error 발생함.

 

해결방법 repeat이나 repeat_interleave 사용하면 됨. 

 

 

cf) expand_as() : ref) https://pytorch.org/docs/stable/generated/torch.Tensor.expand_as.html

ref ) https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html

2) torch.repeat()

단순히 인풋 텐서를 원하는 shape으로 차곡차곡 쌓음

 

사용 예)

 

(1,2,2,3)의 shape을 가지는 tensor

 

0 dim에 대해 2배 반복

 

1 dim에 대해 2배 반복

 

ret) https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html

3) torch.repeat_interleave()

 

원하는 행별로 반복할 수 있어서 numpy의 np.repeat 처럼 사용 가능함.

 

사용 예)

 

(1,2,2,3)의 shape을 가지는 tensor

 

 

dim 0에 대해 2배 확장

 

dim 1에 대해 2배 확장

 

이런식으로 사용했을 때, 원하는 차원 축을 반복하여 확장할 수 있어 model feeding 할 때 유용하게 사용함.

 

ref) https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html

반응형