Blog

ONNXモデルのチューニングテクニック (応用編2)


基礎編 / 応用編1 / 応用編2

サイバーエージェント AI LabHuman Computer Interaction Team に所属している兵頭です。今回は私が半年ほど蓄積したONNXのチューニングテクニックを全てブログに残したいと思います。皆さんが既にご存知であろう基本的なことから、かなりトリッキーなチューニングまで幅広くご紹介したいと思います。長文になりますがご容赦願います。今回は応用編2です。

8. 各種トリック

PyTorchやTensorFlowを経由して生成されたONNXから他のフレームワークへモデルを転用する場合に有効な様々なトリック、ワークアラウンドをご紹介します。あまり真新しい知見ではありませんが、以下でご紹介するようなトリックが記事としてまとまっているものはあまり見かけませんのでご参考になれば幸いです。エンジニアよりもリサーチャーの方々に是非読んでいただきたい内容です。ほとんどの内容が地味で目立たないテクニックですが、実用まで見据えたうえではとても重要なものが一部含まれていると考えています。この最終回では実装の内容へ完全にフォーカスします。

8-1. GridSample (Bilinear no-loop only)

GridSample は ONNX の opset=16 から使用可能なオペレーションです。opset=16 が使用可能になるまで、古い PyTorchonnxruntime では生成できませんでした。これは、PyTorchからONNXへエクスポートするときにバックエンドでコールしている onnx-optimizeronnxruntime、PyTorchの onnx exporter の実装に依存しており、全てのランタイムが最新化されていなければ使用できないためです。しかし、アテンション系やステレオ画像をインプットとして扱うモデルでのアフィン変換で使用する必要があるなど、ニッチですが実は需要が高いオペレーションでもあります。

opset=16 では 下図のようにとてもシンプルな GridSample OP が生成されます。
image

このまま opset=16 の設定のモデルで onnxruntime を使用する場合は特に問題になることはありません。しかし、ONNXの利用シーンは最も基本的な利用形態の ONNX + CPUONNX + CUDA に加えて ONNX + OpenVINOONNX + DirectMLONNX + NNAPIONNX + CoreMLONNX + TensorRT などの様々なハードウェアやフレームワークへ最適化された環境で利用することも多いと思います。その場合、転用先のフレームワークによっては GridSample のような最新のオペレータに対応しておらず、CPUやCUDAへフォールバックしてしまい、動作が遅くなったりすることが多いです。 特に TensorRT は Jetson などのエッジデバイスで高速に推論するために利用されている方が多いと思いますが、JetPack のバージョンによって利用できる TensorRT のバージョンが大きく異なるため、サポートされるオペレーションとサポートされないオペレーションの差が大きいです。

下記の Jetpack の概要を読むと、Jetson ORIN/NX で利用可能な JetPack 5.1 では TensorRT 8.5.2 が利用できるのに対して、Jetson Nano で利用可能な JetPack 4.6 では TensorRT 8.2.1 とあり、マイナーバージョンで 3 の差というのはかなり大きな差で、扱えるオペレーションの種類が大きく異なります。
https://developer.nvidia.com/embedded/jetpack

ご参考までに TensorRT 8.5TensorRT 8.4 の差分を下記の通りマイナーバージョン 1 つ分の差ではどれほど違うかが分かるURLを共有します。

TensorRT 8.5 で対応しているオペレータの一覧は下記のとおりです。

TensorRT 8.4 で対応しているオペレータの一覧は下記のとおりです。

TensorRT 8.2 で対応しているオペレータの一覧は下記の通りです。

opset=17 まで対応しているバージョンでも GridSample に対応しているか、対応していないかの差分があります。また、バージョン差によるバグの内在などの問題によって、一見すると対応しているオペレーションに見えても正常に動作しないこともあります。したがって、最新のバージョンでモデルの構造を大きく最適化することだけが運用するうえでの最適な選択とは言えないため、ココではあえて GridSample をプリミティブなレイヤーへ分解して再実装します。
では、opset=16GridSampleopset=11GridSample (Bilinear no-loop only) を生成してみます。

  • make_GridSample.py

画像が小さく見づらいですが、下図のとおりの生成結果となりました。独自実装の opset=11GridSample はプリミティブなオペレーションのみで構成されているため、比較的古いバージョンのフレームワーク群でも実行可能な構成になっています。

  • opset=16 で生成した GridSample (再掲)
    image

  • opset=11 で生成した自力実装の pseudo_GridSample
    image

    最適化と機能性のトレードオフのギリギリを攻めていますので、Loop などのファンクショナルなオペレーションを使用せずに無理やりリニアに実装していることと、4次元までの Bilinear にしか対応していません。

    20230222172855

念の為、処理パフォーマンスにどれほど差が出るかを見てみます。手元の TensorRT バージョンは 8.5.3、CUDA バージョンは 11.6、cuDNN バージョンは 8.6.0、GPUは RTX3070、CPUは Corei9 Gen.10 です。1ms未満の世界の話ではありますが、CUDAあるいはCPUを選択すると独自実装では 5倍〜10倍程度 のパフォーマンス劣化が発生するようです。GridSample を使用できる環境の場合は素直に GridSample を使用したほうが良さそうですね。TensorRT以外のフレームワークではベンチマークしませんが、その他のフレームワークに GridSample と同じ動作をするオペレーションがある場合は問題にならないかもしれません。オペレーションが存在しない場合はいずれにせよプリミティブなオペレーションへ分解する必要はありそうです。

  • opset=16 で生成した GridSample


  • opset=11 で生成した自力実装の pseudo_GridSample


8-2. Mulit-Class Batched NonMaxSuppression

基礎編 で簡単にご紹介した Multi-Batch / Multi-Class の NonMaxSuppression (NMS) 実装の具体例をご紹介します。まずは PyTorch から正攻法で生成可能な NMS です。

torchvision.opsnms を使用していますが、バッチサイズの指定に対応していないため N=10 (box数)4 (x1y1x2y2) だけを指定してオペレーションを生成しています。

  • NMS

    https://pytorch.org/vision/main/generated/torchvision.ops.nms.html

  • 生成サンプル

    一見すると問題が無いように見えますが、nms に指定可能なパラメータにバッチサイズやクラス数が存在しないため、ONNX生成時にバッチサイズ 1、クラス数 1 が自動的に割り当てられています。したがって、Multi-BatchMulti-Class の NMS を生成する場合は大量のNMSを生成する必要があります。また注意点として、PyTorch の NMS は X1Y1X2Y2 の座標入力を期待しますが、ONNX の NMS は Y1X1Y2X2 の座標入力を期待します。PyTorch 単体で NMS を含めた後処理を構築する場合はあらかじめ X1Y1X2Y2Y1X1Y2X2 などに転置しておくことを忘れないようにしないといけません。また、このように必要機能を満たさないオペレーションを生成するぐらいなら、ロジック側で素直にNMSの処理を実装したほうが効率的です。
    image

  • Mulit-Class Batched NonMaxSuppression の実装紹介

    ONNX の NMS は Mulit-ClassMulit-Batch に対応しています。したがって、冗長な複数のNMSをPyTorchで生成するのではなく、ONNX の NMS を直接生成することでマルチクラス化とマルチバッチ化に対応します。Mulit-Class Batched NonMaxSuppression を生成するサンプルスクリプトは下記のとおりですが、Pythonのコードを書かずに NMS を生成します。私の作業環境がUbuntuであるため、便宜上 bash を利用して下記のスクリプトを実行することを想定しています。コマンドプロンプトやその他のシェルで実行する場合は、変数の宣言方法や参照方法、エスケープの記載方法が異なる点にご注意願います。

    5 Batch、80 Class の NMS が生成されました。この NMS はクラスIDを横断的に見て閾値を判定してくれます。生成したNMS単体ではメリットがありませんので、この部品を生成したあとにモデル本体へマージする必要があります。モデルへの部品のマージは 応用編1 をご覧ください。

    image

  • 各フレームワークへ横展開するうえでのポイント (ArgMax / ReduceMax活用による次元圧縮)

    昔から使用されている方法のため、テクニック、というほどのものではありませんが、バッチサイズとクラス数をNMSに入力する前の段階で 1 にシュリンクする方法があります。この方法であれば最近のほぼどのフレームワークでも実行可能な後処理(NMS)が生成されます。なお、ここで言う NMS は、NonMaxSuppression オペレーション単体を指すのではなく、下図画像のほぼ全体のオペレーション群全体を指します。オペレーションの数が多いためとても複雑なことをしているように見えるかもしれませんが、実はこの方法が最も多くのフレームワークで処理可能なシンプルな構造になっています。できるかぎりプリミティブなオペレーションを組み合わせると共にバッチサイズやクラス数を 1 に限定することで、ほとんどのフレームワークで実行可能なNMSになります。(1) ArgMax でスコアが最大のクラスIDを抽出し、(2) ReduceMax で最大スコアを抽出しています。(1) をバウンディングボックス群を特定のクラスIDひとつに絞り込むインデックスとして使用し、(2) の最大スコア (ひとつのクラスのスコア) とともに NonMaxSuppression の入力に使用しています。これは YOLOv7 の後処理をそのまま引用しています。この実装の弱点は、不必要な FlattenGather を大量に使用することにより、バッチサイズを無視してしまう実装になっていることと、1クラスのみのNMSに限定してしまっていることです。ただ、フレームワーク間の横展開の柔軟性が最も高い実装ですので、運用上、複数クラスのNMSが不要な場合や1バッチのみで処理できれば良い場合は良い選択肢になると考えています。

    image

    実装は下記のとおりです。YOLOv7の後処理の実装そのものです。感覚的に、YOLOv7の後処理はとても綺麗に実装されています。ポイントは max_score, category_id = scores.max(2, keepdim=True) の部分で ArgMaxReduceMax を同時に生成しているところのみです。

8-3. Unity Barracuda用 GatherND

Barracuda は Unity がサポートしているデバイス上でONNXによる推論をサポートするためのパッケージです。2023年3月1日時点では v3.0.0 がリリースされているようです。私はあまりUnityを使用したことが無いのですが、VR/AR界隈で使用されているように見えます。Unityの他にも Unreal Engine というエンジンがあり、そちらでもONNXをサポートしているようです。CPU利用がメインのようですが、DirectMLを経由することでGPUを使うこともできるようです。このブログではそれぞれのエンジンの良し悪しを比較することが目的ではありませんので、ここでは Unity Barracuda での事例を取り上げます。

Unity Barracuda を使用してUnityへONNXを取り込み、UnityにONNX推論を行わせることができます。先述のように、すでに Barracuda は多くのオペレーションに対応しているため、多くのONNXモデルを実行することができるように見えます。

image

image

ここで「できるように見えます。」とあえて含みをもたせた表現をしたことには理由があり、仮に対応オペレーション一覧に記載があるオペレーションであったとしても、Barracuda内部の実装の問題で正常に動作しないことがあります。この点に関しては次の項でご紹介します。

さて本題に戻りますが、この項では対応オペレーション一覧に存在しないオペレーション GatherND を Unity Barracuda に対応させるためのワークアラウンドをご紹介します。GatherND を含むモデルを Unity Barracuda に読み込ませるとオペレーション非対応を示すエラーが表示されます。したがって、GatherND を Unity Barracuda に対応したオペレーションのみの組み合わせに置き換えることで無理やり対応させます。

ただし、PyTorchには GatherND が存在しません。自力で GatherND をシミュレートした処理をPyTorchで実装するか、PyTorch以外の手段で代替するか、のいずれかの方法を選択する必要があります。PyTorchの実装をアレコレと試しましたがなかなかうまくいきませんでした。今回は中間処理に TensorFlow を使用して GatherND を生成します。中間処理にわざわざ TensorFlow を使用することにはあまりメリットを感じないかもしれませんが、GatherND は重みを持たない単純なインデクシングのオペレーションですので PyTorch にこだわる必要もありません。最終的な目的がONNXを使用すること、である場合は、TensorFlow を使用して生成した GatherNDtf2onnx を介してONNXへ変換することでONNXの GatherND を生成します。このブログ自体がかなりトリッキーなノウハウを共有することを目的としているため難解な点がある場合はご容赦願います。

TensorFlow GatherND -> tf2onnx -> ONNX GatherND

  • 通常の GatherND

    image

  • Unity Barracuda対応の GatherND

    image

GatherND を生成するためだけにこれほど多くのロジックを書くこと自体がコストに見合わない、という判断があるのはごもっともですが、最小限のオペレーション数で GatherND の処理を実現することができます。また、ここでご紹介したオペレーションの作例は一例です。応用編1 でご紹介したテクニックを使用して直接 GatherND を生成してしまったほうが圧倒的に早いと思います。このテクニックはモデルの実装が公開されていないシチュエーションで主に使用します。

8-4. Unity Barracuda用 Unsqueeze

Unity Barracuda の Unsqueeze オペレーションのハンドリングには致命的なバグがあり、誤った位置に次元が追加されてテンソルが壊れる問題があります。したがって、PyTorchで生成したモデルを最終的に Unity Barracuda へ転用して運用する想定がある場合は、あらかじめ Unsqueeze を別のオペレーションへ置き換えてモデルを生成しておく必要があります。具体的には、次元を外挿したあとのテンソルの形状が自明な場合に最も安全な置き換え方法は Reshape です。

単純に UnsqueezeReshape へ置き換えるだけではありますが、このワークアラウンドを適用してモデルの構造をできるだけ最適化するにあたってのポイントは3つです。

  1. 外挿する次元を指定する axis (dim) のパラメータの定数化が可能な場合はできるだけ定数化してからオペレーションを生成する。
  2. torch.Tensorshape パラメータはモデルによっては x.shape のまま利用するとPyTorchのonnx exporterが型推定に失敗してエラーになることがあるため、shape が自明で定数化が可能な場合はあえてループの内包表記で Python の int型 へキャストしてからリストを再生成する。
  3. Reshapeshape パラメータには -1 を使用しない。

2.に関しては具体的にどういうモデルで発生する事象かまでは言及しませんが、モデルの中間構造が可変形状になるテンソルを持たない場合はできるだけ定数化したほうがオプティマイザの最適化が有効に働きます。3.に関しては、運用フェーズに入ったときに意外とインパクトが出ます。ONNXの形状を書き換えたくなったときに -1 が含まれた形状を shape に指定していると操作不能になります。分かりやすい状況としては、バッチサイズを可変形状に書き換える場合にこの Reshape-1 が邪魔になります。また -1 が含まれている場合、各種フレームワークのバグを誘発しやすくなります。

  • 通常の Unsqueeze (opset=11)

    image

  • Unity Barracuda へ最適化した Reshape(Unsqueeze) (opset=11)

    image

データベースの正規化と同じように、状況によってあえてほんの少しだけコードに冗長性を持たせるだけで、モデル全体の実用性が大幅に向上する例です。

8-5. onnx.onnx_cpp2py_export.shape_inference.InferenceError の回避

PyTorchからONNXを生成するときに、一見すると意味不明なエラー(コード上の表現が正しくPyTorchで正常に動作するもの)が発生することがあります。PyTorchあるいはONNXのバグなのか仕様なのかを判断することが難しいエラーです。例えば下記のようなエラーです。読みやすいように改行を加えています。

エラーが発生する箇所のコードは下記のとおりです。x には4次元のテンソルが入力されてくる想定です。特にこのエラー発生の原因になっている箇所は torch.max(input=x, dim=1)[0].unsqueeze(1) の部分なのですが一見すると意味合いの整合はとれています。指定された次元(dim)を軸にして最大値を抽出し、[0]でスライスして1つ目の出力を取得し、さらに1次元目を拡張する、というだけの処理です。注意点は、torch.max は2つの出力を得るという点で、戻り値が valuesindicies のタプルになるということです。[0] は戻り値のタプルから values を抽出しています。

しかし、下記はエラーになりません。最終的な結果は上記の処理と全く同じですが、次元圧縮と次元展開の処理順序が異なるだけです。torch.max(x, dim=1, keepdim=True) の部分で keepdim=True を指定することで torch.max による次元圧縮を抑止したうえでスライスを施している点がことなります。つまり、Unsqueeze(1) を処理から排除しただけです。おそらく ONNX側の内部実装に問題があるではないか、と推察しています。

ここでご紹介したエラーはめったに発生しないからこそ、いざ発生したときに全く意味が分からなくて解消までに時間が掛かることがある特殊な問題です。onnx の内部実装を解析してプルリクエストを送ったほうが良いかもしれません。実際に気付いている人が居るのかどうかも分かりません。

8-6. PyTorch の Atan2 をONNXにエクスポートするワークアラウンド

ONNXには Atan の実装はありますが Atan2の実装がありません。
https://github.com/onnx/onnx/blob/main/docs/Operators.md#operator-schemas

image

したがって、Atan2 を使用している一部のモデルは何も手を加えない限りONNXへエクスポートすることができません。余談ですが、TensorFlow Lite には Atan2 の実装はありますが Atan の実装がありません。おそらく、少し手を加えるだけでオペレーションを表現できるためだとは思いますが、何故かどちらか一方しか実装されていないことがあります。本項では、PyTorchの Atan2 をONNXで表現します。

この実装は nikola-j/atan2.py をほぼそのまま引用させていただきました。

少々冗長ですが、下図のとおり Atan2 を表現できます。
image

8-7. PyTorch の torch.inverse をONNXにエクスポートするワークアラウンド

Transformer などのモデルで affine_grid が使用されているとき、何らかの方法で affine_grid を代替して再実装する必要があるのですが、計算を組み合わせていく過程で torch.linalg.inv あるいは torch.inverse を使用して逆行列を求めたいことがあります。しかし、ONNXの標準ランタイムには inverse が実装されていないため、inverse そのものをプリミティブなオペレーションへ分解して表現します。

先にお伝えしておくと、inverse が全く使えないわけではなくて、実際には com.microsoft という拡張モジュールを取り込めば利用できます。標準のランタイムにこだわる必要が無い場合は拡張モジュールを利用したほうが早いです。

image

  • https://pytorch.org/docs/stable/generated/torch.linalg.inv.html#torch.linalg.inv

  • 正方行列になるまで再帰処理でひたすらテンソル分解して逆行列を求めるワークアラウンド

  • torch.linalg.det が正方行列にしか対応していない

  • かなり冗長

    参考:https://github.com/pytorch/pytorch/issues/30563

    もっと冗長性を排除したクリーンな実装方法があるかもしれませんが、torch.linalg.det を使用した結果です。正方行列になるまで高次元を分解しますので、実質何次元でも逆行列化できますが、ロジックがとても単調なためOP数が増えてモデルが肥大化します。
    image

8-8. PyTorch の torch.median をONNXにエクスポートするワークアラウンド

PyTorchにはテンソルの中央値を求めるための torch.median というオペレーションが存在しますがONNXには実装されていません。したがって、torch.median を含むモデルをそのままONNXへエクスポートしようとすると下記のようなエラーが発生します。

動作検証が不十分ですのでバグがあるかもしれませんが、torch.median をONNXへエクスポート可能なオペレーションの組み合わせに置き換えて下記のように表現します。テスト用のコードはあえて中央値を求める次元を偶数の次元にしていますので、中央値の判定の仕方に少しクセがあります。実装上は、mode を切り替えることで、ちょうど境目にある2つの要素のうち、大きい方を選択するか、小さい方を選択するか、2つの要素の平均を計算するか、を切り替えできるようにしているつもりです。

  1. mode='ceil' : 境目の2つの要素のうち大きいほうを選択
  2. mode='floor' : 境目の2つの要素のうち小さいほうを選択
  3. mode='mean' : 境目の2つの要素の平均値を計算

かなり回りくどい書き方をしていますが、なるべく Tensor型 で定数を保持しないようにしていることと、形状指定に -1 をできるだけ使用しないテクニックを再現しています。定数を再計算する箇所では torch.xxx() を使用せず、np.xxx() あるいは PythonのList型 を使用して Tensor型 に変換されないようにしています。

最終次元で中央値を比較すると、PyTorchとONNXの出力値が一致しました。

ONNXの形状は下図のとおりです。ArgMax で値の降順にソートしたテンソルを 要素数 // 2 + 1 個抽出し、抽出したテンソルからさらに末尾の要素、あるいは末尾からひとつ前の要素を抽出することで中央値を取得しています。6 // 2 + 1 = 4個+1個を余分に抽出しているのは、mode='mean' のときに、境目にある2つの要素の平均を算出するために、境界をまたいだ2つ分の要素を抽出するためです。あえて回りくどい書き方をしたおかげで、UnsqueezeSqueezeGather などの各種フレームワーク側で推論時に問題となるオペレーションを含めずに torch.median を表現できました。

image

8-9. MVN (MeanVarianceNormalization) のワークアラウンド

MeanVarianceNormalization は入力テンソルの平均-分散正規化を計算します。PyTorchそのものには MeanVarianceNormalization がオペレーションとして実装されていないように見えますが私の調査が足りていないだけかもしれません。少なくともONNXには MeanVarianceNormalization が実装されており、数式で表現してもあまり難しいものではありませんので、オペレーションそのものを利用することにあまり固執する必要は無いかもしれません。

https://github.com/onnx/onnx/blob/main/docs/Changelog.md#MeanVarianceNormalization-13

下記のように実装しました。今回もとても回りくどい実装をしています。PyTorchのひとつのオペレーションで処理できる部分をあえて分解して個別に実装しています。特に意味が無いことをすぐに見つけることができる部分は (reduce_sub * 1.0)* 1.0 の箇所だと思いますが、実はこれには意味があります。まず、MeanVarianceNormalization が実装されているフレームワーク上でしか MeanVarianceNormalization を持つモデルを実行できないことは当然ですが、あえて MeanVarianceNormalization をプリミティブに近いオペレーションの組み合わせに分解して再実装することでほとんどのフレームワーク上で実行可能になります。また、仮にONNXのモデルに MeanVarianceNormalization が含まれているときにそのまま他のフレームワークへ変換すると MVN という名前の等価なオペレーションへそのまま変換されることが多いです。たとえば、ONNXからOpenVINOへ変換するとモデルオプティマイザが自動的に MVN へ置き換えて処理を最適化してくれます。ただ、MeanVarianceNormalization に対応していないフレームワークがその先の変換作業の工程にひとつでも存在するとその時点でモデルの汎用性が著しく低下します。したがって、各種フレームワークのモデルオプティマイザが優秀な最適化性能を持つことを逆手に取って、無意味な演算(計算値に影響を与えない演算)を加えて MeanVarianceNormalization が自動生成されないようにフォローしています。

通常の MeanVarianceNormalization です。

image

左側の分岐の Mul が各フレームワークへ横展開するうえで有効になるワークアラウンドの * 1 の箇所です。これだけでフレームワークの推論エラーを一部回避できるようになります。全体の動作は MeanVarianceNormalization と同じです。出力値は一致します。

image

念の為それぞれの実装でベンチマークしてみます。結果は下記のとおりです。MeanVarianceNormalization 単体の実装のほうが推論速度が数倍速いですがフレームワーク間の汎用性は低いです。

分解した MeanVarianceNormalization のベンチマーク結果です。何故か TensorRT による推論が異様に遅いです。また、CPU推論が異様に速いです。なお、フレームワーク間での汎用性は高めです。

8-10. TopDown 骨格推定モデルのバッチ処理化ワークアラウンド

TopDown アプローチの骨格推定モデルに関するワークアラウンドです。TopDownの骨格推定モデルは、まず1段階目で物体検出モデルなどを使用して人物のエリアをクロップし、2段階目でクロップした人物エリアの画像に対して骨格推定を行うアプローチです。つまり、ひとつの画角に10人の人物が映っている場合は、物体検出された人物画像10人分を TopDown アプローチの骨格推定モデルに入力して利用します。特に難しいアプローチではなく、人物の身体特徴に合わせて Batch x Channel x Height x Width = 1 x 3 x 256 x 128 のような縦長の入力解像度が設定されていることが多いモデル群です。

モデル設計のアプローチ自体には問題がありません。しかし先述したとおり、物体検出で複数人、なおかつ撮影する映像によっては何人検出されるかが分からない状況でバッチサイズを予め固定したモデルを使用すると、無意味なゼロ埋めデータでパディングして推論する必要があるため推論パフォーマンスを最大化できないばかりか、無用なパディングロジックを記述する必要が生じるため処理が少しだけ冗長になります。例えば、最大10人を検出できるように 10 x 3 x 256 x 128 としてしまった場合、画角内に1人の人物しか映っていなくても不足する9人分の無意味なゼロパディングデータを追加して推論する必要がありますのでとても非効率です。フレームワーク側が可変バッチサイズによる推論に対応している場合(一部のエッジデバイス向けフレームワークを除いて直近の1年でほとんどのフレームワークが対応しました)はバッチサイズを可変サイズのままにしてモデルを生成しておき、必要な画像枚数だけを推論に掛けたいです。

一例として、こちらでプルリクエストを取り込んで頂き、MMPose の動作を変更しました。実装上、とても特別な配慮が必要な内容ではありません。

https://github.com/open-mmlab/mmpose/pull/1242

PyTorchからONNXへモデルをエクスポートする時点で動的サイズを指定できることはご存知だと思いますが、バッチサイズを可変にするにあたって注意しなければいけないポイントがあります。それは、モデルの中間部分のオペレーションにバッチサイズ以外の次元に -1 で形状指定された Reshape が含まれる場合に多くのフレームワークで可変バッチによる推論ができなくなることです。正確には、1次元目のバッチサイズにあたる部分に可変形状を指定できなくなるため、無理やりバッチサイズを可変形状に書き換えるとモデルが破損します。例えば、[-1, -1, 64, 48] のような状況になることを指します。下図は、バッチサイズを可変形状に書き換える前の Reshape の動作を示していますが、仮にこのバッチサイズにあたる1次元目を -1 に書き換えた場合、[-1, -1, 64, 48] となり、推論時にテンソルの形状を正しく推定できなくなりランタイムがAbortします。これはどのフレームワークでも共通の動作です。

image

したがって、下図のとおり Reshape の動作を変更します。これは、PyTorchでモデルを設計する時点で実用性を向上させる観点での配慮が必要な事項です。

image

モデル設計上の修正点は下記のとおりで承認いただけました。

  • From:

    To:

モデルを実際に実戦投入して利用することを想定する場合は、サイズが自明な次元には必ず定数値でサイズ指定していただくようにしたほうが良いと感じています。

8-11. 入力テンソルのサイズが大きめの ArgMax のEdgeTPU対応とTensorRT対応と推論速度向上のワークアラウンド

背景透過などの用途でセグメンテーションモデルを利用されている方々がいらっしゃるかもしれません。世の中に出回っているセグメンテーション系のモデルはどんどん高性能化しており、最近ではかなり綺麗に背景透過ができるようになってきています。ただ、セグメンテーションモデルは全体的に処理負荷が高いものが多く、スペックが高めのマシンを使用することを前提として設計されていることが多いです。仮に、セグメンテーションモデルをロースペックなデバイスで利用しようとした場合に課題・問題となる点は下記のとおりです。

  1. セグメンテーションモデルの後処理をモデルの最後尾にマージしてピクセル単位でクラス分類が終わった状態で最終出力を得たい
  2. クラス分類をするとなると ArgMax を使用したいが、ArgMax の処理負荷が高く推論速度が非常に遅い
  3. バッファが小さいデバイスでは ArgMax を内部処理しきれず Abort してしまうことがある

そもそも、ArgMax の処理が遅いという話は昔からある内容ですので、これからご紹介するようなトリックを使用しない限りは上記の問題をクリアできません。では、どのように問題を小さくするかという点を以下にご紹介します。

FusedArgMax という仕組みをモデルに導入します。ArgMax に入力するテンソルのサイズをできる限り小さく押さえて演算し、ArgMax から出力されたテンソルをもともと必要としていたサイズへリサイズします。ものすごく単純ですね。

引用:https://github.com/tensorflow/models/tree/master/official/projects/edgetpu/vision#argmax-fusion-to-improve-segmentation-model-latency

image

image

ArgMax にテンソルを入力する時にフルサイズの解像度で入力するのではなく、一定比率のスケールへ縮小した状態で処理したあとに、ふたたび本来のスケールへリサイズしています。元の入力解像度が 512x512 でリスケール後の解像度が 256x256 だとした場合、要素数にして 2,097,152 個分の要素の演算を削減できます。スマートフォンの Pixel6 での推論パフォーマンスがおよそ2倍に向上しています。実装上は例えば、torch.nn.functional.interpolate の拡大率を調整するだけの簡単なカスタマイズでパフォーマンスを向上できます。

8-12. Korniaの最適化(rgb_to_ycbcr, ycbcr_to_rgb)

Kornia にはPyTorchで利用可能な便利なオペレーションが多数用意されています。その中でも画像処理で利用されることのある YCbCr 変換に関する処理を取り上げたいと思います。なお、モデルの精度向上などの観点ではなく色変換をする際の処理の書き方による汎用性の変化についてのみフォーカスします。

まずは、Kornia の公式実装の ycbcr_to_rgb です。

image

次に、少しだけ Kornia の実装を改造した ycbcr_to_rgb です。ほとんど公式実装と差がありませんので無意味な修正に感じるかもしれません。公式実装の GatherSlice に置き換わり、スライスしたチャンネルの次元が保持されています。また、公式実装のモデルの最後方の Unsqueeze が全て消失しています。つまりたったこれだけで、Unsqueeze が無くなったことでフレームワーク間の転用がしやすくなったということです。ちなみに、モデル最後尾の ycbcr_to_rgb の部分のみをクローズアップしていますが、モデルの入口の部分には逆の変換の rgb_to_ycbcr が存在するため、合計で6個の Unsqueeze を消失させることができます。モデルを設計・変換する際には、できるだけ次元の圧縮を伴う Gather を使用するのではなく、次元の圧縮を伴わない Slice を使用したほうがモデル全体の構造を綺麗にすることができます。これは、モデルの構造が大きくなればなるほどインパクトが大きくなります。Korniaにプルリクエストをすればいいじゃないか、と思う方がいらっしゃるかもしれませんが、しません。

image

他にもまだ様々なトリックがありますが、終わらなくなるためここで止めます。

9. ONNXから各種フレームワーク向けモデルへの変換

PyTorchからONNXを生成できてさえいれば、あとはどのフレームワークへも期待通りの形式で簡単に変換することができます。主要な機械学習フレームワーク用、ランタイム用のモデルフォーマットへの変換方法を簡単にご紹介します。既出の情報を含みますが、あらためて一気通貫で全部記載します。

9-1. TensorFlow

onnx2tf を使用します。2022年09月に開発を始めた自作ツールですが、最初から他のフレームワークへモデルを転用することを想定した設計にしていますので、現時点ではRNNに対応していないことを除いて最も変換効率が高いです。基本構文は下記です。オプションを何も指定せずに変換すると、 TFLite のモデルが生成されます。

メタデータなども含めて saved_model を出力する場合は下記の通りです。主に saved_model を他のフレームワーク向けに変換するために使用します。

Keras の .h5 形式で出力する場合は下記のとおりです。

Keras の keras_v3 形式で出力する場合は下記のとおりです。

MS-COCO データセットのサンプルデータ 20件 を使用して自動キャリブレーションとINT8量子化を行う場合は下記のとおりです。デフォルト動作は per-channel の量子化です。

per-tensor の量子化を行う場合は下記のとおりです。

オペレーションごとの変換前後の推論誤差を計測する場合は下記の通りです。

下図GIFのように、全てのオペレーションに対してONNXと変換後のTensorFlowの推論誤差を計測して表示します。主にモデル変換時のデバッグに使用します。

212460284-f3480105-4d94-4519-94dc-320d641f5647

TFLite のバイナリを書き換えて、ONNXの入出力名と入出力オーダーにTFLiteの入出力名と入出力オーダーを合わせることができます。ただし、私が独自にカスタマイズした flatbuffers-compiler を導入する必要があります。上記のDocker環境内にはすでに導入済みです。

https://github.com/PINTO0309/onnx2tf#environment

 

image

flatbuffers-compiler の改造内容を知りたい方はこちらをご覧ください。コンパイラが生成する量子化値の算術精度不足を一部だけ解消しています。

https://github.com/PINTO0309/onnx2tf/issues/196

その他にも、FlexOP を無効化するオプションや疑似オペレーションを生成するオプション、 Transpose を指定した次元数まで分解するオプションやこのブログでご紹介した Fused ArgMax を組み込むオプションなど、様々なオプションが用意されていますので気になる方は README をご覧ください。

9-2. TensorFlow.js

昔からある方法ですので今更細かくはご紹介しません。上記の onnx2tf で -osd を指定して生成した saved_model を使用して生成可能です。

9-3. TensorRT

trtexec というTensorRTに付属するコンバーターを使用するか、ONNX の TensorRT ExecutionProvider で推論開始時に自動変換します。匠要素は少ないですが後者のほうが簡単です。手抜きをするならば、sit4onnx という本ブログのベンチマークで使用していたツールを実行するだけで trtengine を自動生成することができます。

image

推論コードは下記です。Dockerの中で実行すれば TensorRT で動作確認できます。

9-4. CoreML

私は iPhone ユーザーではありませんので生成された CoreML モデルの動作確認をすることができませんが、下記でモデルを変換可能です。

https://coremltools.readme.io/docs/unified-conversion-api

image

なお、NCHW形式のCoreMLモデルを生成したい場合は onnx2tf を実行する際に -k オプションを使用して下記のとおり処理します。

  • ONNX

    image

-k オプションはONNXの入力OPの形状を維持するようにツール動作を変更します。-ois オプションは入力形状を指定した静的形状へ上書きします。

  • CoreML

    image

9-5. OpenVINO

OpenVINO 2022.3 が手元の環境にインストールされている前提での変換コマンドは下記です。つまづく要素がありません。なお、さらに OAK-D (Myriad) に対応した形式へコンバートする場合は、ONNXの生成時点でMyriad用の多くのワークアラウンドが必要となることがあります。

X. おわりに

少々長くなりましたが、ONNXのモデルチューニングテクニック (応用編2) は以上です。これでチューニングテクニックシリーズは完結とします。

Author

アバター
hyodo