Python初心者に毛が生えた程度の筆者が画像分類の評価をした時のお話。
ただのPython文法メモ的な。
分類問題設定
1000枚ぐらいある画像に対してラベル分類する問題です。
1枚にscore付きで複数ラベル出てくるのですが、一定の閾値を設けてそれ以上のラベルを採用 = 多ラベル分類 という問題設定で評価しました。
やりたいことがほぼ同じだったので、Qiitaのこちらの記事のやり方をパクりました。
動作環境
- runtime ··· Google Colaboratory
- DB ··· BigQuery
BigQueryにすでに正解データとMLモデルによる予測結果が入っている状態からスタートです。
正解データ ※画像3は該当するラベルなしを明示的に識別できるようにNoneという文字列が入っているとする
img_name | label |
---|---|
画像1 | ラベル1 |
ラベル2 | |
画像2 | ラベル3 |
画像3 | None |
予測結果 ※imagesはSTRUCT型の配列(オブジェクト配列)です。labelとscoreをプロパティにもつ
img_name | images.label | images.score |
---|---|---|
画像1 | ラベル1 | 0.9 |
ラベル2 | 0.8 | |
ラベル3 | 0.3 | |
ラベル4 | 0.2 | |
ラベル5 | 0.1 | |
画像2 | ラベル3 | 0.9 |
画像3 | ラベル1 | 0.1 |
ラベル2 | 0.1 | |
ラベル3 | 0.1 |
作った物
1 2 3 4 5 6 7 8 |
# ライブラリのインポート from google.colab import auth from google.cloud import bigquery import pandas as pd from google.colab import files # BigQueryの認証 auth.authenticate_user() |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
# BQクライアントの用意 client = bigquery.Client('[GCPプロジェクト名]') # 予測結果が格納されているTBLの変数 prediction_tbl = '[予測テーブル名]' # 正解データのDataFrameを作成 sql = """ SELECT img_name, label FROM `[dataset].[正解テーブル名]` """ dft = client.query(sql).to_dataframe() # 変数準備 l_precision = list() l_recall = list() x_score = [i / 100 for i in range(0, 100, 1)] # scoreの閾値を変化させながら計算 for i in range(0,100,1): # 変数初期化 TP = FP = FN = 0 precision = recall = 0 # scoreの閾値設定 threshhold = x_score[i] sql = """ SELECT img_name, ARRAY(SELECT label FROM UNNEST(images) WHERE score > """ + str(threshhold) + """) AS imglabel, FROM `""" + predicition_tbl + """` """ dfy = client.query(sql).to_dataframe() # 正解データの画像枚数分ループさせて多ラベル分類のPrecision, Recallを計算 for index in dft.index.values: # 正解データの集合 seikaisyugou = set(dft.loc[index,'imglabel']) if 'None' in seikaisyugou: seikaisyugou = seikaisyugou.remove('None') if seikaisyugou is None: seikaisyugou = set() # 予測データの集合 imgname = dft.loc[index,'img_name'] if len(dfy.query('img_name == @imgname').index) > 0: yosokusyugou = set(dfy.loc[dfy.query('img_name == @imgname').index.values[0],'imglabel']) else: yosokusyugou = set() # 正解データの集合と予測データの集合の積集合を計算 s_intersection = seikaisyugou & yosokusyugou # TPの更新(TPは正解データと予測データの共通部分=積集合) TP += len(s_intersection) # FPの更新(FPは予測データと正解データの差集合(予測にしかないデータ)) FP += len(yosokusyugou - seikaisyugou) # FNの更新(FNは正解データと予測データの差集合(正解にしかないデータ)) FN += len(seikaisyugou - yosokusyugou) # precision, recallの計算 if (TP+FP) == 0: l_precision.append(1.0) else: l_precision.append(TP/(TP+FP)) if (TP+FN) == 0: l_recall.append(1.0) else: l_recall.append(TP/(TP+FN)) # グラフ出力 df = pd.DataFrame({'score': x_score, 'precision': l_precision, 'recall': l_recall}) df.plot(x='score') |
自分的文法メモ
BigQueryの認証、BQクライアントの用意
ここでブラウザが起動してBigQueryに接続するための認証が走る(認証コードをコピペ入力する)。
公式では'[GCPプロジェクト名]'
はなくてもSQLのFROM句に直接記載すれば良さげだったけどうまくいかなかったので記載。
ほげほげのDataFrameを作成
SELECT hogehoge
FROM
""" + predicition_tbl + """
“””
df = client.query(sql).to_dataframe()
ライブラリimportしているBigQuery接続ライブラリfrom google.cloud import bigquery
を使っている。
ライブラリ関係ないけど"""
は改行込み文字列。
文字列連結はあってるのかわからんけど上のように書いたらできた。
SELECT文
というマジックコマンドも使えるらしいが、その場合、
- SQLの中にどうやって変数含めるのかよくわからんかった
- そもそもソースコードの中に埋め込めるのかよくわからんかった(for文で回したいので)
という理由で素直にライブラリでやりました。
0.01刻みの配列(イテレートオブジェクト)
range()
関数ですぐできると思いきや、range()
は引数intしかダメらしい。
一般的に上のように内包表記するらしい。
range()
は第一引数〜第二引数まで間隔=第三引数の配列を返してくれる関数。
データフレームのindexの値取得
戻り値は配列(だったと思う)
集合(set型)から要素削除
seikaisyugou = seikaisyugou.remove(‘None’)
if seikaisyugou is None:
seikaisyugou = set()
remove()
を使う。ただし集合内に削除対象の要素が含まれてないとエラーになる。
最初はdiscard()
を使っていたが、こちらは集合内に削除対象の要素がなくてもエラーにならない代わりに戻り値が空集合になる。
参考にしたサイトの情報が間違っていてめちゃハマった。
データフレームを検索して一致する行のindexを返す
df.query('hogehgoe [比較演算子] honyahonya')
で検索可能。
変数を使いたければ@らしい。
.index
の戻りはNumPyの配列だったと思う。通常のリストではないので注意
データフレームの特定の行、列の値を取得
df.loc[index, カラム名]
。PK相当のカラムがあるテーブルならそいつをダイレクトに指定して取得することはできるのだろうか?
それとも上述のようにquery()
で絞ってそいつのindexを取得してloc[]
でアクセスするのが普通なのだろうか?
積集合、差集合、その他諸々集合演算
全てはここにある。
グラフ出力
df.plot(x=’score’)
pandasのplot()
でできる。裏はmatplotlibなのかな?(エラった時のstacktrace見たらそんな感じだった。知らんけど)
plot()
の引数なしだと、
- x軸 = dataframeのindex
- y軸 = dataframeの各カラム(全てのカラムがプロットされる)
なので、今回のようにx軸を特定カラムにしたい場合はx='[カラム名]'
で指定してあげる。