@@ -3,6 +3,7 @@ use crate::ty::print::{FmtPrinter, Printer};
3
3
use crate :: ty:: { self , Ty , TyCtxt , TypeFoldable , TypeSuperFoldable } ;
4
4
use crate :: ty:: { EarlyBinder , GenericArgs , GenericArgsRef , TypeVisitableExt } ;
5
5
use rustc_errors:: ErrorGuaranteed ;
6
+ use rustc_hir as hir;
6
7
use rustc_hir:: def:: Namespace ;
7
8
use rustc_hir:: def_id:: { CrateNum , DefId } ;
8
9
use rustc_hir:: lang_items:: LangItem ;
@@ -11,6 +12,7 @@ use rustc_macros::HashStable;
11
12
use rustc_middle:: ty:: normalize_erasing_regions:: NormalizationError ;
12
13
use rustc_span:: Symbol ;
13
14
15
+ use std:: assert_matches:: assert_matches;
14
16
use std:: fmt;
15
17
16
18
/// A monomorphized `InstanceDef`.
@@ -572,6 +574,54 @@ impl<'tcx> Instance<'tcx> {
572
574
Some ( Instance { def, args } )
573
575
}
574
576
577
+ pub fn try_resolve_item_for_coroutine (
578
+ tcx : TyCtxt < ' tcx > ,
579
+ trait_item_id : DefId ,
580
+ trait_id : DefId ,
581
+ rcvr_args : ty:: GenericArgsRef < ' tcx > ,
582
+ ) -> Option < Instance < ' tcx > > {
583
+ let ty:: Coroutine ( coroutine_def_id, args) = * rcvr_args. type_at ( 0 ) . kind ( ) else {
584
+ return None ;
585
+ } ;
586
+ let coroutine_kind = tcx. coroutine_kind ( coroutine_def_id) . unwrap ( ) ;
587
+
588
+ let lang_items = tcx. lang_items ( ) ;
589
+ let coroutine_callable_item = if Some ( trait_id) == lang_items. future_trait ( ) {
590
+ assert_matches ! (
591
+ coroutine_kind,
592
+ hir:: CoroutineKind :: Desugared ( hir:: CoroutineDesugaring :: Async , _)
593
+ ) ;
594
+ hir:: LangItem :: FuturePoll
595
+ } else if Some ( trait_id) == lang_items. iterator_trait ( ) {
596
+ assert_matches ! (
597
+ coroutine_kind,
598
+ hir:: CoroutineKind :: Desugared ( hir:: CoroutineDesugaring :: Gen , _)
599
+ ) ;
600
+ hir:: LangItem :: IteratorNext
601
+ } else if Some ( trait_id) == lang_items. async_iterator_trait ( ) {
602
+ assert_matches ! (
603
+ coroutine_kind,
604
+ hir:: CoroutineKind :: Desugared ( hir:: CoroutineDesugaring :: AsyncGen , _)
605
+ ) ;
606
+ hir:: LangItem :: AsyncIteratorPollNext
607
+ } else if Some ( trait_id) == lang_items. coroutine_trait ( ) {
608
+ assert_matches ! ( coroutine_kind, hir:: CoroutineKind :: Coroutine ( _) ) ;
609
+ hir:: LangItem :: CoroutineResume
610
+ } else {
611
+ return None ;
612
+ } ;
613
+
614
+ if tcx. lang_items ( ) . get ( coroutine_callable_item) == Some ( trait_item_id) {
615
+ Some ( Instance { def : ty:: InstanceDef :: Item ( coroutine_def_id) , args : args } )
616
+ } else {
617
+ // All other methods should be defaulted methods of the built-in trait.
618
+ // This is important for `Iterator`'s combinators, but also useful for
619
+ // adding future default methods to `Future`, for instance.
620
+ debug_assert ! ( tcx. defaultness( trait_item_id) . has_value( ) ) ;
621
+ Some ( Instance :: new ( trait_item_id, rcvr_args) )
622
+ }
623
+ }
624
+
575
625
/// Depending on the kind of `InstanceDef`, the MIR body associated with an
576
626
/// instance is expressed in terms of the generic parameters of `self.def_id()`, and in other
577
627
/// cases the MIR body is expressed in terms of the types found in the substitution array.
0 commit comments