本發明涉及計算機視覺,特別是涉及一種模型訓練方法、目標檢測方法、裝置及電子設備。
背景技術:
1、在計算機視覺技術領域,可以利用目標檢測模型對圖像中的對象進行檢測。在一些實時目標檢測場景下,需要能夠以較低的時延得到圖像的檢測結果。
技術實現思路
1、本發明實施例的目的在于提供一種模型訓練方法、目標檢測方法、裝置及電子設備,以降低目標檢測的時延,提高目標檢測的效率。具體技術方案如下:
2、本技術實施例的第一方面,首先提供了一種模型訓練方法,所述方法包括:
3、獲取樣本圖像,以及各樣本圖像的標簽;其中,一個樣本圖像的標簽用于表示該樣本圖像中的各對象;
4、針對每一樣本圖像,利用初始狀態的目標檢測模型中的編碼器對該樣本圖像進行特征提取,得到第一樣本特征向量;其中,初始狀態的目標檢測模型還包含解碼器;所述解碼器包含至少一個部分交叉注意力模塊,一個部分交叉注意力模塊包含并行連接的一個自注意力層和一個交叉注意力層;
5、針對每一部分交叉注意力模塊,對該部分交叉注意力模塊的輸入數據在通道維度進行拆分,得到第一輸入部分和第二輸入部分;其中,第一個部分交叉注意力模塊的輸入數據為所述第一樣本特征向量;
6、利用該部分交叉注意力模塊的自注意力層和交叉注意力層,分別對所述第一輸入部分和第二輸入部分進行處理;
7、結合該部分交叉注意力模塊的自注意力層和交叉注意力層的處理結果,得到該部分交叉注意力模塊的輸出數據;
8、利用最后一個部分交叉注意力模塊的輸出數據進行檢測,得到該樣本圖像中每一對象的檢測結果;
9、基于該樣本圖像中每一對象的檢測結果,與該樣本圖像的標簽之間的差異,計算損失值,并基于得到的損失值對初始狀態的目標檢測模型的模型參數進行調整,直至達到收斂,得到訓練好的目標檢測模型。
10、在一些實施例中,所述解碼器包含:多個串行連接的部分交叉注意力模塊。
11、在一些實施例中,一個部分交叉注意力模塊還包括:一個殘差門網絡;
12、所述結合該部分交叉注意力模塊的自注意力層和交叉注意力層的處理結果,得到該部分交叉注意力模塊的輸出數據,包括:
13、結合該部分交叉注意力模塊的自注意力層和交叉注意力層的處理結果,得到第一樣本融合結果;
14、將該部分交叉注意力模塊的輸入數據和所述第一樣本融合結果輸入至所述殘差門網絡,得到該部分交叉注意力模塊的輸出數據。
15、在一些實施例中,一個部分交叉注意力模塊還包括前饋神經網絡;一個殘差門網絡中包含一個多層感知機和激活層;
16、所述將該部分交叉注意力模塊的輸入數據和所述第一樣本融合結果輸入至所述殘差門網絡,得到該部分交叉注意力模塊的輸出數據,包括:
17、對該部分交叉注意力模塊的輸入數據和所述第一樣本融合結果進行拼接,得到第一樣本拼接結果;
18、將所述第一樣本拼接結果輸入至所述多層感知機,并將所述多層感知機的輸出結果輸入至所述激活層,得到所述第一樣本拼接結果中每一特征向量的權重向量;
19、分別計算該部分交叉注意力模塊的輸入數據中各特征向量與對應的權重向量的點乘結果,得到第一樣本點乘結果;
20、分別計算所述第一樣本融合結果中各特征向量與對應的權重向量的點乘結果,得到第二樣本點乘結果;
21、計算所述第一樣本點乘結果和第二樣本點乘結果的點加結果;
22、將計算得到的點加結果輸入至所述前饋神經網絡,得到該部分交叉注意力模塊的輸出數據。
23、在一些實施例中,一個部分交叉注意力模塊還包括:第一歸一化層和第二歸一化層;
24、所述結合該部分交叉注意力模塊的自注意力層和交叉注意力層的處理結果,得到第一樣本融合結果,包括:
25、將該部分交叉注意力模塊的自注意力層的處理結果,以及該部分交叉注意力模塊的交叉注意力層的輸入數據的點加結果,輸入至所述第一歸一化層,得到第一樣本歸一化結果;
26、將該部分交叉注意力模塊的交叉注意力層的處理結果,以及該部分交叉注意力模塊的自注意力層的輸入數據的點加結果,輸入至所述第二歸一化層,得到第二樣本歸一化結果;
27、對所述第一樣本歸一化結果和第二樣本歸一化結果進行拼接,得到第一樣本融合結果;
28、和/或,
29、一個部分交叉注意力模塊還包括:第三歸一化層:
30、所述將計算得到的點加結果輸入至所述前饋神經網絡,得到該部分交叉注意力模塊的輸出數據,包括:
31、將計算得到的點加結果輸入至所述前饋神經網絡,得到所述前饋神經網絡的輸出結果;
32、將所述前饋神經網絡的輸出結果輸入至所述第三歸一化層,得到該部分交叉注意力模塊的輸出數據。
33、在一些實施例中,一個樣本圖像的標簽包含:該樣本圖像中的各對象的真實框和真實類別;該樣本圖像中每一對象的檢測結果包含:該對象的預測框和表示該對象屬于對應的真實類別的概率的置信度;
34、所述基于該樣本圖像中每一對象的檢測結果,與該樣本圖像的標簽之間的差異,計算損失值,包括:
35、利用預設的損失函數,基于該樣本圖像中每一對象的檢測結果,與該樣本圖像的標簽之間的差異,計算損失值;其中,預設的損失函數為:
36、;
37、表示損失值;表示該樣本圖像中一個對象的預測框與真實框之間的交并比;表示該對象的置信度;為超參數;表示對數函數。
38、在一些實施例中,所述第一輸入部分和第二輸入部分對應的通道數相同。
39、本技術實施例的第二方面,提供了一種目標檢測方法,所述方法包括:
40、獲取待檢測圖像;
41、利用訓練好的目標檢測模型中的編碼器對所述待檢測圖像進行特征提取,得到第一預測特征向量;其中,訓練好的目標檢測模型還包含解碼器;所述解碼器包含至少一個部分交叉注意力模塊,一個部分交叉注意力模塊包含并行連接的一個自注意力層和一個交叉注意力層;
42、針對每一部分交叉注意力模塊,對該部分交叉注意力模塊的輸入數據在通道維度進行拆分,得到第三輸入部分和第四輸入部分;其中,第一個部分交叉注意力模塊的輸入數據為所述第一預測特征向量;
43、利用該部分交叉注意力模塊的自注意力層和交叉注意力層,分別對所述第三輸入部分和第四輸入部分進行處理;
44、結合該部分交叉注意力模塊的自注意力層和交叉注意力層的處理結果,得到該部分交叉注意力模塊的輸出數據;
45、利用最后一個部分交叉注意力模塊的輸出數據進行檢測,得到所述待檢測圖像中每一對象的檢測結果。
46、在一些實施例中,所述解碼器包含:多個串行連接的部分交叉注意力模塊。
47、在一些實施例中,一個部分交叉注意力模塊還包括:一個殘差門網絡;
48、所述結合該部分交叉注意力模塊的自注意力層和交叉注意力層的處理結果,得到該部分交叉注意力模塊的輸出數據,包括:
49、結合該部分交叉注意力模塊的自注意力層和交叉注意力層的處理結果,得到第一預測融合結果;
50、將該部分交叉注意力模塊的輸入數據和所述第一預測融合結果輸入至所述殘差門網絡,得到該部分交叉注意力模塊的輸出數據。
51、在一些實施例中,一個部分交叉注意力模塊還包括前饋神經網絡;一個殘差門網絡中包含一個多層感知機和激活層;
52、所述將該部分交叉注意力模塊的輸入數據和所述第一預測融合結果輸入至所述殘差門網絡,得到該部分交叉注意力模塊的輸出數據,包括:
53、對該部分交叉注意力模塊的輸入數據和所述第一預測融合結果進行拼接,得到第一預測拼接結果;
54、將所述第一預測拼接結果輸入至所述多層感知機,并將所述多層感知機的輸出結果輸入至所述激活層,得到所述第一預測拼接結果中每一特征向量的權重向量;
55、分別計算該部分交叉注意力模塊的輸入數據中各特征向量與對應的權重向量的點乘結果,得到第一預測點乘結果;
56、分別計算所述第一預測融合結果中各特征向量與對應的權重向量的點乘結果,得到第二預測乘結果;
57、計算所述第一預測點乘結果和第二預測點乘結果的點加結果;
58、將計算得到的點加結果輸入至所述前饋神經網絡,得到該部分交叉注意力模塊的輸出數據。
59、在一些實施例中,一個部分交叉注意力模塊還包括:第一歸一化層和第二歸一化層;
60、所述結合該部分交叉注意力模塊的自注意力層和交叉注意力層的處理結果,得到第一預測融合結果,包括:
61、將該部分交叉注意力模塊的自注意力層的處理結果,以及該部分交叉注意力模塊的交叉注意力層的輸入數據的點加結果,輸入至所述第一歸一化層,得到第一預測歸一化結果;
62、將該部分交叉注意力模塊的交叉注意力層的處理結果,以及該部分交叉注意力模塊的自注意力層的輸入數據的點加結果,輸入至所述第二歸一化層,得到第二預測歸一化結果;
63、對所述第一預測歸一化結果和第二預測歸一化結果進行拼接,得到第一預測融合結果;
64、和/或,
65、一個部分交叉注意力模塊還包括:第三歸一化層:
66、所述將計算得到的點加結果輸入至所述前饋神經網絡,得到該部分交叉注意力模塊的輸出數據,包括:
67、將計算得到的點加結果輸入至所述前饋神經網絡,得到所述前饋神經網絡的輸出結果;
68、將所述前饋神經網絡的輸出結果輸入至所述第三歸一化層,得到該部分交叉注意力模塊的輸出數據。
69、在一些實施例中,所述第三輸入部分和第四輸入部分對應的通道數相同。
70、在一些實施例中,所述待檢測圖像中每一對象的檢測結果包含:該對象的預測框和表示該對象所屬類別的概率的置信度;
71、所述目標檢測模型為基于樣本圖像、各樣本圖像的標簽,利用預設的損失函數進行訓練得到的;其中,一個樣本圖像的標簽包含:該樣本圖像中的各對象的真實框和真實類別;
72、預設的損失函數為:
73、;
74、表示損失值;表示一個樣本圖像中一個對象的預測框與真實框之間的交并比;表示該對象的置信度;為超參數;表示對數函數。
75、本技術實施例的第三方面,提供了一種模型訓練裝置,所述裝置包括:
76、樣本獲取單元,用于獲取樣本圖像,以及各樣本圖像的標簽;其中,一個樣本圖像的標簽用于表示該樣本圖像中的各對象;
77、第一特征提取單元,用于針對每一樣本圖像,利用初始狀態的目標檢測模型中的編碼器對該樣本圖像進行特征提取,得到第一樣本特征向量;其中,初始狀態的目標檢測模型還包含解碼器;所述解碼器包含至少一個部分交叉注意力模塊,一個部分交叉注意力模塊包含并行連接的一個自注意力層和一個交叉注意力層;
78、第一拆分單元,用于針對每一部分交叉注意力模塊,對該部分交叉注意力模塊的輸入數據在通道維度進行拆分,得到第一輸入部分和第二輸入部分;其中,第一個部分交叉注意力模塊的輸入數據為所述第一樣本特征向量;
79、第一處理單元,用于利用該部分交叉注意力模塊的自注意力層和交叉注意力層,分別對所述第一輸入部分和第二輸入部分進行處理;
80、第一融合單元,用于結合該部分交叉注意力模塊的自注意力層和交叉注意力層的處理結果,得到該部分交叉注意力模塊的輸出數據;
81、第一檢測單元,用于利用最后一個部分交叉注意力模塊的輸出數據進行檢測,得到該樣本圖像中每一對象的檢測結果;
82、損失調整單元,用于基于該樣本圖像中每一對象的檢測結果,與該樣本圖像的標簽之間的差異,計算損失值,并基于得到的損失值對初始狀態的目標檢測模型的模型參數進行調整,直至達到收斂,得到訓練好的目標檢測模型。
83、在一些實施例中,所述解碼器包含:多個串行連接的部分交叉注意力模塊。
84、在一些實施例中,一個部分交叉注意力模塊還包括:一個殘差門網絡;
85、所述第一融合單元,包括:
86、第一融合子單元,用于結合該部分交叉注意力模塊的自注意力層和交叉注意力層的處理結果,得到第一樣本融合結果;
87、第一殘差處理子單元,用于將該部分交叉注意力模塊的輸入數據和所述第一樣本融合結果輸入至所述殘差門網絡,得到該部分交叉注意力模塊的輸出數據。
88、在一些實施例中,一個部分交叉注意力模塊還包括前饋神經網絡;一個殘差門網絡中包含一個多層感知機和激活層;
89、所述第一殘差處理子單元,具體用于:
90、對該部分交叉注意力模塊的輸入數據和所述第一樣本融合結果進行拼接,得到第一樣本拼接結果;
91、將所述第一樣本拼接結果輸入至所述多層感知機,并將所述多層感知機的輸出結果輸入至所述激活層,得到所述第一樣本拼接結果中每一特征向量的權重向量;
92、分別計算該部分交叉注意力模塊的輸入數據中各特征向量與對應的權重向量的點乘結果,得到第一樣本點乘結果;
93、分別計算所述第一樣本融合結果中各特征向量與對應的權重向量的點乘結果,得到第二樣本點乘結果;
94、計算所述第一樣本點乘結果和第二樣本點乘結果的點加結果;
95、將計算得到的點加結果輸入至所述前饋神經網絡,得到該部分交叉注意力模塊的輸出數據。
96、在一些實施例中,一個部分交叉注意力模塊還包括:第一歸一化層和第二歸一化層;
97、所述第一融合子單元,具體用于:
98、將該部分交叉注意力模塊的自注意力層的處理結果,以及該部分交叉注意力模塊的交叉注意力層的輸入數據的點加結果,輸入至所述第一歸一化層,得到第一樣本歸一化結果;
99、將該部分交叉注意力模塊的交叉注意力層的處理結果,以及該部分交叉注意力模塊的自注意力層的輸入數據的點加結果,輸入至所述第二歸一化層,得到第二樣本歸一化結果;
100、對所述第一樣本歸一化結果和第二樣本歸一化結果進行拼接,得到第一樣本融合結果;
101、和/或,
102、一個部分交叉注意力模塊還包括:第三歸一化層:
103、所述第一殘差處理子單元,具體用于:
104、將計算得到的點加結果輸入至所述前饋神經網絡,得到所述前饋神經網絡的輸出結果;
105、將所述前饋神經網絡的輸出結果輸入至所述第三歸一化層,得到該部分交叉注意力模塊的輸出數據。
106、在一些實施例中,一個樣本圖像的標簽包含:該樣本圖像中的各對象的真實框和真實類別;該樣本圖像中每一對象的檢測結果包含:該對象的預測框和表示該對象屬于對應的真實類別的概率的置信度;
107、所述損失調整單元,具體用于:
108、利用預設的損失函數,基于該樣本圖像中每一對象的檢測結果,與該樣本圖像的標簽之間的差異,計算損失值;其中,預設的損失函數為:
109、;
110、表示損失值;表示該樣本圖像中一個對象的預測框與真實框之間的交并比;表示該對象的置信度;為超參數;表示對數函數。
111、在一些實施例中,所述第一輸入部分和第二輸入部分對應的通道數相同。
112、本技術實施例的第四方面,提供了一種目標檢測裝置,所述裝置包括:
113、待檢測圖像獲取單元,用于獲取待檢測圖像;
114、第二特征提取單元,用于利用訓練好的目標檢測模型中的編碼器對所述待檢測圖像進行特征提取,得到第一預測特征向量;其中,訓練好的目標檢測模型還包含解碼器;所述解碼器包含至少一個部分交叉注意力模塊,一個部分交叉注意力模塊包含并行連接的一個自注意力層和一個交叉注意力層;
115、第二拆分單元,用于針對每一部分交叉注意力模塊,對該部分交叉注意力模塊的輸入數據在通道維度進行拆分,得到第三輸入部分和第四輸入部分;其中,第一個部分交叉注意力模塊的輸入數據為所述第一預測特征向量;
116、第二處理單元,用于利用該部分交叉注意力模塊的自注意力層和交叉注意力層,分別對所述第三輸入部分和第四輸入部分進行處理;
117、第二融合單元,用于結合該部分交叉注意力模塊的自注意力層和交叉注意力層的處理結果,得到該部分交叉注意力模塊的輸出數據;
118、第二檢測單元,用于利用最后一個部分交叉注意力模塊的輸出數據進行檢測,得到所述待檢測圖像中每一對象的檢測結果。
119、在一些實施例中,所述解碼器包含:多個串行連接的部分交叉注意力模塊。
120、在一些實施例中,一個部分交叉注意力模塊還包括:一個殘差門網絡;
121、所述第二融合單元,包括:
122、第二融合子單元,用于結合該部分交叉注意力模塊的自注意力層和交叉注意力層的處理結果,得到第一預測融合結果;
123、第二殘差處理子單元,用于將該部分交叉注意力模塊的輸入數據和所述第一預測融合結果輸入至所述殘差門網絡,得到該部分交叉注意力模塊的輸出數據。
124、在一些實施例中,一個部分交叉注意力模塊還包括前饋神經網絡;一個殘差門網絡中包含一個多層感知機和激活層;
125、所述第二殘差處理子單元,具體用于:
126、對該部分交叉注意力模塊的輸入數據和所述第一預測融合結果進行拼接,得到第一預測拼接結果;
127、將所述第一預測拼接結果輸入至所述多層感知機,并將所述多層感知機的輸出結果輸入至所述激活層,得到所述第一預測拼接結果中每一特征向量的權重向量;
128、分別計算該部分交叉注意力模塊的輸入數據中各特征向量與對應的權重向量的點乘結果,得到第一預測點乘結果;
129、分別計算所述第一預測融合結果中各特征向量與對應的權重向量的點乘結果,得到第二預測乘結果;
130、計算所述第一預測點乘結果和第二預測點乘結果的點加結果;
131、將計算得到的點加結果輸入至所述前饋神經網絡,得到該部分交叉注意力模塊的輸出數據。
132、在一些實施例中,一個部分交叉注意力模塊還包括:第一歸一化層和第二歸一化層;
133、所述第二融合子單元,具體用于:
134、將該部分交叉注意力模塊的自注意力層的處理結果,以及該部分交叉注意力模塊的交叉注意力層的輸入數據的點加結果,輸入至所述第一歸一化層,得到第一預測歸一化結果;
135、將該部分交叉注意力模塊的交叉注意力層的處理結果,以及該部分交叉注意力模塊的自注意力層的輸入數據的點加結果,輸入至所述第二歸一化層,得到第二預測歸一化結果;
136、對所述第一預測歸一化結果和第二預測歸一化結果進行拼接,得到第一預測融合結果;
137、和/或,
138、一個部分交叉注意力模塊還包括:第三歸一化層:
139、所述第二殘差處理子單元,具體用于:
140、將計算得到的點加結果輸入至所述前饋神經網絡,得到所述前饋神經網絡的輸出結果;
141、將所述前饋神經網絡的輸出結果輸入至所述第三歸一化層,得到該部分交叉注意力模塊的輸出數據。
142、在一些實施例中,所述第三輸入部分和第四輸入部分對應的通道數相同。
143、在一些實施例中,所述待檢測圖像中每一對象的檢測結果包含:該對象的預測框和表示該對象所屬類別的概率的置信度;
144、所述目標檢測模型為基于樣本圖像、各樣本圖像的標簽,利用預設的損失函數進行訓練得到的;其中,一個樣本圖像的標簽包含:該樣本圖像中的各對象的真實框和真實類別;
145、預設的損失函數為:
146、;
147、表示損失值;表示一個樣本圖像中一個對象的預測框與真實框之間的交并比;表示該對象的置信度;為超參數;表示對數函數。
148、本技術實施例的第五方面,提供了一種電子設備,包括:
149、存儲器,用于存放計算機程序;
150、處理器,用于執行存儲器上所存放的程序時,實現上述任一所述的模型訓練方法,或,目標檢測方法。
151、本技術實施例的又一方面,提供了一種計算機可讀存儲介質,所述計算機可讀存儲介質內存儲有計算機程序,所述計算機程序被處理器執行時實現上述任一所述的模型訓練方法,或,目標檢測方法。
152、本技術實施例的又一方面,提供了一種包含指令的計算機程序產品,當其在計算機上運行時,使得計算機執行上述任一所述的模型訓練方法,或,目標檢測方法。
153、本發明實施例有益效果:
154、本發明實施例提供的模型訓練方法,在通過目標檢測模型對圖像(即樣本圖像或待檢測圖像)進行目標檢測的過程中,在得到目標檢測模型中編碼器提取得到的圖像的特征表示(即第一樣本特征向量或第一預測特征向量)之后,可以將編碼器的輸出數據輸入至目標檢測模型中解碼器進行處理。本技術中,解碼器包含至少一個部分交叉注意力模塊,且一個部分交叉注意力模塊包含并行連接的一個自注意力層和一個交叉注意力層。相應的,針對每一部分交叉注意力模塊,可以在通道維度對該部分交叉注意力模塊的輸入數據進行拆分,利用該部分交叉注意力模塊的自注意力層和交叉注意力層,分別對拆分得到的兩個部分(即第一輸入部分和第二輸入部分,或,第三輸入部分或第四輸入部分)進行處理。也就是說,解碼器中的部分交叉注意力模塊能夠利用自注意力層和交叉注意力層的并行結構對編碼器的輸出數據進行處理,也就能夠基于自注意力機制和交叉注意力機制,得到針對圖像中對象的特征表示(即解碼器中的最后一個部分交叉注意力模塊的輸出數據)。進而,可以利用最后一個部分交叉注意力模塊的輸出數據進行檢測,得到圖像中每一對象的檢測結果。如此,能夠通過部分交叉注意力模塊中自注意力層和交叉注意力層的并行結構,降低目標檢測的時延,提高目標檢測的效率。
155、當然,實施本發明的任一產品或方法并不一定需要同時達到以上所述的所有優點。