Daily Musings

The monologue of a man who has reached forty, the age of "no more doubts", yet cannot hide his bewilderment

A roundup of t-SNE

Comparing the t-SNE implementations in R libraries

The author of t-SNE, Laurens van der Maaten, summarizes t-SNE on his official homepage.

R has several t-SNE implementations. They differ from one another, so I compare them here.

Overview

Library   Function   CRAN   Source code
tsne      tsne       link   source
Rtsne     Rtsne      link   source
mp        tSNE       link   source

Notes on each library

tsne

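An R port of van der Maaten's official Python implementation. It computes exact t-SNE; unlike bhtsne, it does not use the Barnes-Hut approximation.

It is written entirely in R, with no C++, so its advantage is that you can read exactly how it is implemented.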

tsne = function(X, initial_config = NULL, k=2, initial_dims=30, perplexity=30,
              max_iter = 1000, min_cost=0,epoch_callback=NULL,whiten=TRUE, epoch=100)

tsne: T-Distributed Stochastic Neighbor Embedding for R (t-SNE). A "pure R" implementation of the t-SNE algorithm.

Version: 0.1-3 Published: 2016-07-15 Author: Justin Donaldson Maintainer: Justin Donaldson <jdonaldson at gmail.com> BugReports: https://github.com/jdonaldson/rtsne/issues License: GPL-2 | GPL-3 [expanded from: GPL] URL: https://github.com/jdonaldson/rtsne/
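Because tsne is plain R, its core step can be followed directly. As an illustration, here is a minimal sketch of one exact t-SNE gradient evaluation: Student-t affinities Q in the map and the gradient 4 * sum_j (p_ij - q_ij)(y_i - y_j) / (1 + |y_i - y_j|^2). `tsne_grad` is a hypothetical helper written for this post, not a function exported by tsne, and it assumes P is already a symmetric joint probability matrix that sums to 1.

```r
# Hypothetical helper: exact t-SNE gradient for an embedding Y (n x k),
# given a symmetric joint probability matrix P (n x n, sums to 1).
tsne_grad = function(P, Y) {
  sq = rowSums(Y^2)
  # num[i, j] = 1 / (1 + ||y_i - y_j||^2), the Student-t kernel
  num = 1 / (1 + outer(sq, sq, "+") - 2 * tcrossprod(Y))
  diag(num) = 0
  Q = num / sum(num)
  # W[i, j] = (p_ij - q_ij) * num[i, j]
  W = (P - Q) * num
  # grad_i = 4 * sum_j W[i, j] * (y_i - y_j)
  4 * (rowSums(W) * Y - W %*% Y)
}

# Usage on a toy 5-point problem
set.seed(42)
Y = matrix(rnorm(10), 5, 2)
A = matrix(runif(25), 5); A = A + t(A); diag(A) = 0
P = A / sum(A)                 # symmetric, sums to 1
G = tsne_grad(P, Y)
dim(G)                         # 5 2: one gradient row per point
```

Everything is dense n x n matrices, which is why an exact, pure-R implementation becomes slow as the number of points grows.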

Rtsne

Rtsne implements Barnes-Hut t-SNE, the accelerated version of the official t-SNE.

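It finds nearest neighbors in the input space with a vantage-point tree, and uses the Barnes-Hut algorithm to approximate the repulsive forces in the embedding instead of computing every pairwise term.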
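The Barnes-Hut idea can be shown on a single group of points: if the group is small relative to its distance from y_i (size / distance < theta), its combined repulsion on y_i is replaced by one interaction with the group's center of mass. This is a simplified sketch, not Rtsne's code. `bh_repulsion` is hypothetical, it uses the points' bounding box as the "cell size", and it skips the quadtree that a real implementation builds.

```r
# Hypothetical helper: compare the exact (unnormalized) t-SNE repulsive
# force on yi from a group of points with its Barnes-Hut approximation.
bh_repulsion = function(yi, cell, theta = 0.5) {
  # Exact: sum_j w_ij^2 * (y_i - y_j), with w_ij = 1 / (1 + ||y_i - y_j||^2)
  diffs = matrix(yi, nrow(cell), length(yi), byrow = TRUE) - cell
  w = 1 / (1 + rowSums(diffs^2))
  exact = colSums(w^2 * diffs)
  # Approximation: treat the whole group as one point at its center of mass
  d_c = yi - colMeans(cell)
  size = max(apply(cell, 2, function(v) diff(range(v))))  # bounding-box side
  summarized = size / sqrt(sum(d_c^2)) < theta            # Barnes-Hut test
  w_c = 1 / (1 + sum(d_c^2))
  approx = if (summarized) nrow(cell) * w_c^2 * d_c else exact
  list(exact = exact, approx = approx, summarized = summarized)
}

# A tight group far from the origin is summarized; a nearby spread-out one is not
set.seed(1)
far = matrix(rnorm(20, mean = 10, sd = 0.1), ncol = 2)
bh_repulsion(c(0, 0), far)
```

A larger theta summarizes more groups (faster, less accurate); theta = 0 never summarizes, which corresponds to the exact computation.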

It calls C++ internally. I could not find the C++ code itself.

Rtsne.default = function(X, dims=2, initial_dims=50, perplexity=30, theta=0.5,
                        check_duplicates=TRUE, pca=TRUE, partial_pca=FALSE,
                        max_iter=1000,verbose=getOption("verbose", FALSE),
                        is_distance=FALSE, Y_init=NULL, pca_center=TRUE,
                        pca_scale=FALSE, normalize=TRUE,
                        stop_lying_iter=ifelse(is.null(Y_init),250L,0L),
                        mom_switch_iter=ifelse(is.null(Y_init),250L,0L),
                        momentum=0.5, final_momentum=0.8, eta=200.0,
                        exaggeration_factor=12.0, num_threads=1, ...)
{
  ...
  out = do.call(Rtsne_cpp,
              c(list(X=X, distance_precomputed=is_distance, num_threads=num_threads),
              tsne.args))
  ...
}

Rtsne.default calls Rtsne_cpp, and Rtsne_cpp in turn invokes the compiled routine _Rtsne_Rtsne_cpp through .Call.

Rtsne_cpp = function(X, no_dims, perplexity, theta, verbose, max_iter,
                    distance_precomputed, Y_in, init, stop_lying_iter,
                    mom_switch_iter, momentum, final_momentum, eta,
                    exaggeration_factor, num_threads)
{
  ...
  .Call(`_Rtsne_Rtsne_cpp`, X, no_dims, perplexity, theta, verbose, max_iter,
        distance_precomputed, Y_in, init, stop_lying_iter, mom_switch_iter,
        momentum, final_momentum, eta, exaggeration_factor, num_threads)
  ...
}
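One way to follow a chain like this without reading the whole package is to check whether a function's body hands off to compiled code. A minimal base-R sketch: `uses_dotcall` is a hypothetical helper that only searches the deparsed body for `.Call(`, so it misses indirect calls. For Rcpp-based packages, the C++ sources themselves usually live in the src/ directory of the package sources.

```r
# Hypothetical helper: does the body of an R function contain a .Call(...)?
uses_dotcall = function(f) {
  any(grepl(".Call(", deparse(body(f)), fixed = TRUE))
}

# Toy functions: to_c is only inspected, never executed
to_c = function(x) .Call("some_routine", x)
pure_r = function(x) x + 1
uses_dotcall(to_c)    # TRUE
uses_dotcall(pure_r)  # FALSE

# If Rtsne is installed, its internal wrapper can be inspected the same way
if (requireNamespace("Rtsne", quietly = TRUE)) {
  print(uses_dotcall(Rtsne:::Rtsne_cpp))
}
```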

Rtsne: T-Distributed Stochastic Neighbor Embedding using a Barnes-Hut Implementation An R wrapper around the fast T-distributed Stochastic Neighbor Embedding implementation by Van der Maaten (see https://github.com/lvdmaaten/bhtsne/ for more information on the original implementation).

Version: 0.15 Imports: Rcpp (≥ 0.11.0), stats LinkingTo: Rcpp Suggests: irlba, testthat Published: 2018-11-10 Author: Jesse Krijthe [aut, cre], Laurens van der Maaten [cph] (Author of original C++ code) Maintainer: Jesse Krijthe License: file LICENSE URL: https://github.com/jkrijthe/Rtsne

mp::tSNE

A function in the mp package (Multidimensional Projection Techniques).

Internally it calls mp_tSNE via .Call. I do not know the details of mp_tSNE.

tSNE = function(X, Y = NULL, k = 2, perplexity = 30.0, n.iter = 1000, eta = 500,
              initial.momentum = 0.5, final.momentum = 0.8,
              early.exaggeration = 4.0, gain.fraction = 0.2,
              momentum.threshold.iter = 20, exaggeration.threshold.iter = 100,
              max.binsearch.tries = 50)
{
  ...
  .Call("mp_tSNE", X, Y, perplexity, k, n.iter, is.dist, eta, initial.momentum,
        final.momentum, early.exaggeration, gain.fraction, momentum.threshold.iter,
        exaggeration.threshold.iter, max.binsearch.tries, PACKAGE = "mp")
  ...
}

mp: Multidimensional Projection Techniques Multidimensional projection techniques are used to create two dimensional representations of multidimensional data sets.

Version: 0.4.1 Depends: R (≥ 1.8.0) Imports: Rcpp (≥ 0.11.0) LinkingTo: Rcpp, RcppArmadillo Published: 2016-08-15 Author: Francisco M. Fatore, Samuel G. Fadel Maintainer: Francisco M. Fatore License: GPL-2 | GPL-3 [expanded from: GPL]

mp uses RcppArmadillo internally.

RcppArmadillo brings Armadillo, a C++ linear algebra library, to R through Rcpp.

'Armadillo' is a templated C++ linear algebra library (by Conrad Sanderson) that aims towards a good balance between speed and ease of use. Integer, floating point and complex numbers are supported, as well as a subset of trigonometric and statistics functions. Various matrix decompositions are provided through optional integration with LAPACK and ATLAS libraries. The 'RcppArmadillo' package includes the header files from the templated 'Armadillo' library. Thus users do not need to install 'Armadillo' itself in order to use 'RcppArmadillo'. From release 7.800.0 on, 'Armadillo' is licensed under Apache License 2; previous releases were under licensed as MPL 2.0 from version 3.800.0 onwards and LGPL-3 prior to that; 'RcppArmadillo' (the 'Rcpp' bindings/bridge to Armadillo) is licensed under the GNU GPL version 2 or later, as is the rest of 'Rcpp'. Armadillo requires a C++11 compiler.

The documentation states that it uses the default parameters chosen by the authors.

T-Distributed Stochastic Neighbor Embedding Creates a k-dimensional representation of the data by modeling the probability of picking neighbors using a Gaussian for the high-dimensional data and t-Student for the low-dimensional map and then minimizing the KL divergence between them. This implementation uses the same default parameters as defined by the authors.
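The description above (Gaussian affinities P in the input space, Student-t affinities Q in the map, minimize the KL divergence between them) can be written in a few lines of base R. This is a minimal sketch for illustration only; `tsne_cost` is hypothetical and assumes P is already a symmetric joint probability matrix summing to 1.

```r
# Hypothetical helper: t-SNE cost KL(P || Q) for an embedding Y (n x k)
tsne_cost = function(P, Y) {
  sq = rowSums(Y^2)
  num = 1 / (1 + outer(sq, sq, "+") - 2 * tcrossprod(Y))  # Student-t kernel
  diag(num) = 0
  Q = num / sum(num)
  off = row(P) != col(P) & P > 0        # skip the diagonal and zero entries
  sum(P[off] * log(P[off] / Q[off]))
}

# Usage: a random symmetric P against a random 2-D embedding
set.seed(2)
Y = matrix(rnorm(10), 5, 2)
A = matrix(runif(25), 5); A = A + t(A); diag(A) = 0
tsne_cost(A / sum(A), Y)                # positive unless Q matches P exactly
```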

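Comparing the parameters

Gradient descent is driven by a learning rate (an epsilon step size), but the learning rates cannot be compared directly: tsne does not expose it as an argument, and Rtsne passes eta straight to C++.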
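To make the differences concrete, the defaults visible in the three signatures quoted above can be collected into one data frame. For tsne, the exaggeration value comes from the notes below; NA marks a value the function does not expose. This is only a bookkeeping sketch, and the column names are my own.

```r
# Default values read off the function signatures quoted in this post
defaults = data.frame(
  func          = c("tsne::tsne", "Rtsne::Rtsne", "mp::tSNE"),
  out_dims      = c(2, 2, 2),           # k / dims / k
  initial_dims  = c(30, 50, NA),        # mp::tSNE has no such argument
  perplexity    = c(30, 30, 30),
  max_iter      = c(1000, 1000, 1000),  # max_iter / max_iter / n.iter
  learning_rate = c(NA, 200, 500),      # not an argument of tsne::tsne / eta / eta
  exaggeration  = c(4, 12, 4)           # internal P x 4 / exaggeration_factor / early.exaggeration
)

# Keep only the parameters whose defaults differ between implementations
differs = sapply(defaults[-1], function(v) length(unique(v)) > 1)
defaults[, c(TRUE, differs)]
```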

tsne

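I use tsne as the baseline.

  • Output dimensions (k): 2
  • Dimensions kept by the SVD (initial_dims): 30 or the number of columns, whichever is smaller
  • perplexity: 30
  • Maximum number of gradient-descent iterations (max_iter): 1000
  • Minimum cost (min_cost): 0
  • Whitening (whiten): yes
  • Learning rate: unknown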

Rtsne

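Parameters I changed

  • initial_dims = 30
  • check_duplicates = FALSE
  • With pca = TRUE and partial_pca = FALSE, PCA is done with prcomp.
  • With pca = TRUE and partial_pca = TRUE, PCA is done with irlba::prcomp_irlba.
  • With pca_center = TRUE and pca_scale = FALSE, PCA is run on centered but unscaled data.
  • With normalize = TRUE, each column is centered to mean zero and the whole matrix is scaled so that its largest absolute value is 1.
    • tsne::tsne instead rescales the data and standardizes it inside the internal function whiten.
    • Inside whiten, tsne::tsne applies the SVD after that standardization.
  • stop_lying_iter = 250: the Rtsne default when Y_init = NULL.
  • mom_switch_iter = 250: the tsne::tsne default.
  • exaggeration_factor = 4: tsne::tsne multiplies P by 4 by default.
  • num_threads = 1: compute on a single core.
  • eta = 200.0: this is passed straight to Rtsne_cpp, so I could not check how it is used.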
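The two preprocessing pipelines in the list can be sketched side by side. Both functions are hypothetical approximations of my reading of the packages, not their actual code: `prep_tsne_like` assumes tsne rescales the whole matrix to [0, 1] and then whitens with an SVD of the covariance matrix; `prep_rtsne_like` assumes Rtsne runs prcomp (centered, unscaled) and then normalize centers each column and divides by the largest absolute value.

```r
# tsne-like preprocessing (assumption): scale the whole matrix to [0, 1],
# center the columns, then PCA-whiten to initial_dims components.
prep_tsne_like = function(X, initial_dims = 30) {
  X = as.matrix(X)
  X = (X - min(X)) / (max(X) - min(X))
  initial_dims = min(initial_dims, ncol(X))
  Xc = scale(X, center = TRUE, scale = FALSE)
  s = svd(crossprod(Xc) / nrow(Xc))             # covariance = U D U^T
  idx = seq_len(initial_dims)
  K = s$u[, idx, drop = FALSE] %*% diag(1 / sqrt(s$d[idx]), nrow = initial_dims)
  Xc %*% K                                       # uncorrelated, unit variance
}

# Rtsne-like preprocessing (assumption): PCA with pca_center = TRUE and
# pca_scale = FALSE, then normalize = TRUE (center columns, max |x| = 1).
prep_rtsne_like = function(X, initial_dims = 50) {
  X = prcomp(as.matrix(X), center = TRUE, scale. = FALSE, rank. = initial_dims)$x
  X = scale(X, center = TRUE, scale = FALSE)
  X / max(abs(X))
}

A = prep_tsne_like(iris[, 1:4])
B = prep_rtsne_like(iris[, 1:4])
round(crossprod(A) / nrow(A), 6)   # approximately the identity matrix
range(B)                           # within [-1, 1]
```

The practical difference: whitening gives every retained component unit variance, while Rtsne's normalization keeps the PCA components' relative variances.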

mp::tSNE

Parameters I changed

  • n.iter = 1000
  • eta = 200
  • early.exaggeration = 4.0
  • max.binsearch.tries = 30
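  • initial_dims: unknown. mp::tSNE has no matching argument; judging by its name, max.binsearch.tries limits the binary search that fits each point's Gaussian to the target perplexity, so it is probably unrelated.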
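Judging by its name, max.binsearch.tries caps a per-point binary search: the Gaussian precision beta is adjusted until the conditional distribution P(. | i) has entropy log(perplexity). A minimal sketch of that search for one point, assuming natural logarithms and a tolerance on the entropy; `calibrate_row` is hypothetical, not mp's or tsne's code.

```r
# Hypothetical helper: find beta (= 1 / (2 sigma^2)) for one point so that
# P(. | i) over its neighbors has the requested perplexity.
calibrate_row = function(d2, perplexity = 30, tol = 1e-5, max_tries = 50) {
  target = log(perplexity)                 # target entropy in nats
  row_at = function(beta) {
    p = exp(-(d2 - min(d2)) * beta)        # shift by min(d2) to avoid underflow
    p = p / sum(p)
    nz = p > 0
    list(p = p, H = -sum(p[nz] * log(p[nz])))
  }
  beta = 1; lo = 0; hi = Inf
  for (iter in seq_len(max_tries)) {       # max_tries plays the max.binsearch.tries role
    H = row_at(beta)$H
    if (abs(H - target) < tol) break
    if (H > target) {                      # too flat: sharpen the Gaussian
      lo = beta
      beta = if (is.finite(hi)) (beta + hi) / 2 else beta * 2
    } else {                               # too peaked: widen it
      hi = beta
      beta = (beta + lo) / 2
    }
  }
  row_at(beta)$p
}

# Usage: squared distances from the first iris flower to all the others
d2 = as.matrix(dist(iris[, 1:4]))[1, -1]^2
p = calibrate_row(d2, perplexity = 30)
exp(-sum(p[p > 0] * log(p[p > 0])))        # close to 30
```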

Running the comparison

Functions

setwd("D:/R065021/Desktop/working/")  # author's working directory; change or remove
tsne_3 = function(X, k = 2, initial_dims = 30, perplexity = 30, max_iter = 1000){
  X = as.matrix(X)  # give all three implementations the same numeric matrix
  # tsne::tsne: pure-R exact t-SNE
  set.seed(123)
  res1 = tsne::tsne(X, initial_config = NULL, k = k, initial_dims = initial_dims, perplexity = perplexity, max_iter = max_iter, min_cost = 0, epoch_callback = NULL, whiten = TRUE, epoch = 100)
  # Rtsne::Rtsne: Barnes-Hut t-SNE in C++ (max_iter is now passed by name)
  set.seed(123)
  res2 = Rtsne::Rtsne(X, dims = k, initial_dims = initial_dims, perplexity = perplexity, theta = 0.5, check_duplicates = FALSE, pca = TRUE, partial_pca = FALSE, max_iter = max_iter, verbose = FALSE, is_distance = FALSE, Y_init = NULL, pca_center = TRUE, pca_scale = FALSE, normalize = TRUE, stop_lying_iter = 250, mom_switch_iter = 250, momentum = 0.5, final_momentum = 0.8, eta = 200.0, exaggeration_factor = 4.0, num_threads = 1)
  # mp::tSNE: C++ (RcppArmadillo) implementation
  set.seed(123)
  res3 = mp::tSNE(X, Y = NULL, k = k, perplexity = perplexity, n.iter = max_iter, eta = 200, initial.momentum = 0.5, final.momentum = 0.8, early.exaggeration = 4.0, gain.fraction = 0.2, momentum.threshold.iter = 20, exaggeration.threshold.iter = 100, max.binsearch.tries = 30)
  # Rtsne returns a list; the embedding is its $Y element
  return(list(tsne = res1, Rtsne = res2$Y, mp_tSNE = res3))
}

plot_tsne_3 = function(X, id, col, ...){
  if(missing(col))
    col = RColorBrewer::brewer.pal(n=length(levels(id)), name="Dark2")
  par(mfrow = c(3,1))
  lapply(names(X), function(i){
    par(mai=c(.4,.4,.5,1.3))
    par(xpd=FALSE)
    plot(X[[i]], type = "n", ann = FALSE)
    title(main = i, cex=4)
    grid()
    lapply(seq(levels(id)), function(j)
      points(X[[i]][as.numeric(id) == j,], pch = 21, col=NA, bg = alpha_brend(col = col[j]), cex = 2))
    par(xpd = TRUE)
    legend(x = par()$usr[2], y = par()$usr[4], legend = levels(id), pch = 16, col = col, cex=1.5)
  })
}

alpha_brend = function(col, alpha=.5){
  col = col2rgb(col)[,1]/255
  col = rgb(red = col[1], green = col[2], blue = col[3], alpha)
  return(col)
}

iris dataset

data = iris
data = data[apply(!is.na(data),1,all),]
X = data[, -5]
id = factor(data[, 5])
col = RColorBrewer::brewer.pal(n=length(levels(id)), name="Dark2")
res = tsne_3(X)

pdf("./iris_tsne3.pdf", family="Japan1GothicBBB", width=6, height=9)
plot_tsne_3(res, id, col)
dev.off()

mtcars dataset

data = mtcars
data = data[apply(!is.na(data),1,all),]
X = data[, -2]
id = factor(data[, 2])
col = RColorBrewer::brewer.pal(n=length(levels(id)), name="Dark2")
res = tsne_3(X, perplexity = 10)  # Rtsne stops if 3 * perplexity > nrow(X) - 1, and mtcars has only 32 rows

pdf("./mtcars_tsne3.pdf", family="Japan1GothicBBB", width=6, height=9)
plot_tsne_3(res, id, col)
dev.off()

airquality dataset

data = airquality
data = data[apply(!is.na(data),1,all),]
X = data[, -5]
id = factor(data[, 5])
col = RColorBrewer::brewer.pal(n=length(levels(id)), name="Dark2")
res = tsne_3(X)

pdf("./airquality_tsne3.pdf", family="Japan1GothicBBB", width=6, height=9)
plot_tsne_3(res, id, col)
dev.off()

Next, let's take a closer look at tsne, the pure-R implementation of van der Maaten's t-SNE.