本技術涉及人工智能,具體而言,涉及一種高斯預測模型訓練方法、機器人強化學習方法及電子設備。
背景技術:
1、在強化學習?(reinforcement?learning,rl)中,尤其是基于視覺的強化學習任務中,出色的環(huán)境表示參數(shù)對學習效果至關重要,其質(zhì)量直接影響學習任務的完成情況。
2、現(xiàn)有基于視覺的rl通常采用顯示或隱式方式表示環(huán)境,如圖像、點、體素和神經(jīng)輻射場,但是這些表示方式存在諸多缺陷,要么無法描述復雜的局部幾何結(jié)構(gòu),要么難以推廣到未見場景,部分還需要前景掩膜,導致無法為rl提供的出色的環(huán)境表示參數(shù),降低強化學習的效果。
技術實現(xiàn)思路
1、本技術的目的在于,針對上述現(xiàn)有技術中的不足,提供一種高斯預測模型訓練方法、機器人強化學習方法及電子設備,以便生成高質(zhì)量的環(huán)境表示參數(shù),提高強化學習的效果。
2、為實現(xiàn)上述目的,本技術實施例采用的技術方案如下:
3、第一方面,本技術實施例提供了一種高斯預測模型訓練方法,所述方法包括:
4、獲取訓練三維場景在第一預設視角的第一樣本圖像和第一目標視角的第二樣本圖像;
5、根據(jù)所述第一樣本圖像采用預訓練高斯預測模型進行參數(shù)預測,得到第一預測高斯參數(shù)和第二預測高斯參數(shù),所述第二預測高斯參數(shù)是對所述第一預測高斯參數(shù)進行去噪后的屬性參數(shù);
6、根據(jù)所述第二預測高斯參數(shù)生成所述第一目標視角的第一預測圖像;
7、根據(jù)所述第一預測圖像和所述第二樣本圖像,計算第一訓練損失;
8、根據(jù)所述第一預測高斯參數(shù)和所述第二預測高斯參數(shù),計算第二訓練損失;
9、根據(jù)所述第一訓練損失和所述第二訓練損失,對所述預訓練高斯預測模型進行訓練,得到目標高斯預測模型,所述目標高斯預測模型用于部署于機器人,以使得所述機器人根據(jù)所述目標高斯預測模型生成工作場景的場景高斯參數(shù),根據(jù)所述場景高斯參數(shù)進行強化學習,確定所述機器人的動作。
10、可選地,所述預訓練高斯預測模型包括:預訓練參數(shù)預測模型和初始參數(shù)優(yōu)化模型;所述根據(jù)所述第一樣本圖像采用預訓練高斯預測模型進行參數(shù)預測,得到第一預測高斯參數(shù)和第二預測高斯參數(shù),包括:
11、根據(jù)所述第一樣本圖像采用所述預訓練參數(shù)預測模型進行參數(shù)預測,得到所述第一預測高斯參數(shù);
12、根據(jù)所述第一預測高斯參數(shù)采用所述初始參數(shù)優(yōu)化模型進行優(yōu)化,得到所述第二預測高斯參數(shù)。
13、可選地,所述根據(jù)所述第一樣本圖像采用所述預訓練參數(shù)預測模型進行參數(shù)預測,得到所述第一預測高斯參數(shù),包括:
14、采用所述預訓練參數(shù)預測模型中的深度預測模塊獲取所述第一樣本圖像的樣本特征圖,并對所述樣本特征圖進行深度估計,得到深度圖像;
15、采用所述預訓練參數(shù)預測模型中的屬性預測模塊對所述第一樣本圖像、所述樣本特征圖和所述深度圖像對應的融合特征圖進行參數(shù)預測,得到預測旋轉(zhuǎn)參數(shù)、預測縮放參數(shù)和預測透明度參數(shù);
16、根據(jù)所述深度圖像確定預測三維位置參數(shù);
17、根據(jù)所述第一樣本圖像的顏色值確定預測顏色參數(shù);
18、其中,所述第一預測高斯參數(shù)包括:所述預測三維位置參數(shù)、所述預測旋轉(zhuǎn)參數(shù)、所述預測縮放參數(shù)、所述預測顏色參數(shù)和所述預測透明度參數(shù)。
19、可選地,所述采用所述預訓練參數(shù)預測模型中的深度預測模塊對所述樣本特征圖進行深度估計,得到深度圖像,包括:
20、若所述第一樣本圖像為單張樣本圖像,采用單圖像深度預測器對所述單張樣本圖像的樣本特征圖進行深度預測,得到所述深度圖像;
21、若所述第一樣本圖像為兩張樣本圖像,采用視差預測網(wǎng)絡對所述兩張樣本圖像的樣本特征圖進行視差預測,得到所述深度圖像。
22、可選地,所述獲取訓練三維場景在第一預設視角的第一樣本圖像和第一目標視角的第二樣本圖像,包括:
23、從所述訓練三維場景的樣本圖像集中確定所述第二樣本圖像;
24、根據(jù)所述第二樣本圖像的第一目標視角,從所述樣本圖像集中確定距離所述第一目標視角最近的相鄰視角的兩張樣本圖像為所述第一樣本圖像。
25、可選地,所述根據(jù)所述第一樣本圖像采用預訓練高斯預測模型進行參數(shù)預測,得到第一預測高斯參數(shù)和第二預測高斯參數(shù)之前,所述方法還包括:
26、獲取第二預設視角的第三樣本圖像和第二目標視角的第四樣本圖像;
27、根據(jù)所述第四樣本圖像采用初始高斯預測模型進行參數(shù)預測,得到第三預測高斯參數(shù);
28、根據(jù)所述第三預測高斯參數(shù)生成所述第二目標視角的第二預測圖像;
29、根據(jù)所述第二預測圖像和所述第四樣本圖像,計算第三訓練損失;
30、根據(jù)所述第三訓練損失對所述初始高斯預測模型中的深度預測模塊進行預訓練,得到所述預訓練高斯預測模型。
31、可選地,所述根據(jù)所述第一訓練損失和所述第二訓練損失,對所述預訓練高斯預測模型進行訓練,得到目標高斯預測模型,包括:
32、根據(jù)所述第一訓練損失和所述第二訓練損失,對所述預訓練高斯預測模型中除所述深度預測模塊之外的其他模塊進行訓練,得到目標高斯預測模型。
33、可選地,所述初始參數(shù)優(yōu)化模型為圖神經(jīng)網(wǎng)絡模型,所述根據(jù)所述第一預測高斯參數(shù)采用所述初始參數(shù)優(yōu)化模型進行優(yōu)化,得到所述第二預測高斯參數(shù),包括:
34、根據(jù)所述第一預測高斯參數(shù)生成第一預測高斯點云;
35、采用所述圖神經(jīng)網(wǎng)絡模型對所述第一預測高斯點云進行優(yōu)化,得到第二預測高斯點云,所述第二預測高斯點云為所述第二預測高斯參數(shù)表示的點云;
36、所述根據(jù)所述第二預測高斯參數(shù)生成所述第一目標視角的第一預測圖像,包括:
37、根據(jù)所述第二預測高斯點云進行渲染,生成所述第一預測圖像。
38、第二方面,本技術實施例還提供一種基于高斯預測模型的機器人強化學習方法,所述方法包括:
39、獲取預設工作場景的場景圖像;
40、采用預先訓練的目標高斯預測模型對所述場景圖像進行參數(shù)預測,得到所述預設工作場景的場景高斯參數(shù),所述目標高斯預測模型為預先采用如第一方面任一項所述的高斯預測模型訓練方法進行訓練得到的;
41、根據(jù)所述場景高斯參數(shù)采用強化學習網(wǎng)絡預測機器人的動作,以控制所述機器人在所述預設工作場景進行作業(yè)。
42、第三方面,本技術實施例提供了一種高斯預測模型訓練裝置,所述裝置包括:
43、第一圖像獲取模塊,用于獲取訓練三維場景在第一預設視角的第一樣本圖像和第一目標視角的第二樣本圖像;
44、第一參數(shù)預測模塊,用于根據(jù)所述第一樣本圖像采用預訓練高斯預測模型進行參數(shù)預測,得到第一預測高斯參數(shù)和第二預測高斯參數(shù),所述第二預測高斯參數(shù)是對所述第一預測高斯參數(shù)進行去噪后的屬性參數(shù);
45、圖像生成模塊,用于根據(jù)所述第二預測高斯參數(shù)生成所述第一目標視角的第一預測圖像;
46、損失計算模塊,用于根據(jù)所述第一預測圖像和所述第二樣本圖像,計算第一訓練損失;
47、所述損失計算模塊,還用于根據(jù)所述第一預測高斯參數(shù)和所述第二預測高斯參數(shù),計算第二訓練損失;
48、模型訓練模塊,用于根據(jù)所述第一訓練損失和所述第二訓練損失,對所述預訓練高斯預測模型進行訓練,得到目標高斯預測模型,所述目標高斯預測模型用于部署于機器人,以使得所述機器人根據(jù)所述目標高斯預測模型生成工作場景的場景高斯參數(shù),根據(jù)所述場景高斯參數(shù)進行強化學習,確定所述機器人的動作。
49、可選地,所述預訓練高斯預測模型包括:預訓練參數(shù)預測模型和初始參數(shù)優(yōu)化模型;所述第一參數(shù)預測模塊,包括:
50、第一參數(shù)預測單元,用于根據(jù)所述第一樣本圖像采用所述預訓練參數(shù)預測模型進行參數(shù)預測,得到所述第一預測高斯參數(shù);
51、第二參數(shù)預測單元,用于根據(jù)所述第一預測高斯參數(shù)采用所述初始參數(shù)優(yōu)化模型進行優(yōu)化,得到所述第二預測高斯參數(shù)。
52、可選地,所述第一參數(shù)預測單元,具體用于采用所述預訓練參數(shù)預測模型中的深度預測模塊獲取所述第一樣本圖像的樣本特征圖,并對所述樣本特征圖進行深度估計,得到深度圖像;采用所述預訓練參數(shù)預測模型中的屬性預測模塊對所述第一樣本圖像、所述樣本特征圖和所述深度圖像對應的融合特征圖進行參數(shù)預測,得到預測旋轉(zhuǎn)參數(shù)、預測縮放參數(shù)和預測透明度參數(shù);根據(jù)所述深度圖像確定預測三維位置參數(shù);根據(jù)所述第一樣本圖像的顏色值確定預測顏色參數(shù);其中,所述第一預測高斯參數(shù)包括:所述預測三維位置參數(shù)、所述預測旋轉(zhuǎn)參數(shù)、所述預測縮放參數(shù)、所述預測顏色參數(shù)和所述預測透明度參數(shù)。
53、可選地,所述第一參數(shù)預測單元,具體用于若所述第一樣本圖像為單張樣本圖像,采用單圖像深度預測器對所述單張樣本圖像的樣本特征圖進行深度預測,得到所述深度圖像;若所述第一樣本圖像為兩張樣本圖像,采用視差預測網(wǎng)絡對所述兩張樣本圖像的樣本特征圖進行視差預測,得到所述深度圖像。
54、可選地,所述第一圖像獲取模塊,具體用于從所述訓練三維場景的樣本圖像集中確定所述第二樣本圖像;根據(jù)所述第二樣本圖像的第一目標視角,從所述樣本圖像集中確定距離所述第一目標視角最近的相鄰視角的兩張樣本圖像為所述第一樣本圖像。
55、可選地,所述第一圖像獲取模塊,還用于獲取第二預設視角的第三樣本圖像和第二目標視角的第四樣本圖像;
56、所述第一參數(shù)預測模塊,還用于根據(jù)所述第四樣本圖像采用初始高斯預測模型進行參數(shù)預測,得到第三預測高斯參數(shù);
57、所述圖像生成模塊,還用于根據(jù)所述第三預測高斯參數(shù)生成所述第二目標視角的第二預測圖像;
58、所述損失計算模塊,還用于根據(jù)所述第二預測圖像和所述第四樣本圖像,計算第三訓練損失;
59、所述模型訓練模塊,還用于根據(jù)所述第三訓練損失對所述初始高斯預測模型中的深度預測模塊進行預訓練,得到所述預訓練高斯預測模型。
60、可選地,所述模型訓練模塊,具體用于根據(jù)所述第一訓練損失和所述第二訓練損失,對所述預訓練高斯預測模型中除所述深度預測模塊之外的其他模塊進行訓練,得到目標高斯預測模型。
61、可選地,所述初始參數(shù)優(yōu)化模型為圖神經(jīng)網(wǎng)絡模型,所述第二參數(shù)預測單元,具體用于根據(jù)所述第一預測高斯參數(shù)生成第一預測高斯點云;采用所述圖神經(jīng)網(wǎng)絡模型對所述第一預測高斯點云進行優(yōu)化,得到第二預測高斯點云,所述第二預測高斯點云為所述第二預測高斯參數(shù)表示的點云;
62、所述圖像生成模塊,具體用于根據(jù)所述第二預測高斯點云進行渲染,生成所述第一預測圖像。
63、第四方面,本技術實施例還提供一種基于高斯預測模型的機器人強化學習裝置,所述裝置包括:
64、第二圖像獲取模塊,用于獲取預設工作場景的場景圖像;
65、第二參數(shù)預測模塊,用于采用預先訓練的目標高斯預測模型對所述場景圖像進行參數(shù)預測,得到所述預設工作場景的場景高斯參數(shù),所述目標高斯預測模型為預先采用如第一方面任一項所述的高斯預測模型訓練方法進行訓練得到的;
66、控制模塊,用于根據(jù)所述場景高斯參數(shù)采用強化學習網(wǎng)絡預測機器人的動作,以控制所述機器人在所述預設工作場景進行作業(yè)。
67、第五方面,本技術實施例還提供一種電子設備,包括:處理器、存儲介質(zhì)和總線,所述存儲介質(zhì)存儲有所述處理器可執(zhí)行的程序指令,當電子設備運行時,所述處理器與所述存儲介質(zhì)之間通過總線通信,所述處理器執(zhí)行所述程序指令,以執(zhí)行如第一方面任一項所述的高斯預測模型訓練方法的步驟,或者如第二方面所述的基于高斯預測模型的機器人強化學習方法的步驟。
68、第六方面,本技術實施例還提供一種計算機可讀存儲介質(zhì),所述存儲介質(zhì)上存儲有計算機程序,所述計算機程序被處理器運行時執(zhí)行如第一方面任一項所述的高斯預測模型訓練方法,或者如第二方面所述的基于高斯預測模型的機器人強化學習方法。
69、本技術的有益效果是:
70、本技術提供的高斯預測模型訓練方法、機器人強化學習方法及電子設備,通過三維場景的樣本圖像訓練高斯預測模型,以使得通過高斯預測模型可以生成場景的場景高斯參數(shù),以通過場景高斯參數(shù)詳細描述場景的三維局部幾何結(jié)構(gòu),便于機器人基于場景高斯參數(shù)進行強化學習,提高機器人的強化學習效果,保證機器人的作業(yè)精準性;且基于該高斯預測模型可以針對不同場景的機器人提供高斯參數(shù)預測,具有可泛化性,應用場景廣泛。