計算機視覺是一個顯著增長的領域,有許多實際應用,從自動駕駛汽車到面部識別系統。該領域的主要挑戰之一是獲得高質量的數據集來訓練機器學習模型。
Torchvision作為Pytorch的圖形庫,一直服務于PyTorch深度學習框架,主要用于構建計算機視覺模型。
為了解決這一挑戰,Torchvision提供了訪問預先構建的數據集、模型和專門為計算機視覺任務設計的轉換。此外,Torchvision還支持CPU和GPU的加速,使其成為開發計算機視覺應用程序的靈活且強大的工具。
什么是“Torchvision數據集”?
Torchvision數據集是計算機視覺中常用的用于開發和測試機器學習模型的流行數據集集合。運用Torchvision數據集,開發人員可以在一系列任務上訓練和測試他們的機器學習模型,例如,圖像分類、對象檢測和分割。數據集還經過預處理、標記并組織成易于加載和使用的格式。
據了解,Torchvision包由流行的數據集、模型體系結構和通用的計算機視覺圖像轉換組成。簡單地說就是“常用數據集+常見模型+常見圖像增強”方法。
Torchvision中的數據集共有11種:MNIST、CIFAR-10等,下面具體說說。
Torchvision中的11種數據集
MNIST手寫數字數據庫
這個Torchvision數據集在機器學習和計算機視覺領域中非常流行和廣泛應用。它由7萬張手寫數字0-9的灰度圖像組成。其中,6萬張用于訓練,1萬張用于測試。每張圖像的大小為28×28像素,并有相應的標簽表示它所代表的數字。
要訪問此數據集,您可以直接從Kaggle下載或使用torchvision加載數據集:
importtorchvision.datasetsasdatasets#Loadthetrainingdataset train_dataset=datasets.MNIST(root='data/',train=True,transform=None,download=True)#Loadthetestingdataset test_dataset=datasets.MNIST(root='data/',train=False,transform=None,download=True)
左右滑動查看完整代碼
CIFAR-10(廣泛使用的標準數據集)
CIFAR-10數據集由6萬張32×32彩色圖像組成,分為10個類別,每個類別有6000張圖像,總共有5萬張訓練圖像和1萬張測試圖像。這些圖像又分為5個訓練批次和一個測試批次,每個批次有1萬張圖像。數據集可以從Kaggle下載。
importtorchimporttorchvisionimporttorchvision.transformsastransforms transform=transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]) trainset=torchvision.datasets.CIFAR10(root='./data',train=True, download=True,transform=transform) testset=torchvision.datasets.CIFAR10(root='./data',train=False, download=True,transform=transform) trainloader=torch.utils.data.DataLoader(trainset,batch_size=4, shuffle=True,num_workers=2) testloader=torch.utils.data.DataLoader(testset,batch_size=4, shuffle=False,num_workers=2)左右滑動查看完整代碼
在此提醒一句,您可以根據需要調整數據加載器的批處理大小和工作進程的數量。
CIFAR-100(廣泛使用的標準數據集)
CIFAR-100數據集在100個類中有60,000張(50,000張訓練圖像和10,000張測試圖像)32×32的彩色圖像。每個類有600張圖像。這100個類被分成20個超類,用一個細標簽表示它的類,另一個粗標簽表示它所屬的超類。
importtorchimporttorchvisionimporttorchvision.transformsastransforms importtorchvision.datasetsasdatasetsimporttorchvision.transformsastransforms#Definetransformtonormalizedata transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])]) #LoadCIFAR-100trainandtestdatasets trainset=datasets.CIFAR100(root='./data',train=True,download=True,transform=transform) testset=datasets.CIFAR100(root='./data',train=False,download=True,transform=transform) #Createdataloadersfortrainandtestdatasets trainloader=torch.utils.data.DataLoader(trainset,batch_size=64,shuffle=True) testloader=torch.utils.data.DataLoader(testset,batch_size=64,shuffle=False)左右滑動查看完整代碼
ImageNet數據集
Torchvision中的ImageNet數據集包含大約120萬張訓練圖像,5萬張驗證圖像和10萬張測試圖像。數據集中的每張圖像都被標記為1000個類別中的一個,如“貓”、“狗”、“汽車”、“飛機”等。
importtorchvision.datasetsasdatasetsimporttorchvision.transformsastransforms #SetthepathtotheImageNetdatasetonyourmachine data_path="/path/to/imagenet" #CreatetheImageNetdatasetobjectwithcustomoptions imagenet_train=datasets.ImageNet( root=data_path, split='train', transform=transforms.Compose([ transforms.Resize(256), transforms.RandomCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize( mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]) ]), download=False) imagenet_val=datasets.ImageNet( root=data_path, split='val', transform=transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]) ]), download=False) #Printthenumberofimagesinthetrainingandvalidationsetsprint("Numberofimagesinthetrainingset:",len(imagenet_train))print("Numberofimagesinthevalidationset:",len(imagenet_val))左右滑動查看完整代碼
MSCoco數據集
Microsoft Common Objects in Context(MS Coco)數據集包含32.8萬張日常物體和人類的高質量視覺圖像,通常用作實時物體檢測中比較算法性能的標準。
Fashion-MNIST數據集
時尚MNIST數據集是由Zalando Research創建的,作為原始MNIST數據集的替代品。Fashion MNIST數據集由70000張服裝灰度圖像(訓練集60000張,測試集10000張)組成。
圖片大小為28×28像素,代表10種不同類別的服裝,包括:t恤/上衣、褲子、套頭衫、連衣裙、外套、涼鞋、襯衫、運動鞋、包和短靴。它類似于原始的MNIST數據集,但由于服裝項目的復雜性和多樣性,分類任務更具挑戰性。這個Torchvision數據集可以從Kaggle下載。
importtorchimporttorchvisionimporttorchvision.transformsastransforms #Definetransformations transform=transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5,),(0.5,))])#Loadthedataset trainset=torchvision.datasets.FashionMNIST(root='./data',train=True, download=True,transform=transform) testset=torchvision.datasets.FashionMNIST(root='./data',train=False, download=True,transform=transform) #Createdataloaders trainloader=torch.utils.data.DataLoader(trainset,batch_size=4, shuffle=True,num_workers=2) testloader=torch.utils.data.DataLoader(testset,batch_size=4, shuffle=False,num_workers=2)左右滑動查看完整代碼
SVHN數據集
SVHN(街景門牌號)數據集是一個來自谷歌街景圖像的圖像數據集,它由從街道級圖像中截取的門牌號的裁剪圖像組成。它包含所有門牌號及其包圍框的完整格式和僅包含門牌號的裁剪格式。完整格式通常用于對象檢測任務,而裁剪格式通常用于分類任務。
SVHN數據集也包含在Torchvision包中,它包含了73,257張用于訓練的圖像、26,032張用于測試的圖像和531,131張用于額外訓練數據的額外圖像。
importtorchvisionimporttorch #Loadthetrainandtestsets train_set=torchvision.datasets.SVHN(root='./data',split='train',download=True,transform=torchvision.transforms.ToTensor()) test_set=torchvision.datasets.SVHN(root='./data',split='test',download=True,transform=torchvision.transforms.ToTensor()) #Createdataloaders train_loader=torch.utils.data.DataLoader(train_set,batch_size=64,shuffle=True) test_loader=torch.utils.data.DataLoader(test_set,batch_size=64,shuffle=False)
左右滑動查看完整代碼
STL-10數據集
STL-10數據集是一個圖像識別數據集,由10個類組成,總共約6000+張圖像。STL-10代表“圖像識別標準訓練和測試集-10類”,數據集中的10個類是:飛機、鳥、汽車、貓、鹿、狗、馬、猴子、船、卡車。您可以直接從Kaggle下載數據集。
importtorchvision.datasetsasdatasetsimporttorchvision.transformsastransforms #Definethetransformationtoapplytothedata transform=transforms.Compose([ transforms.ToTensor(), #ConvertPILimagetoPyTorchtensor transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))#Normalizethedata]) #LoadtheSTL-10dataset train_dataset=datasets.STL10(root='./data',split='train',download=True,transform=transform) test_dataset=datasets.STL10(root='./data',split='test',download=True,transform=transform)
左右滑動查看完整代碼
CelebA數據集
這個Torchvision數據集是一個流行的大規模面部屬性數據集,包含超過20萬張名人圖像。2015年,香港中文大學的研究人員首次發布了這一數據。CelebA中的圖像包含40個面部屬性,如,年齡、頭發顏色、面部表情和性別。
此外,這些圖片是從互聯網上檢索到的,涵蓋了廣泛的面部外觀,包括不同的種族、年齡和性別。每個圖像中面部位置的邊界框注釋,以及眼睛、鼻子和嘴巴的5個地標點。
importtorchvision.datasetsasdatasetsimporttorchvision.transformsastransforms transform=transforms.Compose([ transforms.CenterCrop(178), transforms.Resize(128), transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]) celeba_dataset=datasets.CelebA(root='./data',split='train',transform=transform,download=True)
左右滑動查看完整代碼
PASCAL VOC數據集
VOC數據集(視覺對象類)于2005年作為PASCAL VOC挑戰的一部分首次引入。該挑戰旨在推進視覺識別的最新水平。它由20種不同類別的物體組成,包括:動物、交通工具和常見的家用物品。這些圖像中的每一個都標注了圖像中物體的位置和分類。注釋包括邊界框和像素級分割掩碼。
數據集分為兩個主要集:訓練集和驗證集。
訓練集包含大約5000張帶有注釋的圖像,而驗證集包含大約5000張沒有注釋的圖像。此外,該數據集還包括一個包含大約10,000張圖像的測試集,但該測試集的注釋是不可公開的。
importtorchimporttorchvisionfromtorchvisionimporttransforms #Definetransformationstoapplytotheimages transform=transforms.Compose([ transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])]) #Loadthetrainandvalidationdatasets train_dataset=torchvision.datasets.VOCDetection(root='./data',year='2007',image_set='train',transform=transform) val_dataset=torchvision.datasets.VOCDetection(root='./data',year='2007',image_set='val',transform=transform)#Createdataloaders train_loader=torch.utils.data.DataLoader(train_dataset,batch_size=32,shuffle=True) val_loader=torch.utils.data.DataLoader(val_dataset,batch_size=32,shuffle=False)
左右滑動查看完整代碼
Places365數據集
Places365數據集是一個大型場景識別數據集,擁有超過180萬張圖像,涵蓋365個場景類別。Places365標準數據集包含約180萬張圖像,而Places365挑戰數據集包含5萬張額外的驗證圖像,這些圖像對識別模型更具挑戰性。
importtorchimporttorchvisionfromtorchvisionimporttransforms #Definetransformationstoapplytotheimages transform=transforms.Compose([ transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])]) #Loadthetrainandvalidationdatasets train_dataset=torchvision.datasets.Places365(root='./data',split='train-standard',transform=transform) val_dataset=torchvision.datasets.Places365(root='./data',split='val',transform=transform)#Createdataloaders train_loader=torch.utils.data.DataLoader(train_dataset,batch_size=32,shuffle=True) val_loader=torch.utils.data.DataLoader(val_dataset,batch_size=32,shuffle=False)
左右滑動查看完整代碼
總結
總之,Torchvision數據集通常用于訓練和評估機器學習模型,如卷積神經網絡(CNNs)。這些模型通常用于計算機視覺應用,任何人都可以免費下載和使用。本文的主要圖像是通過HackerNoon的AI穩定擴散模型生成的。
審核編輯:湯梓紅
-
gpu
+關注
關注
28文章
4700瀏覽量
128697 -
計算機
+關注
關注
19文章
7419瀏覽量
87713 -
數據庫
+關注
關注
7文章
3765瀏覽量
64274 -
深度學習
+關注
關注
73文章
5492瀏覽量
120975 -
pytorch
+關注
關注
2文章
803瀏覽量
13146
原文標題:你需要知道的11個Torchvision計算機視覺數據集
文章出處:【微信號:vision263com,微信公眾號:新機器視覺】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
評論