군침이 싹 도는 코딩

epoch 에 따른 데이터의 overfitting 과 해결법 (callback class) 본문

Python/Deep Learning

epoch 에 따른 데이터의 overfitting 과 해결법 (callback class)

mugoori 2022. 12. 29. 11:24

X축 epoch Y축 Loss

# 인공지능 학습에서 에포크를 반복 할 수록 학습 데이터의 결과가 좋아진다

반대로 벨리데이션 데이터는 점점 로스가 커지는것을 볼 수 있다

학습한 데이터(Train Loss) 와 벨리데이션 데이터(Validation Loss) 가 멀어지는 현상을

overfitting 이라고 한다

 

 

 

 

X축 epoch Y축 accuracy

# 위와 마찬가지로 정확도(accuracy) 또한 학습한 데이터와 벨리데이션 데이터간의 격차가 커진다

이를 해결하기 위해서는 callback class를 이용해 원하는 조건을 걸어서 멈추게하면 된다

 

 

 

 

class myCallback(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=()) :
    if logs['val_accuracy'] > 0.88 :
      print('\n내가 정한 정확도에 도달했으니, 학습을 멈춘다.')
      self.model.stop_training = True

# 코드 작성을 통해 콜백클래스를 만들어준다

이 코드는 에포크가 1번 끝날때마다 val_accuracy가 0.88 이상인지 확인해서

0.88 이상이라면 내가 정한 정확도에 도달했으니, 학습을 멈춘다라고 프린트하고

학습을 종료하라는 함수를 만들었다

 

 

 

 

my_cb = myCallback()

# 이것을 변수에 저장한다

 

 

 

 

def build_model():
  model = Sequential()
  model.add( Flatten())
  model.add( Dense(128,'relu') )
  model.add( Dense(64,'relu') )
  model.add( Dense(10, 'softmax'))
  model.compile('adam','sparse_categorical_crossentropy',['accuracy'])
  return model

# 인공지능 모델링을 함수를 통해 만들어준다

 

 

 

 

model = build_model()

# 이것을 변수에 저장한다

 

 

 

 

epoch_history = model.fit(X_train,y_train,epochs=30,validation_split=0.2,callbacks=[my_cb])
>>>
Epoch 1/30
1500/1500 [==============================] - 7s 4ms/step - loss: 0.5129 - accuracy: 0.8174 - val_loss: 0.4252 - val_accuracy: 0.8474
Epoch 2/30
1500/1500 [==============================] - 7s 5ms/step - loss: 0.3823 - accuracy: 0.8605 - val_loss: 0.3926 - val_accuracy: 0.8575
Epoch 3/30
1500/1500 [==============================] - 6s 4ms/step - loss: 0.3414 - accuracy: 0.8737 - val_loss: 0.3881 - val_accuracy: 0.8539
Epoch 4/30
1500/1500 [==============================] - 6s 4ms/step - loss: 0.3194 - accuracy: 0.8809 - val_loss: 0.3396 - val_accuracy: 0.8778
Epoch 5/30
1500/1500 [==============================] - 7s 5ms/step - loss: 0.2979 - accuracy: 0.8900 - val_loss: 0.3465 - val_accuracy: 0.8760
Epoch 6/30
1489/1500 [============================>.] - ETA: 0s - loss: 0.2859 - accuracy: 0.8926
내가 정한 정확도에 도달했으니, 학습을 멈춘다.
1500/1500 [==============================] - 6s 4ms/step - loss: 0.2853 - accuracy: 0.8929 - val_loss: 0.3146 - val_accuracy: 0.8842

# 학습을 시킬때 파라미터에 벨리데이션 스플릿과 콜백스에 아까 만들어두었던

콜백클래스 함수를 저장한 변수를 넣어준다

30번의 에포크를 지정했지만 6번만에 val_accuracy 0.88에 도달했기때문에 멈추었다