matplotlibを使って混同行列を描画するPythonライブラリをつくってみた

March 03, 2018

はじめに

研究でよく混同行列を作る機会があるんですが、割と使うわりに一発で描画できるような関数がmatplotlibに無いため、練習を兼ねてそれを実現するPythonライブラリを作ってみました。

使い方

ソースコードは以下のGitHubに置いてあります。

https://github.com/yuuuuwwww/useful_graphs

PyPIはこちらです。

https://pypi.python.org/pypi/useful-graphs/

とりあえず以下でインストールできます。

pip install useful_graphs

まず混同行列自体はscikit-learnのconfusion_matrixを利用することを想定しています。 これを使うと、予測結果と正解ラベルすることでnumpyアレイで混同行列が返されます。

そして以下のように呼び出すことで混同行列を描画することができます。

import numpy as np
import useful_graphs

data = np.array([
                    [10, 3, 4],
                    [0, 9, 2],
                    [3, 2, 10],
                ])
classes = ['label-1', 'label-2', 'label-3']
cm = useful_graphs.ConfusionMatrix()
cm.read_cm(data, class_list=classes)
cm.plot()

個人的にはJupyterノートブックで使うと割と便利だなと思います。

以下のようにsave_pathパラメータを指定することで画像として保存することができます。

cm.plot(save_path="path_to_figure.pdf")

デフォルトでは横方向に正規化していて各ラベルに対する精度と、その下にデータ数を表示しています。 これもパラメータを指定すれば正規化をオフにすることもできます。

工夫した点

カラーマップの背景色が暗い時に文字を白くするという点で少し工夫してみました。 背景色のrelative luminanceを計算して、ある閾値を超えた時に文字を白くする仕様にしています。 そのままmin-maxで正規化して閾値で切るよりも、いろんな色で綺麗に文字が見えるようになりました。

おわりに

ライブラリの名前がuseful_graphsなのに混同行列しか今は描けないので、よく使うけどmatplotlibやseabornでまだないような機能を見つけて実装してみたいと思います。