本技術涉及通信,尤其涉及一種面向模型訓練的梯度數據傳輸方法及相關設備。
背景技術:
1、隨著大模型的模型參數和訓練數據集的指數級增長,數據中心使用分布式訓練框架來更有效地利用計算資源。數據并行是一種分布式人工智能(artificialintelligence,ai)模型訓練框架,它允許每個計算節點共享相同的模型參數,同時處理不同的數據子集。
2、然而在數據并行的參數服務器(parameter?server,ps)架構中,多個工作節點根據各自的數據子集計算梯度,而一個或多個參數服務器則存儲并更新模型參數。在完成一次梯度計算后,工作端將計算的梯度傳輸到參數服務器以更新模型的整體參數,這種多對一的通信模式被成為梯度聚合。在梯度聚合中,由于所有工作端使用同一條鏈路,容易產生帶寬瓶頸,從而引發網絡擁塞,這種網絡擁塞會延長梯度的傳輸時間,進而降低分布式人工智能模型訓練的效率。
技術實現思路
1、有鑒于此,本技術的目的在于提出一種面向模型訓練的梯度數據傳輸方法及相關設備,用以解決上述技術問題。
2、基于上述目的,本技術的第一方面提供了一種面向模型訓練的梯度數據傳輸方法,應用于梯度數據傳輸控制系統,所述系統包括交換機、參數服務器和多個工作端,所述方法由各個工作端執行,包括:
3、確定當前訓練次數的梯度數據,獲取與當前訓練次數的梯度數據對應的初始傳輸速率,以及獲取與當前訓練次數的梯度數據對應的當前時刻的傳輸速率;
4、確定當前時刻的傳輸速率為目標傳輸速率,基于所述初始傳輸速率和所述目標傳輸速率確定傳輸速率變化比率;
5、獲取當前訓練次數的梯度數據的總字節數,并統計當前訓練次數下已成功傳輸的梯度數據的字節數,基于當前訓練次數的梯度數據的總字節數和當前訓練次數下已成功傳輸的梯度數據的字節數確定當前訓練次數的剩余傳輸字節數;
6、基于當前訓練次數下已成功傳輸的梯度數據的字節數、當前訓練次數的梯度數據的總字節數、當前訓練次數的剩余傳輸字節數和傳輸速率變化比率通過壓縮比算法進行壓縮比確定處理,得到當前訓練次數的壓縮比;
7、基于所述當前訓練次數的壓縮比和當前訓練次數的梯度數據確定當前訓練次數的目標梯度數據;
8、獲取所述交換機與所述參數服務器之間傳輸鏈路的第一鏈路帶寬,以及多個工作端的總數量;
9、基于所述第一鏈路帶寬和所述多個工作端的總數量,確定每個工作端和所述交換機之間傳輸鏈路的第二鏈路帶寬;
10、將所述當前訓練次數的目標梯度數據按照所述第二鏈路帶寬傳輸至所述交換機,以供所述交換機將所述當前訓練次數的目標梯度數據按照所述第一鏈路帶寬傳輸至所述參數服務器。
11、可選地,所述確定當前時刻的傳輸速率為目標傳輸速率,包括:
12、獲取當前時刻的下一時刻的傳輸速率;
13、利用所述當前時刻的傳輸速率和所述下一時刻的傳輸速率進行求差處理,得到第一求差處理結果;
14、對所述第一求差處理結果和所述當前時刻的傳輸速率進行比值處理,得到傳輸速率變化值;
15、響應于所述傳輸速率變化值處于預設的傳輸速率變化值閾值的范圍內,確定所述當前時刻的傳輸速率為目標傳輸速率。
16、可選地,所述基于所述初始傳輸速率和所述目標傳輸速率確定傳輸速率變化比率,包括:
17、利用所述初始傳輸速率和所述目標傳輸速率進行求差處理,得到第二求差處理結果;
18、通過所述第二求差處理結果和所述初始傳輸速率進行比值處理,得到傳輸速率變化比率。
19、可選地,所述基于當前訓練次數的梯度數據的總字節數和當前訓練次數下已成功傳輸的梯度數據的字節數確定當前訓練次數的剩余傳輸字節數,包括:
20、對當前訓練次數的梯度數據的總字節數和當前訓練次數下已成功傳輸的梯度數據的字節數進行求差處理,得到當前訓練次數的剩余傳輸字節數。
21、可選地,所述基于當前訓練次數下已成功傳輸的梯度數據的字節數、當前訓練次數的梯度數據的總字節數、當前訓練次數的剩余傳輸字節數和傳輸速率變化比率通過壓縮比算法進行壓縮比確定處理,得到當前訓練次數的壓縮比,包括:
22、利用當前訓練次數的剩余傳輸字節數和傳輸速率變化比率進行乘積處理,得到乘積處理結果;
23、對當前訓練次數下已成功傳輸的梯度數據的字節數和所述乘積處理結果進行求和處理,得到求和處理結果;
24、利用所述求和處理結果和所述當前訓練次數的梯度數據的總字節數進行比值處理,得到當前訓練次數的壓縮比。
25、可選地,所述基于所述當前訓練次數的壓縮比和當前訓練次數的梯度數據確定當前訓練次數的目標梯度數據,包括:
26、對所述當前訓練次數的壓縮比和當前訓練次數的梯度數據進行乘積處理,得到當前訓練次數的目標梯度數據。
27、可選地,所述基于所述第一鏈路帶寬和所述多個工作端的總數量,確定每個工作端和所述交換機之間傳輸鏈路的第二鏈路帶寬,包括:
28、對所述第一鏈路帶寬和所述多個工作端的總數量進行比值處理,得到每個工作端和所述交換機之間傳輸鏈路的第二鏈路帶寬。
29、基于同一發明構思,本技術的第二方面提供了一種面向模型訓練的梯度數據傳輸裝置,應用于梯度數據傳輸控制系統,所述系統包括交換機、參數服務器和多個工作端,所述裝置設置于各個工作端,所述裝置包括:
30、第一獲取模塊,被配置為確定當前訓練次數的梯度數據,獲取與當前訓練次數的梯度數據對應的初始傳輸速率,以及獲取與當前訓練次數的梯度數據對應的當前時刻的傳輸速率;
31、變化比率確定模塊,被配置為確定當前時刻的傳輸速率為目標傳輸速率,基于所述初始傳輸速率和所述目標傳輸速率確定傳輸速率變化比率;
32、剩余傳輸確定模塊,被配置為獲取當前訓練次數的梯度數據的總字節數,并統計當前訓練次數下已成功傳輸的梯度數據的字節數,基于當前訓練次數的梯度數據的總字節數和當前訓練次數下已成功傳輸的梯度數據的字節數確定當前訓練次數的剩余傳輸字節數;
33、壓縮比確定模塊,被配置為基于當前訓練次數下已成功傳輸的梯度數據的字節數、當前訓練次數的梯度數據的總字節數、當前訓練次數的剩余傳輸字節數和傳輸速率變化比率通過壓縮比算法進行壓縮比確定處理,得到當前訓練次數的壓縮比;
34、目標梯度確定模塊,被配置為基于所述當前訓練次數的壓縮比和當前訓練次數的梯度數據確定當前訓練次數的目標梯度數據;
35、第二獲取模塊,被配置為獲取所述交換機與所述參數服務器之間傳輸鏈路的第一鏈路帶寬,以及多個工作端的總數量;
36、鏈路帶寬確定模塊,被配置為基于所述第一鏈路帶寬和所述多個工作端的總數量,確定每個工作端和所述交換機之間傳輸鏈路的第二鏈路帶寬;
37、梯度數據傳輸模塊,被配置為將所述當前訓練次數的目標梯度數據按照所述第二鏈路帶寬傳輸至所述交換機,以供所述交換機將所述當前訓練次數的目標梯度數據按照所述第一鏈路帶寬傳輸至所述參數服務器。
38、基于同一發明構思,本技術的第三方面提供了一種電子設備,包括存儲器、處理器及存儲在所述存儲器上并在處理器上運行的計算機程序,所述處理器在執行所述計算機程序時實現如上第一方面所述的方法。
39、基于同一發明構思,本技術的第四方面提供了一種非暫態計算機可讀存儲介質,所述非暫態計算機可讀存儲介質存儲計算機指令,所述計算機指令用于使計算機執行如上第一方面所述的方法。
40、從上面所述可以看出,本技術提供的面向模型訓練的梯度數據傳輸方法及相關設備,通過實時監控傳輸速率,根據傳輸速率的變化動態確定壓縮比,進而能夠根據壓縮比智能壓縮梯度數據,顯著減少了數據傳輸量,減輕了網絡負擔;同時,根據交換機與所述參數服務器之間的傳輸鏈路帶寬通過擁塞控制調整每個工作端的傳輸速率來緩解擁塞。在二者的共同作用下,顯著縮短了梯度數據的傳輸時間,提高了數據傳輸效率,從而加速了分布式人工智能模型的訓練過程,提升了整體訓練效率。