TensorFlow模型詳解與應用
DNNLinearCombinedClassifier 類繼承于類 Estimator,Estimator 類繼承于類 BaseEstimator。BaseEstimator 是一個抽象類,定義了通用的模型訓練以及評測的函數接口 (train_model, evaluate_model, infer_model),Estimator 類中用一個統一函數 call_model_fn 來實現 train_model, evaluate_model, infer_model。
圖 7 estimator 的類關系圖
為了更好了解整個過程,我們看看內部函數的調用過程(代碼可以參見 estimator/estimator.py):
圖 8 Estmiator 類的函數調用圖
模型訓練通過調用 BaseEstimator 的 fit() 接口開始,其調用棧是:fit -》 _train_model -》 _get_train_ops -》_call_model_fn(ModelKeys.TRAIN) -》 _model_fn,最終_model_fn() 產生模型并通過 export 函數將模型輸出到 model_dir 對應目錄中。
我們把訓練模型的調用過程在代碼級別展開,標出關鍵的幾個函數和數據結構,省略不關鍵的代碼,希望能讓讀者看到訓練模型的大致過程:
圖 9 模型訓練的調用棧
評測(evaluate)和預測(predict)的過程與訓練(train)大致相同,讀者可以通過源代碼文件找到對應函數了解。可以看出,整個函數調用棧中最關鍵的 2 個函數是: input_fn 和 model_fn。input_fn 從輸入數據中生成 features 和 labels,features 是一個 Tensor 或者是一個從特征名到 Tensor 的字典,如果 features 是一個 Tensor,程序會給這個 Tensor 一個空字符串的鍵值,轉換成特征名到 Tensor 的字典。labels 是樣本的 label 構成的 tensor。input_fn 由應用程序調用者提供實現,返回(features, labels)二元組,要求 tf.get_shape(features)[0] == tf.get_shape(labels)[0],也就是兩個 tensor 的行數目得保持一致。model_fn 定義訓練和評測模型的具體邏輯,如模型訓練產生的誤差 (model_fn_ops.loss) 以及訓練算子(model_fn_ops.train_op)通過封裝在 EstmiatorSpec 的對象中由 training 的 Session 進行調用。每個具體模型需要實現的是自定義的 model_fn。
DNNLinearCombinedClassifier 是如何實現自己的 model_fn 的呢?本文開頭我們給出了它的初始化函數原型,進入初始化函數的實現中我們定位到代碼行 model_fn=_dnn_linear_combined_model_fn。
這個就是 DNNLinearCombinedClassifier 的 model_fn。這個函數的定義如下:
def_dnn_linear_combined_model_fn(features, labels, mode, params, config= None)
features 和 labels 大家都已經知道,mode 指定 model_fn 的操作模式,目前支持 3 個值:訓練模型 (model_fn.ModeKeys.TRAIN),對模型進行評測 (model_fn.ModeKeys.EVAL),根據輸入特征進行預測 (model_fn.ModeKeys.PREDICT),mode 的定義可參見文件 estimator/model_fn.py。params 和 config 參數分別定義模型訓練的參數以及模型運行的配置。
非常好我支持^.^
(2) 40%
不好我反對
(3) 60%