Magicode logo
Magicode
4 min read

torchvision.transformsでランダムに画像を90°単位で回転させる

https://cdn.apollon.ai/media/notebox/blob_7LrD1Xz

はじめに

torchvision.transformsで画像を90°単位でランダム回転させたい!
そう思ったのが事の始まりです。

RandomRotationじゃダメなの?

torchvision.transformsでランダム回転をさせるにはRandomRotationがあるのですが、90°単位で回転させるとなるとちょっと痒い所に手が届かない感じ。
主に次の二点が問題。
  • 指定した角度の範囲内でランダムに回転させるものとなっていること
    • ※角度の範囲指定を(90, 90)とかにすればできなくはないのですが、90°, 180°, 270°でそれぞれ用意する必要があるためあまりスマートではないですよね。
  • 回転対象がTensor型の場合、その配列がPILで読み込める(例えば配列中身のデータ型がint8である)ものでないとエラーになってしまうこと
    • ※例えば地理データとかを画像っぽくCNNで扱うとき(画素値がfloat型だったり、そもそもチャンネル数が1や3でなかったり)とかエラーになってしまうのです。。。
ということで、そういったものにも柔軟に対応できる自作のtransformを実装してみました。

実装

以下が実装したものです。
動作の確認も後半で用意してます。 コード実行できなくなったのですね。。。
python
# pip
!pip install torch torchvision pillow scikit-image

Requirement already satisfied: torch in /srv/conda/envs/notebook/lib/python3.7/site-packages (1.11.0) Requirement already satisfied: torchvision in /srv/conda/envs/notebook/lib/python3.7/site-packages (0.12.0) Requirement already satisfied: pillow in /srv/conda/envs/notebook/lib/python3.7/site-packages (9.1.0) Requirement already satisfied: scikit-image in /srv/conda/envs/notebook/lib/python3.7/site-packages (0.19.2) Requirement already satisfied: typing-extensions in /srv/conda/envs/notebook/lib/python3.7/site-packages (from torch) (4.0.1) Requirement already satisfied: numpy in /srv/conda/envs/notebook/lib/python3.7/site-packages (from torchvision) (1.19.5) Requirement already satisfied: requests in /srv/conda/envs/notebook/lib/python3.7/site-packages (from torchvision) (2.27.1)
Requirement already satisfied: scipy>=1.4.1 in /srv/conda/envs/notebook/lib/python3.7/site-packages (from scikit-image) (1.7.3) Requirement already satisfied: networkx>=2.2 in /srv/conda/envs/notebook/lib/python3.7/site-packages (from scikit-image) (2.6.3) Requirement already satisfied: tifffile>=2019.7.26 in /srv/conda/envs/notebook/lib/python3.7/site-packages (from scikit-image) (2021.11.2) Requirement already satisfied: imageio>=2.4.1 in /srv/conda/envs/notebook/lib/python3.7/site-packages (from scikit-image) (2.19.1) Requirement already satisfied: packaging>=20.0 in /srv/conda/envs/notebook/lib/python3.7/site-packages (from scikit-image) (21.3) Requirement already satisfied: PyWavelets>=1.1.1 in /srv/conda/envs/notebook/lib/python3.7/site-packages (from scikit-image) (1.3.0)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /srv/conda/envs/notebook/lib/python3.7/site-packages (from packaging>=20.0->scikit-image) (3.0.7) Requirement already satisfied: charset-normalizer~=2.0.0 in /srv/conda/envs/notebook/lib/python3.7/site-packages (from requests->torchvision) (2.0.10) Requirement already satisfied: certifi>=2017.4.17 in /srv/conda/envs/notebook/lib/python3.7/site-packages (from requests->torchvision) (2021.10.8) Requirement already satisfied: urllib3<1.27,>=1.21.1 in /srv/conda/envs/notebook/lib/python3.7/site-packages (from requests->torchvision) (1.26.8) Requirement already satisfied: idna<4,>=2.5 in /srv/conda/envs/notebook/lib/python3.7/site-packages (from requests->torchvision) (3.3)
python
import random, torch
from PIL import Image
from torchvision import transforms

class RandomRotation90:
  def __init__(self, p=0.5):
    self.p = p

  def __call__(self, x):
    if random.random() < self.p:
      i = random.randint(1, 3)

      if isinstance(x, Image.Image):
        x = transforms.RandomRotation((90*i, 90*i), expand=True)(x)

      elif isinstance(x, torch.Tensor):
        x = torch.rot90(x, i, [1, 2])

      else:
        raise TypeError(f'{type(x)} is unexpected type.')

    return x

動作確認

scilit-imageのサンプルデータに用意がある猫のチェルシー君を回転させて動作を見て行きます。
python
import skimage
origin = skimage.data.chelsea()
origin = Image.fromarray(origin)
origin

<PIL.Image.Image image mode=RGB size=451x300>

PIL

まずはPILで読み込まれた画像の回転。
引数pは回転をさせる確率です。
今回は動作確認なので100%回転するよう1.0を指定しています(以後の動作確認も同様)。
python
transformer = RandomRotation90(p=1.0)
rotated = transformer(origin)
print(f'{type(origin)} >> {type(rotated)}')
rotated

<class 'PIL.Image.Image'> >> <class 'PIL.Image.Image'>
<PIL.Image.Image image mode=RGB size=300x451>

Tensor

python
# origin を Tensor型に変換
tensor = transforms.ToTensor()(origin)

# 回転
transformer = RandomRotation90(p=1.0)
rotated = transformer(tensor)
print(f'{type(tensor)} >> {type(rotated)}')

# 表示するために再度PILに変換
rotated = transforms.ToPILImage()(rotated)
rotated

<class 'torch.Tensor'> >> <class 'torch.Tensor'>
<PIL.Image.Image image mode=RGB size=451x300>

パイプライン化

torchvision.Composeでパイプラインに組み込むこともできます。
ここでは先ほどTensor型の動作確認のときに、PIL画像をTensor型に変換していたところをパイプラインで一気に処理が流れるようにします。
python
transformer = transforms.Compose([      
    transforms.ToTensor(),                         
    RandomRotation90(p=1.0)        
])
rotated = transformer(origin)
print(f'{type(origin)} >> {type(rotated)}')

# 表示するために再度PILに変換
rotated = transforms.ToPILImage()(rotated)
rotated

<class 'PIL.Image.Image'> >> <class 'torch.Tensor'>