本發明屬于機器學習和人工智能,尤其涉及一種基于任務關聯度的多任務學習輔助任務選擇方法。
背景技術:
1、在工業應用中,機器學習模型的準確性常常能夠借助遷移學習從相關的輔助信息中得到提升。特別是在制造業領域,預測設備故障時間這一主任務,可以借助輔助信息如設備的歷史傳感器數據、維護記錄和運行狀態等來優化。一種常見的實現這種遷移學習的方法是將這些輔助信息轉化為輔助任務(例如,預測機器的溫度波動或振動異常),并與主任務(如故障預測)在多任務網絡中聯合優化。這種多任務網絡通過共享底層結構,實現了輔助任務知識向主任務的傳遞,從而提升主任務的預測準確性。
2、提升主任務性能的關鍵在于選擇關聯性強的輔助任務。然而,傳統的輔助任務選擇方法存在以下問題,難以滿足復雜多任務場景的需求:
3、1、傳統的輔助任務選擇方法中,輔助任務的選擇往往依賴于專家經驗或簡單的任務組合,缺乏科學系統的選擇機制,無法確保最優的任務組合。
4、2、當前研究主要通過調整任務權重或優化模型結構來提升多任務學習效果。然而,權重調整過程復雜且耗時,特別是在數據量大或任務數量多的情況下,難以實現最優平衡。
5、3、隨著任務數量和數據復雜性增加,傳統方法難以處理復雜任務關聯,輔助任務選擇的準確性隨之下降。
6、4、不當的任務選擇可能導致輔助任務與主任務產生沖突,造成負遷移現象,不僅無法提升主任務的性能,反而會增加模型訓練的難度和不確定性。
7、為了解決以上問題,本發明提出了一種基于任務關聯度的多任務學習輔助任務選擇方法。
技術實現思路
1、本發明的目的在于提供一種基于任務關聯度的多任務學習輔助任務選擇方法,旨在解決上述背景技術中提出的問題。
2、為實現上述目的,本發明提供如下技術方案:
3、一種基于任務關聯度的多任務學習輔助任務選擇方法,包括以下步驟:
4、步驟s1、數據集準備:在多任務學習中,首先定義主任務和若干輔助任務,然后構建一個包含主任務和多個輔助任務的多任務學習數據集,數據集包含各任務的特征變量和目標值;
5、步驟s2、tcc計算:通過衡量每個任務損失變化與主任務梯度方向的相似度來動態評估任務間的相關性;
6、步驟s3、輔助任務選擇:根據計算得到的tcc值,分析輔助任務與主任務之間的協同作用或沖突,并選擇協同作用最強的輔助任務;
7、步驟s4、模型訓練與優化:使用選定的輔助任務和主任務構建多任務學習模型;在訓練過程中,采用標準的訓練步驟,使用所有任務的聯合梯度對模型進行優化;定期計算并更新tcc值。
8、進一步的,所述tcc值的取值范圍為-1~1,接近1的tcc值表示協同作用強,接近-1的值表示存在沖突。
9、進一步的,所述tcc計算步驟的具體過程如下:
10、對于每個訓練批次在時間步 t上,設表示在針對主任務應用梯度步驟后更新的共享參數;在假設隨機梯度下降的條件下,表示為:
11、式1:;
12、其中,表示更新前的共享參數;表示學習率;表示主任務在批量下對參數的梯度;
13、使用更新后的共享參數計算主任務和每個輔助任務的損失變化,同時保持每個任務的特定參數和輸入批次不變;
14、輔助任務 i的損失變化定義為:
15、式2:;
16、其中,表示輔助任務 i的損失函數;
17、主任務的損失變化定義為:
18、式3:;
19、其中,表示主任務特定部分參數;
20、收集所有批次的損失變化序列,并計算tcc,量化輔助任務對主任務優化方向的協同程度或沖突:
21、式4:;
22、其中,表示主任務和輔助任務的tcc分數;表示主任務損失變化序列;表示輔助任務損失變化序列;表示主任務損失變化序列和輔助任務損失變化序列之間的協方差;和分別表示主任務和輔助任務損失變化的標準差;
23、在每個訓練周期結束時計算tcc,并將結果平均,以獲得整個訓練過程中的綜合tcc值,綜合tcc值表示為:
24、式5:;
25、其中, n表示第n個訓練周期; n表示總訓練周期數。
26、與現有技術相比,本發明的有益效果是:
27、1、本發明通過引入tcc,量化輔助任務與主任務之間的關聯性,從而自動篩選最適合主任務的輔助任務。相較于傳統依賴專家經驗和直覺的選擇方法,本發明提升了選擇的效率和精度,有效減少了負遷移現象的發生,使得多任務學習更加高效和可靠。
28、2、本發明引入的tcc能夠有效評估各任務之間的相關性,自動選擇出與主任務最相關的輔助任務,避免直接搜索任務組合所帶來的高昂計算成本,特別是在任務數量較多時顯著降低了計算復雜度,還確保了多任務學習的效果和效率。
29、3、本發明提出的方法不僅不受模型類型的限制,展現出廣泛的適用性,而且無需增加額外的參數,通過簡單的設置即可獲得顯著效果,具備較強的適用性和易用性。這使得本發明為復雜多任務場景下的模型優化提供了可靠的解決方案,并具有廣泛的應用前景。
1.一種基于任務關聯度的多任務學習輔助任務選擇方法,其特征在于,包括以下步驟:
2.根據權利要求1所述的基于任務關聯度的多任務學習輔助任務選擇方法,其特征在于,所述tcc值的取值范圍為-1~1,接近1的tcc值表示協同作用強,接近-1的值表示存在沖突。
3.根據權利要求1所述的基于任務關聯度的多任務學習輔助任務選擇方法,其特征在于,所述tcc計算步驟的具體過程如下: