[PyTorch] PyTorch 다차원 텐서 곱(matmul)

2022. 7. 10. 17:14·프로그래밍 Programming/파이썬 Python

PyTorch에서 텐서끼리의 곱이 나와있는데 그 규칙을 알기 위해 이것저것 해본 것을 기록한 글이다.

 

여기서는 PyTorch의 matmul만 다룬다.

 

https://pytorch.org/docs/stable/generated/torch.matmul.html?highlight=matmul#torch.matmul 

 

torch.matmul — PyTorch 1.12 documentation

Shortcuts

pytorch.org

 

다음 글을 보면

 

1차원 텐서(벡터)나 2차원 텐서(행렬) 곱은 이해하기 어렵지 않다.

맨 마지막 항목에 주목해야한다. 맨 마지막 항목이 두 텐서 중 하나는 1차원 이상, 다른 하나는 N(>3)차원 이상일 때의 곱을 나타낸 것이다.

 

여기서 미리 알아야할 것 ::

Tensor에서 첫 번째 dimension 자리는 batch dimension라고 불려진다.

 

요약하면

1) 첫 번째 argument가 1차원 텐서

 

이 경우에는 batched matrix multiply를 적용한다고 하는데, 잘 이해가 안되어서 다음 예시를 해보았다.

 

[예시]

tensor1 = torch.randn(10)
tensor2 = torch.ones(4, 10, 5)
torch.matmul(tensor1, tensor2).size()
torch.Size([4, 5])

 

두 번째 argument의 N차원 텐서의 두 번째 dimension이 일치해야 계산 가능.

두 번째 argument는 3차원 텐서인데 첫 번째 dimension은 batch 축이므로, 그 다음 dimension이 벡터와 차원 일치해야 계산이 가능하다.

 

다음과 같은 경우는 모두 에러를 내뱉는다.

 

 

2) 두 번째 argument가 1차원 텐서

 

tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4)
torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3])

 

이번에는 tensor1의 (10,3,4) * tensor2의 (4)

마지막 dimension이 일치해야 계산 가능. 그 외에는 error를 띄운다.

 

 

3) 일반적인 rule

batch size dimension을 제외하고는 broadcast할 수 있는 dimension은 broadcast를 적용한다.

 


계산 동작 방식

 

이번에는 정확히 어떤 방식으로 계산이 되는지 예시를 보면서 확인한다.

 

a = np.ones((2,4))
b = 2 * a
c = 3 * a
d = np.stack((a,b,c), axis=2)
print(f"d matrix : \n{d}")

d_t = torch.from_numpy(d).double()
print(f"size of d : {d_t.shape}")

e = torch.tensor([1,10,100]).double()
print(f"size of e : {e.shape}")

result = torch.matmul(d_t,e)
print(f"**result : \n{result}")
print(f"**size of result : {result.shape}")

이 코드에서 d matrix는

 

$a = \begin{bmatrix}
1 & 1 & 1 & 1 \\
1 & 1 & 1 & 1 \\
\end{bmatrix}$

$b = \begin{bmatrix}
2 & 2 & 2 & 2 \\
2 & 2 & 2 & 2 \\
\end{bmatrix}$

$c=\begin{bmatrix}
3 & 3 & 3 & 3 \\
3 & 3 & 3 & 3 \\
\end{bmatrix}$

을 결합한 3차원 텐서이다.

 

matmul(d_t, e)로 계산했다.

d matrix : 
[[[1. 2. 3.]
  [1. 2. 3.]
  [1. 2. 3.]
  [1. 2. 3.]]

 [[1. 2. 3.]
  [1. 2. 3.]
  [1. 2. 3.]
  [1. 2. 3.]]]
size of d : torch.Size([2, 4, 3])
size of e : torch.Size([3])
**result : 
tensor([[321., 321., 321., 321.],
        [321., 321., 321., 321.]], dtype=torch.float64)
**size of result : torch.Size([2, 4])

여기서는 2(배치 축)를 제외하고

(4,3) * 3 의 곱으로 계산되었다.

또한 pytorch는 출력을 할 때 배치 축을 기준으로 표기를 한다. 그래서 위의 결과 코드를 보면 4*3이 2번(배치 축)으로 출력된 것을 볼 수 있다. 차원이 3인 축을 기준으로 각각 [1,10,100] 벡터가 곱해져서 모두 321이라는 결과가 나온다.


Torch에서는 batch size dimension을 건드리지 않기 때문에 아래와 같이 코드를 짜면 오류가 나온다.

a = np.ones((2,4))
b = 2 * a
c = 3 * a
d = np.stack((a,b,c), axis=0)

print(f"d matrix : \n{d}")
d_t = torch.from_numpy(d).double()
print(f"size of d : {d_t.shape}")

e = torch.tensor([1,10,100]).double()
e = e.view((1,-1))
print(f"size of e : {e.shape}")

result = torch.matmul(e,d_t)
print(f"**result : \n{result}")
print(f"**size of result : {result.shape}")

 

d matrix : 
[[[1. 1. 1. 1.]
  [1. 1. 1. 1.]]

 [[2. 2. 2. 2.]
  [2. 2. 2. 2.]]

 [[3. 3. 3. 3.]
  [3. 3. 3. 3.]]]
size of d : torch.Size([3, 2, 4])
size of e : torch.Size([1, 3])
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [94], in <cell line: 15>()
     11 e = e.view((1,-1))
     12 print(f"size of e : {e.shape}")
---> 15 result = torch.matmul(e,d_t)
     16 print(f"**result : \n{result}")
     17 print(f"**size of result : {result.shape}")

RuntimeError: mat1 and mat2 shapes cannot be multiplied (12x2 and 3x1)

d의 사이즈가 (3,2,4) 이고 e가 (1,3)이기 때문에

 

e*d => (1,3) * (3,2,4) 이렇게 될 것 같지만

오류가 난다.


동일하면서 조금 다르게 풀면

a = np.ones((2,4))
b = 2 * a
c = 3 * a
d = np.stack((a,b,c), axis=1)
print(d)

d_t = torch.from_numpy(d).double()
print(f"size of d : {d_t.shape}")

e = torch.tensor([1,10,100]).double()
print(f"size of e : {e.shape}")

result = torch.matmul(e,d_t)
print(f"**result : \n{result}")
print(f"**size of result : {result.shape}")

여기서는 d의 차원을 바꾸고, matmul(e,d_t)로 풀었다.

 

[[[1. 1. 1. 1.]
  [2. 2. 2. 2.]
  [3. 3. 3. 3.]]

 [[1. 1. 1. 1.]
  [2. 2. 2. 2.]
  [3. 3. 3. 3.]]]
size of d : torch.Size([2, 3, 4])
size of e : torch.Size([3])
**result : 
tensor([[321., 321., 321., 321.],
        [321., 321., 321., 321.]], dtype=torch.float64)
**size of result : torch.Size([2, 4])

 

마찬가지로 배치사이즈인 2를 제외하고 (3,4)로 나타나기 때문에

3 * (3,4) 로 문제없이 결과가 나온다.

 

파이토치에서 직접 텐서 계산을 하고 싶다면 이 점을 고려해야한다.

저작자표시 비영리 변경금지 (새창열림)

'프로그래밍 Programming > 파이썬 Python' 카테고리의 다른 글

[PyTorch] DataLoader shuffle 기능 사용 시, RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'  (0) 2022.07.14
[PyTorch] GPU에서 텐서 사용하기  (0) 2022.07.12
[에러기록] matplotlib의 imshow를 쓸 때 커널이 죽는 현상 (추가)  (4) 2022.07.09
[PyTorch] PyTorch에서 GPU 사용  (2) 2022.07.07
[에러기록] (pytorch) RuntimeError: Numpy is not available  (0) 2022.07.06
'프로그래밍 Programming/파이썬 Python' 카테고리의 다른 글
  • [PyTorch] DataLoader shuffle 기능 사용 시, RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
  • [PyTorch] GPU에서 텐서 사용하기
  • [에러기록] matplotlib의 imshow를 쓸 때 커널이 죽는 현상 (추가)
  • [PyTorch] PyTorch에서 GPU 사용
보통의공대생
보통의공대생
수학,프로그래밍,기계항공우주 등 공부하는 기록들을 남깁니다.
  • 보통의공대생
    뛰는 놈 위에 나는 공대생
    보통의공대생
  • 전체
    오늘
    어제
    • 분류 전체보기 (468)
      • 공지 (1)
      • 영어 공부 English Study (40)
        • 텝스 TEPS (7)
        • 글 Article (21)
        • 영상 Video (10)
      • 연구 Research (99)
        • 최적화 Optimization (3)
        • 데이터과학 Data Science (7)
        • 인공지능 Artificial Intelligent (40)
        • 제어 Control (45)
      • 프로그래밍 Programming (103)
        • 매트랩 MATLAB (25)
        • 파이썬 Python (33)
        • 줄리아 Julia (2)
        • C++ (3)
        • 리눅스 우분투 Ubuntu (6)
      • 항공우주 Aeronautical engineeri.. (21)
        • 항법 Navigation (0)
        • 유도 Guidance (0)
      • 기계공학 Mechanical engineering (13)
        • 열역학 Thermodynamics (0)
        • 고체역학 Statics & Solid mechan.. (10)
        • 동역학 Dynamics (1)
        • 유체역학 Fluid Dynamics (0)
      • 수학 Mathematics (34)
        • 선형대수학 Linear Algebra (18)
        • 미분방정식 Differential Equation (3)
        • 확률및통계 Probability &amp; Sta.. (2)
        • 미적분학 Calculus (1)
        • 복소해석학 Complex Analysis (5)
        • 실해석학 Real Analysis (0)
      • 수치해석 Numerical Analysis (27)
      • 확률 및 랜덤프로세스 Random process (2)
      • 추론 & 추정 이론 Estimation (3)
      • 기타 (26)
        • 설계 프로젝트 System Design (8)
        • 논문작성 Writing (55)
        • 세미나 Seminar (2)
        • 생산성 Productivity (3)
      • 실험 Experiment (1)
      • 유학 생활 Daily (8)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    Linear algebra
    Numerical Analysis
    텝스
    teps
    obsidian
    Julia
    옵시디언
    서버
    JAX
    고체역학
    ChatGPT
    pytorch
    IEEE
    MATLAB
    에러기록
    논문작성법
    Dear abby
    인공지능
    우분투
    논문작성
    LaTeX
    Zotero
    수치해석
    Statics
    딥러닝
    WOX
    생산성
    Python
    matplotlib
    텝스공부
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
보통의공대생
[PyTorch] PyTorch 다차원 텐서 곱(matmul)
상단으로

티스토리툴바