Blog
PyTorchで深層学習データセットを効率的に取り扱うために
AI Labの大田(@ciela)です。近頃はリサーチエンジニアとして研究成果を最大化するためのエンジニアリングに責任を持って日々活動しています。
深層学習データセット取扱の課題
昨今の深層学習に用いられるデータセット内のデータ数は一般的に大規模です。実際に学習する上ではデータセットをシャッフルしつつ繰り返しロードし、場合によっては適切な前処理やData Augmentationだってしなくてはなりません。それらの処理を並列化することで、学習にかかる時間を少しでも減らそうと涙ぐましい努力をされている方も多いことでしょう。その一方で、時には数万件にも及ぶファイルに対してランダムアクセスしつつ何度もロードするアクセスパターンは入出力効率を低下させてしまうためストレージデバイスにとって優しくありません。この矛盾を解決するために、格納プロトコルによって規格化された独自アーカイブファイルにデータセットを変換し、それを一定サイズに分割(シャーディング)した形で保管するアプローチが採用されたりします。ある程度まとまった容量のファイルをシーケンシャルにロードすることで、ストレージの入出力効率を高めることが狙いです。
深層学習フレームワークの代表格であるTensorFlowにおいてはプロトコルバッファ形式でデータをシリアライズ・アーカイブするTFRecordという仕組みが提供されています。例えばGoogle Cloud Platform (GCP)のCloud Storageに200〜300MB程度の容量ごとにシャーディングされた学習データのTFRecordファイルを配置しておけば、あとはそのパスパターンを指定するだけで学習データをシーケンシャルにロードすることができ、ファイル単体ごとにランダムにロードする場合と比べて学習速度の向上が期待できます。学習データの前処理パイプライン構築に最適なApache Beam(GCP Dataflow)からも簡単に書き出せるため、GCP上でTensorFlowを利用している方にとっては既に当たり前の仕組みかもしれません。一方で、TensorFlowと肩を並べる深層学習フレームワークであるPyTorchではどういったアプローチが考えられているのでしょうか。昨年8月、PyTorch公式ブログにEfficient PyTorch I/O library for Large Datasets, Many Files, Many GPUsという記事が投稿されており、記事内ではWebDatasetというTensorFlowでのTFRecordに相当する仕組みが紹介されています。本記事ではこのWebDatasetについて簡単にまとめ、実際のデータセットでの利用例をお見せしようと思います。
WebDataset
2021年3月現在はPyTorchとは別リポジトリで開発が進められているものの、既にRFCとしてプロポーザルが上がっており将来的にはPyTorch内部に取り入れられるかもしれません。
WebDatasetの大きな特徴として、実体がPOSIX tar形式でアーカイブされていることが挙げられます。*nixなコンピュータを利用する方にとっては馴染み深すぎる形式ですね。データセット内の各データとラベルを同一のbasenameを持ったファイルとして対応付けた状態でアーカイブしており、tarコマンドで中身を確認したり展開したりといったことも容易に行なえます。読み込み用の特別な仕組みを必要としないため、ポータブルなデータセットアーカイブとしての役割も果たしていると言えるでしょう。下記は手元で作成したWebDatasetの中身を閲覧してみた例です。
1 2 3 4 5 |
$ tar tvf bam-wds-0001.tar | head -n 4 -r--r--r-- bigdata/bigdata 14 2021-03-30 12:02 content_flower/4.attr # ラベルファイル -r--r--r-- bigdata/bigdata 442991 2021-03-30 12:02 content_flower/4.jpg # データファイル -r--r--r-- bigdata/bigdata 14 2021-03-30 12:02 content_flower/5.attr # ラベルファイル -r--r--r-- bigdata/bigdata 176634 2021-03-30 12:02 content_flower/5.jpg # データファイル |
WebDatasetを読み込んだあとの、シャッフル・torch.Tensor化・バッチ化などの学習中前処理はTensorFlow tf.data.Datasetと似たような書き方で定義できます。具体的にはメソッドチェーンを利用して処理を繋げていくfluent interfaceとなっており、下記のような形で記述できます。
1 2 |
# 100件のバッファでデータをシャッフル、RGB3チャンネルのtorch.Tensorとして実体化、バッチサイズ4でバッチ化 wds.WebDataset("bam-wds-0001.tar").shuffle(100).decode("torchrgb").batched(4) |
この他にも、メディアファイルの拡張子に基づく自動デコード・エンコード、取得したリソース(シャード、レコード共に)のローカルキャッシュ、などアーカイブされたデータセットを取り扱う上で嬉しい工夫が凝らされてます。詳しくはGitHubリポジトリのJupyter Notebook集によくまとまっていますのでこちらをご参照ください。
既存のデータセットをWebDatasetへ変換
実際に既存のデータセットをWebDatasetへ変換してみたいと思います。例としてBehance Artistic Media(BAM)データセットを利用してみましょう。BAMデータセットの実体はSQLiteファイルとなっており、DBレコードのsrcカラムに指定されたBehance CDNのURLを参照することで画像をロードできますが、学習時にいちいちインターネットアクセスをしていては学習ルーチンの効率が悪すぎます。さらに既にアクセスが不可能となってしまっている(404)画像に対する例外処理が必要であったりと、そのままでは扱いが若干ややこしいデータセットです。画像を予めダウンロードしてローカルストレージに保存しておき、各DBレコードのURLをファイルパスへ変換しつつランダムアクセスするのがPyTorchでの素直な学習ワークロードになるかと思われますが、ここではそれをWebDatasetに置き換えてみましょう。
今回は下記のようなSQLを発行して得られる、BAMデータセットDBのcrowd_labelsテーブルによる属性ラベルを画像と対応付けたレコードをデータセットとして考え、これをWebDataset化したいと思います。BAMデータセットの主要なDBスキーマについては以前の記事で簡単に紹介しておりますのでよろしければご覧ください。
クエリ
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
SELECT * FROM ( SELECT ROW_NUMBER() OVER(PARTITION BY attribute) AS attrseq, -- ラベルごとの通し番号 src, -- 画像URL attribute -- ラベル FROM modules INNER JOIN crowd_labels ON modules.mid = crowd_labels.mid WHERE label = 'positive' ) WHERE attrseq <= 10 |
レコードサンプル
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
1|https://mir-s3-cdn-cf.behance.net/project_modules/disp/e0533d11631403.560fad7cb723b.jpg|content_bicycle 2|https://mir-s3-cdn-cf.behance.net/project_modules/disp/88e45042270399.5607181dccf11.jpg|content_bicycle 3|https://mir-s3-cdn-cf.behance.net/project_modules/disp/72053821069043.562fb004d1c67.jpg|content_bicycle 4|https://mir-s3-cdn-cf.behance.net/project_modules/disp/37cbc824014175.5632ca1c9f9bb.jpg|content_bicycle 5|https://mir-s3-cdn-cf.behance.net/project_modules/disp/5cd51b922413.5629a98a73644.jpg|content_bicycle 6|https://mir-s3-cdn-cf.behance.net/project_modules/disp/e84475922413.5629a93f2986e.jpg|content_bicycle 7|https://mir-s3-cdn-cf.behance.net/project_modules/disp/7456a329931731.560b01dbde761.jpg|content_bicycle 8|https://mir-s3-cdn-cf.behance.net/project_modules/disp/a6c90115285239.5628f21d09ca1.jpg|content_bicycle 9|https://mir-s3-cdn-cf.behance.net/project_modules/disp/80cdd024014175.5632ca1caebcf.jpg|content_bicycle 10|https://mir-s3-cdn-cf.behance.net/project_modules/disp/8aae9918574135.562cbbaa0c1d3.jpg|content_bicycle 1|https://mir-s3-cdn-cf.behance.net/project_modules/disp/02a0e610386217.560e418f245eb.JPG|content_bird 2|https://mir-s3-cdn-cf.behance.net/project_modules/disp/01d54d10386217.560e4167dc493.JPG|content_bird 3|https://mir-s3-cdn-cf.behance.net/project_modules/disp/eb2a8d10204191.560e10d04ed41.jpg|content_bird 4|https://mir-s3-cdn-cf.behance.net/project_modules/disp/c4b5d220463133.562ebb165443f.jpg|content_bird 5|https://mir-s3-cdn-cf.behance.net/project_modules/disp/6b1b1120463133.562ebb165faa8.jpg|content_bird (以下略) |
今回は少量データで試すために各属性ごとのデータは10件ずつに絞っています。さらにこのデータセットをWebDatasetに変換するために下記のようなコードを実行してみましょう。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
db_file = "20170509-bam-1m-18UThu3ICM.sqlite" wds_pattern = "bam-wds-%04d.tar.gz" with sqlite3.connect(db_file) as con,\ wds.ShardWriter(wds_pattern, maxsize=10_000_000, compress=True) as writer: cur = con.cursor() cur.execute(QUERY) # 上記クエリ for row in cur: url = row[1] try: _, ext = os.path.splitext(url) with urlopen(url) as res: body = res.read() sample = { "__key__": f"{row[2]}/{row[0]}", # [attr]/[attrseq] ext[1:].lower(): body, # png or jpg "attr": row[2], } writer.write(sample) log.debug(f"Archived as WDS {row}") except HTTPError as e: log.warning(f"{url}: {e.code}") except URLError as e: log.warning(f"{url}: {e.reason}") |
URLにアクセスして取得できたバイト列をリソースの拡張子に合わせてWebDatasetに一件一件格納していきます。さらにシャーディングを試すために圧縮前の最大ファイルサイズはあえて10MBに制限しています(デフォルト3GB)。実行した結果、手元の環境に4つのtarballファイルが生成されました。データセットのWebDataset化自体はこれで完了です。簡単ですね。
1 2 |
$ ls bam-wds-*.tar.gz bam-wds-0000.tar.gz bam-wds-0001.tar.gz bam-wds-0002.tar.gz bam-wds-0003.tar.gz |
あとは読み込みです。上記のローカルファイルをファイルパスとして直接読み込んでも良いのですが、せっかくなのでこれらが外部サーバに配置された状態を想定し、HTTP経由でウェブリソースとして取得するシミュレーションを行ってみましょう。Python3のhttp.serverモジュールを利用して簡易ウェブサーバを立上げWebDatasetをローカル配信してみます。
1 |
$ python3 -m http.server |
この状態で下記のコードを実行してみましょう。WebDatasetに格納されているデータセットから、画像はRGB3チャンネルでサイズ224×224のtorch.Tensor、ラベルはstrとしてそれぞれデコード・前処理を行って読み込むコードになります。
1 2 3 4 5 6 7 8 9 10 11 12 |
wds_urls = "http://localhost:8000/bam-wds-{0000..0003}.tar.gz" dataset = ( wds.WebDataset(wds_urls, shardshuffle=True) .shuffle(100) .decode("torchrgb") .to_tuple("jpg;png", "attr") .map_tuple(transforms.Resize((224, 224)), lambda x: x.decode()) .batched(4) ) loader = data.DataLoader(dataset, num_workers=4, batch_size=None) for tensors, attrs in loader: log.info(f"Tensor: {tensors.size()} Labels: {attrs}") |
実行結果は下記のようになりました。無事WebDatasetを読み込むことに成功し、画像とその属性ラベルがバッチ化された状態で取得できたことが確認できました。
1 2 3 4 5 6 |
Tensor: torch.Size([4, 3, 224, 224]) Labels: ['content_cars', 'content_bicycle', 'content_bicycle', 'content_building'] Tensor: torch.Size([4, 3, 224, 224]) Labels: ['content_tree', 'emotion_peaceful', 'emotion_peaceful', 'content_flower'] Tensor: torch.Size([4, 3, 224, 224]) Labels: ['emotion_scary', 'emotion_scary', 'media_comic', 'media_3d_graphics'] Tensor: torch.Size([3, 3, 224, 224]) Labels: ['media_watercolor', 'media_watercolor', 'media_watercolor'] Tensor: torch.Size([4, 3, 224, 224]) Labels: ['content_building', 'content_bird', 'content_bird', 'content_bicycle'] (以下略) |
おわりに
本記事では深層学習データセットを取り巻く課題とWebDatasetによるアプローチを簡単に解説しました。PyTorchで大規模なメディアデータセットを取り扱う上での指針となるのではないでしょうか。
AI Labではこのような研究のためのエンジニアリングをより加速させていきたいと思います。今後は採用も考えていければと思うのでこういった活動に興味のある方はどうぞご連絡ください。
Author