Skip to content

Commit aa029d4

Browse files
committed
Support conditional drops
This adds support for branching and merging control flow and uses this to correctly handle the case where a value is dropped in one branch of an if expression but not another. There are other cases we need to handle, which will come in follow up patches. Issue rust-lang#57478
1 parent f246c0b commit aa029d4

File tree

4 files changed

+220
-28
lines changed

4 files changed

+220
-28
lines changed

Cargo.lock

+1
Original file line numberDiff line numberDiff line change
@@ -4383,6 +4383,7 @@ dependencies = [
43834383
name = "rustc_typeck"
43844384
version = "0.0.0"
43854385
dependencies = [
4386+
"itertools 0.9.0",
43864387
"rustc_arena",
43874388
"rustc_ast",
43884389
"rustc_attr",

compiler/rustc_typeck/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ test = false
88
doctest = false
99

1010
[dependencies]
11+
itertools = "0.9"
1112
rustc_arena = { path = "../rustc_arena" }
1213
tracing = "0.1"
1314
rustc_macros = { path = "../rustc_macros" }

compiler/rustc_typeck/src/check/generator_interior.rs

+196-28
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@
33
//! is calculated in `rustc_const_eval::transform::generator` and may be a subset of the
44
//! types computed here.
55
6+
use std::mem;
7+
68
use crate::expr_use_visitor::{self, ExprUseVisitor};
79

810
use super::FnCtxt;
911
use hir::{HirIdMap, Node};
12+
use itertools::Itertools;
1013
use rustc_data_structures::fx::{FxHashSet, FxIndexSet};
1114
use rustc_errors::pluralize;
1215
use rustc_hir as hir;
@@ -24,6 +27,9 @@ use rustc_span::Span;
2427
use smallvec::SmallVec;
2528
use tracing::debug;
2629

30+
#[cfg(test)]
31+
mod tests;
32+
2733
struct InteriorVisitor<'a, 'tcx> {
2834
fcx: &'a FnCtxt<'a, 'tcx>,
2935
types: FxIndexSet<ty::GeneratorInteriorTypeCause<'tcx>>,
@@ -80,7 +86,9 @@ impl<'a, 'tcx> InteriorVisitor<'a, 'tcx> {
8086
);
8187

8288
match self.drop_ranges.get(&hir_id) {
83-
Some(range) if range.contains(yield_data.expr_and_pat_count) => {
89+
Some(range)
90+
if range.is_dropped_at(yield_data.expr_and_pat_count) =>
91+
{
8492
debug!("value is dropped at yield point; not recording");
8593
return false;
8694
}
@@ -229,7 +237,7 @@ pub fn resolve_interior<'a, 'tcx>(
229237
hir: fcx.tcx.hir(),
230238
consumed_places: <_>::default(),
231239
borrowed_places: <_>::default(),
232-
drop_ranges: vec![<_>::default()],
240+
drop_ranges: <_>::default(),
233241
expr_count: 0,
234242
};
235243

@@ -254,7 +262,7 @@ pub fn resolve_interior<'a, 'tcx>(
254262
guard_bindings: <_>::default(),
255263
guard_bindings_set: <_>::default(),
256264
linted_values: <_>::default(),
257-
drop_ranges: drop_range_visitor.drop_ranges.pop().unwrap(),
265+
drop_ranges: drop_range_visitor.drop_ranges,
258266
}
259267
};
260268
intravisit::walk_body(&mut visitor, body);
@@ -671,7 +679,7 @@ struct DropRangeVisitor<'tcx> {
671679
/// Maps a HirId to a set of HirIds that are dropped by that node.
672680
consumed_places: HirIdMap<HirIdSet>,
673681
borrowed_places: HirIdSet,
674-
drop_ranges: Vec<HirIdMap<DropRange>>,
682+
drop_ranges: HirIdMap<DropRange>,
675683
expr_count: usize,
676684
}
677685

@@ -684,28 +692,42 @@ impl DropRangeVisitor<'tcx> {
684692
}
685693

686694
fn record_drop(&mut self, hir_id: HirId) {
687-
let drop_ranges = self.drop_ranges.last_mut().unwrap();
695+
let drop_ranges = &mut self.drop_ranges;
688696
if self.borrowed_places.contains(&hir_id) {
689697
debug!("not marking {:?} as dropped because it is borrowed at some point", hir_id);
690698
} else {
691699
debug!("marking {:?} as dropped at {}", hir_id, self.expr_count);
692-
drop_ranges.insert(hir_id, DropRange { dropped_at: self.expr_count });
700+
drop_ranges.insert(hir_id, DropRange::new(self.expr_count));
693701
}
694702
}
695703

696-
fn push_drop_scope(&mut self) {
697-
self.drop_ranges.push(<_>::default());
704+
fn swap_drop_ranges(&mut self, mut other: HirIdMap<DropRange>) -> HirIdMap<DropRange> {
705+
mem::swap(&mut self.drop_ranges, &mut other);
706+
other
698707
}
699708

700-
fn pop_and_merge_drop_scope(&mut self) {
701-
let mut old_last = self.drop_ranges.pop().unwrap();
702-
let drop_ranges = self.drop_ranges.last_mut().unwrap();
703-
for (k, v) in old_last.drain() {
704-
match drop_ranges.get(&k).cloned() {
705-
Some(v2) => drop_ranges.insert(k, v.intersect(&v2)),
706-
None => drop_ranges.insert(k, v),
707-
};
708-
}
709+
#[allow(dead_code)]
710+
fn fork_drop_ranges(&self) -> HirIdMap<DropRange> {
711+
self.drop_ranges.iter().map(|(k, v)| (*k, v.fork_at(self.expr_count))).collect()
712+
}
713+
714+
fn intersect_drop_ranges(&mut self, drops: HirIdMap<DropRange>) {
715+
drops.into_iter().for_each(|(k, v)| match self.drop_ranges.get_mut(&k) {
716+
Some(ranges) => *ranges = ranges.intersect(&v),
717+
None => {
718+
self.drop_ranges.insert(k, v);
719+
}
720+
})
721+
}
722+
723+
#[allow(dead_code)]
724+
fn merge_drop_ranges(&mut self, drops: HirIdMap<DropRange>) {
725+
drops.into_iter().for_each(|(k, v)| {
726+
if !self.drop_ranges.contains_key(&k) {
727+
self.drop_ranges.insert(k, DropRange { events: vec![] });
728+
}
729+
self.drop_ranges.get_mut(&k).unwrap().merge_with(&v, self.expr_count);
730+
});
709731
}
710732

711733
/// ExprUseVisitor's consume callback doesn't go deep enough for our purposes in all
@@ -751,7 +773,10 @@ impl<'tcx> expr_use_visitor::Delegate<'tcx> for DropRangeVisitor<'tcx> {
751773
Some(parent) => parent,
752774
None => place_with_id.hir_id,
753775
};
754-
debug!("consume {:?}; diag_expr_id={:?}, using parent {:?}", place_with_id, diag_expr_id, parent);
776+
debug!(
777+
"consume {:?}; diag_expr_id={:?}, using parent {:?}",
778+
place_with_id, diag_expr_id, parent
779+
);
755780
self.mark_consumed(parent, place_with_id.hir_id);
756781
place_hir_id(&place_with_id.place).map(|place| self.mark_consumed(parent, place));
757782
}
@@ -800,15 +825,47 @@ impl<'tcx> Visitor<'tcx> for DropRangeVisitor<'tcx> {
800825
self.visit_expr(lhs);
801826
self.visit_expr(rhs);
802827

803-
self.push_drop_scope();
828+
let old_drops = self.swap_drop_ranges(<_>::default());
804829
std::mem::swap(&mut old_count, &mut self.expr_count);
805830
self.visit_expr(rhs);
806831
self.visit_expr(lhs);
807832

808833
// We should have visited the same number of expressions in either order.
809834
assert_eq!(old_count, self.expr_count);
810835

811-
self.pop_and_merge_drop_scope();
836+
self.intersect_drop_ranges(old_drops);
837+
}
838+
ExprKind::If(test, if_true, if_false) => {
839+
self.visit_expr(test);
840+
841+
match if_false {
842+
Some(if_false) => {
843+
let mut true_ranges = self.fork_drop_ranges();
844+
let mut false_ranges = self.fork_drop_ranges();
845+
846+
true_ranges = self.swap_drop_ranges(true_ranges);
847+
self.visit_expr(if_true);
848+
true_ranges = self.swap_drop_ranges(true_ranges);
849+
850+
false_ranges = self.swap_drop_ranges(false_ranges);
851+
self.visit_expr(if_false);
852+
false_ranges = self.swap_drop_ranges(false_ranges);
853+
854+
self.merge_drop_ranges(true_ranges);
855+
self.merge_drop_ranges(false_ranges);
856+
}
857+
None => {
858+
let mut true_ranges = self.fork_drop_ranges();
859+
debug!("true branch drop range fork: {:?}", true_ranges);
860+
true_ranges = self.swap_drop_ranges(true_ranges);
861+
self.visit_expr(if_true);
862+
true_ranges = self.swap_drop_ranges(true_ranges);
863+
debug!("true branch computed drop_ranges: {:?}", true_ranges);
864+
debug!("drop ranges before merging: {:?}", self.drop_ranges);
865+
self.merge_drop_ranges(true_ranges);
866+
debug!("drop ranges after merging: {:?}", self.drop_ranges);
867+
}
868+
}
812869
}
813870
_ => intravisit::walk_expr(self, expr),
814871
}
@@ -825,20 +882,131 @@ impl<'tcx> Visitor<'tcx> for DropRangeVisitor<'tcx> {
825882
}
826883
}
827884

828-
#[derive(Clone)]
885+
#[derive(Clone, Debug, PartialEq, Eq)]
886+
enum Event {
887+
Drop(usize),
888+
Reinit(usize),
889+
}
890+
891+
impl Event {
892+
fn location(&self) -> usize {
893+
match *self {
894+
Event::Drop(i) | Event::Reinit(i) => i,
895+
}
896+
}
897+
}
898+
899+
impl PartialOrd for Event {
900+
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
901+
self.location().partial_cmp(&other.location())
902+
}
903+
}
904+
905+
#[derive(Clone, Debug, PartialEq, Eq)]
829906
struct DropRange {
830-
/// The post-order id of the point where this expression is dropped.
831-
///
832-
/// We can consider the value dropped at any post-order id greater than dropped_at.
833-
dropped_at: usize,
907+
events: Vec<Event>,
834908
}
835909

836910
impl DropRange {
911+
fn new(begin: usize) -> Self {
912+
Self { events: vec![Event::Drop(begin)] }
913+
}
914+
837915
fn intersect(&self, other: &Self) -> Self {
838-
Self { dropped_at: self.dropped_at.max(other.dropped_at) }
916+
let mut events = vec![];
917+
self.events
918+
.iter()
919+
.merge_join_by(other.events.iter(), |a, b| a.partial_cmp(b).unwrap())
920+
.fold((false, false), |(left, right), event| match event {
921+
itertools::EitherOrBoth::Both(_, _) => todo!(),
922+
itertools::EitherOrBoth::Left(e) => match e {
923+
Event::Drop(i) => {
924+
if !left && right {
925+
events.push(Event::Drop(*i));
926+
}
927+
(true, right)
928+
}
929+
Event::Reinit(i) => {
930+
if left && !right {
931+
events.push(Event::Reinit(*i));
932+
}
933+
(false, right)
934+
}
935+
},
936+
itertools::EitherOrBoth::Right(e) => match e {
937+
Event::Drop(i) => {
938+
if left && !right {
939+
events.push(Event::Drop(*i));
940+
}
941+
(left, true)
942+
}
943+
Event::Reinit(i) => {
944+
if !left && right {
945+
events.push(Event::Reinit(*i));
946+
}
947+
(left, false)
948+
}
949+
},
950+
});
951+
Self { events }
952+
}
953+
954+
fn is_dropped_at(&self, id: usize) -> bool {
955+
match self.events.iter().try_fold(false, |is_dropped, event| {
956+
if event.location() < id {
957+
Ok(match event {
958+
Event::Drop(_) => true,
959+
Event::Reinit(_) => false,
960+
})
961+
} else {
962+
Err(is_dropped)
963+
}
964+
}) {
965+
Ok(is_dropped) | Err(is_dropped) => is_dropped,
966+
}
967+
}
968+
969+
#[allow(dead_code)]
970+
fn drop(&mut self, location: usize) {
971+
self.events.push(Event::Drop(location))
972+
}
973+
974+
#[allow(dead_code)]
975+
fn reinit(&mut self, location: usize) {
976+
self.events.push(Event::Reinit(location));
977+
}
978+
979+
/// Merges another range with this one. Meant to be used at control flow join points.
980+
///
981+
/// After merging, the value will be dead at the end of the range only if it was dead
982+
/// at the end of both self and other.
983+
///
984+
/// Assumes that all locations in each range are less than joinpoint
985+
#[allow(dead_code)]
986+
fn merge_with(&mut self, other: &DropRange, join_point: usize) {
987+
let mut events: Vec<_> =
988+
self.events.iter().merge(other.events.iter()).dedup().cloned().collect();
989+
990+
events.push(if self.is_dropped_at(join_point) && other.is_dropped_at(join_point) {
991+
Event::Drop(join_point)
992+
} else {
993+
Event::Reinit(join_point)
994+
});
995+
996+
self.events = events;
839997
}
840998

841-
fn contains(&self, id: usize) -> bool {
842-
id > self.dropped_at
999+
/// Creates a new DropRange from this one at the split point.
1000+
///
1001+
/// Used to model branching control flow.
1002+
#[allow(dead_code)]
1003+
fn fork_at(&self, split_point: usize) -> Self {
1004+
Self {
1005+
events: vec![if self.is_dropped_at(split_point) {
1006+
Event::Drop(split_point)
1007+
} else {
1008+
Event::Reinit(split_point)
1009+
}],
1010+
}
8431011
}
8441012
}

src/test/ui/generator/drop-if.rs

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// build-pass
2+
3+
// This test case is reduced from src/test/ui/drop/dynamic-drop-async.rs
4+
5+
#![feature(generators)]
6+
7+
struct Ptr;
8+
impl<'a> Drop for Ptr {
9+
fn drop(&mut self) {
10+
}
11+
}
12+
13+
fn main() {
14+
let arg = true;
15+
let _ = || {
16+
let arr = [Ptr];
17+
if arg {
18+
drop(arr);
19+
}
20+
yield
21+
};
22+
}

0 commit comments

Comments
 (0)