模型的可解釋性是機(jī)器學(xué)習(xí)領(lǐng)域的一個(gè)重要分支,隨著 AI 應(yīng)用范圍的不斷擴(kuò)大,人們越來越不滿足于模型的黑盒特性,與此同時(shí),金融、自動駕駛等領(lǐng)域的法律法規(guī)也對模型的可解釋性提出了更高的要求,在可解釋 AI 一文中我們已經(jīng)了解到模型可解釋性發(fā)展的相關(guān)背景以及目前較為成熟的技術(shù)方法,本文通過一個(gè)具體實(shí)例來了解下在 MATLAB 中是如何使用這些方法的,以及在得到解釋的數(shù)據(jù)之后我們該如何理解分析結(jié)果。
?
要分析的機(jī)器學(xué)習(xí)模型
?
我們以一個(gè)經(jīng)典的人體姿態(tài)識別為例,該模型的目標(biāo)是通過訓(xùn)練來從傳感器數(shù)據(jù)中檢測人體活動。傳感器數(shù)據(jù)包括三軸加速計(jì)和三軸陀螺儀共6組數(shù)據(jù),我們可以通過手機(jī)或其他設(shè)備收集,訓(xùn)練的目的是識別出人體目前是走路、站立、坐、躺等六種姿態(tài)中的哪一種。我們將收集到的數(shù)據(jù)做進(jìn)一步統(tǒng)計(jì)分析,如求均值和標(biāo)準(zhǔn)差等,最終獲得18組數(shù)據(jù),即18個(gè)特征。然后可以在 MATLAB 中使用分類學(xué)習(xí)器 App 或者通過編程的形式進(jìn)行訓(xùn)練,訓(xùn)練得到的模型混淆矩陣如下,可以看到對于某些姿態(tài)的識別,模型會存在一定誤差。那么接下來我們就通過一系列模型可解釋性的方法去嘗試解讀一下錯誤判別的來源。
從混淆矩陣中可以看到,模型對于躺 ‘Laying’ 的姿態(tài)識別率為 100%,而對于正常走路和上下樓這三種 ‘Walking’ 的姿態(tài)識別準(zhǔn)確率較低,尤其是上樓和下樓均低于70%。這也符合我們的預(yù)期,因?yàn)樘傻淖藨B(tài)和其他差別較大,而幾種走路之間差異較小。
但我們也留意到模型在 ‘Sitting’ 和 ‘Standing’ 之間也產(chǎn)生了較大的誤差,考慮到這兩者之間的差異,我們想探究一下產(chǎn)生這種分類錯誤背后的原因。首先我們從圖中所示的區(qū)域選擇了一個(gè)樣本點(diǎn) query point,該樣本的正確姿態(tài)為 ‘Sitting’,但是模型識別成了 ‘Standing’,為便于下一步分析,這里將該樣本點(diǎn)所有特征及其取值列舉了出來,如前所述一共 18 個(gè),分別對應(yīng)于原始的6個(gè)傳感器數(shù)據(jù)的平均值、標(biāo)準(zhǔn)差以及第一主成分:
使用可解釋性方法進(jìn)行分析
模型可解釋性分析的目的在于嘗試對機(jī)器學(xué)習(xí)黑盒模型的預(yù)測結(jié)果給出一個(gè)合理的解釋,定性地反映出輸入數(shù)據(jù)的各個(gè)特征和預(yù)測結(jié)果之間的關(guān)系。對于預(yù)測正確的結(jié)果,我們可以判斷預(yù)測過程是否符合我們基于領(lǐng)域知識對該問題的理解,是否有一些偶然因素導(dǎo)致結(jié)果碰巧正確,從而保證了模型可以在大規(guī)模生產(chǎn)環(huán)境下做進(jìn)一步應(yīng)用,也可以滿足一些法規(guī)的要求。
而對于錯誤的結(jié)果,如上文中的姿態(tài)識別,我們可以通過可解釋性來分析錯誤結(jié)果是由哪些因素導(dǎo)致的,更具體地說,即上述 18 個(gè)特征對結(jié)果的影響。在此基礎(chǔ)上,可以更有針對性地進(jìn)行特征選擇、參數(shù)優(yōu)化等模型改進(jìn)工作。
接下來我們就嘗試用幾種不同的可解釋性方法來對上文中的 query point 做進(jìn)一步分析,希望可以找到一些模型分類錯誤的線索。
2.1 Shapley 值
我們嘗試的第一個(gè)方法是 Shapley 值,Shapley 值起源于合作博弈理論,它基于嚴(yán)格的理論分析并給出了完整的解釋。作為一個(gè)局部解釋方法,Shapley 值通過對所有可能的特征組合依次計(jì)算,從而得到每個(gè)特征對預(yù)測結(jié)果的平均邊際貢獻(xiàn),并且這些值是相對于該分類的平均得分而言的。可以簡單理解為邊際貢獻(xiàn)的分值越高,對產(chǎn)生當(dāng)前預(yù)測結(jié)果的影響越大。因?yàn)橛兄晟频睦碚摶A(chǔ)且發(fā)展時(shí)間較長,Shapley 值被廣泛應(yīng)用于金融領(lǐng)域來滿足一些法律法規(guī)的要求。
在 MATLAB 中使用 Shapley 值的方法也非常簡單,具體代碼如下:
exp = shapley(model,humanActivityDataTest,'QueryPoint',queryPt,'MaxNumSubsets',400);
plot(exp)
其中 shapley 即我們要調(diào)用的函數(shù),函數(shù)的輸入依次是訓(xùn)練好的模型,測試時(shí)完整的數(shù)據(jù)集,上文中要探測的樣本點(diǎn) query point。值得一提的是,由于嚴(yán)格的 Shapley 值計(jì)算過程中需要對所有可能的特征組合依次計(jì)算,計(jì)算時(shí)間隨特征數(shù)量呈指數(shù)增長,所以我們在調(diào)用時(shí)設(shè)置了控制計(jì)算時(shí)間的 Subsets 參數(shù)。函數(shù)的輸出 exp 是結(jié)構(gòu)體的形式,可以直接使用 plot 進(jìn)行繪制,結(jié)果如下圖:
圖中按照 Shapley 值的絕對值大小依次進(jìn)行了排序,那么該如何理解這些值即圖中所示的得分的含義呢?
我們之前已經(jīng)了解到 Shapley 值反應(yīng)的是每個(gè)特征的平均邊際貢獻(xiàn),并且這些值是相對于該分類的平均得分而言的。首先需要計(jì)算出 ‘Standing’ 的平均得分,我們會將數(shù)據(jù)集中所有點(diǎn)關(guān)于 ‘Standing’ 的預(yù)測得分取平均得到相應(yīng)的值,即 0.17577。而我們關(guān)注的樣本點(diǎn)預(yù)測為 ‘Standing’ 的得分為 1,相對較高,它和所有點(diǎn)的平均值相比差值為 0.82423,Shapley 值反應(yīng)的正是該樣本點(diǎn)中每個(gè)特征對這個(gè)差值的貢獻(xiàn),其總和也正是 0.82423。
圖中顯示了排行前十的特征及對應(yīng)的 Shapley 值,我們可以看到 rowmean_body_gyro_z 的值最大,說明它對錯誤判別的影響最大,當(dāng)然緊隨其后的幾個(gè)特征的 Shapley 值也較為接近。
特征 rowmean_body_gyro_z的實(shí)際含義為z方向陀螺儀的平均值,為什么這個(gè)特征可能導(dǎo)致了錯誤的結(jié)果?我們可以接著往下分析。
2.2 PDP - Partial Dependency Plot
Shapley 值雖然很清晰地給出了各個(gè)特征對于最終預(yù)測結(jié)果的貢獻(xiàn),但是我們需要更多的信息來分析錯誤產(chǎn)生的來源,一個(gè)有效的方法是結(jié)合 PDP 又稱部分依賴圖來進(jìn)行查看。
PDP 是一個(gè)全局解釋方法,關(guān)注單個(gè)特征對某一預(yù)測結(jié)果的整體影響,其思想是假設(shè)所有樣本中的該特征等于某一個(gè)固定值,從而計(jì)算出一個(gè)預(yù)測結(jié)果的平均值。當(dāng)我們將該特征取一系列值時(shí)(取值范圍仍然來源于樣本),便可以繪制出對應(yīng)的曲線。我們接著 Shapley 值的分析選擇特征 rowmean_body_gyro_z(對應(yīng)數(shù)據(jù)中的位置為第6個(gè)特征),以及 query point 對應(yīng)的真實(shí)分類 ‘Sitting’ 和錯誤分類 ‘Standing’ 分別繪制 PDP,在 MATLAB 中使用的方法仍然非常簡單,具體代碼及對應(yīng)結(jié)果如下:
plotPartialDependence(model,6,'Sitting');
% rowmean_body_gyro_zis the 6th predictor in our data table
plotPartialDependence(model,6,'Standing');
根據(jù)上圖以及第 1 節(jié)中 query point 在該特征的實(shí)際取值 0.017 可以看出,當(dāng)該特征的取值接近于 0 時(shí),分類為 ‘Standing’ 的分?jǐn)?shù)較高,而當(dāng)取值向兩端靠攏尤其是接近于 -0.5 時(shí)分類為 ‘sitting’ 的分?jǐn)?shù)較高,甚至大于 0.5,這也符合該點(diǎn)的實(shí)際預(yù)測值。
為了驗(yàn)證上述分析結(jié)果,我們繪制了一部分樣本點(diǎn)(約 1000 個(gè))body_gyro_z 的實(shí)際取值,結(jié)果如下圖所示,可以看到 ‘Sitting’(圖中紫色數(shù)據(jù))的整體趨勢確實(shí)比 ‘Standing’(圖中綠色數(shù)據(jù))要小一些,這說明了模型的訓(xùn)練及預(yù)測過程是合理的。但兩者的差別并不大,而且對于單個(gè)的樣本點(diǎn),比如我們現(xiàn)在關(guān)注的 query point,取值可能更大或者更小,并不符合大多數(shù)樣本的整體趨勢,這也是預(yù)測結(jié)果中個(gè)別樣本分類錯誤的原因之一。
通過部分依賴圖我們對 Shapley 值的分析結(jié)果有了更清楚的認(rèn)識,雖然該樣本點(diǎn)的預(yù)測結(jié)果是錯誤的,但結(jié)合原始數(shù)據(jù)可以看出,這樣的結(jié)果是有跡可循且合理的。
在討論下一步工作之前,我們再嘗試一個(gè)新的可解釋性方法。
2.3 LIME - Local Interpretable Model-Agnostic Explanations
除了 Shapley 值,LIME 是另外一個(gè)應(yīng)用廣泛的局部解釋方法,其簡單易理解,基本思想是針對關(guān)注的樣本點(diǎn),在附近范圍內(nèi)生成擾動數(shù)據(jù)并用黑盒模型獲得對應(yīng)的預(yù)測結(jié)果,然后使用這些數(shù)據(jù)訓(xùn)練出一個(gè)局部近似的可解釋模型,通過該模型幫助分析原始機(jī)器學(xué)習(xí)模型的預(yù)測過程。MATLAB 中可以使用線性模型與決策樹模型作為局部的可解釋模型。
值得一提的是,由于近似模型的訓(xùn)練使用隨機(jī)生成的擾動數(shù)據(jù),模型的預(yù)測結(jié)果以及特征排序也會出現(xiàn)一定的隨機(jī)性。我們?nèi)匀豢紤]上文中姿態(tài)識別模型的 query point,使用線性模型對該點(diǎn)做近似分析,具體代碼及結(jié)果如下:
limeObj= lime(model, humanActivityData, 'QueryPoint',queryPt,'NumImportantPredictors',6);
f =plot(limeObj);
由于是線性模型,預(yù)測結(jié)果只是簡單地給出是否為 ‘Standing’,而橫坐標(biāo)反映的是線性模型中每個(gè)特征對應(yīng)的系數(shù)。一個(gè)有趣的現(xiàn)象是簡單模型的預(yù)測結(jié)果與黑盒模型的預(yù)測結(jié)果并不相同,這是否意味著這樣的結(jié)果是無效的、甚至是錯誤的?
我們先來選擇 rowstd_total_accd_z 與 towmean_total_acc_x,即系數(shù)正值和負(fù)值中絕對值最大的兩個(gè)特征(對應(yīng)在數(shù)據(jù)中的位置為 9 和 1),采用上文中介紹的方法分別繪制 PDP,我們將 ‘Sitting’ 和 ‘Standing’ 兩個(gè)類別的曲線繪制在一張圖中,結(jié)果如下:
plotPartialDependence(model,9,{'Sitting','Standing'},humanActivityDataTest)
plotPartialDependence(model,1,{'Sitting','Standing'},humanActivityDataTest);
這兩個(gè)特征分別代表 z 方向加速度的標(biāo)準(zhǔn)差與 x 方向加速度的均值,結(jié)合第 1 節(jié)中其在該樣本點(diǎn)的實(shí)際取值 rowstd_total_acc_z=0.0048 以及 rowmean_total_acc_x=1.0129 可以看出,1.0129 對于模型做出正確預(yù)測會起到十分積極的作用,這可能也是簡單模型能夠做出不是 ‘Standing’ 的原因,因?yàn)檎玖⒌淖藨B(tài)通常不會在 x 方向產(chǎn)生較大的加速度,與此同時(shí)簡單模型的 rowstd_total_acc_z 的系數(shù)雖然很大,但是取值較小,這意味著z方向加速度標(biāo)準(zhǔn)差較小,數(shù)據(jù)比較集中,從 PDP 中也能看出在該點(diǎn)對于 ‘Standing’ 和 ‘Sitting’ 的區(qū)分度并不高,要在數(shù)值增大之后才會對結(jié)果有較為顯著的影響。
需要說明的是,通過 LIME 得到的特征排序(或系數(shù)大小)和 Shapley 值得到的結(jié)果相差較大,部分原因是在 LIME 中基于隨機(jī)擾動生成的數(shù)據(jù)得到的模型和黑盒模型原本就存在一定差異,可以嘗試使用不同的隨機(jī)數(shù)或使用其他簡單模型來得到多樣化的結(jié)果進(jìn)行對比分析。
回到剛才的問題,這樣的簡單模型是否是無效的?其實(shí)機(jī)器學(xué)習(xí)的模型預(yù)測本身是一個(gè)十分復(fù)雜的過程,這是與黑盒模型強(qiáng)大的功能分不開的,無論是哪種解釋方法,目的都是幫助我們窺探預(yù)測的機(jī)理,從某一個(gè)角度理解分析產(chǎn)生這樣結(jié)果的原因,這些不同的角度相結(jié)合可以讓我們逐漸接近一個(gè)更加全面的分析結(jié)果,因此都是有意義的。
而 LIME 方法本身具備的隨機(jī)性以及簡單模型算法的選擇也給了我們更多可能性來進(jìn)行不同的嘗試,關(guān)于 LIME 的使用可以參考之前的文章了解更多:如何信任機(jī)器學(xué)習(xí)模型的預(yù)測結(jié)果?(上)與 如何信任機(jī)器學(xué)習(xí)模型的預(yù)測結(jié)果?(下)。
后續(xù)工作
獲得模型的解釋結(jié)果只是第一步,在得到以上分析結(jié)果之后我們接下來可以做些什么呢?
現(xiàn)在我們已經(jīng)知道 rowmean_body_gyro_z, rowstd_total_acc_z 等幾個(gè)特征對錯誤的分類結(jié)果有較大影響,我們可以進(jìn)一步從原始數(shù)據(jù)分析更深層次的原因,比如我們采集的這個(gè)樣本點(diǎn)的數(shù)據(jù)是否有誤差?如果原始數(shù)據(jù)沒問題,那么求平均值或標(biāo)準(zhǔn)差的特征提取方式是否合適?是否應(yīng)該選擇更加復(fù)雜的統(tǒng)計(jì)方式獲取特征?在模型的訓(xùn)練階段是否可以通過修改代價(jià)函數(shù)等手段提高預(yù)測準(zhǔn)確率?
顯然通過對一個(gè)樣本的分析,就得出關(guān)于整個(gè)模型的結(jié)論是不嚴(yán)謹(jǐn)?shù)摹R陨戏治鼋Y(jié)果提供給了我們一些思路和線索,我們可以對更多樣本點(diǎn)做類似分析,再結(jié)合其他手段去做下一步的改進(jìn)。
采用類似的方法,我們還可以對判斷正確的樣本進(jìn)行可解釋性的分析,來和我們對該問題的先驗(yàn)知識進(jìn)行對比,從而驗(yàn)證模型是否正確。
其他方法
上文中通過實(shí)例介紹了幾種不同的可解釋性方法,除此之外 MATLAB 還支持與 PDP 類似、但是會將單個(gè)預(yù)測結(jié)果進(jìn)行繪制以體現(xiàn)結(jié)果分布的 ICE 圖,以及本身具備可解釋性的 Generalized Additive 等諸多方法,可以在我們的幫助文檔中了解更多信息。
而對于深度學(xué)習(xí),同樣發(fā)展了很多類似的可解釋性的方法,深度學(xué)習(xí)被廣泛應(yīng)用于圖像、語音、信號處理等領(lǐng)域,針對這類問題,在 MATLAB 中可以很方便地使用 Occlusion Sensitivity, GradCAM 和 Image LIME 等方法,由于篇幅限制,本文不做詳細(xì)展開。
?
審核編輯:湯梓紅
評論
查看更多