기억 저장소

클라우드 기반 인공지능 개발과 DevOps 실무

인공지능/딥러닝

딥러닝 : callback 함수 , 콜백 함수 코드 및 사용

하늘.98 2021. 12. 1. 18:13

아래의 사진은 callback 함수를 사용한 

이미지를 학습시키는 코드이다.

callback 함수를 쓰는 이유는 training 시키는 과정중 epoch 을 많이 시키게 되면 

train 값에 대한 정확률만 높아져 오버핏팅이 되기 때문에 

입력자가 정해놓은 정확도 값 이상이 되면 학습을 멈추는 것을 뜻한다. 

아래 코드를 보면 두가지 방법이 있다.

 

 

차트화

import matplotlib.pyplot as plt

def plot_history(history):
  hist = pd.DataFrame(history.history)
  hist['epoch'] = history.epoch

  plt.figure(figsize=(8,12))

  plt.subplot(2,1,1)
  plt.xlabel('Epoch')
  plt.ylabel('Mean Abs Error [MPG]')
  plt.plot(hist['epoch'], hist['mae'],
           label='Train Error')
  plt.plot(hist['epoch'], hist['val_mae'],
           label = 'Val Error')
  plt.ylim([0,5])
  plt.legend()

  plt.subplot(2,1,2)
  plt.xlabel('Epoch')
  plt.ylabel('Mean Square Error [$MPG^2$]')
  plt.plot(hist['epoch'], hist['mse'],
           label='Train Error')
  plt.plot(hist['epoch'], hist['val_mse'],
           label = 'Val Error')
  plt.ylim([0,20])
  plt.legend()
  plt.show()

plot_history(epoch_history)

그래프를 보면 수 백번 에포크를 진행한 이후에는 모델이 거의 향상되지 않는 것 같습니다.

model.fit 메서드를 수정하여 검증 점수가 향상되지 않으면 자동으로 훈련을 멈추도록 만들어줍니다.

 

그래서 !! EarlyStopping 콜백(callback)사용합니다!!

#위에 에포크 1000개 했다면 모델을 다시 만들어 줍니다.
model = build_model()
# 안했다면 윗부분은 건너뜁니다.

# patience= 파라미터는 성능향상을 체크할 에포크 수로서
# 10이라고 세팅하면 에포크가 10번 지났는데도 성능향상없으면, 멈추라는 뜻입니다.
early_stop=tf.keras.callbacks.EarlyStopping(monitor='val_loss',patience=10)

epoch_history = model.fit(X_train,y_train,epochs=1000,validation_split=0.2,
                          callbacks = [early_stop])

이번엔 에포크 172까지만 수행하고 종료됐습니다.

 

import matplotlib.pyplot as plt

def plot_history(history):
  hist = pd.DataFrame(history.history)
  hist['epoch'] = history.epoch

  plt.figure(figsize=(8,12))

  plt.subplot(2,1,1)
  plt.xlabel('Epoch')
  plt.ylabel('Mean Abs Error [MPG]')
  plt.plot(hist['epoch'], hist['mae'],
           label='Train Error')
  plt.plot(hist['epoch'], hist['val_mae'],
           label = 'Val Error')
  plt.ylim([0,5])
  plt.legend()

  plt.subplot(2,1,2)
  plt.xlabel('Epoch')
  plt.ylabel('Mean Square Error [$MPG^2$]')
  plt.plot(hist['epoch'], hist['mse'],
           label='Train Error')
  plt.plot(hist['epoch'], hist['val_mse'],
           label = 'Val Error')
  plt.ylim([0,20])
  plt.legend()
  plt.show()

plot_history(epoch_history)

몇번 동일한 값이 나오니 자동으로 종료했습니다.