0x0. 前言
這篇文章來解析一下Megaton-LM涉及到的一個優化gradient_accumulation_fusion。這里fusion的意思是在gemm接口中會將當前的結果累加到先前計算的梯度上,所有這些都在一個操作中完成,可以避免多次訪問global memory提升算子的帶寬。下面解析一下這個優化的調度邏輯和cuda實現。
https://github.com/BBuf/how-to-optim-algorithm-in-cuda 這個倉庫整理了一些cuda優化相關鏈接以及大模型訓練推理相關的知識鏈接(large-language-model-note子目錄下),歡迎查看。
0x1. 調度邏輯解析
gradient_accumulation_fusion的調度邏輯是和LinearWithGradAccumulationAndAsyncCommunication這個類的實現有關的,LinearWithGradAccumulationAndAsyncCommunication 這個類又被包了一層變成 linear_with_grad_accumulation_and_async_allreduce 這個函數,這個函數又給RowParallelLinear和ColumnParallelLinear這兩個實現模型并行的Linear類使用。
下面解析一下linear_with_grad_accumulation_and_async_allreduce這個函數(https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py#L356-L446):
#這部分定義了一個函數,名為linear_with_grad_accumulation_and_async_allreduce, #它接收七個參數:輸入張量、權重張量、一個可選的偏置張量和3個布爾標志。 deflinear_with_grad_accumulation_and_async_allreduce( input:torch.Tensor, weight:torch.Tensor, bias:Optional[torch.Tensor], gradient_accumulation_fusion:bool, async_grad_allreduce:bool, sequence_parallel_enabled:bool, )->torch.Tensor: """帶有反向傳播的異步通信和梯度累積融合的線性層實現. 此函數提供了一個選項,可以將反向傳播計算的結果累積到一個現有的梯度緩沖區中, 從而避免在梯度計算后進行額外的加法核操作。 此外,輸入梯度的張量并行allreduce可以與權重梯度的計算異步進行。 在使用序列并行的情況下,輸入梯度的reducescatter與權重梯度的計算異步進行。 使用此模塊需要環境變量CUDA_DEVICE_MAX_CONNECTIONS=1。代碼中有一些集合操作, 應該在計算核之前調度,以使通信與計算重疊,這對于加速是必要的,但對于正確性則不是必要的, 因此調度器不會強制這種排序。將CUDA_DEVICE_MAX_CONNECTIONS設置為1會強制按照它們被調用的順序調度內核。 Arguments: input(torch.Tensorrequired):輸入,類似torch.nn.functional.linear weight(torch.Tensorrequired):權重,類似torch.nn.functional.linear bias(torch.Tensoroptional):偏置,類似torch.nn.functional.linear gradient_accumulation_fusion(boolrequired):執行梯度累積融合, 需要自定義的CUDA擴展模塊fused_weight_gradient_mlp_cuda。 要使用gradient_accumulation_fusion,你必須使用--cpp_ext和--cuda_ext安裝APEX。 例如:"pipinstall--global-option="--cpp_ext"--global-option="--cuda_ext." 注意,此擴展要求CUDA版本大于或等于11。否則,你必須關閉梯度累積融合。 async_grad_allreduce(boolrequired):異步地與權重梯度的計算進行輸入梯度的allreduce。 如果sequence_parallel_enabled為True,這必須為False,因為不執行allreduce。 sequence_parallel_enabled(boolrequired):表示使用了序列并行, 因此在前向傳播中,輸入是addgather后的,在反向傳播中,輸入梯度是reducescatter后的。 """ #這部分創建了一個名為args的列表,它基本上是函數輸入參數的集合。 args=[ input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce, sequence_parallel_enabled, ] #這部分檢查是否已經發出警告。函數使用一個類級別變量warned來記住是否已經向用戶顯示了警告。 ifnotlinear_with_grad_accumulation_and_async_allreduce.warned: #這部分檢查環境變量CUDA_DEVICE_MAX_CONNECTIONS是否設置為"1"。 #如果沒有,并且滿足某些條件(sequence_parallel_enabled或async_grad_allreduce), #它會發出警告。然后將warned標志設置為True,以便不會重復發出此警告。 ifos.environ.get('CUDA_DEVICE_MAX_CONNECTIONS')!="1": ifsequence_parallel_enabled: warnings.warn( "Whenusingsequenceparallelismitisrecommendedtosetthe" "environmentvariableCUDA_DEVICE_MAX_CONNECTIONSto1for" "maximumspeedup") linear_with_grad_accumulation_and_async_allreduce.warned=True ifasync_grad_allreduce: warnings.warn( "Whenusingasyncgradallreduceitisrecommendedtosetthe" "environmentvariableCUDA_DEVICE_MAX_CONNECTIONSto1for" "maximumspeedup") linear_with_grad_accumulation_and_async_allreduce.warned=True #最后,函數調用另一個名為LinearWithGradAccumulationAndAsyncCommunication的類并返回其結果。 returnLinearWithGradAccumulationAndAsyncCommunication.apply(*args) #在函數外部,初始化屬性warned為False。這用于檢查是否已經向用戶發出警告。 linear_with_grad_accumulation_and_async_allreduce.warned=False
解著解析一下LinearWithGradAccumulationAndAsyncCommunication這個類的實現(https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py#L232):
#這定義了一個名為LinearWithGradAccumulationAndAsyncCommunication的類, #該類繼承自torch.autograd.Function。 classLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): """Seelinear_with_grad_accumulation_and_async_allreduce""" #使用兩個裝飾器標記forward方法。其中@staticmethod表示這是一個靜態方法, #而@custom_fwd是一個自定義裝飾器,用于特定的前向傳播操作。 @staticmethod @custom_fwd defforward( ctx, input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce, sequence_parallel, ): #使用上下文對象ctx保存輸入和權重,以便在后向傳播中使用。 ctx.save_for_backward(input,weight) #在上下文對象ctx中存儲其他變量和標志。 ctx.use_bias=biasisnotNone ctx.gradient_accumulation_fusion=gradient_accumulation_fusion ctx.async_grad_allreduce=async_grad_allreduce ctx.sequence_parallel=sequence_parallel #如果啟用了序列并行,則進行以下操作: ifsequence_parallel: #獲取模型并行的world_size(通常是參與并行處理的GPU數量)。 world_size=get_tensor_model_parallel_world_size() #更改輸入的第一個維度以考慮模型并行的全部大小。 dim_size=list(input.size()) dim_size[0]=dim_size[0]*world_size #收集所有GPU上的輸入。 all_gather_buffer=get_global_memory_buffer().get_tensor(dim_size,input.dtype,"mpu") torch.distributed._all_gather_base( all_gather_buffer,input,group=get_tensor_model_parallel_group() ) #更新total_input為收集的數據。 total_input=all_gather_buffer else: #如果不使用序列并行,則total_input僅僅是傳入的輸入。 total_input=input #對total_input和weight的轉置進行矩陣乘法以計算輸出。 output=torch.matmul(total_input,weight.t()) #如果提供了偏置,則將其添加到輸出中 ifbiasisnotNone: output=output+bias returnoutput @staticmethod @custom_bwd defbackward(ctx,grad_output): #從上下文對象中恢復前向傳播保存的張量。 input,weight=ctx.saved_tensors #從上下文對象中恢復偏置使用的信息。 use_bias=ctx.use_bias #如果啟用了序列并行,要如何獲取完整的輸入數據。 #它通過分布式的_all_gather_base函數來異步地聚集所有輸入。 ifctx.sequence_parallel: world_size=get_tensor_model_parallel_world_size() dim_size=list(input.size()) dim_size[0]=dim_size[0]*world_size all_gather_buffer=get_global_memory_buffer().get_tensor(dim_size,input.dtype,"mpu") handle=torch.distributed._all_gather_base( all_gather_buffer,input,group=get_tensor_model_parallel_group(),async_op=True ) #HerewerelyonCUDA_DEVICE_MAX_CONNECTIONS=1toensurethatthe #gatherisscheduledbeforetheinputgradientcomputation total_input=all_gather_buffer #如果沒有啟用序列并行,那么完整的輸入就是原始輸入。 else: total_input=input #通過矩陣乘法計算關于輸入的梯度。 grad_input=grad_output.matmul(weight) #如果啟用了序列并行,則等待所有聚集操作完成。 ifctx.sequence_parallel: handle.wait() #Doinggather+slicingduringtheNeMoforwardpasscanmakethistensor #notbecontiguous.PyTorchonlychecksifthetensoriscontiguous,andonly #clonesitifit'snotcontiguous: #https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761 #這些是注釋,提到在NeMo的前向傳遞中,執行gather和slicing操作可能會導致grad_output張量 #不是連續的。PyTorch只檢查張量是否是連續的,并且只在不連續時克隆它。 grad_output=grad_output.contiguous()#確保grad_output是連續的 #Convertthetensorshapesto2Dforexecutioncompatibility #將grad_output張量的形狀轉化為2D,以確保兼容性。 grad_output=grad_output.view( grad_output.shape[0]*grad_output.shape[1],grad_output.shape[2] ) #同樣地,將total_input張量也轉化為2D。 total_input=total_input.view( total_input.shape[0]*total_input.shape[1],total_input.shape[2] ) #如果啟用了異步的梯度all-reduce,執行該操作。這是一個分布式操作,用于聚合所有工作節點上的梯度。 ifctx.async_grad_allreduce: #Asynchronousall-reduce handle=torch.distributed.all_reduce( grad_input,group=get_tensor_model_parallel_group(),async_op=True ) #HerewerelyonCUDA_DEVICE_MAX_CONNECTIONS=1toensurethatthe #all-reduceisscheduledbeforetheweightgradientcomputation #如果啟用了序列并行,則不應該在此處啟用異步all-reduce(由assert語句確保)。 #接著,創建一個新的sub_grad_input張量,并執行一個reduce_scatter操作。 #這是一個分布式操作,它會將輸入的梯度從所有工作節點上聚合到一個工作節點上。 ifctx.sequence_parallel: assertnotctx.async_grad_allreduce dim_size=list(input.size()) sub_grad_input=torch.empty( dim_size,dtype=input.dtype,device=torch.cuda.current_device(),requires_grad=False ) #reduce_scatter handle=torch.distributed._reduce_scatter_base( sub_grad_input,grad_input,group=get_tensor_model_parallel_group(),async_op=True ) #HerewerelyonCUDA_DEVICE_MAX_CONNECTIONS=1toensurethatthe #reducescatterisscheduledbeforetheweightgradientcomputation #根據是否啟用了梯度累積融合,使用特定的CUDA操作或標準的矩陣乘法來計算權重的梯度。 #這個條件檢查是否啟用了梯度累積融合。梯度累積通常在小批量訓練中用于累積梯度以在較大的有效批量上更新模型。 ifctx.gradient_accumulation_fusion: ifweight.main_grad.dtype==torch.float32: fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32( total_input,grad_output,weight.main_grad ) elifweight.main_grad.dtypein(torch.float16,torch.bfloat16): fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16( total_input,grad_output,weight.main_grad ) else: raiseRuntimeError("Unsupportedgradienttypeforgradientaccumulationfusion") #在梯度累積融合的情況下,設置grad_weight為None, #這意味著梯度已經在前面的CUDA函數中直接更新了(weight.main_grad),所以在這里沒有返回值。 grad_weight=None else: grad_weight=grad_output.t().matmul(total_input) #如果使用偏置,則計算關于偏置的梯度。 grad_bias=grad_output.sum(dim=0)ifuse_biaselseNone #如果啟用了序列并行,等待上述操作完成,并返回計算得到的梯度。 ifctx.sequence_parallel: handle.wait() returnsub_grad_input,grad_weight,grad_bias,None,None,None #如果啟用了異步all-reduce,等待all-reduce操作完成。 ifctx.async_grad_allreduce: handle.wait() returngrad_input,grad_weight,grad_bias,None,None,None
可以看到gradient_accumulation_fusion這個優化作用于Linear層中對weight求梯度的時候,調用了apex庫提供的2個fuse cuda kernel原地更新了weight的梯度。
0x2. fused_weight_gradient_mlp_cuda 實現
fused_weight_gradient_mlp_cuda接口分別為float32和float16/bfloat16提供了2個cuda kernel實現,我們先看一下上層的接口。(https://github.com/NVIDIA/apex/blob/master/csrc/megatron/fused_weight_gradient_dense.cpp)
//定義了一個名為wgrad_gemm_accum_fp32_cuda_stub的函數原型。這是一個CUDAC++函數, //用于處理float32數據類型的權重梯度累積。該函數接受三個at::Tensor參數: //input_2d,d_output_2d,和d_weight。 voidwgrad_gemm_accum_fp32_cuda_stub( at::Tensor&input_2d, at::Tensor&d_output_2d, at::Tensor&d_weight ); //定義了一個名為wgrad_gemm_accum_fp16_cuda_stub的函數原型,與上面的函數類似, //但它是為float16數據類型設計的。 voidwgrad_gemm_accum_fp16_cuda_stub( at::Tensor&input_2d, at::Tensor&d_output_2d, at::Tensor&d_weight ); PYBIND11_MODULE(TORCH_EXTENSION_NAME,m){ m.def("wgrad_gemm_accum_fp32",&wgrad_gemm_accum_fp32_cuda_stub,"wgradgemmaccuminfp32"); m.def("wgrad_gemm_accum_fp16",&wgrad_gemm_accum_fp16_cuda_stub,"wgradgemmaccuminfp16"); }
接下來解析一下wgrad_gemm_accum_fp32這個kernel,對應 https://github.com/NVIDIA/apex/blob/master/csrc/megatron/fused_weight_gradient_dense_cuda.cu 這個文件。
//這個函數是一個封裝了NVIDIAcuBLAS庫中的cublasGemmEx函數的C++函數, //專門用于執行BFloat16(BF16)的矩陣乘法(GEMM)操作。 //函數的名稱為gemmex_wrapper,它的設計意圖是提供一個簡單的接口, //使得PyTorch可以方便地利用cuBLAS中的高效GEMM操作,特別是當使用BFloat16數據類型時。 //BF16TensorcorewrapperaroundcublasGEMMEx voidgemmex_wrapper( cublasHandle_thandle,//cuBLAS庫的句柄,用于管理cuBLAS調用。 cublasOperation_ttransa, cublasOperation_ttransb,//這兩個參數描述了兩個輸入矩陣A和B是否需要轉置。 //定義了矩陣A,B和輸出矩陣C的維度。具體來說,矩陣A的維度為mxk, //矩陣B的維度為kxn,輸出矩陣C的維度為mxn。 intm, intn, intk, constfloat*alpha,//標量系數,用于計算alpha*A*B。 at::BFloat16*A,//輸入矩陣A,它們都是BFloat16數據類型。 intlda,//這個參數是矩陣A的leadingdim,通常與矩陣的行數相同。 at::BFloat16*B, intldb, constfloat*beta,//標量系數,用于計算beta*C。 float*C,//輸出矩陣C,它是float數據類型。 intldc){//矩陣C的leading維度,通常與矩陣C的行數相同。 //使用TORCH_CUDABLAS_CHECK宏調用了cublasGemmEx函數。這是cuBLAS庫中用于執行混合精度矩陣乘法的函數。 //cublasGemmEx函數的參數主要用于描述輸入和輸出矩陣的屬性,以及要執行的具體操作。 //在這里,輸入矩陣A和B都是BFloat16數據類型,而輸出矩陣C是float數據類型。 //CUDA_R_16BF和CUDA_R_32F是枚舉值,用于描述矩陣的數據類型。 //CUBLAS_GEMM_DEFAULT_TENSOR_OP是一個枚舉值,指示cuBLAS使用默認的TensorCore操作來執行GEMM。 TORCH_CUDABLAS_CHECK(cublasGemmEx( handle, transa, transb, m, n, k, alpha, A, CUDA_R_16BF, lda, B, CUDA_R_16BF, ldb, beta, C, CUDA_R_32F, ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } //類似上面的函數,用于執行FP16的矩陣乘法 //FP16TensorcorewrapperaroundcublasGEMMEx voidgemmex_wrapper( cublasHandle_thandle, cublasOperation_ttransa, cublasOperation_ttransb, intm, intn, intk, constfloat*alpha, at::Half*A, intlda, at::Half*B, intldb, constfloat*beta, float*C, intldc){ TORCH_CUDABLAS_CHECK(cublasGemmEx( handle, transa, transb, m, n, k, alpha, A, CUDA_R_16F, lda, B, CUDA_R_16F, ldb, beta, C, CUDA_R_32F, ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } //類似上面的函數,用于執行FP32的矩陣乘法 //FP32wrapperaroundcublasGEMMEx voidgemmex_wrapper( cublasHandle_thandle, cublasOperation_ttransa, cublasOperation_ttransb, intm, intn, intk, constfloat*alpha, float*A, intlda, float*B, intldb, constfloat*beta, float*C, intldc){ TORCH_CUDABLAS_CHECK(cublasGemmEx( handle, transa, transb, m, n, k, alpha, A, CUDA_R_32F, lda, B, CUDA_R_32F, ldb, beta, C, CUDA_R_32F, ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } //這個函數wgrad_gemm_accum_fp32_cuda是一個模板函數,用于在CUDA上執行累加的權重梯度計算(矩陣乘法)。 //它使用了前面提到的gemmex_wrapper函數,該函數是NVIDIAcuBLAS庫中的cublasGemmEx函數的封裝, //用于執行高效的矩陣乘法。 template voidwgrad_gemm_accum_fp32_cuda(T*input,T*d_output,float*d_weight,intin_dim,inthidden_dim,intout_dim){ //獲取當前CUDAcuBLAS句柄。 cublasHandle_thandle=at::getCurrentCUDABlasHandle(); //獲取CUDAStream。 cudaStream_tstream; //從cuBLAS句柄獲取當前CUDA流。 cublasGetStream(handle,&stream); //定義矩陣乘法的標量系數,用于計算alpha*A*B+beta*C。 constfloatalpha=1.0; constfloatbeta=1.0; //使用CUBLAS_OP_N和CUBLAS_OP_T作為參數,表示輸入矩陣不需要轉置,但d_output矩陣需要轉置。 //使用輸入矩陣input和輸出矩陣的梯度d_output作為輸入,將結果存儲在權重梯度d_weight中。 gemmex_wrapper( handle, CUBLAS_OP_N, CUBLAS_OP_T, in_dim, out_dim, hidden_dim, &alpha, input, in_dim, d_output, out_dim, &beta, d_weight, in_dim); } //這是為數據類型at::Half(即半精度浮點型,也稱為FP16)顯式實例化的wgrad_gemm_accum_fp32_cuda函數。 //使用此數據類型的版本,可以進行更快速的計算,尤其是在支持FP16計算的硬件上。 templatevoidwgrad_gemm_accum_fp32_cuda(at::Half*input,at::Half*d_output,float*d_weight,intin_dim,inthidden_dim,intout_dim); templatevoidwgrad_gemm_accum_fp32_cuda(at::BFloat16*input,at::BFloat16*d_output,float*d_weight,intin_dim,inthidden_dim,intout_dim); templatevoidwgrad_gemm_accum_fp32_cuda(float*input,float*d_output,float*d_weight,intin_dim,inthidden_dim,intout_dim); //這個函數名為wgrad_gemm_accum_fp32_cuda_stub,從名字中可以看出這是一個為CUDA定義的存根函數。 //它處理輸入的張量,調整它們的維度,然后調用對應的CUDA模板函數來完成具體的操作。 voidwgrad_gemm_accum_fp32_cuda_stub( at::Tensor&input, at::Tensor&d_output, at::Tensor&d_weight ){ at::Tensorinput_2d,d_output_2d; //inputtensor:collapsetothefirstdim autoin_sizes=input.sizes(); //如果input張量的維度大于2,它將最后一個維度以外的所有維度折疊為第一個維度, //使其成為一個2D張量input_2d。否則,它將使用原始input張量。 if(input.dim()>2){ input_2d=input.view({-1,in_sizes[in_sizes.size()-1]}); }else{ input_2d=input; } //d_outputtensor:collapsetothefirstdim //類似地,如果d_output張量的維度大于2,它也會進行同樣的維度轉換。 //否則,它會使用原始的d_output張量。 autod_out_sizes=d_output.sizes(); if(d_output.dim()>2){ d_output_2d=d_output.view({-1,d_out_sizes[d_out_sizes.size()-1]}); }else{ d_output_2d=d_output; } //hidden_dim是input_2d的第一個維度的大小。 constinthidden_dim=input_2d.size(0); //in_dim是input_2d的第二個維度的大小。 constintin_dim=input_2d.size(1); //out_dim是d_weight的第一個維度的大小。 constintout_dim=d_weight.size(0); //使用DISPATCH_FLOAT_HALF_AND_BFLOAT宏來基于input_2d的數據類型調用相應的函數。 //這意味著,根據輸入數據的數據類型(浮點、半精度或BFloat16), //它將選擇正確的版本的wgrad_gemm_accum_fp32_cuda函數進行調用。 DISPATCH_FLOAT_HALF_AND_BFLOAT(input_2d.scalar_type(),0,"wgrad_gemm_accum_fp32", wgrad_gemm_accum_fp32_cuda( input_2d.data_ptr(), d_output_2d.data_ptr(), d_weight.data_ptr(), in_dim, hidden_dim, out_dim); ); }
注意,在Kernel中這里會將當前的結果累加到先前計算的梯度上,所有這些都在一個操作中完成,這是fuse的思想,可以避免多次訪問global memory提升算子的帶寬。
審核編輯:彭菁
-
邏輯
+關注
關注
2文章
824瀏覽量
29387 -
異步通信
+關注
關注
1文章
55瀏覽量
10088 -
函數
+關注
關注
3文章
4235瀏覽量
61965 -
大模型
+關注
關注
2文章
2134瀏覽量
1974
原文標題:0x3. 總結
文章出處:【微信號:GiantPandaCV,微信公眾號:GiantPandaCV】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
評論