We provide an introductory demo of how to use the package. This package implements transfer learning algorithms for high-dimensional generalized linear models (Tian and Feng (2021)).
Given the predictor $\bm{x}$, if the response $y$ follows a generalized linear model (GLM), then its distribution satisfies $y|\bm{x} \sim \mathbb{P}(y|\bm{x}) = \rho(y)\exp\{y\bm{x}^T\bm{w} - \psi(\bm{x}^T\bm{w})\}$, where $\psi'(\bm{x}^T\bm{w}) = \mathbb{E}(y|\bm{x})$ is called the inverse link function (McCullagh and Nelder (1989)). Another important property is that $\mathrm{Var}(y|\bm{x}) = \psi''(\bm{x}^T\bm{w})$, which follows from properties of the exponential family. It is $\psi$ that characterizes different GLMs. For example, in the Gaussian model, the response $y$ is continuous and $\psi(u)=\frac{1}{2}u^2$; in the logistic model, $y$ is binary and $\psi(u) = \log(1 + e^u)$; and in the Poisson model, $y$ takes integer values and $\psi(u) = e^u$.
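As a quick check of these identities in the logistic model, differentiating $\psi(u) = \log(1 + e^u)$ recovers the familiar Bernoulli mean and variance:

$$\psi'(u) = \frac{e^u}{1 + e^u} = \mathbb{P}(y = 1|\bm{x}), \qquad \psi''(u) = \frac{e^u}{(1 + e^u)^2} = \psi'(u)\,(1 - \psi'(u)),$$

where $u = \bm{x}^T\bm{w}$.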
Consider the multi-source transfer learning problem. Suppose we have the target data $(X^{(0)}, Y^{(0)}) = \{\bm{x}_i^{(0)}, y_i^{(0)}\}_{i=1}^{n_0}$ and the source data $\{(X^{(k)}, Y^{(k)})\}_{k=1}^{K} = \{\{(\bm{x}_i^{(k)}, y_i^{(k)})\}_{i=1}^{n_k}\}_{k=1}^{K}$. Denote the target coefficient $\bm{\beta} = \bm{w}^{(0)}$. Suppose the target and source data follow the GLM $y^{(k)}|\bm{x}^{(k)} \sim \mathbb{P}(y^{(k)}|\bm{x}^{(k)}) = \rho(y^{(k)})\exp\{y^{(k)}(\bm{x}^{(k)})^T\bm{w}^{(k)} - \psi((\bm{x}^{(k)})^T\bm{w}^{(k)})\}$ for $k = 0, 1, \ldots, K$.
In order to borrow information from transferable sources, Bastani (2020) and Li, Cai, and Li (2020) developed two-step transfer learning algorithms for high-dimensional linear models. In the first step, an approximate estimator is obtained using information from the target data and the useful source data. In the second step, the target data is used to debias the estimator from the first step, yielding the final estimator.
Tian and Feng (2021) extend this idea to GLMs and propose the corresponding algorithm, which can be easily applied when the transferable sources are known. The resulting estimator is proved to enjoy a sharper bound on the ℓ2-estimation error when the transferable source and target data are sufficiently similar.
In the multi-source transfer learning problem, some adversarial sources may share little similarity with the target, which can mislead the fitted model. We call this phenomenon negative transfer (Pan and Yang (2009), Torrey and Shavlik (2010), Weiss, Khoshgoftaar, and Wang (2016)).
To detect which sources are transferable, Tian and Feng (2021) develop a transferable source detection approach. Roughly speaking, it computes the gain from transferring each single source and compares it with the baseline where only the target data is used. Sources enjoying a significant performance gain over the baseline are regarded as transferable. Tian and Feng (2021) also prove the detection consistency property of this method under the high-dimensional GLM setting.
The implementation of this package leverages the package glmnet, which applies cyclic coordinate descent and is very efficient (Friedman, Hastie, and Tibshirani (2010)). We use the offset argument provided by the functions glmnet and cv.glmnet to implement our two-step algorithms. Besides the Lasso (Tibshirani (1996)), this package also supports the elastic-net penalty (Zou and Hastie (2005)).
glmtrans is now available on CRAN and can be easily installed with one line of code.
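install.packages("glmtrans")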
Then we can load the package:
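library(glmtrans)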
In this section, we show how to use the provided functions to fit the model, make predictions, and visualize the results. We take logistic regression data as an example.
We first generate some logistic data through the function models. For the target data, we set the coefficient vector $\bm{\beta} = (0.5 \cdot \bm{1}_s, \bm{0}_{p-s})$, where p = 500 and s = 10. For k in the transferable source index set 𝒜, let $\bm{w}^{(k)} = \bm{\beta} + h/p \cdot \mathcal{R}_p^{(k)}$, where h = 5 and $\mathcal{R}_p^{(k)}$ is a vector of p independent Rademacher variables (being −1 or 1 with equal probability) for each k, and $\mathcal{R}_p^{(k)}$ is independent of $\mathcal{R}_p^{(k')}$ for any k ≠ k′. The coefficient of each non-transferable source is set to $\bm{\xi} + h/p \cdot \mathcal{R}_p^{(k)}$, where $\bm{\xi}_{S'} = 0.5 \cdot \bm{1}_{2s}$ and $\bm{\xi}_{(S')^c} = \bm{0}_{p-2s}$, with $S' = S_1' \cup S_2'$ and $|S_1'| = |S_2'| = s = 10$. Here $S_1' = \{s+1, \ldots, 2s\}$ and $S_2'$ is randomly sampled from $\{2s+1, \ldots, p\}$. We also add an intercept of 0.5. The generating procedure of each non-transferable source data set is independent. The target predictors satisfy $\bm{x}^{(0)}_i \overset{i.i.d.}{\sim} N(\bm{0}, \bm{\Sigma})$, where $\bm{\Sigma} = (0.9^{|i-j|})_{p \times p}$, and the source predictors satisfy $\bm{x}^{(k)}_i \overset{i.i.d.}{\sim} t_4$. The target sample size is $n_0 = 100$ and each source sample size is $n_k = 100$ for k = 1, …, K. Let K = 5 and 𝒜 = {1, 2, 3}. We generate the training data as follows.
set.seed(1, kind = "L'Ecuyer-CMRG")
D.training <- models(family = "binomial", type = "all", cov.type = 2, Ka = 3,
K = 5, s = 10, n.target = 100, n.source = rep(100, 5))
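The returned object is a list with a target component and a list of source components. The quick checks below mirror the accesses used later in this demo; the per-source x/y layout is assumed to match the target's:

dim(D.training$target$x)       # 100 x 500 target design matrix
length(D.training$source)      # 5 source data sets
dim(D.training$source[[1]]$x)  # 100 x 500 (assumed per-source layout)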
Then suppose we know 𝒜. Let's fit an “oracle” GLM transfer learning model on the target data and the source data in 𝒜 by the oracle algorithm. We denote this procedure as Oracle-Trans-GLM.
fit.oracle <- glmtrans(target = D.training$target, source = D.training$source,
family = "binomial", transfer.source.id = 1:3, cores = 2)
Notice that we set the argument transfer.source.id equal to 𝒜 = {1, 2, 3} to transfer only the first three sources.
The output of the glmtrans function is an object of S3 class “glmtrans”. It contains:
beta: the estimated coefficient vector.
family: the response type.
transfer.source.id: the transferable source index. If the input transfer.source.id = 1:length(source) or transfer.source.id = "all", then the output transfer.source.id = 1:length(source). If the input transfer.source.id = "auto", only the transferable sources detected by the algorithm will be returned.
fitting.list: a list of other parameters of the fitted model. When transfer.source.id = "auto", it additionally records the validation (cv) loss on the target data, the loss on each source data set, the constant epsilon0, and the transferability threshold, which is set as (1 + epsilon0)⋅the validation (cv) loss on the target data; these components are only available when transfer.source.id = "auto".
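For instance, we can peek at a few components of the oracle fit; the expected values follow from the description above:

fit.oracle$family              # "binomial"
fit.oracle$transfer.source.id  # 1 2 3, the sources used for transfer
length(fit.oracle$beta)        # 501: intercept plus p = 500 coefficients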
Then suppose we do not know 𝒜. Let's set transfer.source.id = "auto" to apply the transferable source detection algorithm and obtain the estimate $\hat{\mathcal{A}}$. After that, glmtrans will automatically run the oracle algorithm on $\hat{\mathcal{A}}$ to fit the model. We denote this approach as Trans-GLM.
fit.detection <- glmtrans(target = D.training$target, source = D.training$source,
family = "binomial", transfer.source.id = "auto", cores = 2)
## Loss difference between source data and the threshold: (negative to be transferable)
## Source 1: -0.046214
## Source 2: -0.064628
## Source 3: -0.105844
## Source 4: 0.113911
## Source 5: 0.211745
##
## Source data set(s) 1, 2, 3 are transferable!
From the results, we can see that 𝒜 = {1, 2, 3} is successfully detected by the detection algorithm. Next, to demonstrate the effectiveness of the GLM transfer learning algorithm and the transferable source detection algorithm, we also fit the naive Lasso on the target data (Lasso) and a transfer learning model using all source data (Pooled-Trans-GLM) as baselines.
library(glmnet)
fit.lasso <- cv.glmnet(x = D.training$target$x, y = D.training$target$y,
family = "binomial")
fit.pooled <- glmtrans(target = D.training$target, source = D.training$source,
family = "binomial", transfer.source.id = "all", cores = 2)
Finally, we compare the ℓ2-estimation errors of the target coefficient β by different methods.
beta <- c(0, rep(0.5, 10), rep(0, 500 - 10))
er <- numeric(4)
names(er) <- c("Lasso", "Pooled-Trans-GLM", "Trans-GLM", "Oracle-Trans-GLM")
er["Lasso"] <- sqrt(sum((coef(fit.lasso) - beta)^2))
er["Pooled-Trans-GLM"] <- sqrt(sum((fit.pooled$beta - beta)^2))
er["Trans-GLM"] <- sqrt(sum((fit.detection$beta - beta)^2))
er["Oracle-Trans-GLM"] <- sqrt(sum((fit.oracle$beta - beta)^2))
er
## Lasso Pooled-Trans-GLM Trans-GLM Oracle-Trans-GLM
## 1.322086 1.298818 1.100324 1.185506
Note that the transfer learning models outperform the classical Lasso fitted on the target data. Due to negative transfer, Pooled-Trans-GLM performs worse than Oracle-Trans-GLM. By correctly detecting 𝒜, Trans-GLM mimics the behavior of the oracle.
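The fitted models can also be used for prediction. Below is a minimal sketch that generates an independent test set with the same models call used for training; it assumes “glmtrans” objects support a predict method with newx and type = "response" arguments analogous to glmnet, which is an assumption rather than an interface confirmed above:

D.test <- models(family = "binomial", type = "all", cov.type = 2, Ka = 3,
                 K = 5, s = 10, n.target = 500, n.source = rep(100, 5))
# Assumed interface: predicted probabilities on new target predictors.
prob <- predict(fit.detection, newx = D.test$target$x, type = "response")
mean((prob > 0.5) == D.test$target$y)  # test-set classification accuracy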
We can visualize the transferable source detection results by applying the plot function to objects of class “glmtrans” or “glmtrans_source_detection”. The loss of each source and the transferability threshold will be drawn. The function glmtrans outputs objects of class “glmtrans”, while the function source_detection outputs objects of class “glmtrans_source_detection”; source_detection detects 𝒜 without the post-detection model fitting step. Call the plot function to visualize the results as follows.
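plot(fit.detection)

A standalone detection run can be plotted in the same way; the call below assumes source_detection shares the target, source, and family arguments of glmtrans:

detection <- source_detection(target = D.training$target,
                              source = D.training$source, family = "binomial")
plot(detection)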