在前面幾節中,我們了解了概率論和隨機變量。為了將這一理論付諸實踐,讓我們介紹一下樸素貝葉斯分類器。這只使用概率基礎知識來讓我們執行數字分類。
學習就是做假設。如果我們想要對以前從未見過的新數據示例進行分類,我們必須對哪些數據示例彼此相似做出一些假設。樸素貝葉斯分類器是一種流行且非常清晰的算法,它假設所有特征彼此獨立以簡化計算。在本節中,我們將應用此模型來識別圖像中的字符。
%matplotlib inline
import math
import tensorflow as tf
from d2l import tensorflow as d2l
d2l.use_svg_display()
22.9.1。光學字符識別
MNIST ( LeCun et al. , 1998 )是廣泛使用的數據集之一。它包含 60,000 張用于訓練的圖像和 10,000 張用于驗證的圖像。每個圖像包含一個從 0 到 9 的手寫數字。任務是將每個圖像分類為相應的數字。
GluonMNIST
在模塊中提供了一個類data.vision
來自動從 Internet 檢索數據集。隨后,Gluon 將使用已經下載的本地副本。train
我們通過將參數的值分別設置為True
或來指定我們是請求訓練集還是測試集False
。每個圖像都是一個灰度圖像,寬度和高度都是28具有形狀(28,28,1). 我們使用自定義轉換來刪除最后一個通道維度。此外,數據集用無符號表示每個像素8位整數。我們將它們量化為二進制特征以簡化問題。
data_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
lambda x: torch.floor(x * 255 / 128).squeeze(dim=0)
])
mnist_train = torchvision.datasets.MNIST(
root='./temp', train=True, transform=data_transform, download=True)
mnist_test = torchvision.datasets.MNIST(
root='./temp', train=False, transform=data_transform, download=True)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./temp/MNIST/raw/train-images-idx3-ubyte.gz
0%| | 0/9912422 [00:00
0%| | 0/28881 [00:00
Extracting ./temp/MNIST/raw/train-labels-idx1-ubyte.gz to ./temp/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./temp/MNIST/raw/t10k-images-idx3-ubyte.gz
0%| | 0/1648877 [00:00
Extracting ./temp/MNIST/raw/t10k-images-idx3-ubyte.gz to ./temp/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./temp/MNIST/raw/t10k-labels-idx1-ubyte.gz
0%| | 0/4542 [00:00
Extracting ./temp/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./temp/MNIST/raw
((train_images, train_labels), (
test_images, test_labels)) = tf.keras.datasets.mnist.load_data()
# Original pixel values of MNIST range from 0-255 (as the digits are stored as
# uint8). For this section, pixel values that are greater than 128 (in the
# original image) are converted to 1 and values that are less than 128 are
# converted to 0. See section 18.9.2 and 18.9.3 for why
train_images = tf.floor(tf.constant(train_images / 128, dtype = tf.float32))
test_images = tf.floor(tf.constant(test_images / 128, dtype = tf.float32))
train_labels = tf.constant(train_labels, dtype = tf.int32)
test_labels = tf.constant(test_labels, dtype = tf.int32)
我們可以訪問一個特定的示例,其中包含圖像和相應的標簽。
我們的示例存儲在此處的變量中image
,對應于高度和寬度為28像素。
我們的代碼將每個圖像的標簽存儲為標量。它的類型是 32位整數。
label,
評論
查看更多