Skip to content

Commit

Permalink
[r] add support for multiclass cliassification (#210)
Browse files Browse the repository at this point in the history
* model_info rebuilded

* new residual function

* Update misc_yhat.R

* add measures for multiclass classification

* typo fixes

* Fix for print

* Update model_performance.R

* Update NAMESPACE

* Update test_model_performance.R

* Requested changes

* Param names changed

* Turn off test for 3.5

* Switch codecov env to R3.6

* Update test-coverage.yaml

* Update test-coverage.yaml

* Update NEWS.md

* Rebuilded vignette

* gh-actions update

* ignore python for pkgdown

* Update R-CMD-check.yaml

* install ggpubr before dependencies

* what is a pillar package

* Extract residuals functions. Fix typos

* Ajust to changes in #219

* Fix dependencies of residual_function

Co-authored-by: hbaniecki <[email protected]>
Co-authored-by: Hubert Baniecki <[email protected]>
  • Loading branch information
3 people authored May 7, 2020
1 parent a400dd4 commit 05aa82a
Show file tree
Hide file tree
Showing 16 changed files with 236 additions and 110 deletions.
15 changes: 8 additions & 7 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,13 @@ jobs:
fail-fast: false
matrix:
config:
- {os: macOS-latest, r: 'devel'}
- {os: macOS-latest, r: '4.0'}
- {os: windows-latest, r: '4.0'}
- {os: windows-latest, r: '3.6'}
- {os: macOS-latest, r: '3.6'}
- {os: macOS-latest, r: 'devel'}
- {os: ubuntu-16.04, r: '3.5', rspm: "https://demo.rstudiopm.com/all/__linux__/xenial/latest"}
- {os: ubuntu-16.04, r: '4.0', rspm: "https://demo.rstudiopm.com/all/__linux__/xenial/latest"}
- {os: ubuntu-16.04, r: '3.6', rspm: "https://demo.rstudiopm.com/all/__linux__/xenial/latest"}
- {os: ubuntu-16.04, r: '3.5', rspm: "https://demo.rstudiopm.com/all/__linux__/xenial/latest"}

env:
R_REMOTES_NO_ERRORS_FROM_WARNINGS: true
Expand Down Expand Up @@ -87,7 +89,7 @@ jobs:
Rscript -e "remotes::install_github('r-hub/sysreqs')"
sysreqs=$(Rscript -e "cat(sysreqs::sysreq_commands('DESCRIPTION'))")
sudo -s eval "$sysreqs"
sudo apt-get install -y qpdf
- name: Install dependencies
run: |
remotes::install_deps(dependencies = TRUE)
Expand All @@ -97,9 +99,8 @@ jobs:
shell: Rscript {0}

- name: Check
env:
_R_CHECK_CRAN_INCOMING_REMOTE_: false
run: rcmdcheck::rcmdcheck(args = c("--no-manual", "--as-cran", "--run-donttest", "--run-dontrun"), error_on = "warning", check_dir = "check")
run: rcmdcheck::rcmdcheck(args = c("--no-manual", "--as-cran", "--run-donttest", "--run-dontrun"),
build_args = "--compact-vignettes=no", error_on = "warning", check_dir = "check")
shell: Rscript {0}

- name: Upload check results
Expand Down
5 changes: 2 additions & 3 deletions .github/workflows/pkgdown.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ on:
paths-ignore:
- 'python/**'


name: pkgdown

jobs:
Expand All @@ -30,8 +29,8 @@ jobs:
uses: actions/cache@v1
with:
path: ${{ env.R_LIBS_USER }}
key: macOS-r-3.6-${{ hashFiles('.github/depends.Rds') }}
restore-keys: macOS-r-3.6-
key: macOS-r-4.0-1-${{ hashFiles('.github/depends.Rds') }}
restore-keys: macOS-r-4.0-1-

- name: Install dependencies
run: |
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test-coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ jobs:
uses: actions/cache@v1
with:
path: ${{ env.R_LIBS_USER }}
key: macOS-r-3.6-${{ hashFiles('.github/depends.Rds') }}
restore-keys: macOS-r-3.6-
key: macOS-r-4.0-1-${{ hashFiles('.github/depends.Rds') }}
restore-keys: macOS-r-4.0-1-

- name: Install dependencies
run: |
Expand Down
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: DALEX
Title: moDel Agnostic Language for Exploration and eXplanation
Version: 1.2.1
Version: 1.2.2
Authors@R: c(person("Przemyslaw", "Biecek", email = "[email protected]", role = c("aut", "cre"), comment = c(ORCID = "0000-0001-8423-1823")),
person("Szymon", "Maksymiuk", role = "aut"),
person("Hubert", "Baniecki", role = "aut",
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ importFrom(stats,model.frame)
importFrom(stats,predict)
importFrom(stats,quantile)
importFrom(stats,reorder)
importFrom(stats,weighted.mean)
importFrom(utils,head)
importFrom(utils,install.packages)
importFrom(utils,installed.packages)
Expand Down
7 changes: 7 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
DALEX 1.2.2
----------------------------------------------------------------
* `DALEX` now fully supports multiclass classification.
* `explain()` will use new residual function (1 - true class probability) if multiclass classification is detected.
* `model_performance()` now support measures for multiclass classification.
* Remove `ggpubr` from suggests.

DALEX 1.2.1
----------------------------------------------------------------
* fixed tests and WARNINGs on CRAN
Expand Down
59 changes: 41 additions & 18 deletions R/explain.R
Original file line number Diff line number Diff line change
Expand Up @@ -199,19 +199,6 @@ explain.default <- function(model, data = NULL, y = NULL, predict_function = NUL
}
}

if (is.null(model_info)) {
# extract defaults
model_info <- model_info(model)
verbose_cat(" -> model_info : package", model_info$package[1], ", ver.", model_info$ver[1], ", task", model_info$type, "(", color_codes$yellow_start,"default",color_codes$yellow_end, ")", "\n", verbose = verbose)
} else {
verbose_cat(" -> model_info : package", model_info$package[1], ", ver.", model_info$ver[1], ", task", model_info$type, "\n", verbose = verbose)
}
# if type specified then it overwrite the type in model_info
if (!is.null(type)) {
model_info$type <- type
verbose_cat(" -> model_info : type set to ", type, "\n", verbose = verbose)
}

# REPORT: checks for predict_function
if (is.null(predict_function)) {
# predict_function not specified
Expand Down Expand Up @@ -251,15 +238,30 @@ explain.default <- function(model, data = NULL, y = NULL, predict_function = NUL
}
}

if (is.null(model_info)) {
# extract defaults
task_subtype <- check_if_multilabel(model, predict_function, data[1:2,])
model_info <- model_info(model, is_multiclass = task_subtype)
verbose_cat(" -> model_info : package", model_info$package[1], ", ver.", model_info$ver[1], ", task", model_info$type, "(", color_codes$yellow_start,"default",color_codes$yellow_end, ")", "\n", verbose = verbose)
} else {
verbose_cat(" -> model_info : package", model_info$package[1], ", ver.", model_info$ver[1], ", task", model_info$type, "\n", verbose = verbose)
}
# if type specified then it overwrite the type in model_info
if (!is.null(type)) {
model_info$type <- type
verbose_cat(" -> model_info : type set to ", type, "\n", verbose = verbose)
}

# REPORT: checks for residual_function
if (is.null(residual_function)) {
# residual_function not specified
# try the default
if (!is.null(predict_function)) {
residual_function <- function(model, data, y) {
y - predict_function(model, data)
}
if (!is.null(predict_function) & model_info$type != "multiclass") {
residual_function <- default_residual_function
verbose_cat(" -> residual function : difference between y and yhat (",color_codes$yellow_start,"default",color_codes$yellow_end,")\n", verbose = verbose)
} else if (!is.null(predict_function) & model_info$type == "multiclass") {
residual_function <- multiclass_residual_function
verbose_cat(" -> residual function : difference between 1 and probability of true class (",color_codes$yellow_start,"default",color_codes$yellow_end,")\n", verbose = verbose)
}
} else {
if (!"function" %in% class(residual_function)) {
Expand All @@ -271,7 +273,7 @@ explain.default <- function(model, data = NULL, y = NULL, predict_function = NUL
# if data is specified then we may test residual_function
residuals <- NULL
if (!is.null(data) && !is.null(residual_function) && !is.null(y) && (verbose | precalculate)) {
residuals <- try(residual_function(model, data, y), silent = TRUE)
residuals <- try(residual_function(model, data, y, predict_function), silent = TRUE)
if (class(residuals)[1] == "try-error") {
residuals <- NULL
verbose_cat(" -> residuals : the residual_function returns an error when executed (",color_codes$red_start,"WARNING",color_codes$red_end,") \n", verbose = verbose)
Expand Down Expand Up @@ -317,6 +319,27 @@ is_y_in_data <- function(data, y) {
}))
}

# check if model whether model is multilabel classification task
check_if_multilabel <- function(model, predict_function, sample_data) {
response_sample <- try(predict_function(model, sample_data), silent = TRUE)
!is.null(dim(response_sample))
}

default_residual_function <- function(model, data, y, predict_function) {
y - predict_function(model, data)
}


multiclass_residual_function <- function(model, data, y, predict_function) {
y_char <- as.character(y)
pred <- predict_function(model, data)
res <- numeric(length(y))
for (i in 1:nrow(pred)) {
res[i] <- 1-pred[i, y_char[i]]
}
res
}

#' @rdname explain
#' @export
explain <- explain.default
Expand Down
14 changes: 14 additions & 0 deletions R/misc_yhat.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ yhat.svm <- function(X.model, newdata, ...) {
yhat.gbm <- function(X.model, newdata, ...) {
n.trees <- X.model$n.trees
response <- predict(X.model, newdata = newdata, n.trees = n.trees, type = "response")
#gbm returns and 3D array for multilabel classif
if(length(dim(response)) > 2){
response <- response[,,1]
}
response
}


Expand All @@ -90,6 +95,10 @@ yhat.cv.glmnet <- function(X.model, newdata, ...) {
}
if (!is.null(X.model$glmnet.fit$classnames)) {
pred <- predict(X.model, newdata, type = "response", s = X.model$lambda[length(X.model$lambda)])
#glmnet returns and 3D array for multilabel classif
if(length(dim(pred)) > 2){
return(pred[,,1])
}
if (ncol(pred) == 1) {
return(as.numeric(pred))
}
Expand All @@ -110,6 +119,10 @@ yhat.glmnet <- function(X.model, newdata, ...) {
}
if (!is.null(X.model$classnames)) {
pred <- predict(X.model, newdata, type = "response", s = X.model$lambda[length(X.model$lambda)])
#glmnet returns and 3D array for multilabel classif
if(length(dim(pred)) > 2){
return(pred[,,1])
}
# For binary classifiaction matrix with one column is returned
if (ncol(pred) == 1) {
return(as.numeric(pred))
Expand Down Expand Up @@ -143,6 +156,7 @@ yhat.ranger <- function(X.model, newdata, ...) {
yhat.model_fit <- function(X.model, newdata, ...) {
if (X.model$spec$mode == "classification") {
response <- as.matrix(predict(X.model, newdata, type = "prob"))
colnames(response) <- X.model$lvl
if (ncol(response) == 2) {
response <- response[,2]
}
Expand Down
Loading

0 comments on commit 05aa82a

Please sign in to comment.