Tensorflow2-tensorflow-keras-回調函數

Introduction

  • 回調函數的API類位於tf.keras.callbacks類中
  • 通常在模型訓練的過程中,常常會額外作一些非模型訓練的事情

    • EarlyStopping:應用在當訓練模型的過程中,如果loss值不再下降,便可提前停止訓練

      • keras.callbacks.EarlyStopping(monitor, min_delta)
      • monitor:指定要關注的指標之值,一般情況都是關注驗證集上目標(損失)函數之值
      • min_delta:閾值
        • 此次訓練相較於上次訓練的差距是否比閾值;提前結束訓練
      • patience:當此次訓練相較於上次訓練比min_delta還小時,patience次數後就提前結束訓練
    • ModelCheckpint:訓練模型的過程中,記錄所有訓練參數的中間狀態,其會每隔一段時間將checkpoint保存下來

      • keras.callbacks.ModelCheckpoint(filepath)
      • 其須指定filepath為一文件名
      • save_best_onlyTrue時則保存最好的模型參數,否則會保存最近一次訓練的模型
    • TensorBoard:在模型訓練過程中,實時查看一些參數改變狀況的dashboard
      • keras.callbacks.Tensorboard(logdir)
      • 其須指定logdir位置為一資料夾
    • 其他回調函數皆收錄於 https://www.tensorflow.org/api_docs/python/tf/keras/callbacks

Usage

  • 通常callbacks都是在訓練的過程中進行監聽;因此是在模型fit函數裡面添加callbacks函數
  • 通常會定義一個callbacks的列表,再將其作為參數傳入到模型fit的參數中

Example


callbacks文件夾目錄結構

1
2
3
4
5
6
7
8
9
10
11
.
├── fashion_mnist_model.h5
├── train
│ ├── events.out.tfevents.1582554144.chenruiyude-iMac.local.1478.262.v2
│ ├── events.out.tfevents.1582554144.chenruiyude-iMac.local.profile-empty
│ └── plugins
│ └── profile
│ └── 2020-02-24_22-22-24
│ └── local.trace
└── validation
└── events.out.tfevents.1582554147.chenruiyude-iMac.local.1478.13246.v2

使用tensorboard可視化

1
tensorboard --logdir="儲存model的文件夾位置"

SCALARS

tensorboard_SCALARS

GRAPHS

tensorboard_GRAPHS