Skip to content

Commit a62401b

Browse files
committed
Fix #1199, fix #1200
1 parent 597f032 commit a62401b

File tree

4 files changed

+148
-4
lines changed

4 files changed

+148
-4
lines changed

NEWS.md

+3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# Version 7.11.0.9000
22

3+
## Bug fixes
4+
5+
* Restrict static transforms so they only use the upstream part of the plan (#1199, #1200, @bart1).
36

47
# Version 7.11.0
58

R/igraph.R

+9
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,15 @@ downstream_nodes <- function(graph, from) {
1818
)
1919
}
2020

21+
upstream_nodes <- function(graph, from) {
22+
nbhd_vertices(
23+
graph = graph,
24+
vertices = from,
25+
mode = "in",
26+
order = igraph::gorder(graph)
27+
)
28+
}
29+
2130
nbhd_graph <- function(graph, vertices, mode, order) {
2231
vertices <- nbhd_vertices(
2332
graph = graph,

R/transform_plan.R

+12-4
Original file line numberDiff line numberDiff line change
@@ -277,19 +277,27 @@ transform_plan_ <- function(
277277
plan$transform <- lapply(plan$transform, parse_transform)
278278
graph <- dsl_graph(plan)
279279
order <- igraph::topo_sort(graph)$name
280+
subplans <- split(plan, f = plan$target)
280281
for (target in order) {
281-
index <- which(target == plan$target)
282-
rows <- transform_row(index, plan, graph, max_expand)
283-
plan <- sub_in_plan(plan, rows, index)
284-
old_cols(plan) <- old_cols
282+
upstream_plan <- dsl_upstream_plan(target, graph, subplans)
283+
index <- which(target == upstream_plan$target)
284+
old_cols(upstream_plan) <- old_cols
285+
subplans[[target]] <- transform_row(index, upstream_plan, graph, max_expand)
285286
}
287+
plan <- drake_bind_rows(subplans)
288+
old_cols(plan) <- old_cols
286289
plan <- dsl_trace(plan = plan, trace = trace)
287290
old_cols(plan) <- plan$transform <- NULL
288291
plan <- dsl_tidy_eval(plan = plan, tidy_eval = tidy_eval, envir = envir)
289292
plan <- dsl_sanitize(plan = plan, sanitize = sanitize, envir = envir)
290293
plan
291294
}
292295

296+
dsl_upstream_plan <- function(target, graph, subplans) {
297+
upstream_targets <- upstream_nodes(graph, target)
298+
drake_bind_rows(subplans[upstream_targets])
299+
}
300+
293301
dsl_trace <- function(plan, trace) {
294302
if (!trace) {
295303
keep <- as.character(intersect(colnames(plan), old_cols(plan)))

tests/testthat/test-7-dsl.R

+124
Original file line numberDiff line numberDiff line change
@@ -3078,3 +3078,127 @@ test_with_dir("NAs removed from old grouping vars grid (#1010)", {
30783078
)
30793079
equivalent_plans(out, exp)
30803080
})
3081+
3082+
test_with_dir("static transforms use only upstream part of plan (#1199)", {
3083+
skip_on_cran()
3084+
radars <- c("radar1", "radar2")
3085+
seasons <- c("season1", "season2")
3086+
months <- c(1, 2)
3087+
radar_seasons <- expand.grid(
3088+
radar = radars,
3089+
season = seasons,
3090+
stringsAsFactors = FALSE
3091+
)
3092+
out <- drake_plan(
3093+
data = target(
3094+
get_data(radar, month),
3095+
transform = cross(radar = !!radars, month = !!months)
3096+
),
3097+
to_cross = target(
3098+
list(data),
3099+
transform = combine(data, .by = radar)
3100+
),
3101+
problem = target(
3102+
list(to_cross, season),
3103+
transform = cross(to_cross, season = !!seasons)
3104+
),
3105+
separate = target(
3106+
list(radar, season),
3107+
transform = map(.data = !!radar_seasons)
3108+
),
3109+
trace = TRUE
3110+
)
3111+
exp <- drake_plan(
3112+
data_radar1_1 = target(
3113+
command = get_data("radar1", 1),
3114+
radar = "\"radar1\"",
3115+
month = "1",
3116+
data = "data_radar1_1"
3117+
),
3118+
data_radar2_1 = target(
3119+
command = get_data("radar2", 1),
3120+
radar = "\"radar2\"",
3121+
month = "1",
3122+
data = "data_radar2_1"
3123+
),
3124+
data_radar1_2 = target(
3125+
command = get_data("radar1", 2),
3126+
radar = "\"radar1\"",
3127+
month = "2",
3128+
data = "data_radar1_2"
3129+
),
3130+
data_radar2_2 = target(
3131+
command = get_data("radar2", 2),
3132+
radar = "\"radar2\"",
3133+
month = "2",
3134+
data = "data_radar2_2"
3135+
),
3136+
problem_season1_to_cross_radar1 = target(
3137+
command = list(to_cross_radar1, "season1"),
3138+
radar = "\"radar1\"",
3139+
season = "\"season1\"",
3140+
separate = "separate_radar1_season1",
3141+
to_cross = "to_cross_radar1",
3142+
problem = "problem_season1_to_cross_radar1"
3143+
),
3144+
problem_season2_to_cross_radar1 = target(
3145+
command = list(to_cross_radar1, "season2"),
3146+
radar = "\"radar1\"",
3147+
season = "\"season2\"",
3148+
separate = "separate_radar1_season2",
3149+
to_cross = "to_cross_radar1",
3150+
problem = "problem_season2_to_cross_radar1"
3151+
),
3152+
problem_season1_to_cross_radar2 = target(
3153+
command = list(to_cross_radar2, "season1"),
3154+
radar = "\"radar1\"",
3155+
season = "\"season1\"",
3156+
separate = "separate_radar1_season1",
3157+
to_cross = "to_cross_radar2",
3158+
problem = "problem_season1_to_cross_radar2"
3159+
),
3160+
problem_season2_to_cross_radar2 = target(
3161+
command = list(to_cross_radar2, "season2"),
3162+
radar = "\"radar1\"",
3163+
season = "\"season2\"",
3164+
separate = "separate_radar1_season2",
3165+
to_cross = "to_cross_radar2",
3166+
problem = "problem_season2_to_cross_radar2"
3167+
),
3168+
separate_radar1_season1 = target(
3169+
command = list("radar1", "season1"),
3170+
radar = "\"radar1\"",
3171+
season = "\"season1\"",
3172+
separate = "separate_radar1_season1"
3173+
),
3174+
separate_radar2_season1 = target(
3175+
command = list("radar2", "season1"),
3176+
radar = "\"radar2\"",
3177+
season = "\"season1\"",
3178+
separate = "separate_radar2_season1"
3179+
),
3180+
separate_radar1_season2 = target(
3181+
command = list("radar1", "season2"),
3182+
radar = "\"radar1\"",
3183+
season = "\"season2\"",
3184+
separate = "separate_radar1_season2"
3185+
),
3186+
separate_radar2_season2 = target(
3187+
command = list("radar2", "season2"),
3188+
radar = "\"radar2\"",
3189+
season = "\"season2\"",
3190+
separate = "separate_radar2_season2"
3191+
),
3192+
to_cross_radar1 = target(
3193+
command = list(data_radar1_1, data_radar1_2),
3194+
radar = "\"radar1\"",
3195+
to_cross = "to_cross_radar1"
3196+
),
3197+
to_cross_radar2 = target(
3198+
command = list(data_radar2_1, data_radar2_2),
3199+
radar = "\"radar2\"",
3200+
to_cross = "to_cross_radar2"
3201+
)
3202+
)
3203+
equivalent_plans(out, exp)
3204+
})

0 commit comments

Comments
 (0)