Skip to content

Commit

Permalink
indepth vignette small fixes (#412)
Browse files Browse the repository at this point in the history
* indepth vignette small fixes

* Update NEWS.md

* ps_replicate prefixes -> affixes

* typo fix

* ps_union postfix_names

* postfix in PSC

* tests

* document

* NEWS entry

* test fix
  • Loading branch information
mb706 authored Aug 26, 2024
1 parent 5a84ccb commit 87cb68d
Show file tree
Hide file tree
Showing 10 changed files with 229 additions and 54 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# paradox 1.0.1-9000

* `ParamSetCollection$flatten()` now detaches `$extra_trafo` completely from original ParamSetCollection.
* Option to postfix, instead of prefix, in `ParamSetCollection`, `c()`/`ps_union()`, and `ps_replicate()`.

# paradox 1.0.1

Expand Down
2 changes: 1 addition & 1 deletion R/ParamSet.R
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,8 @@ ParamSet = R6Class("ParamSet",
assert_list(x, names = "unique")
trafos = private$.trafos[names(x), .(id, trafo), nomatch = 0]
value = NULL # static checks
trafos[, value := x[id]]
if (nrow(trafos)) {
trafos[, value := x[id]]
transformed = pmap(trafos, function(id, trafo, value) trafo(value))
x = insert_named(x, set_names(transformed, trafos$id))
}
Expand Down
63 changes: 43 additions & 20 deletions R/ParamSetCollection.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,13 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
#' Whether to add tags of the form `"set_<set_id>"` to each parameter originating from a given `ParamSet` given with name `<set_id>`.
#' @param tag_params (`logical(1)`)\cr
#' Whether to add tags of the form `"param_<param_id>"` to each parameter with original ID `<param_id>`.
initialize = function(sets, tag_sets = FALSE, tag_params = FALSE) {
#' @param postfix_names (`logical(1)`)\cr
#' Whether to use the names inside `sets` as postfixes, rather than prefixes.
initialize = function(sets, tag_sets = FALSE, tag_params = FALSE, postfix_names = FALSE) {
assert_list(sets, types = "ParamSet")
assert_flag(tag_sets)
assert_flag(tag_params)
private$.postfix = assert_flag(postfix_names)

if (is.null(names(sets))) names(sets) = rep("", length(sets))

Expand All @@ -52,7 +55,11 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
set(params_child, , "original_id", params_child$id)
set(params_child, , "owner_ps_index", i)
set(params_child, , "owner_name", n)
if (n != "") set(params_child, , "id", sprintf("%s.%s", n, params_child$id))
if (n != "") {
set(params_child, , "id",
private$.add_name_prefix(n, params_child$id)
)
}
params_child
}
}), prototype = {
Expand All @@ -74,13 +81,13 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
if (tag_sets || tag_params) {
ids = s$.__enclos_env__$private$.params$id
newids = ids
if (n != "") newids = sprintf("%s.%s", n, ids)
if (n != "") newids = private$.add_name_prefix(n, ids)
}
tags_child = s$.__enclos_env__$private$.tags
list(
if (nrow(tags_child)) {
tags_child = copy(tags_child)
if (n != "") set(tags_child, , "id", sprintf("%s.%s", n, tags_child$id))
if (n != "") set(tags_child, , "id", private$.add_name_prefix(n, tags_child$id))
tags_child
},
if (tag_sets && n != "" && length(newids)) {
Expand Down Expand Up @@ -109,7 +116,7 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
trafos_child = s$.__enclos_env__$private$.trafos
if (nrow(trafos_child)) {
trafos_child = copy(trafos_child)
if (n != "" && nrow(trafos_child)) set(trafos_child, , "id", sprintf("%s.%s", n, trafos_child$id))
if (n != "" && nrow(trafos_child)) set(trafos_child, , "id", private$.add_name_prefix(n, trafos_child$id))
trafos_child
}
}), prototype = structure(list(
Expand Down Expand Up @@ -149,7 +156,7 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
}
pnames = p$ids()
nameclashes = intersect(
ifelse(n != "", sprintf("%s.%s", n, pnames), pnames),
ifelse(n != "", private$.add_name_prefix(n, pnames), pnames),
self$ids()
)
if (length(nameclashes)) {
Expand All @@ -158,7 +165,7 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,

new_index = length(private$.sets) + 1
paramtbl = p$params[, `:=`(original_id = id, owner_ps_index = new_index, owner_name = n)]
if (n != "") set(paramtbl, , "id", sprintf("%s.%s", n, paramtbl$id))
if (n != "") set(paramtbl, , "id", private$.add_name_prefix(n, paramtbl$id))

if (!nrow(paramtbl)) {
# when paramtbl is empty, use special setup to make sure information about the `.tags` column is present.
Expand Down Expand Up @@ -225,7 +232,7 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
} else if (prefix == "") {
info$owner_name
} else {
paste0(prefix, ".", info$owner_name)
private$.add_name_prefix(prefix, info$owner_name)
}

if (!test_class(subset, "ParamSetCollection")) return(prefix)
Expand All @@ -237,7 +244,7 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
xs = private$.params[list(id_), "cargo", on = "id"][[1]][[1]]$disable_in_tune
prefix = full_prefix(self, id_)
if (prefix == "") return(xs)
set_names(xs, paste0(full_prefix(self, id_), ".", names(xs)))
set_names(xs, private$.add_name_prefix(full_prefix(self, id_), names(xs)))
})) %??% named_list()
self$set_values(.values = pvs)
},
Expand Down Expand Up @@ -274,7 +281,7 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
} else if (prefix == "") {
info$owner_name
} else {
paste0(prefix, ".", info$owner_name)
private$.add_name_prefix(prefix, info$owner_name)
}
subset = get_private(param_set)$.sets[[info$owner_ps_index]]
if (!test_class(subset, "ParamSetCollection")) {
Expand All @@ -296,17 +303,17 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,

in_tune_fn = cargo$in_tune_fn

set_ids = info$ids
prefixed_set_ids = private$.add_name_prefix(prefix, info$ids)
cargo$in_tune_fn = crate(function(domain, param_vals) {
param_vals = param_vals[names(param_vals) %in% paste0(prefix, ".", set_ids)]
param_vals = param_vals[names(param_vals) %in% prefixed_set_ids]
names(param_vals) = gsub(sprintf("^\\Q%s.\\E", prefix), "", names(param_vals))
in_tune_fn(domain, param_vals)
}, in_tune_fn, prefix, set_ids)
}, in_tune_fn, prefix, prefixed_set_ids)

if (length(cargo$disable_in_tune)) {
cargo$disable_in_tune = set_names(
cargo$disable_in_tune,
paste0(prefix, ".", names(cargo$disable_in_tune))
private$.add_name_prefix(prefix, names(cargo$disable_in_tune))
)
}
cargo
Expand All @@ -328,7 +335,7 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
dd = s$deps
if (id != "" && nrow(dd)) {
ids_old = s$ids()
ids_new = sprintf("%s.%s", id, ids_old)
ids_new = private$.add_name_prefix(id, ids_old)
dd$id = map_values(dd$id, ids_old, ids_new)
dd$on = map_values(dd$on, ids_old, ids_new)
}
Expand Down Expand Up @@ -363,8 +370,23 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
),

private = list(
.postfix = FALSE,
.add_name_prefix = function(owner, id) {
if (private$.postfix) sprintf("%s.%s", id, owner) else sprintf("%s.%s", owner, id)
},
.get_values = function() {
vals = unlist(map(private$.sets, "values"), recursive = FALSE)
if (private$.postfix && !is.null(names(private$.sets))) {
vals = imap(private$.sets, function(x, n) {
vals_subset = x$values
if (nchar(n)) {
names(vals_subset) = sprintf("%s.%s", names(vals_subset), n)
}
vals_subset
})
vals = unlist(unname(vals), recursive = FALSE)
} else {
vals = unlist(map(private$.sets, "values"), recursive = FALSE)
}
if (length(vals)) vals else named_list()
},
.store_values = function(xs) {
Expand Down Expand Up @@ -393,7 +415,7 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
children_with_trafos = private$.children_with_trafos()
sets_with_trafos = private$.sets[children_with_trafos]
translation = private$.translation
psc_extra_trafo(x, children_with_trafos, sets_with_trafos, translation)
psc_extra_trafo(x, children_with_trafos, sets_with_trafos, translation, private$.postfix)
},
# get an extra_trafo function that does not have any references to the PSC object or any of its contained sets.
# This is used for flattening.
Expand All @@ -406,7 +428,8 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
}
if (!length(children_with_trafos)) return(NULL)
sets_with_trafos = lapply(private$.sets[children_with_trafos], function(x) x$clone(deep = TRUE)) # get new objects that are detached from PSC
crate(function(x) psc_extra_trafo(x, children_with_trafos, sets_with_trafos, translation), children_with_trafos, sets_with_trafos, translation, psc_extra_trafo)
postfix = private$.postfix
crate(function(x) psc_extra_trafo(x, children_with_trafos, sets_with_trafos, translation, postfix), children_with_trafos, sets_with_trafos, translation, psc_extra_trafo, postfix)
},
.constraint_explicit = function(x) {
children_with_constraints = private$.children_with_constraints()
Expand Down Expand Up @@ -447,7 +470,7 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
# We have this functoin outside of the ParamSetCollection class, because we anticipate that PSC can be "flattened", i.e. turned into
# a normal ParamSet. In that case, the resulting ParamSet's extra_trafo should be a function that can stand on its own, without
# referring to private$<anything>.
psc_extra_trafo = function(x, children_with_trafos, sets_with_trafos, translation) {
psc_extra_trafo = function(x, children_with_trafos, sets_with_trafos, translation, postfix) {
changed = unlist(lapply(seq_along(children_with_trafos), function(i) {
set_index = children_with_trafos[[i]]
changing_ids = translation[J(set_index), id, on = "owner_ps_index"]
Expand All @@ -464,7 +487,7 @@ psc_extra_trafo = function(x, children_with_trafos, sets_with_trafos, translatio
changing_values = trafo(changing_values_in)
prefix = names(sets_with_trafos)[[i]]
if (prefix != "") {
names(changing_values) = sprintf("%s.%s", prefix, names(changing_values))
names(changing_values) = if (postfix) sprintf("%s.%s", names(changing_values), prefix) else sprintf("%s.%s", prefix, names(changing_values))
}
changing_values
}), recursive = FALSE)
Expand Down
25 changes: 16 additions & 9 deletions R/ps_replicate.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@
#' [`ParamSet`] to use as template.
#' @param times (`integer(1)`)\cr
#' Number of times to repeat `set`.
#' Should not be given if `prefixes` is provided.
#' @param prefixes (`character`)\cr
#' A `character` vector indicating the prefixes to use for each repetition of `set`.
#' If this is given, `times` is inferred from `length(prefixes)` and should not be given separately.
#' Should not be given if `affixes` is provided.
#' @param affixes (`character`)\cr
#' A `character` vector indicating the prefixes / postfixes to use for each repetition of `set`.
#' Per default, these are prefixes; if `postfix` is `TRUE`, these values are postfixed instead.
#' If this is given, `times` is inferred from `length(affixes)` and should not be given separately.
#' If `times` is given, this defaults to `"repX"`, with `X` counting up from 1.
#' @param postfix (`logical(1)`)\cr
#' Whether to use `affixes` as a postfix instead of a prefix.
#' Default `FALSE` (use prefixes).
#' @param tag_sets (`logical(1)`)\cr
#' Whether to add a tag of the form `"set_<prefixes[[i]]>"` to each parameter in the result, indicating the repetition each parameter belongs to.
#' Whether to add a tag of the form `"set_<affixes[[i]]>"` to each parameter in the result, indicating the repetition each parameter belongs to.
#' @param tag_params (`logical(1)`)\cr
#' Whether to add a tag of the form `"param_<id>"` to each parameter in the result, indicating the original parameter ID inside `set`.
#' @examples
Expand All @@ -26,7 +30,9 @@
#'
#' ps_replicate(pset, 3)
#'
#' ps_replicate(pset, prefixes = c("first", "last"))
#' ps_replicate(pset, affixes = c("first", "last"))
#'
#' ps_replicate(pset, affixes = c("first", "last"), postfix = TRUE)
#'
#' pset$values = list(i = 1, z = FALSE)
#'
Expand All @@ -52,9 +58,10 @@
#' # get all values associated with the first repetition "rep1"
#' psr$get_values(any_tags = "set_rep1")
#' @export
ps_replicate = function(set, times = length(prefixes), prefixes = sprintf("rep%s", seq_len(times)), tag_sets = FALSE, tag_params = FALSE) {
ps_replicate = function(set, times = length(affixes), affixes = sprintf("rep%s", seq_len(times)), postfix = FALSE, tag_sets = FALSE, tag_params = FALSE) {
assert_count(times)
assert_character(prefixes, any.missing = FALSE, unique = TRUE, len = times)
assert_character(affixes, any.missing = FALSE, unique = TRUE, len = times)
assert_flag(postfix)

ps_union(named_list(prefixes, set), tag_sets = tag_sets, tag_params = tag_params)
ps_union(named_list(affixes, set), postfix_names = postfix, tag_sets = tag_sets, tag_params = tag_params)
}
17 changes: 13 additions & 4 deletions R/ps_union.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,17 @@
#' have their `<id>` changed to `<name in "sets">.<id>`. This is also reflected in deps.
#'
#' The `c()` operator, applied to [`ParamSet`]s, is a synony for `ps_union()`.
#' The named arguments `tag_sets`, `tag_params`, and `postfix_names` are also available in the `c()` operator, but need to be
#' used with a preceding dot instead: `.tag_sets`, `.tag_params`, and `.postfix_names`.
#' @param sets (`list` of [`ParamSet`])\cr
#' This may be a named list, in which case non-empty names are prefixed to parameters in the corresponding [`ParamSet`].
#' @param tag_sets (`logical(1)`)\cr
#' Whether to add tags of the form `"set_<set_id>"` to each parameter originating from a given `ParamSet` given with name `<name in "sets">`.
#' @param tag_params (`logical(1)`)\cr
#' Whether to add tags of the form `"param_<param_id>"` to each parameter with original ID `<param_id>`.
#' @param postfix_names (`logical(1)`)\cr
#' Whether to use names in `sets` as postfixes, instead of prefixes.
#' Default `FALSE`.
#' @examples
#' ps1 = ps(x = p_dbl())
#' ps1$values = list(x = 1)
Expand Down Expand Up @@ -43,14 +48,18 @@
#'
#' pu2$values
#'
#' pu3 = c(one = ps1, two = ps1, ps2, .postfix_names = TRUE)
#' pu3
#'
#'
#' @export
ps_union = function(sets, tag_sets = FALSE, tag_params = FALSE) {
ps_union = function(sets, tag_sets = FALSE, tag_params = FALSE, postfix_names = FALSE) {
assert_list(sets, types = "ParamSet")
if (!length(sets)) return(ParamSet$new())
ParamSetCollection$new(sets, tag_sets = tag_sets, tag_params = tag_params)$flatten()
ParamSetCollection$new(sets, tag_sets = tag_sets, tag_params = tag_params, postfix_names = postfix_names)$flatten()
}

#' @export
c.ParamSet = function(..., .tag_sets = FALSE, .tag_params = FALSE) {
ps_union(list(...), tag_sets = .tag_sets, tag_params = .tag_params)
c.ParamSet = function(..., .tag_sets = FALSE, .tag_params = FALSE, .postfix_names = FALSE) {
ps_union(list(...), tag_sets = .tag_sets, tag_params = .tag_params, postfix_names = .postfix_names)
}
10 changes: 9 additions & 1 deletion man/ParamSetCollection.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 16 additions & 8 deletions man/ps_replicate.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 87cb68d

Please sign in to comment.