59doit

순환신경망 (2) 본문

인공지능

순환신경망 (2)

yul_S2 2023. 1. 21. 17:28
반응형

텐서플로를 이용한 순환신경망 만들기

SimpeRNN : 텐서플로에서 가장 기본적인 순환층

 

 

# 1. 필요한 클래스 임포트

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, SimpleRNN

 

 

# 2. 모델만들기

model = Sequential()
 
model.add(SimpleRNN(32, input_shape = (100,100))) #타임스텝의 길이가 100이고 원핫인코딩 크기가 100이므로 입력크기 (100,100)
model.add(Dense(1, activation = 'sigmoid'))
 
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                        Output Shape              Param #   
=================================================================
 simple_rnn (SimpleRNN)      (None, 32)                 4256      
                                                                 
 dense_6 (Dense)                  (None, 1)                   33        
                                                                 
=================================================================
Total params: 4,289
Trainable params: 4,289
Non-trainable params: 0
_________________________________________________________________

▶ 타임스텝의 길이가 100이고 원핫인코딩 크기가 100이므로 입력크기 (100,100)

원-핫 인코딩 된 100차원 벡터이고 셀 개수가 32개이므로 W_1x 행렬 요소의 개수는 100 x 32가 된다.
W_1h 행렬의 요소 개수도 32 x 32가 된다. 마지막으로 셀마다 하나씩 총 32개의 절편이 있다.

 

 

 

# 3.모델 컴파일 & 훈련

model.compile(optimizer='sgd', loss='binary_crossentropy', metrics=['accuracy'])
history = model.fit(x_train_onehot, y_train, epochs=20, batch_size=32, validation_data=(x_val_onehot, y_val))

▶모델 컴파일 하고 IMDB데이터 세트에 훈련시키기
확률적 경사 하강법 알고리즘(sgd)을 지정
이진분류이므로 손실함수 : binary_crossentropy로 지정

 

 

 

# 4. plot(훈련, 검증 세트에 대한 손실 그래프와 정확도그래프)

plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.legend(['loss','val_loss'])
plt.show()
 
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.legend(['accuracy','val_accuracy'])
plt.show()



▲ 손실그래프
▲ 정확도그래프

 

 

 

 

 

 

# 5. 검증세트 정확도 평가하기

loss, accuracy = model.evaluate(x_val_onehot, y_val, verbose=0)
print(accuracy)
# 0.7077999711036682

▶검증세트에 대한 정확도 평가 -> 약 70%의 정확도를 보여줌
순환신경망을 직접 구현하여 모델 훈련시켜 얻은 정확도에 비해 많이 좋아지진 않았음

 

 

 

반응형

'인공지능' 카테고리의 다른 글

순환신경망 (4)  (6) 2023.01.22
순환신경망 (3)  (1) 2023.01.22
순환신경망 (1)  (5) 2023.01.21
합성곱 (6)  (2) 2023.01.21
합성곱(4)  (12) 2023.01.18
Comments