일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | 6 | 7 |
8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | 17 | 18 | 19 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 30 | 31 |
- Dear abby
- teps
- 논문작성
- Python
- Statics
- IEEE
- Linear algebra
- 인공지능
- 수치해석
- 옵시디언
- 텝스공부
- 고체역학
- 수식삽입
- 에러기록
- Julia
- Numerical Analysis
- 딥러닝
- MATLAB
- WOX
- 생산성
- 우분투
- pytorch
- 논문작성법
- LaTeX
- ChatGPT
- matplotlib
- obsidian
- 텝스
- JAX
- Zotero
- Today
- Total
뛰는 놈 위에 나는 공대생
[PyTorch] 특정 조건에 맞는 텐서 출력/인덱싱 등 본문
데이터 안에 이상한 값이 없는지는 확인하고 싶을 때가 있다. 데이터가 이상하면 당연히 결과도 이상하기 때문이다.
넘파이의 경우에는 아래와 같이 where함수를 쓰면 인덱스를 리턴한다.
print(np.where(np.isnan(data)==True))
>>(array([], dtype=int64), array([], dtype=int64))
위 경우는 내부에 nan이 없기 때문에 빈 array를 return한다.
a = np.array([np.nan, 1, 2, 3])
print(np.where(np.isnan(a)==True))
>> (array([0], dtype=int64),)
만약 다음과 같이 nan이 들어있다면 그 값이 있는 index를 리턴한다. 이 때 튜플을 return한다는 점에 주의해야한다.
이렇게 인덱스를 return하면 유용한 점이 있는데 바로 그 값과 추출할 수 있다는 점이다. nan값은 오히려 없애고 싶은 경우라서 제거해야한다.
check = np.where(np.isnan(a)==True)[0] # tuple로 return하기 때문에 tuple 중에서 첫번째 인덱스 값을 가져온다.
a[check]
>> array([nan])
다음과 같이 사용할 수 있다.
제거하고 싶은 경우에는 True, False로 구성된 배열(binary mask)을 통해서 가능하다.
# 제거하고 싶은 경우
~np.isnan(a)
>> array([False, True, True, True])
a[~np.isnan(a)]
>> array([1., 2., 3.])
여기서 핵심은 numpy에서는 index [] 대괄호 안에 True, False로 구성된 배열 또는 index 값으로 구성된 배열을 통해서 원하는 배열의 일부를 추출 가능하다는 것이다.
그런데 PyTorch는 자체적인 torch tensor를 사용한다. 넘파이와 거의 비슷하기 때문에 다음과 같이 조건에 맞는 원소를 가져오는 것이 가능하다.
a = torch.tensor([0.1, 0.3, -0.5, 1.0])
check = a < 0 # 0보다 작은 원소 추출
print(check)
>> tensor([False, False, True, False])
print(a[check])
>> tensor([-0.5000])
다음과 같이 조건에 맞는지에 대한 배열(check)을 먼저 구하고 이를 인덱스에 넣어서 구하는 방법이다.
그런데 이런 경우도 있을 것이다.
1) 배열 원소 중에 조건에 부합하는 값들은 다른 값으로 바꾸고 싶다.
2) 배열 원소 중에 조건에 부합하는 값들의 인덱스를 가지고 싶다.
1) 배열 원소 중에 조건에 부합하는 값들은 다른 값으로 바꾸고 싶을 때
1)의 경우에는 where함수를 쓸 수 있다. 위의 numpy에서 where이 인덱스를 return하는 것과 달리 torch.where은 첫 번째 argument에 기준이 되는 조건을 넣고, 그 조건에 충족될 경우에 어떤 값을 return할 지, 충족하지 않을 경우에는 어떤 값을 return할 지 지정할 수 있다. 그래서 최종적인 배열은 조건으로 활용한 배열과 동일한 shape이 된다.
a = torch.tensor([0.1, 0.3, -0.5, 1.0])
check = torch.where(a < 0.0, True, False) # 조건에 맞으면 True, 아니면 False
>> tensor([False, False, True, False])
그리고 앞에서 그냥 텐서 a에 check를 사용해서 필요한 값을 추출할 수도 있지
torch.masked_select(a, check)
다음과 같이 masked_select 함수를 통해서도 a[check]와 동일한 효과를 낼 수 있다.
2) 배열 원소 중에 조건에 부합하는 값들의 인덱스를 가지고 싶을 때
2)와 같이 인덱싱을 하는 것에 대해서는 구글링을 열심히 했는데 아래 방법이 일단은 최선이었다.
a = torch.tensor([0.1, 0.3, -0.5, 1.0])
check = ( a<0 ).nonzero(as_tuple=True)[0]
print(check)
>> tensor([2])
nonzero를 사용하는 구체적인 방법은 파이토치 문서를 가보면 잘 설명되어있다.
위의 a<0 조건문은 True, False로 return하지만 (torch에서 구체적으로 True, False를 어떻게 쓰는지는 모르겠지만) 보통은 True가 1이고, False가 0이기 때문에 True인 index만 걸러낼 수 있다.
as_tuple=True로 지정하지 않으면 default는 False이다. 이 경우에는 2차원 텐서가 나온다고 하는데 다음과 같이 2차원 텐서에 대해서 nonzero 값을 찾는다면
다음과 같이 [row, column]으로 구성된 인덱스를 return한다.
하지만 as_tuple=True인 경우에는 다음과 같이 row와 column을 따로 분리한다. (그래서 1차원 텐서의 nonzero 인덱스를 찾을 때는 굳이 True로 지정할 필요는 없다.)
참고자료
https://pytorch.org/docs/stable/generated/torch.nonzero.html#torch.nonzero
'프로그래밍 Programming > 파이썬 Python' 카테고리의 다른 글
[Matplotlib] legend 그림 바깥에 배치/원하는 위치에 배치 (0) | 2023.01.14 |
---|---|
[Matplotlib] Matplotlib 폰트 스타일 바꾸기 (0) | 2023.01.13 |
[Pytorch] multi-output일 때 input gradient 구하기 (0) | 2023.01.03 |
[에러기록] matplotlib에서 figure만 그려지고 plot이 없는 경우 (0) | 2022.12.27 |
[에러기록] Visual studio code에서 아나콘다 가상환경이 안 돌아갈 때 (0) | 2022.11.01 |