導讀
本文總結了13種圖像增強技術的PyTorch實現方法,附代碼詳解。
使用數據增強技術可以增加數據集中圖像的多樣性,從而提高模型的性能和泛化能力。主要的圖像增強技術包括:
- 調整大小
- 灰度變換
- 標準化
- 隨機旋轉
- 中心裁剪
- 隨機裁剪
- 高斯模糊
- 亮度、對比度和飽和度調節
- 水平翻轉
- 垂直翻轉
- 高斯噪聲
- 隨機塊
- 中心區域
調整大小
在開始圖像大小的調整之前我們需要導入數據(圖像以眼底圖像為例)。
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

# Demo: load the sample fundus image and show it next to two resized copies.
plt.rcParams["savefig.bbox"] = 'tight'

orig_img = Image.open(Path('image/000001.tif'))

# Seed the CPU random number generator so later runs are reproducible.
torch.manual_seed(0)

# The article's author reports (800, 800, 3) for this image.
print(np.asarray(orig_img).shape)

# Resize the image once per requested size.
resized_imgs = []
for edge in [128, 256]:
    resized_imgs.append(T.Resize(size=edge)(orig_img))

# (subplot position, panel title, image) for each panel, drawn left to right.
panels = (
    (131, 'original', orig_img),
    (132, 'resize:128*128', resized_imgs[0]),
    (133, 'resize:256*256', resized_imgs[1]),
)
for position, title, picture in panels:
    axis = plt.subplot(position)
    axis.set_title(title)
    axis.imshow(picture)
plt.show()
灰度變換
此操作將RGB圖像轉化為灰度圖像。
# Demo: convert the image to grayscale and show it beside the original.
# This snippet has no imports of its own; it relies on `orig_img`, `T` and
# `plt` already being defined by the snippet that precedes it.
gray_img = T.Grayscale()(orig_img)

ax1 = plt.subplot(121)
ax1.set_title('original')
ax1.imshow(orig_img)

ax2 = plt.subplot(122)
ax2.set_title('gray')
# A gray colormap is requested explicitly for the grayscale panel.
ax2.imshow(gray_img, cmap='gray')

# Show the figure, as every other snippet in this article does.
plt.show()
標準化
標準化可以加快基于神經網絡結構的模型的計算速度,加快學習速度。
從每個輸入通道中減去通道平均值
將其除以通道標準差。
# Demo: normalise each channel with mean 0.5 and std 0.5, then show the
# result beside the original. Relies on `orig_img`, `T` and `plt` already
# being defined by an earlier snippet.
to_tensor = T.ToTensor()
normalize = T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
normalized_tensor = normalize(to_tensor(orig_img))

# NOTE(review): the normalised tensor is converted straight back to a PIL
# image for display; confirm the rendered colours are what is intended.
normalized_img = [T.ToPILImage()(normalized_tensor)]

panels = (
    (121, 'original', orig_img),
    (122, 'normalize', normalized_img[0]),
)
for position, title, picture in panels:
    axis = plt.subplot(position)
    axis.set_title(title)
    axis.imshow(picture)
plt.show()
隨機旋轉
在設定的角度範圍內隨機選取一個角度來旋轉圖像
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

# Demo: rotate the image by a randomly chosen angle and show it beside the
# original.
plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/2.png'))

# With a single number, RandomRotation picks an angle from
# (-degrees, +degrees), so the result is not a fixed 90-degree turn.
rotated_imgs = [T.RandomRotation(degrees=90)(orig_img)]
print(rotated_imgs)

plt.figure('resize:128*128')
ax1 = plt.subplot(121)
ax1.set_title('original')
ax1.imshow(orig_img)

ax2 = plt.subplot(122)
# The title states the sampling range rather than a fixed "90°".
ax2.set_title('random rotation in [-90°, 90°]')
ax2.imshow(np.array(rotated_imgs[0]))

# Show the figure, as the other snippets in this article do.
plt.show()
?
?
中心裁剪
裁剪圖像的中心區域
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

# Demo: crop the centre of the image at two sizes and show both crops beside
# the original.
plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/2.png'))

center_crops = [T.CenterCrop(size=size)(orig_img) for size in (128, 64)]

plt.figure('resize:128*128')
ax1 = plt.subplot(131)
ax1.set_title('original')
ax1.imshow(orig_img)

ax2 = plt.subplot(132)
# The title is a crop size, so it carries no degree sign.
ax2.set_title('128*128')
ax2.imshow(np.array(center_crops[0]))

ax3 = plt.subplot(133)
ax3.set_title('64*64')
ax3.imshow(np.array(center_crops[1]))
plt.show()
隨機裁剪
隨機剪切圖像的某一部分
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

# Demo: take two random crops of different sizes and show them beside the
# original.
plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/2.png'))

# One random crop per requested size, taken in the listed order.
random_crops = []
for crop_size in (400, 300):
    random_crops.append(T.RandomCrop(size=crop_size)(orig_img))

plt.figure('resize:128*128')
panels = (
    (131, 'original', orig_img),
    (132, '400*400', np.array(random_crops[0])),
    (133, '300*300', np.array(random_crops[1])),
)
for position, title, picture in panels:
    axis = plt.subplot(position)
    axis.set_title(title)
    axis.imshow(picture)
plt.show()
高斯模糊
使用高斯核對圖像進行模糊變換
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

# Demo: blur the image with a 3x3 Gaussian kernel at two sigma values and
# show both results beside the original.
plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/2.png'))

blurred_imgs = []
for sigma in (3, 7):
    blur = T.GaussianBlur(kernel_size=(3, 3), sigma=sigma)
    blurred_imgs.append(blur(orig_img))

plt.figure('resize:128*128')
panels = (
    (131, 'original', orig_img),
    (132, 'sigma=3', np.array(blurred_imgs[0])),
    (133, 'sigma=7', np.array(blurred_imgs[1])),
)
for position, title, picture in panels:
    axis = plt.subplot(position)
    axis.set_title(title)
    axis.imshow(picture)
plt.show()
亮度、對比度和飽和度調節
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

# Demo: adjust brightness, contrast and saturation and show the result
# beside the original.
plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/2.png'))

# Each range has equal lower and upper bounds, so the factors are fixed:
# brightness 2, contrast 0.5, saturation 0.5.
jitter = T.ColorJitter(
    brightness=(2, 2),
    contrast=(0.5, 0.5),
    saturation=(0.5, 0.5),
)
colorjitter_img = [jitter(orig_img)]

plt.figure('resize:128*128')
panels = (
    (121, 'original', orig_img),
    (122, 'colorjitter_img', np.array(colorjitter_img[0])),
)
for position, title, picture in panels:
    axis = plt.subplot(position)
    axis.set_title(title)
    axis.imshow(picture)
plt.show()

# --- Next section: horizontal flip ---
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

# Demo: flip the image horizontally and show it beside the original.
plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/2.png'))

# p=1 means the flip is always applied.
HorizontalFlip_img = [T.RandomHorizontalFlip(p=1)(orig_img)]

plt.figure('resize:128*128')
ax1 = plt.subplot(121)
ax1.set_title('original')
ax1.imshow(orig_img)

ax2 = plt.subplot(122)
# The title names this transform, matching the 'VerticalFlip' panel used in
# the vertical-flip snippet.
ax2.set_title('HorizontalFlip')
ax2.imshow(np.array(HorizontalFlip_img[0]))
plt.show()
垂直翻轉
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

# Demo: flip the image vertically and show it beside the original.
plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/2.png'))

# p=1 means the flip is always applied.
flip = T.RandomVerticalFlip(p=1)
VerticalFlip_img = [flip(orig_img)]

plt.figure('resize:128*128')
panels = (
    (121, 'original', orig_img),
    (122, 'VerticalFlip', np.array(VerticalFlip_img[0])),
)
for position, title, picture in panels:
    axis = plt.subplot(position)
    axis.set_title(title)
    axis.imshow(picture)
plt.show()

# --- Next section: Gaussian noise ---
向圖像中加入高斯噪聲。噪聲強度由噪聲因子控制:噪聲因子越高,圖像的噪聲越大。
import sys
from pathlib import Path

import numpy as np
import torch


def add_noise(inputs, noise_factor=0.3):
    """Return ``inputs`` with additive Gaussian noise, clipped to [0, 1].

    Args:
        inputs: Tensor to corrupt. It is passed to ``torch.randn_like``, so
            it is presumably a floating-point tensor (the demo passes the
            output of ``ToTensor``).
        noise_factor: Multiplier applied to the standard-normal noise; a
            larger value gives a noisier result.

    Returns:
        A tensor shaped like ``inputs`` with every value clipped to the
        range [0, 1].
    """
    noise = torch.randn_like(inputs) * noise_factor
    return torch.clip(inputs + noise, 0., 1.)


def main():
    """Show the original image beside two noisy copies (factors 0.3, 0.6)."""
    # Imported here so that `add_noise` can be imported and reused without
    # the imaging and plotting packages.
    import matplotlib.pyplot as plt
    import torchvision.transforms as T
    from PIL import Image

    plt.rcParams["savefig.bbox"] = 'tight'
    orig_img = Image.open(Path('image/2.png'))

    # Convert the image to a tensor once; it is the same for every factor.
    img_tensor = T.ToTensor()(orig_img)
    noise_imgs = [
        add_noise(img_tensor, noise_factor) for noise_factor in (0.3, 0.6)
    ]
    noise_imgs = [T.ToPILImage()(noise_img) for noise_img in noise_imgs]

    plt.figure('resize:128*128')
    ax1 = plt.subplot(131)
    ax1.set_title('original')
    ax1.imshow(orig_img)

    ax2 = plt.subplot(132)
    ax2.set_title('noise_factor=0.3')
    ax2.imshow(np.array(noise_imgs[0]))

    ax3 = plt.subplot(133)
    ax3.set_title('noise_factor=0.6')
    ax3.imshow(np.array(noise_imgs[1]))
    plt.show()


if __name__ == "__main__":
    main()
隨機塊
正方形補丁隨機應用在圖像中。這些補丁的數量越多,神經網絡解決問題的難度就越大。
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/2.png'))


def add_random_boxes(img, n_k, size=64):
    """Return a copy of ``img`` with ``n_k`` square black boxes drawn on it.

    Args:
        img: Source image; anything ``np.asarray`` accepts. The result is
            built in 'RGB' mode, so a height x width x 3 array is assumed.
        n_k: Number of boxes to draw. Boxes may overlap.
        size: Side length of each box in pixels.

    Returns:
        A new PIL image in 'RGB' mode; the input is left untouched.

    Raises:
        ValueError: If ``size`` is not smaller than both image dimensions,
            so that no valid box position exists.
    """
    h, w = size, size
    pixels = np.asarray(img).copy()
    # Rows are bounded by the height and columns by the width, so boxes
    # also land correctly on non-square images.
    img_h, img_w = pixels.shape[:2]
    if h >= img_h or w >= img_w:
        raise ValueError(
            f"box size {size} must be smaller than the image "
            f"({img_w}x{img_h})"
        )
    for _ in range(n_k):
        y = np.random.randint(0, img_h - h)
        x = np.random.randint(0, img_w - w)
        pixels[y:y + h, x:x + w] = 0
    return Image.fromarray(pixels.astype('uint8'), 'RGB')


# Demo: show the original image beside a copy with ten black boxes.
blocks_imgs = [add_random_boxes(orig_img, n_k=10)]

plt.figure('resize:128*128')
ax1 = plt.subplot(131)
ax1.set_title('original')
ax1.imshow(orig_img)

ax2 = plt.subplot(132)
ax2.set_title('10 black boxes')
ax2.imshow(np.array(blocks_imgs[0]))
plt.show()
中心區域
和隨機塊類似,只不過在圖像的中心加入補丁
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/2.png'))


def add_central_region(img, size=32):
    """Return a copy of ``img`` with a black square over its centre.

    Args:
        img: Source image; anything ``np.asarray`` accepts. The result is
            built in 'RGB' mode, so a height x width x 3 array is assumed.
        size: Half the side length of the square in pixels: the blanked
            region reaches ``size`` pixels from the centre in each
            direction.

    Returns:
        A new PIL image in 'RGB' mode; the input is left untouched.
    """
    pixels = np.asarray(img).copy()
    # The centre is computed per axis so that non-square images are blanked
    # around their true centre.
    img_h, img_w = pixels.shape[:2]
    # Clamp the start at 0: a negative slice start would count from the far
    # edge of the image instead of the near one.
    top = max(0, int(img_h / 2 - size))
    left = max(0, int(img_w / 2 - size))
    # End indices past the image edge are clipped by slicing itself.
    bottom = int(img_h / 2 + size)
    right = int(img_w / 2 + size)
    pixels[top:bottom, left:right] = 0
    return Image.fromarray(pixels.astype('uint8'), 'RGB')


# Demo: show the original image beside a copy with its centre blanked out.
central_imgs = [add_central_region(orig_img, size=128)]

plt.figure('resize:128*128')
ax1 = plt.subplot(131)
ax1.set_title('original')
ax1.imshow(orig_img)

ax2 = plt.subplot(132)
ax2.set_title('')
ax2.imshow(np.array(central_imgs[0]))
plt.show()
?
編輯:黃飛
評論
查看更多