유사개발자 샤이와 무지

[Tensorflow] dtype이 매칭되지 않던 문제 #TDP 02-step04_05 본문

WIL/Bug_fixed

[Tensorflow] dtype이 매칭되지 않던 문제 #TDP 02-step04_05

Shy & Mujee 2023. 3. 8. 02:59
from tensorflow.python.ops.init_ops import deprecated_arg_values
from tensorflow._api.v2 import linalg
import tensorflow as tf

A=tf.constant([[1,4,1],[1,6,-1],[2,-1,2]],dtype=tf.float64)

L_U,p= tf.linalg.lu(A)
print(L_U)
print(p)
print();print()

#making P, L, U
U=tf.linalg.band_part(L_U, 0, -1) #Upper triangular
print(U)
L=tf.linalg.band_part(L_U,-1,0) #Lower triangular
print(L,end="\n\n")

L=tf.linalg.set_diag(L, [1,1,1]) #strictly lower triangular part of LU
print(L)

P=tf.gather(tf.eye(3),p) #tf.eye is unit matrix tf.gatheris https://www.tensorflow.org/api_docs/python/tf/gather
print(P)
print();print()

#check A == PLU under this 2 function is same tool
print(tf.linalg.lu_reconstruct(L_U,p))
print(tf.matmul(P, tf.matmul(L,U))) #It is same with tf.gather(tf.matmul(L,U),p)
print();print()

#solve Ax=b using PLUx=b
print(tf.linalg.lu_solve(L_U, p, b))
y=tf.linalg.triangular_solve(L, tf.matmul(tf.transpose(P),b))
print(y)

x=tf.linalg.triangular_solve(U, y, lower=False)
print(x)
print();print()

#pivot, calculate det and rank
D=tf.linalg.diag_part(L_U) #tf.linalg.diag_part(U)
print(D)
rank=tf.math.count_nonzero(D)
print(rank,end="\n\n\n")

det_U=tf.reduce_prod(tf.linalg.diag_part(U)) #tf.linalg.det(U)
det_L=tf.reduce_prod(tf.linalg.diag_part(L)) #tf.linalg.det(L)
det_P=tf.linalg.det(P)

print(det_U);print(det_L);print(det_P)
det_A=det_P*det_L*det_U
print(det_A)

어제 http://www.yes24.com/Product/Goods/93156986를 공부하다가 발생한 issue였습니다

 

텐서플로 딥러닝 프로그래밍 - YES24

텐서플로 딥러닝 프로그래밍

www.yes24.com

[Problem]

코드만 보면 별 이슈꺼리는 없어보였지만

tf.matmul(P, tf.matmul(L,U))

InvalidArgumentError Traceback (most recent call last)


InvalidArgumentError: cannot compute MatMul as input #1(zero-based) was expected to be a float tensor but is a double tensor [Op:MatMul]

잡다한 로그들은 집어치우고 보면

 

matmul(행렬곱)을 실행할 수 없다고 합니다. 

원인은 분명히

"넌 data_type이 float인 tensor를 집어 넣었는데  double인 tensor가 잘못 들어온 것 같다"

라고 하는군요.

이 말은 제가 어디선가 자료형을 잘못 지정을 해줬겠죠? 라고 하기엔 파이썬입니다.

컴파일러 니가 지금 지정 이상하게 해두고 저한테 이거 니탓이야 이러고 있네요 ㅋ

그럼 이제 원인을 알았으니 대충 예상가는 위치로 이동해봅시다.

 

 

이 코드를 수정하고 난 느낌은 딱 이랬습니다. 또 파이썬 너냐...

 

이제 자료형이 변할만 한 부분을 찾으면서 코드에 대해 가볍게 설명을 해보자면 에러가 있는 부분까지는 선형대수의 PLU 분해를 컴퓨터로 구현해 둔 형태입니다. PLU분해에 대한 자세한 이야기는 제가 아래 써두었던 포스팅을 참고해주세요

 

추후 작성 예정

 

[Solve]

이 글을 읽고 오셨다면 주의깊게 봐야 할 부분은 이미 눈치 채셨을 것 같습니다. P L U 가 각각 분해되는 과정이겠죠

사실 설명 시작하면서 에러의 원인을 발견했었습니다. 정말 맨 위에 있더라고요. 같이 보시죠

L_U,p= tf.linalg.lu(A)
print(L_U)
print(p)

tf.Tensor( [[ 2. -1. 2. ] [ 0.5 6.5 -2. ] [ 0.5 0.69230769 1.38461538]], shape=(3, 3), dtype=float64)
tf.Tensor([2 1 0], shape=(3,), dtype=int32)

자 컴파일러가 저에게 구라를 한번 더 쳤지만 expected라고 명시를 해뒀기때문에 한 번만 넘기도록 하겠습니다.

직관적으로 data type이 다른 부분을 확인을 했으니 수정을 해야겠습니다. 그럼 어떻게 해야하는가

구글링을 또 열심히 해서 찾아본 결과 이 함수가 저에게 도움이 되겠군요.

tf.cast(x, dtype, name=None)

x Tensor or SparseTensor or IndexedSlices of numeric type. It could be uint8uint16uint32uint64int8int16int32int64float16float32float64complex64complex128bfloat16.
dtype The destination type. The list of supported dtypes is the same as x.
name A name for the operation (optional).

dtype의 형변환을 시켜주는 친구입니다.

저는 이 함수로 P의 dtype을 float64로 변형시켜주었습니다.

그럼 이제 결과를 확인해 봅시다

 

P=tf.cast(P, tf.float64) #P is int32 before. So We have to cast P to float64
print(P)

print(tf.matmul(P, tf.matmul(L,U))) #It is same with tf.gather(tf.matmul(L,U),p)

tf.Tensor(
[[0. 0. 1.]
[0. 1. 0.]
[1. 0. 0.]], shape=(3, 3), dtype=float64)

tf.Tensor(
[[ 1. 4. 1.]
 [ 1. 6. -1.]
 [ 2. -1. 2.]], shape=(3, 3), dtype=float64)

제가 원하는대로 dtype이 잘 변형되었고, 결과 도출이 잘 되었군요.

 

[Why It caused?]

그럼 이제 이 현상이 발생한 원인에 대해 생각해봅시다.

당연하게도 이 문제의 원인은 파이썬 자체적으로 자동 형변환(casting)입니다.

자동 형변환에 대해 간단하게 설명하자면  아래와 같습니다.

더보기

컴퓨팅 언어들은 '변수를 지정할때 자료형을 지정해주어야 한다' 라는 매우*1e+100 중요한 전제를 깔고 가야합니다.

파이썬에선 코더와 프로그래머의 편의를 위하여 이를 자동으로 지정하고 감추어 두었지만 파이썬은 뿌리가 되는 C만큼 많은 자료형을 가지지는 않지만 int,float,string등등의 자료형을 엄연히 가지고 있으며 이는 유저의 요청에 따라 컴파일러에서 알아서 변형시킵니다.

더 자세하게 알고 싶으신 분들은 제가 또 준비해둔 글을 참고해보세요!

추후 작성 예정

 

그리고, tensorflow는 파이썬의 자료형이 아닌 자체적인 자료형을 사용합니다.

(정확히는 Numpy라는 패키지와의 호완성을 위하여 Numpy의 자료형을 가져다가 사용합니다.)

 

위의 내용이 어느정도 이해가 가셨다면 아래부턴 쉽습니다.

미리 정답을 공개하자면 컴파일러들은 최적화 문제나 여러 issue로 자료형을 최대한 작게 설정합니다.

 

그 결과 data type의 충돌로 인하여 연산 오류가 발생하던 것이었습니다.

 

만약 프로그래밍과 코딩을 파이썬으로 처음 접하시는 분들이라면 아마 이 에러를 접하시고 당혹스러우실 수도 있으시겠지만, 좀 더 low language로 내려갈수록 흔히 볼 수 있는 실수입니다.

이번에 잘 익혀서 다른 언어를 배우실때 도움이 되시면 좋겠습니다.

 

 

Comments