作者:安平博,Xilinx高級工程師;來源:AI加速微信公眾號
算符融合將多個計算單元揉進一個計算核中進行,減少了中間數據的搬移,節省了計算時間。TVM中將計算算符分成四種:
1 injective。一一映射函數,比如加法,點乘等。
2 reduction。輸入到輸出具有降維性質的,比如sum。
3 complex-out。這是計算比較復雜的,比如卷積運算等。
4 opaque。無法被融合的算符,比如sort。
根據以上對算符的不同類型,TVM提供了三種融合規則:
從一定角度看,這種融合實際上是數據計算pipeline化,即兩次計算中間數據不再經歷store-load的過程,而是直接給到下一個計算單元完成計算。
在走入fuse ops代碼之前,還需要了解一些算法基礎知識。算符融合中應用了支配樹算法。在一個有向無環圖中,對于一個節點n來說,從初始節點s出發到達n的所有路徑都經歷一個節點m,那么m就是n的支配點。而距離n最近的支配點被稱作立即支配點。以r為樹根,將所有立即支配點按照支配關系連接起來就形成了支配樹。立即后支配點是從一個點n出發所有到終止節點的路徑中通過的最近節點,形成的支配樹是后支配樹。
在DAG中,對于一個點,所有能到達它的點在支配樹中的LCA,就是它支配樹中的父親。為什么算符融合要建立在后支配樹的基礎上呢?我猜測可能是因為對于兩個可融合算符在DAG中位置分為兩種,一種是父子關系,那么可以直接執行算符融合算法;另外一種是它們之間是后支配關系。對于具有后支配關系的兩個節點(n->m),就要判斷未來路徑上的節點是否都能夠和點m發生融合,如果可以,那么n也可以和m發生融合。比如下圖:
Conv2d要和elemwise add融合,必須判斷它的三個op是否能和elemwise add融合。
TVM中融合流程分為三步:
1 遍歷relay樹,建立DAG用于后支配樹分析;
2 建立后支配樹;
3 應用算符融合算法。
一 建立DAG圖
算符融合代碼在src/relay/transforms/fuse_ops.cc中。其中算符融合也應用在常量折疊中。
首先TVM中通過如下代碼來遍歷relay樹結構并建立DAG圖。
VisitExpr可以遞歸的調用在類IndexedforwardGraph中定義的VisitExpr_函數,通過深度優先搜索遍歷relay樹,并且建立DAG圖。深度優先搜索是從exit節點作為根節點反向搜鎖的,因此搜索樹是一個后序搜索樹。Outputs中保存了一個節點的輸入的邊,在構建后序支配樹會通過這些輸入邊求取LCA。那么在這個搜索樹基礎上應用支配樹算法,就能夠得到一個后序支配樹了。在這個類中針對不同節點類型重寫visitExpr_函數,節點類型有FunctionNode,ConstantNode, CallNode, TuppleNode等。我們來看CallNode的訪問函數定義:
在最后還會遞歸調用ExprVisitor::VisitExpr_函數,最終將深度優先搜索到的節點按照葉節點起始順序一次加入DAG圖中。只有ConstantNode的訪問函數中不再調用VisitExpr_,因為常量節點應該不存在葉節點了。在callNode中會將其輸入加入到DAG中,同時遍歷和輸入以及其op連接的節點,ExprVisitor中對CallNode訪問函數定義為:
因為ExprVisitor是被IndexForwardGraph繼承的,而VisitExpr_是虛擬函數,this就會指向IndexForwardGraph實例,最終就會調用這個類中定義的VisitExpr_函數,實現遞歸的遍歷relay樹。
這里要關注一下OpPatternKind,它定義了算子類型,是不同融合算法使用的依據。其定義在include/tvm/relay/op_attr_types.h文件中。
二 建立后序支配樹
接下來看后序支配樹的構建。構建函數是PostDom。因為根節點(DAG圖的出口)在post_dfs_order中最后,所以從根節點開始尋找每個節點出點的LCA,這個LCA就是后序支配點。
GetNode函數是獲得支配點,構建支配樹。在GetNode中,首先初始化根節點,然后求每個節點的輸入節點的LCA,即是這個節點的支配點。
LeastComonAncestor函數中主要代碼是:
通過兩兩求節點的LCA,來求取所有節點的LCA。程序會將計算圖中的末節點深度設置為1。然后向上逐層增加,那么LCA的共同祖先是相同的,深度也一定是一致。遍歷所有的節點,就得到一個后向支配樹。節點的pattern指向他的LCA。在計算支配點的pattern的時候,會依據pattern的定義,選擇pattern值最大的作為LCA的pattern。這塊不是太深入理解。可能是其定義的從最小值到最大值pattern可以向下進行融合,比如kElemWise=0, kInjective=2, 那么前者就能融合到KInjective中。
三 融合
完成了DAG和postDominator tree構建后,就開始融合操作。TVM中定義了group結構體,用于表示融合后的圖結構。Group結構體如下:
如果某些算符可以融合,那么就通過這個結構體中的parent,master_ref將這些節點建立連接關系。Group首先進行初始化和DAG相同的圖。然后分別遍歷dag,postDominator tree,以及group圖中節點,來判斷算符是否能被融合。Dag中和postDom中對應相同index的節點分別是被支配點和支配點。主要融合函數是以下兩個函數:
在runFuse中,有幾種情況是不進行算符融合的:
1 算符類型是Kopaque的。
2 該節點不存在支配點。
3 能夠融合的節點超過了一定數量。
融合操作算法基本上是考察當前節點到其支配點所有路徑上的節點是否都符合融合規則,如果符合就進行融合,不符合就不融合。函數CheckPath就是用于考察src到sink路徑是否能夠融合的。
融合分成了三個phase,每個phase處理不同可融合類型。這里我沒有深入研究。當判斷支配樹的前后節點可以融合后,就通過函數commitFuse執行融合操作。
完成融合之后,會遍歷節點創建新的graph。
審核編輯:何安
-
TVM
+關注
關注
0文章
19瀏覽量
3654
發布評論請先 登錄
相關推薦
評論