はじめに
ナレッジグラフ(KG)の埋め込み手法であるProjEを実装,実験してみました. 他の手法に比べると引用数が少ないですが,結構シンプルで精度も出て,リンクの予測ができたりベクトルの事前学習も必要ないなどいろいろな利点が存在するようです. 論文は以下になります.バージョンが2つあり,訓練のアルゴリズムは古い方のものにだけ書いてあります.
[1611.05425] ProjE: Embedding Projection for Knowledge Graph Completion
ナレッジグラフの埋め込みでは,Tramp is Presidentのような主語,述語, 目的語のセット(RDFトリプル) に対して
を満たす空間を設計することを目的としています. 有名な手法としてはTransEなどがあります.
Translating Embeddings for Modeling Multi-relational Data
本論文ではそれらの従来手法よりパラメータ数を減らしつつ,高精度なエンティティ,リンク予測を 行える埋め込み空間を設計する手法を提案しています. ただ,エンティティ予測の際,TransEなどが埋め込みベクトルの類似度のみでそれを行なっている のに対して,本手法では予測のための重みを掛けてから類似度をはかるということをしています. なので単純なエンベディングとはちょっと違うかもしれません.
手法の概要
まず,エンティティベクトルとリレーションベクトルを足し合わせる演算として次を定義します.
ここでとは,ベクトルの次元を縦横のサイズとして持つ対角行列です. つまりベクトルの各要素を定数倍するだけのものになります.ここで普通のdenseな行列でええやろとしていたら,精度があまり出ませんでした.
次に,この関係を満たすエンティティを探すため, 候補エンティティのベクトルを列として持つ行列に対して次の演算を行います.
式中のf, gはsigmoid,tanhなどの活性化関数です.ここでは,で得られたベクトルと, 全ての候補エンティティとの類似度計算をしています.の番目の要素は 候補との類似度ということになります.
最後に損失関数の部分になります. 損失としては関係を満たすエンティティと満たさないエンティティを分類するpointwise lossと, 候補エンティティが正解となる確率について負の対数尤度を最小化するlistwise,listwiseに重みを追加したwlistwiseが提案されています. point wiseの式は以下になります.
ほぼBinary Cross Entropyです.第一項は正例を評価していて,第二項は負例を評価しています. 負例は,e+rを満たさないエンティティから二項分布でサンプリングしてきます.
そしてwlistwiseは以下になります.
softmaxで出力を確信度の形式にしてから,負の対数尤度を取ります. wlistwiseでは対数尤度の足し合わせの際,(e,r)の関係を満たすエンティティの総数の逆数を 重みとして掛け合わせます.通常のlistwiseの際は重みを利用しません.
実装・実験
実装はPytorchで行い,実験はQuadro P6000を用いました. 全然整理できていないのですが,コードは以下になります.
GitHub - sheepover96/ProjE.torch: pytorch implementation of ProjE: Embedding Projection for KGC
また,tensorflowによる公式実装は以下になります.
GitHub - bxshi/ProjE: Embedding Projection for Knowledge Graph Completion
公式実装は学習結果の表示が綺麗です.あと公式なので多分正確です. 私の実装の方は公式に比べ学習が数倍早いのですが,メモリをめちゃくちゃ食います. これはネガティブサンプリングのための候補をキャッシュしているかどうかの違いだと思います.
データセットはFB15KとWN18RRを利用しました.
Relation Prediction | NLP-progress
評価は,論文中で利用されているMean RankとHIT@10を用いました. Mean Rankは,目標のエンティティが現れるまでのランクの平均で, HIT@10は,目標のエンティティが10位以内にあらわれる割合です.
結果は次のようになりました.オレオレの方はtailの予測を行なった結果ですが, 公式の方はtail, head予測どちらをやってるのかわかりませんでした.また, 実験結果はwlistwiseのものだけを示しています.
実装 | MeanRank | HITS@10 |
---|---|---|
論文 | 124 | 54.7 |
公式(tail) | 182.9 | 49.4 |
公式(head) | 275 | 41.5 |
オレオレ(tail) | 153 | 57.8 |
オレオレ(head) | 252 | 49.9 |
公式実装の方は,以前動かした時は論文ぐらいの精度だったのですが,なぜか今回は調子が悪いです. オレオレ実装は,部分的に論文に勝ったり負けたりしていますが,それっぽい結果は出ているんじゃないかと思います.