本文將基于蝰蛇峽谷(Serpent Canyon) 詳細(xì)介紹如何在英特爾獨(dú)立顯卡上訓(xùn)練 TensorFlow 模型的全流程。
1.1 英特爾 銳炫 獨(dú)立顯卡簡(jiǎn)介
英特爾 銳炫 顯卡基于 Xe-HPG 微架構(gòu),Xe HPG GPU 中的每個(gè) Xe 內(nèi)核都配置了一組 256 位矢量引擎,旨在加速傳統(tǒng)圖形和計(jì)算工作負(fù)載,以及新的 1024 位矩陣引擎或 Xe 矩陣擴(kuò)展,旨在加速人工智能工作負(fù)載。
1.2 蝰蛇峽谷簡(jiǎn)介
蝰蛇峽谷(Serpent Canyon) 是一款性能強(qiáng)勁,并且體積小巧的高性能迷你主機(jī),搭載全新一代混合架構(gòu)的第 12 代智能英特爾 酷睿 處理器,并且內(nèi)置了英特爾 銳炫 A770M 獨(dú)立顯卡。
搭建訓(xùn)練 TensorFlow 模型的開發(fā)環(huán)境
Windows 版本要求
訓(xùn)練 TensorFlow 所依賴的軟件包 TensorFlow-DirectML-Plugin 包要求:
Windows 10的版本≥1709
Windows 11的版本≥21H2
用“Windows logo 鍵+ R鍵”啟動(dòng)“運(yùn)行”窗口,然后輸入命令“winver”可以查得Windows版本。
到英特爾官網(wǎng)下載并安裝最新的英特爾顯卡驅(qū)動(dòng)。驅(qū)動(dòng)下載鏈接:
https://www.intel.cn/content/www/cn/zh/download/726609/intel-arc-iris-xe-graphics-whql-windows.html
下載并安裝Anaconda
下載并安裝 Python 虛擬環(huán)境和軟件包管理工具Anaconda:
https://www.anaconda.com/
安裝完畢后,用下面的命令創(chuàng)建并激活虛擬環(huán)境tf2_a770:
conda create --name tf2_a770 python=3.9 conda activate tf2_a770
向右滑動(dòng)查看完整代碼
安裝TensorFlow2
在虛擬環(huán)境 tf2_a770 中安裝 TensorFlow 2.10。需要注意的是:tensorflow-directml-plugin軟件包當(dāng)前只支持TensorFlow 2.10。
pip install tensorflow-cpu==2.10
向右滑動(dòng)查看完整代碼
安裝 tensorflow-directml-plugin
在虛擬環(huán)境 tf2_a770 中安裝 tensorflow-directml-plugin,這是一個(gè)在 Windows 平臺(tái)上的機(jī)器學(xué)習(xí)訓(xùn)練加速軟件包。
// @brief 加載推理數(shù)據(jù) // @param input_node_name 輸入節(jié)點(diǎn)名 // @param input_data 輸入數(shù)據(jù)數(shù)組 public void load_input_data(string input_node_name, float[] input_data) { ptr = NativeMethods.load_input_data(ptr, input_node_name, ref input_data[0]); } // @brief 加載圖片推理數(shù)據(jù) // @param input_node_name 輸入節(jié)點(diǎn)名 // @param image_data 圖片矩陣 // @param image_size 圖片矩陣長(zhǎng)度 public void load_input_data(string input_node_name, byte[] image_data, ulong image_size, int type) { ptr = NativeMethods.load_image_input_data(ptr, input_node_name, ref image_data[0], image_size, type); }
向右滑動(dòng)查看完整代碼
到此,在 Windows 平臺(tái)上用英特爾獨(dú)立顯卡訓(xùn)練 TensorFlow 模型的開發(fā)環(huán)境配置完畢。
在英特爾獨(dú)立顯卡上訓(xùn)練 TensorFlow 模型
下載并解壓 flower 數(shù)據(jù)集
用下載器(例如,迅雷)下載并解壓 flower 數(shù)據(jù)集,下載鏈接:
https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
下載訓(xùn)練代碼啟動(dòng)訓(xùn)練
請(qǐng)下載 tf2_training_on_A770.py 并放入 flower_photos 同一個(gè)文件夾下運(yùn)行。鏈接:
https://gitee.com/ppov-nuc/training_on_intel_GPU/blob/main/tf2_training_on_A770.py
from pathlib import Path import tensorflow as tf data_dir = Path("flower_photos") image_count = len(list(data_dir.glob('*/*.jpg'))) print("Number of image files:", image_count) # 導(dǎo)入Flower數(shù)據(jù)集 train_ds = tf.keras.utils.image_dataset_from_directory(data_dir, validation_split=0.2, subset="training", seed=123, image_size=(180, 180), batch_size=32) val_ds = tf.keras.utils.image_dataset_from_directory(data_dir, validation_split=0.2, subset="validation", seed=123, image_size=(180, 180), batch_size=32) # 啟動(dòng)預(yù)取和數(shù)據(jù)緩存 train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=tf.data.AUTOTUNE) val_ds = val_ds.cache().prefetch(buffer_size=tf.data.AUTOTUNE) # 創(chuàng)建模型 model = tf.keras.Sequential([ tf.keras.layers.Rescaling(1./255), tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu'), tf.keras.layers.MaxPooling2D(), tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu'), tf.keras.layers.MaxPooling2D(), tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'), tf.keras.layers.MaxPooling2D(), tf.keras.layers.Dropout(0.2), tf.keras.layers.Flatten(), tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dense(5) ]) # 編譯模型 model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy']) #訓(xùn)練模型 model.fit(train_ds,validation_data=val_ds,epochs=20)
向右滑動(dòng)查看完整代碼
總結(jié)
英特爾獨(dú)立顯卡支持 TensorFlow 模型訓(xùn)練。下一篇文章,我們將介紹在英特爾獨(dú)立顯卡上訓(xùn)練 PyTorch 模型。
審核編輯 :李倩
-
英特爾
+關(guān)注
關(guān)注
60文章
9694瀏覽量
170392 -
模型
+關(guān)注
關(guān)注
1文章
3003瀏覽量
48231 -
tensorflow
+關(guān)注
關(guān)注
13文章
327瀏覽量
60375
原文標(biāo)題:在英特爾獨(dú)立顯卡上訓(xùn)練TensorFlow模型 | 開發(fā)者實(shí)戰(zhàn)
文章出處:【微信號(hào):英特爾物聯(lián)網(wǎng),微信公眾號(hào):英特爾物聯(lián)網(wǎng)】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
評(píng)論