MAX78000是具有超低功耗卷積神經(jīng)網(wǎng)絡(luò)加速器的人工智能微控制器,可以在芯片上有效地運(yùn)行人工智能模型。用戶應(yīng)首先使用ADI公司在PyTorch上的開(kāi)發(fā)流程開(kāi)發(fā)神經(jīng)網(wǎng)絡(luò)模型。然后,MAX78000頻率合成器工具接受YAML格式的PyTorch檢查點(diǎn)和模型描述,自動(dòng)生成C代碼,在MAX78000上編譯和執(zhí)行。模型開(kāi)發(fā)階段使用的基本軟件組件之一是數(shù)據(jù)加載器,它負(fù)責(zé)特定于應(yīng)用程序的數(shù)據(jù)準(zhǔn)備任務(wù)。本文檔介紹在準(zhǔn)備適合MAX78000模型訓(xùn)練的特定應(yīng)用訓(xùn)練和驗(yàn)證/測(cè)試集實(shí)體時(shí),數(shù)據(jù)加載器實(shí)現(xiàn)的原則和設(shè)計(jì)注意事項(xiàng)。
介紹
在應(yīng)用程序開(kāi)發(fā)周期中,第一步是準(zhǔn)備和預(yù)處理可用數(shù)據(jù)以創(chuàng)建訓(xùn)練和驗(yàn)證/測(cè)試數(shù)據(jù)集。除了通常的數(shù)據(jù)預(yù)處理外,在MAX78000上運(yùn)行模型還需要考慮幾個(gè)硬件限制。
數(shù)據(jù)加載器的主要職責(zé)可以總結(jié)如下:
[可選]將原始資源的輸入和標(biāo)簽數(shù)據(jù)下載到通過(guò)調(diào)用ADI公司的CNN培訓(xùn)工具(培訓(xùn)存儲(chǔ)庫(kù)/train.py提供的數(shù)據(jù)路徑中。
從指定的數(shù)據(jù)路徑(csv/二進(jìn)制文件、帶或不帶層次結(jié)構(gòu)的文件夾/s 等)讀取原始輸入數(shù)據(jù)。
讀取提供的數(shù)據(jù)路徑中的原始標(biāo)簽/注釋?zhuān)╟sv/二進(jìn)制文件/s、帶或不帶層次結(jié)構(gòu)的文件夾等)。
[可選]應(yīng)用數(shù)據(jù)預(yù)處理步驟,如增強(qiáng)、數(shù)據(jù)清理等。
對(duì)輸入數(shù)據(jù)和標(biāo)簽應(yīng)用所需的數(shù)據(jù)類(lèi)型和范圍轉(zhuǎn)換。
執(zhí)行訓(xùn)練和測(cè)試/驗(yàn)證拆分。
提供數(shù)據(jù)加載器方法和與MAX78000模型訓(xùn)練工具兼容的定義字典。
[可選]將處理后的數(shù)據(jù)實(shí)體保留在磁盤(pán)上,以便將來(lái)訪問(wèn)。
[可選]將上述步驟應(yīng)用于可從同一原始數(shù)據(jù)源生成的每個(gè)不同數(shù)據(jù)集變體。
提供兩個(gè)用于訓(xùn)練和測(cè)試數(shù)據(jù)的 PyTorch 數(shù)據(jù)集。
以下各節(jié)提供了有關(guān)創(chuàng)建高效數(shù)據(jù)加載器的說(shuō)明,以滿足所需的功能并方便地集成訓(xùn)練工具。
圖 1 抽象地顯示了數(shù)據(jù)加載器實(shí)現(xiàn)的主流。以下各節(jié)介紹了詳細(xì)信息。
圖1.數(shù)據(jù)加載器模塊的主流。
自定義數(shù)據(jù)加載器實(shí)現(xiàn)的設(shè)計(jì)原則
數(shù)據(jù)加載器實(shí)現(xiàn)的主要職責(zé)之一是在將數(shù)據(jù)集實(shí)體饋送到 CNN 模型之前進(jìn)行數(shù)據(jù)范圍調(diào)整和數(shù)據(jù)類(lèi)型管理。圖 2 總結(jié)了這些操作,以下各節(jié)將詳細(xì)介紹這些操作。
圖2.數(shù)據(jù)范圍規(guī)范化和類(lèi)型轉(zhuǎn)換。
預(yù)期數(shù)據(jù)范圍
對(duì)于訓(xùn)練,輸入數(shù)據(jù)應(yīng)在
.當(dāng)評(píng)估量化權(quán)重或在硬件上運(yùn)行時(shí),輸入數(shù)據(jù)應(yīng)位于本機(jī)MAX7800X范圍[-128, +127]。
如以下部分所述,數(shù)據(jù)加載器函數(shù)將數(shù)據(jù)路徑和一些參數(shù)作為輸入?yún)?shù)。參數(shù)字段包括兩個(gè)必填字段:act_mode_8bit 和 truncate_testset。當(dāng)設(shè)置為T(mén)rue時(shí),第一個(gè)參數(shù)是指對(duì)于本地MAX7800X范圍,即范圍[-128, +127],應(yīng)正確進(jìn)行大小寫(xiě)歸一化。設(shè)置為 False 時(shí),規(guī)范化應(yīng)在訓(xùn)練范圍內(nèi)
如果可用數(shù)據(jù)在 [0 1] 范圍內(nèi),例如,在 PIL 圖像中,數(shù)據(jù)加載器可以直接調(diào)用 ai8x.normalize() 函數(shù),使用提供的 args 參數(shù)將數(shù)據(jù)規(guī)范化為兩個(gè)支持的數(shù)據(jù)范圍:
class normalize: """ Normalize input to either [-128/128, +127/128] or [-128, +127] """ def __init__(self, args): self.args = args def __call__(self, img): if self.args.act_mode_8bit: return img.sub(0.5).mul(256.).round().clamp(min=-128, max=127) return img.sub(0.5).mul(256.).round().clamp(min=-128, max=127).div(128.)
如果可用數(shù)據(jù)范圍為 [0 255],則需要在調(diào)用 ai256x.normalize() 函數(shù)之前將其除以 0 以使其達(dá)到 [1 8] 范圍。
注意:ai8x 模塊的設(shè)備設(shè)置方法ai8x.set_device也接受相關(guān)參數(shù)模擬:True 表示訓(xùn)練案例 (act_mode_8bit = True),F(xiàn)alse 表示量化模型的評(píng)估或在也初始化 act_mode_8bit = False 的硬件上運(yùn)行。此方法由具有適當(dāng)參數(shù)管理的訓(xùn)練腳本使用,但如果在外部調(diào)用函數(shù),則應(yīng)正確設(shè)置模擬參數(shù)。
在MAX7800X硬件上運(yùn)行推理時(shí),必須考慮本地?cái)?shù)據(jù)格式,并且在推理過(guò)程中應(yīng)盡可能少地進(jìn)行預(yù)處理。
數(shù)據(jù)類(lèi)型
數(shù)據(jù)源可能具有不同范圍內(nèi)各種格式和值的原始數(shù)據(jù)文件。數(shù)據(jù)集類(lèi)和數(shù)據(jù)加載器函數(shù)負(fù)責(zé)處理必要的轉(zhuǎn)換。
數(shù)據(jù)加載器函數(shù)應(yīng)返回?cái)?shù)據(jù)類(lèi)的訓(xùn)練和測(cè)試數(shù)據(jù)集元組。類(lèi)型轉(zhuǎn)換和轉(zhuǎn)換通常在 __get_item__ 函數(shù)內(nèi)處理,該函數(shù)應(yīng)返回指定索引數(shù)據(jù)實(shí)體的數(shù)據(jù)元組和標(biāo)簽。數(shù)據(jù)項(xiàng)的類(lèi)型應(yīng)為:火炬。庫(kù)達(dá)]。形狀火炬的浮動(dòng)張量。大小(數(shù)據(jù)集字典的相關(guān)條目“輸入”字段)。
標(biāo)簽維度可能因問(wèn)題類(lèi)型或輸入數(shù)據(jù)形狀而異。每個(gè)標(biāo)簽類(lèi)型都應(yīng)強(qiáng)制轉(zhuǎn)換為 np.long,以便在訓(xùn)練腳本中正確計(jì)算訓(xùn)練損失。
在完成所有數(shù)據(jù)增強(qiáng)和預(yù)處理任務(wù)并將數(shù)據(jù)范圍規(guī)范化為 [0 1] 后,應(yīng)使用該ai8x_normalize進(jìn)行適當(dāng)?shù)倪M(jìn)一步規(guī)范化,然后可以使用 torchvision.transforms.ToTensor 執(zhí)行類(lèi)型轉(zhuǎn)換。
注意:拿到火炬。庫(kù)達(dá)]。FloatTensor,numpy 數(shù)組必須事先轉(zhuǎn)換為 float32。
Torchvision軟件包包括各種預(yù)處理轉(zhuǎn)換,例如可以根據(jù)應(yīng)用程序需求使用的PIL圖像的隨機(jī)裁剪。數(shù)據(jù)類(lèi)可以利用Torchvision包的復(fù)合轉(zhuǎn)換按順序應(yīng)用多個(gè)轉(zhuǎn)換,例如ToTensor轉(zhuǎn)換和ai8x_normalize每當(dāng)訪問(wèn)數(shù)據(jù)條目時(shí)。
數(shù)據(jù)實(shí)體的存儲(chǔ)
通常,有兩種方法可以存儲(chǔ)數(shù)據(jù)集條目;整個(gè)數(shù)據(jù)集條目可以存儲(chǔ)在內(nèi)存中,也可以在使用 __getitem__ 方法訪問(wèn)時(shí)從磁盤(pán)讀取。基本的決策因素是數(shù)據(jù)集的大小和每個(gè)實(shí)體的大小。當(dāng)數(shù)據(jù)集太大而無(wú)法放入內(nèi)存時(shí)(在初始化函數(shù)中處理預(yù)處理和增強(qiáng)任務(wù)后),所有數(shù)據(jù)集條目都可以保存到磁盤(pán)中,并在以后每次訪問(wèn)時(shí)從磁盤(pán)單獨(dú)讀取。雖然將數(shù)據(jù)條目保留在內(nèi)存中可以加快數(shù)據(jù)訪問(wèn)速度,但內(nèi)存限制可能會(huì)阻止在所有情況下使用基于內(nèi)存的方法。
注意:即使采用基于內(nèi)存的方法,也建議將預(yù)處理和增強(qiáng)的數(shù)據(jù)條目寫(xiě)入磁盤(pán),因?yàn)樗鼈冎粓?zhí)行一次。然后,在每次生成數(shù)據(jù)類(lèi)實(shí)例時(shí),可以執(zhí)行將所有數(shù)據(jù)批量讀取到內(nèi)存中。
表 1 總結(jié)了同一數(shù)據(jù)源的兩個(gè)數(shù)據(jù)加載器實(shí)現(xiàn)選項(xiàng)的一些度量。從第一行可以看出,磁盤(pán)存儲(chǔ)方法可以處理更多圖像。內(nèi)存預(yù)算是限制已處理圖像數(shù)量的因素。這兩種方法的數(shù)據(jù)集生成時(shí)間都很長(zhǎng),因?yàn)榈谝环N方法還處理預(yù)處理、擴(kuò)充等步驟,然后將所有可用數(shù)據(jù)寫(xiě)入磁盤(pán)。基于內(nèi)存的方法在以后生成數(shù)據(jù)集需要更長(zhǎng)的時(shí)間,因?yàn)閷?duì)象創(chuàng)建需要將大文件從磁盤(pán)批量讀取到內(nèi)存中。而在第二種方法中,每個(gè)數(shù)據(jù)集項(xiàng)都是獨(dú)立保存的,__getitem__方法創(chuàng)建數(shù)據(jù)集對(duì)象和實(shí)體檢索都花費(fèi)很少的時(shí)間。第一種方法的內(nèi)存消耗很高,因?yàn)樗鼘⑺袛?shù)據(jù)集實(shí)體保留在內(nèi)存中。在磁盤(pán)使用情況方面,第一種方法通常使用所有數(shù)據(jù)條目的單個(gè)文件,第二種方法為每個(gè)數(shù)據(jù)條目使用單獨(dú)的文件。這應(yīng)該會(huì)導(dǎo)致大致相似的磁盤(pán)預(yù)算。
注意:在表 1 中,由于處理圖像的數(shù)量減少,第一種方法的磁盤(pán)空間要小得多。磁盤(pán)方法的唯一缺點(diǎn)是它增加了訓(xùn)練時(shí)間,因?yàn)槊總€(gè)數(shù)據(jù)輸入讀取都是作為單獨(dú)的磁盤(pán)操作完成的。
數(shù)據(jù)加載器 圖像存儲(chǔ)在內(nèi)存中 圖像存儲(chǔ)在內(nèi)存中 | 數(shù)據(jù)加載器 使用從磁盤(pán)讀取的映像 | |
---|---|---|
可處理的圖像數(shù)量 | 20 000 * 1 = 20000 | 34 426 * 3 = 103 278 |
數(shù)據(jù)集生成時(shí)間 – 首次運(yùn)行 | 30 分 | 60 分 |
數(shù)據(jù)集生成時(shí)間 – 后續(xù)運(yùn)行 | 15 分 | 瞬間 |
運(yùn)行時(shí)內(nèi)存消耗峰值 | ~55 千兆字節(jié) | ~5 千兆字節(jié) |
磁盤(pán)消耗 | ~50 千兆字節(jié) | ~ 240 GB |
訓(xùn)練時(shí)間 單紀(jì)元 | 60-62秒 | 1450 秒 |
自定義數(shù)據(jù)加載器實(shí)現(xiàn)的編程原則
數(shù)據(jù)加載器模塊將在 PyTorch 中實(shí)現(xiàn),預(yù)計(jì)至少具有以下三個(gè)組件:
數(shù)據(jù)集類(lèi)定義。示例:類(lèi) AISegment(數(shù)據(jù)集)
torch.utils.data.Dataset是自定義數(shù)據(jù)集實(shí)現(xiàn)類(lèi)應(yīng)繼承的抽象類(lèi)。有關(guān) PyTorch 中自定義數(shù)據(jù)加載器實(shí)現(xiàn)的教程,請(qǐng)參閱 [1]。
應(yīng)重寫(xiě)__len__方法,以便 len(dataset) 返回?cái)?shù)據(jù)集的大小。
還應(yīng)該實(shí)現(xiàn)__getitem__來(lái)支持索引,以便數(shù)據(jù)集[i]可用于獲取i千樣本。對(duì)于MAX78000應(yīng)用,該方法應(yīng)返回一個(gè)數(shù)據(jù)元組及其相應(yīng)的標(biāo)簽。
__init__功能參數(shù)和內(nèi)容可根據(jù)應(yīng)用需求進(jìn)行定制。前兩個(gè)參數(shù)通常是數(shù)據(jù)根路徑和類(lèi)型(測(cè)試或訓(xùn)練),如MAX78000訓(xùn)練存儲(chǔ)庫(kù)數(shù)據(jù)集文件夾中的幾個(gè)數(shù)據(jù)加載器實(shí)現(xiàn)所示。但是,只要以下項(xiàng)目中介紹并作為外部通信點(diǎn)提供的數(shù)據(jù)加載器函數(shù)是固定的并執(zhí)行所需的操作,就可以更改這些參數(shù)的順序或命名。
數(shù)據(jù)加載器函數(shù):不應(yīng)修改此函數(shù)的簽名。第一個(gè)輸入是指定數(shù)據(jù)目錄和程序參數(shù)的元組。兩個(gè)前導(dǎo)布爾輸入指定是否應(yīng)加載訓(xùn)練和/或測(cè)試數(shù)據(jù)。
程序參數(shù)有兩個(gè)與數(shù)據(jù)類(lèi)實(shí)現(xiàn)相關(guān)的關(guān)鍵字段;act_mode_8bit和truncate_testset。第一個(gè)是指規(guī)范化類(lèi)型(有關(guān)更多詳細(xì)信息,請(qǐng)參閱預(yù)期數(shù)據(jù)范圍部分),第二個(gè)用于將測(cè)試集截?cái)酁閱蝹€(gè)元素集。
示例:def AISegment352_get_datasets(data, load_train=True, load_test=True)。
數(shù)據(jù)集字典包括可用的數(shù)據(jù)加載器函數(shù)。字典名稱(chēng)和鍵值都不應(yīng)更改,僅應(yīng)根據(jù)自定義數(shù)據(jù)集實(shí)現(xiàn)調(diào)整值。同一數(shù)據(jù)源的每個(gè)變體都可以作為此字典中的單獨(dú)元素存在。
例:
“name”鍵的值使ADI公司的CNN訓(xùn)練工具(訓(xùn)練存儲(chǔ)庫(kù)/train.py)能夠在提供--dataset參數(shù)時(shí)查找數(shù)據(jù)集。因此,此字段的值在自定義數(shù)據(jù)集中應(yīng)該是唯一的。
“input”鍵的值是輸入數(shù)據(jù)的維度。第一個(gè)維度作為num_channels傳遞給模型,而其余維度作為維度傳遞。例如,“input”:(1, 28, 28) 以 num_channels=1 和 dimension=(28, 28) 的形式傳遞給模型。一維輸入使用單個(gè)“維度”,例如“input”:(2, 512) 以 num_channels=2 和 dimension=(512, ) 的形式傳遞給模型。
“output”鍵的值指定分類(lèi)問(wèn)題的可用類(lèi)類(lèi)型。此鍵的值也可以使用字符串文本定義。
示例:“輸出”:(“背景”、“縱向”)。
“weight”鍵的值指定每個(gè)數(shù)據(jù)實(shí)體的權(quán)重,并引用類(lèi)標(biāo)簽。這是一個(gè)可選字段,如果未提供,則全部為“1”。
可以通過(guò)使用可用樣本數(shù)為每個(gè)類(lèi)提供成反比的權(quán)重來(lái)解決訓(xùn)練數(shù)據(jù)集中的類(lèi)不平衡問(wèn)題。因此,訓(xùn)練腳本更加關(guān)注頻率較低的樣本。
可選的回歸可以設(shè)置為 True 以自動(dòng)選擇訓(xùn)練腳本的 --regression 命令行參數(shù)。
注意:當(dāng)類(lèi)數(shù)為 1 時(shí),訓(xùn)練腳本會(huì)自動(dòng)設(shè)置回歸。 示例:“輸出”:(“id”),“回歸”:真。
示例數(shù)據(jù)加載器
MAX78000訓(xùn)練存儲(chǔ)庫(kù)數(shù)據(jù)集文件夾包括幾種不同的數(shù)據(jù)加載器實(shí)現(xiàn),詳情請(qǐng)參見(jiàn)[2]。在本節(jié)中,將介紹一個(gè)定制的數(shù)據(jù)加載器來(lái)舉例說(shuō)明所有提到的原則。縱向分割數(shù)據(jù)集用于此目的,更多詳細(xì)信息請(qǐng)參見(jiàn) [3]。此數(shù)據(jù)集源包括 34,427 張分辨率為 600 × 800 的人類(lèi)肖像圖像(紅色、綠色和藍(lán)色 (RGB) 顏色格式),以及相同數(shù)量的標(biāo)簽圖像,具有相同大小的紅色、綠色、藍(lán)色和 Alpha (RGBA) 格式的相應(yīng)蒙版。
初始化
設(shè)計(jì)的數(shù)據(jù)加載器模塊的第一個(gè)組件是具有以下初始化函數(shù)的數(shù)據(jù)加載器類(lèi)。將跳過(guò)有關(guān)生成數(shù)據(jù)集信息數(shù)據(jù)框的詳細(xì)信息。簡(jiǎn)而言之,這些行包括一些路徑處理代碼,用于保留原始圖像路徑、原始遮罩文件路徑、裁剪 idx 和要保存到的數(shù)據(jù)集條目的泡菜文件路徑。除了這些路徑生成部分之外,初始化函數(shù)的主要功能是保留提供的參數(shù)并相應(yīng)地排列一些局部變量(例如,訓(xùn)練或測(cè)試數(shù)據(jù)集信息數(shù)據(jù)幀)并生成數(shù)據(jù)集實(shí)體。
對(duì)于第一次初始化調(diào)用,所有數(shù)據(jù)處理任務(wù)都使用 __gen_datasets__ 方法處理,并為每個(gè)數(shù)據(jù)集項(xiàng)生成 pickle 文件并存儲(chǔ)在磁盤(pán)上,以便在每次數(shù)據(jù)訪問(wèn)時(shí)讀取。
class AISegment(Dataset): … def __init__(self, root_dir, d_type, transform=None, im_size=[80, 80], fold_ratio=1): … self.d_type = d_type self.transform = transform self.img_ds_dim = im_size self.fold_ratio = fold_ratio # Generate and save dataset information file if not already available # Training and Test split is also performed here using the hash of file names (all three cropped images should fall into the same set) # Information data frames include raw data path, raw label path, crop idx, pickle file path, etc. for each data entity … # One of the created data frames is selected from: train_img_files_info & test_img_files_info if self.d_type == 'train': self.img_files_info = train_img_files_info elif self.d_type == 'test': self.img_files_info = test_img_files_info else: print('Unknown data type: %s' % self.d_type) return # Create and save pt files for each data entity (if not available before) self.__create_pt_files() self.is_truncated = False def __create_pt_files(self): if self.__check_pt_files_exist(): return self.__makedir_exist_ok(self.processed_train_data_folder) self.__makedir_exist_ok(self.processed_test_data_folder) self.__gen_datasets()
數(shù)據(jù)增強(qiáng)
gen_datasets方法處理所有必需的預(yù)處理、擴(kuò)充和預(yù)規(guī)范化步驟。實(shí)施的步驟如下:
從原始圖像裁剪三個(gè)正方形圖像(因?yàn)閁-Net模型使用方形圖像)。
根據(jù)提供的數(shù)據(jù)集參數(shù),裁剪圖像和遮罩圖像被下采樣為 80×80 或 352×352。
相應(yīng)的遮罩圖像將轉(zhuǎn)換為二進(jìn)制“背景”或“肖像”標(biāo)簽。
如果需要,將折疊圖像(352×352 圖像折疊為大小為 88×88×48 的圖像)。
圖像在保存之前按 256 縮放,因?yàn)閺?fù)合變壓器期望圖像在 [0 1] 范圍內(nèi),但原始圖像是 RGB 并且值在 [0 255] 范圍內(nèi)。
圖像裁剪原理如圖3所示。
圖 4 包括從同一原始圖像構(gòu)建的三個(gè)示例訓(xùn)練圖像。
圖3.示例數(shù)據(jù)加載器 – 數(shù)據(jù)增強(qiáng):從原始圖像中裁剪出三個(gè)方形圖像。
圖4.從原始圖像(600 × 600)裁剪了三個(gè)大小為 800 × 600 的圖像(以及相應(yīng)的消光圖像)。
gen_datasets方法的實(shí)現(xiàn)方式如下:
def __normalize_image(self, image): return image / 256 def __gen_datasets(self): # For each entry in dataset information dataframe for _, row in tqdm(self.img_files_info.iterrows()): img_file = row['img_file_path'] matting_file = row['lbl_file_path'] pickle_file = row['pickle_file_path'] img_crp_idx = row['crp_idx'] img = Image.open(img_file) lbl_rgba = Image.open(matting_file) vertical_crop_area = AISegment.img_dim[0] - AISegment.img_crp_dim[0] step_size = vertical_crop_area / (AISegment.num_of_cropped_imgs - 1) # Determine top left coordinate of the crop area top_left_x = 0 top_left_y = 0 + img_crp_idx * step_size # Determine bottom right coordinate of the crop area bottom_right_x = AISegment.img_crp_dim[0] bottom_right_y = top_left_y + AISegment.img_crp_dim[0] img_crp = img.crop((top_left_x, top_left_y, bottom_right_x, bottom_right_y)) img_crp_lbl = lbl_rgba.crop((top_left_x, top_left_y, bottom_right_x, bottom_right_y)) img_crp = img_crp.resize(self.img_ds_dim) img_crp = np.asarray(img_crp).astype(np.uint8) img_crp_lbl = img_crp_lbl.resize(self.img_ds_dim) img_crp_lbl = (np.asarray(img_crp_lbl)[:, :, 3] == 0).astype(np.uint8) # Fold the data (ex: 352 x 352 x 3 folded into 88 x 88 x 48) if required and save to pt file if self.fold_ratio == 1: img_crp_folded = img_crp else: img_crp_folded = None for i in range(self.fold_ratio): for j in range(self.fold_ratio): if img_crp_folded is not None: img_crp_folded = np.concatenate((img_crp_folded, img_crp[i::self.fold_ratio, j::self.fold_ratio, :]), axis=2) else: img_crp_folded = img_crp[i::self.fold_ratio, j::self.fold_ratio, :] pickle.dump((img_crp_folded, img_crp_lbl), open(pickle_file, 'wb'))
數(shù)據(jù)加載器方法和轉(zhuǎn)換器定義
數(shù)據(jù)加載器方法是第二個(gè)必需的組件定制數(shù)據(jù)加載器模塊。對(duì)于示例 AISegment 數(shù)據(jù)集,實(shí)現(xiàn)了兩個(gè)不同的數(shù)據(jù)加載器函數(shù)。第一個(gè) (AISegment_get_datasets) 使用較小的 U-Net 網(wǎng)絡(luò)模型返回大小為 80x80 的圖像。后者 (AISegment352_get_datasets) 返回大小為 352×352 的圖像。下面是第二個(gè)的實(shí)現(xiàn),它生成具有所需屬性的 AISegment 對(duì)象。復(fù)合變壓器也在此函數(shù)中定義。此外,如果需要截?cái)啵瑒t測(cè)試數(shù)據(jù)集將被截?cái)唷?/p>
def AISegment352_get_datasets(data, load_train=True, load_test=True): """…""" (data_dir, args) = data if load_train: train_transform = transforms.Compose([ transforms.ToTensor(), ai8x.normalize(args=args) ]) train_dataset = AISegment(root_dir=data_dir, d_type='train', transform=train_transform, im_size=[352, 352]) else: train_dataset = None if load_test: test_transform = transforms.Compose([ transforms.ToTensor(), ai8x.normalize(args=args) ]) test_dataset = AISegment(root_dir=data_dir, d_type='test', transform=test_transform, im_size=[352, 352]) if args.truncate_testset: test_dataset.data = test_dataset.data[:1] else: test_dataset = None return train_dataset, test_dataset
數(shù)據(jù)集字典
數(shù)據(jù)集字典是自定義數(shù)據(jù)加載器模塊的第三個(gè)必需組件,該模塊包含可用的數(shù)據(jù)加載器功能。對(duì)于示例 AISegment 數(shù)據(jù)加載程序,由于有兩種數(shù)據(jù)集變體可以生成具有不同分辨率(80×80 或 352×352)的數(shù)據(jù)集實(shí)體,因此數(shù)據(jù)集字典有兩個(gè)實(shí)體,每個(gè)實(shí)體都包括輸入和輸出大小的正確定義以及數(shù)據(jù)加載器函數(shù)名稱(chēng)。
使用圖像測(cè)試訓(xùn)練模型
在模型開(kāi)發(fā)階段之后,可以使用測(cè)試數(shù)據(jù)集或任意樣本來(lái)測(cè)試模型。關(guān)鍵的一點(diǎn)是,在向模型提供輸入之前,必須在外部完成正確的轉(zhuǎn)換/s數(shù)據(jù)加載器實(shí)現(xiàn)。
例如,使用AISegment數(shù)據(jù)集訓(xùn)練的示例模型需要輸入形狀[48, 88, 88],這是分辨率為352×352的折疊RGB圖像的通道優(yōu)先表示,具有MAX7800X所需的歸一化像素值。外部提供的測(cè)試圖像甚至可能沒(méi)有相同的顏色格式,但必須事先實(shí)現(xiàn)所需的轉(zhuǎn)換,因?yàn)槟P褪轻槍?duì) RGB 圖像訓(xùn)練的。下面是具有 470×470 分辨率和 YbCr 顏色格式的縱向圖像的測(cè)試模型的示例代碼片段:
rgb_img = yuv_img.convert('RGB') rgb_img_ds = rgb_img.resize([352, 352]) # Image to numpy array conversion: rgb_img_ds = np.asarray(rgb_img_ds).astype(np.uint8) # Fold image (352 x 352 x 3 folded into 88 x 88 x 48) rgb_img_ds_folded = fold_image(rgb_img_ds, 4) # Covert pixel values to range [0 1] and cast to float type (required for Torch) rgb_img_ds_folded_scaled = (rgb_img_ds_folded / 256).astype(np.float32) # Normalize for MAX78000 # Set act_mode_8bit=True as we will set model parameter simulate=True args = Args(act_mode_8bit=True) transform = transforms.Compose([ transforms.ToTensor(), ai8x.normalize(args=args) ]) rgb_img_ds_folded_scaled_normalized = transform(rgb_img_ds_folded_scaled) # Add batch dimension rgb_img_batch = rgb_img_ds_folded_scaled_normalized.unsqueeze(0) # Load model device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') load_model_path = 'unet/qat_ai85unet_v7_352_4_best_q.pth.tar' ai8x.set_device(device=85, simulate=True, round_avg=False) model = mod.AI85Unet_v7_pt(num_classes=2, num_channels=3, dimensions=(88, 88), bias=True, fold_ratio=4) checkpoint = torch.load(load_model_path, map_location=lambda storage, loc: storage) ai8x.fuse_bn_layers(model) model = apputils.load_lean_checkpoint(model, load_model_path, model_device=device) ai8x.update_model(model) model = model.to(device) # Run model with torch.no_grad(): sample_img_rgb_batch = rgb_img_batch.to(device) model_out_rgb = model(sample_img_rgb_batch) # Retrieve model output out_vals = np.argmax(model_out_rgb[0, :, :, :].detach().cpu().numpy(), axis=0) plt.imshow(out_vals, cmap='Greys')
圖5包括以YCbCr格式給出的示例外部測(cè)試數(shù)據(jù)項(xiàng),相應(yīng)的RGB圖像以及執(zhí)行所有必需轉(zhuǎn)換后的模型輸出。首先,色彩空間需要轉(zhuǎn)換為RGB。然后,應(yīng)將圖像縮減像素采樣,使其具有 352 × 352 分辨率。下一個(gè)操作是折疊,需要轉(zhuǎn)換和規(guī)范化。
圖5.在 YCbCr 色彩空間、RGB 空間和模型輸出中采樣縱向圖像。
審核編輯:郭婷
-
微控制器
+關(guān)注
關(guān)注
48文章
7487瀏覽量
151045 -
神經(jīng)網(wǎng)絡(luò)
+關(guān)注
關(guān)注
42文章
4762瀏覽量
100535 -
人工智能
+關(guān)注
關(guān)注
1791文章
46853瀏覽量
237546
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
評(píng)論