Skip to content

Commit

Permalink
Use a helper to zip together parent and child captures for coroutine-…
Browse files Browse the repository at this point in the history
…closures
  • Loading branch information
compiler-errors committed Apr 8, 2024
1 parent 54a93ab commit 8af94ce
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 73 deletions.
47 changes: 47 additions & 0 deletions compiler/rustc_middle/src/ty/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::{mir, ty};
use std::fmt::Write;

use crate::query::Providers;
use rustc_data_structures::captures::Captures;
use rustc_data_structures::fx::FxIndexMap;
use rustc_hir as hir;
use rustc_hir::def_id::LocalDefId;
Expand Down Expand Up @@ -415,6 +416,52 @@ impl BorrowKind {
}
}

pub fn analyze_coroutine_closure_captures<'a, 'tcx: 'a, T>(
parent_captures: impl IntoIterator<Item = &'a CapturedPlace<'tcx>>,
child_captures: impl IntoIterator<Item = &'a CapturedPlace<'tcx>>,
mut for_each: impl FnMut((usize, &'a CapturedPlace<'tcx>), (usize, &'a CapturedPlace<'tcx>)) -> T,
) -> impl Iterator<Item = T> + Captures<'a> + Captures<'tcx> {
let mut parent_captures = parent_captures.into_iter().enumerate().peekable();
// Make sure we use every field at least once, b/c why are we capturing something
// if it's not used in the inner coroutine.
let mut field_used_at_least_once = false;
child_captures.into_iter().enumerate().map(move |(child_field_idx, child_capture)| {
loop {
let Some(&(parent_field_idx, parent_capture)) = parent_captures.peek() else {
bug!("we ran out of parent captures!")
};

let HirPlaceBase::Upvar(parent_base) = parent_capture.place.base else {
bug!("expected capture to be an upvar");
};
let HirPlaceBase::Upvar(child_base) = child_capture.place.base else {
bug!("expected capture to be an upvar");
};

if parent_base.var_path.hir_id != child_base.var_path.hir_id
|| !std::iter::zip(
&child_capture.place.projections,
&parent_capture.place.projections,
)
.all(|(child, parent)| child.kind == parent.kind)
{
// Make sure the field was used at least once.
assert!(
field_used_at_least_once,
"we captured {parent_capture:#?} but it was not used in the child coroutine?"
);
field_used_at_least_once = false;
// Skip this field.
let _ = parent_captures.next().unwrap();
continue;
}

field_used_at_least_once = true;
break for_each((parent_field_idx, parent_capture), (child_field_idx, child_capture));
}
})
}

pub fn provide(providers: &mut Providers) {
*providers = Providers { closure_typeinfo, ..*providers }
}
7 changes: 4 additions & 3 deletions compiler/rustc_middle/src/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,10 @@ pub use rustc_type_ir::ConstKind::{
pub use rustc_type_ir::*;

pub use self::closure::{
is_ancestor_or_same_capture, place_to_string_for_capture, BorrowKind, CaptureInfo,
CapturedPlace, ClosureTypeInfo, MinCaptureInformationMap, MinCaptureList,
RootVariableMinCaptureList, UpvarCapture, UpvarId, UpvarPath, CAPTURE_STRUCT_LOCAL,
analyze_coroutine_closure_captures, is_ancestor_or_same_capture, place_to_string_for_capture,
BorrowKind, CaptureInfo, CapturedPlace, ClosureTypeInfo, MinCaptureInformationMap,
MinCaptureList, RootVariableMinCaptureList, UpvarCapture, UpvarId, UpvarPath,
CAPTURE_STRUCT_LOCAL,
};
pub use self::consts::{
Const, ConstData, ConstInt, ConstKind, Expr, ScalarInt, UnevaluatedConst, ValTree,
Expand Down
80 changes: 10 additions & 70 deletions compiler/rustc_mir_transform/src/coroutine/by_move_body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
use rustc_data_structures::unord::UnordMap;
use rustc_hir as hir;
use rustc_middle::hir::place::{PlaceBase, Projection, ProjectionKind};
use rustc_middle::hir::place::{Projection, ProjectionKind};
use rustc_middle::mir::visit::MutVisitor;
use rustc_middle::mir::{self, dump_mir, MirPass};
use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt, TypeVisitableExt};
Expand Down Expand Up @@ -124,62 +124,10 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
.tuple_fields()
.len();

let mut field_remapping = UnordMap::default();

// One parent capture may correspond to several child captures if we end up
// refining the set of captures via edition-2021 precise captures. We want to
// match up any number of child captures with one parent capture, so we keep
// peeking off this `Peekable` until the child doesn't match anymore.
let mut parent_captures =
tcx.closure_captures(parent_def_id).iter().copied().enumerate().peekable();
// Make sure we use every field at least once, b/c why are we capturing something
// if it's not used in the inner coroutine.
let mut field_used_at_least_once = false;

for (child_field_idx, child_capture) in tcx
.closure_captures(coroutine_def_id)
.iter()
.copied()
// By construction we capture all the args first.
.skip(num_args)
.enumerate()
{
loop {
let Some(&(parent_field_idx, parent_capture)) = parent_captures.peek() else {
bug!("we ran out of parent captures!")
};

let PlaceBase::Upvar(parent_base) = parent_capture.place.base else {
bug!("expected capture to be an upvar");
};
let PlaceBase::Upvar(child_base) = child_capture.place.base else {
bug!("expected capture to be an upvar");
};

assert!(
child_capture.place.projections.len() >= parent_capture.place.projections.len()
);
// A parent matches a child they share the same prefix of projections.
// The child may have more, if it is capturing sub-fields out of
// something that is captured by-move in the parent closure.
if parent_base.var_path.hir_id != child_base.var_path.hir_id
|| !std::iter::zip(
&child_capture.place.projections,
&parent_capture.place.projections,
)
.all(|(child, parent)| child.kind == parent.kind)
{
// Make sure the field was used at least once.
assert!(
field_used_at_least_once,
"we captured {parent_capture:#?} but it was not used in the child coroutine?"
);
field_used_at_least_once = false;
// Skip this field.
let _ = parent_captures.next().unwrap();
continue;
}

let field_remapping: UnordMap<_, _> = ty::analyze_coroutine_closure_captures(
tcx.closure_captures(parent_def_id).iter().copied(),
tcx.closure_captures(coroutine_def_id).iter().skip(num_args).copied(),
|(parent_field_idx, parent_capture), (child_field_idx, child_capture)| {
// Store this set of additional projections (fields and derefs).
// We need to re-apply them later.
let child_precise_captures =
Expand Down Expand Up @@ -210,26 +158,18 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
),
};

field_remapping.insert(
(
FieldIdx::from_usize(child_field_idx + num_args),
(
FieldIdx::from_usize(parent_field_idx + num_args),
parent_capture_ty,
needs_deref,
child_precise_captures,
),
);

field_used_at_least_once = true;
break;
}
}

// Pop the last parent capture
if field_used_at_least_once {
let _ = parent_captures.next().unwrap();
}
assert_eq!(parent_captures.next(), None, "leftover parent captures?");
)
},
)
.collect();

if coroutine_kind == ty::ClosureKind::FnOnce {
assert_eq!(field_remapping.len(), tcx.closure_captures(parent_def_id).len());
Expand Down

0 comments on commit 8af94ce

Please sign in to comment.