@@ -2,7 +2,7 @@ use std::iter;
2
2
3
3
use ast:: make;
4
4
use either:: Either ;
5
- use hir:: { HirDisplay , InFile , Local , ModuleDef , Semantics , TypeInfo } ;
5
+ use hir:: { HasSource , HirDisplay , InFile , Local , ModuleDef , Semantics , TypeInfo } ;
6
6
use ide_db:: {
7
7
defs:: { Definition , NameRefClass } ,
8
8
famous_defs:: FamousDefs ,
@@ -27,6 +27,7 @@ use syntax::{
27
27
28
28
use crate :: {
29
29
assist_context:: { AssistContext , Assists , TreeMutator } ,
30
+ utils:: generate_impl_text,
30
31
AssistId ,
31
32
} ;
32
33
@@ -106,6 +107,8 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
106
107
let params =
107
108
body. extracted_function_params ( ctx, & container_info, locals_used. iter ( ) . copied ( ) ) ;
108
109
110
+ let extracted_from_trait_impl = body. extracted_from_trait_impl ( ) ;
111
+
109
112
let name = make_function_name ( & semantics_scope) ;
110
113
111
114
let fun = Function {
@@ -124,8 +127,13 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
124
127
125
128
builder. replace ( target_range, make_call ( ctx, & fun, old_indent) ) ;
126
129
127
- let fn_def = format_function ( ctx, module, & fun, old_indent, new_indent) ;
128
- let insert_offset = insert_after. text_range ( ) . end ( ) ;
130
+ let fn_def = match fun. self_param_adt ( ctx) {
131
+ Some ( adt) if extracted_from_trait_impl => {
132
+ let fn_def = format_function ( ctx, module, & fun, old_indent, new_indent + 1 ) ;
133
+ generate_impl_text ( & adt, & fn_def) . replace ( "{\n \n " , "{" )
134
+ }
135
+ _ => format_function ( ctx, module, & fun, old_indent, new_indent) ,
136
+ } ;
129
137
130
138
if fn_def. contains ( "ControlFlow" ) {
131
139
let scope = match scope {
@@ -150,6 +158,8 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
150
158
}
151
159
}
152
160
161
+ let insert_offset = insert_after. text_range ( ) . end ( ) ;
162
+
153
163
match ctx. config . snippet_cap {
154
164
Some ( cap) => builder. insert_snippet ( cap, insert_offset, fn_def) ,
155
165
None => builder. insert ( insert_offset, fn_def) ,
@@ -381,6 +391,14 @@ impl Function {
381
391
} ,
382
392
}
383
393
}
394
+
395
+ fn self_param_adt ( & self , ctx : & AssistContext ) -> Option < ast:: Adt > {
396
+ let self_param = self . self_param . as_ref ( ) ?;
397
+ let def = ctx. sema . to_def ( self_param) ?;
398
+ let adt = def. ty ( ctx. db ( ) ) . strip_references ( ) . as_adt ( ) ?;
399
+ let InFile { file_id : _, value } = adt. source ( ctx. db ( ) ) ?;
400
+ Some ( value)
401
+ }
384
402
}
385
403
386
404
impl ParamKind {
@@ -485,6 +503,20 @@ impl FunctionBody {
485
503
}
486
504
}
487
505
506
+ fn node ( & self ) -> & SyntaxNode {
507
+ match self {
508
+ FunctionBody :: Expr ( e) => e. syntax ( ) ,
509
+ FunctionBody :: Span { parent, .. } => parent. syntax ( ) ,
510
+ }
511
+ }
512
+
513
+ fn extracted_from_trait_impl ( & self ) -> bool {
514
+ match self . node ( ) . ancestors ( ) . find_map ( ast:: Impl :: cast) {
515
+ Some ( c) => return c. trait_ ( ) . is_some ( ) ,
516
+ None => false ,
517
+ }
518
+ }
519
+
488
520
fn from_expr ( expr : ast:: Expr ) -> Option < Self > {
489
521
match expr {
490
522
ast:: Expr :: BreakExpr ( it) => it. expr ( ) . map ( Self :: Expr ) ,
@@ -1111,10 +1143,7 @@ fn either_syntax(value: &Either<ast::IdentPat, ast::SelfParam>) -> &SyntaxNode {
1111
1143
///
1112
1144
/// Function should be put right after returned node
1113
1145
fn node_to_insert_after ( body : & FunctionBody , anchor : Anchor ) -> Option < SyntaxNode > {
1114
- let node = match body {
1115
- FunctionBody :: Expr ( e) => e. syntax ( ) ,
1116
- FunctionBody :: Span { parent, .. } => parent. syntax ( ) ,
1117
- } ;
1146
+ let node = body. node ( ) ;
1118
1147
let mut ancestors = node. ancestors ( ) . peekable ( ) ;
1119
1148
let mut last_ancestor = None ;
1120
1149
while let Some ( next_ancestor) = ancestors. next ( ) {
@@ -1126,9 +1155,8 @@ fn node_to_insert_after(body: &FunctionBody, anchor: Anchor) -> Option<SyntaxNod
1126
1155
break ;
1127
1156
}
1128
1157
}
1129
- SyntaxKind :: ASSOC_ITEM_LIST if !matches ! ( anchor, Anchor :: Method ) => {
1130
- continue ;
1131
- }
1158
+ SyntaxKind :: ASSOC_ITEM_LIST if !matches ! ( anchor, Anchor :: Method ) => continue ,
1159
+ SyntaxKind :: ASSOC_ITEM_LIST if body. extracted_from_trait_impl ( ) => continue ,
1132
1160
SyntaxKind :: ASSOC_ITEM_LIST => {
1133
1161
if ancestors. peek ( ) . map ( SyntaxNode :: kind) == Some ( SyntaxKind :: IMPL ) {
1134
1162
break ;
@@ -4777,6 +4805,43 @@ fn fun_name() {
4777
4805
fn $0fun_name2() {
4778
4806
let x = 0;
4779
4807
}
4808
+ "# ,
4809
+ ) ;
4810
+ }
4811
+
4812
+ #[ test]
4813
+ fn extract_method_from_trait_impl ( ) {
4814
+ check_assist (
4815
+ extract_function,
4816
+ r#"
4817
+ struct Struct(i32);
4818
+ trait Trait {
4819
+ fn bar(&self) -> i32;
4820
+ }
4821
+
4822
+ impl Trait for Struct {
4823
+ fn bar(&self) -> i32 {
4824
+ $0self.0 + 2$0
4825
+ }
4826
+ }
4827
+ "# ,
4828
+ r#"
4829
+ struct Struct(i32);
4830
+ trait Trait {
4831
+ fn bar(&self) -> i32;
4832
+ }
4833
+
4834
+ impl Trait for Struct {
4835
+ fn bar(&self) -> i32 {
4836
+ self.fun_name()
4837
+ }
4838
+ }
4839
+
4840
+ impl Struct {
4841
+ fn $0fun_name(&self) -> i32 {
4842
+ self.0 + 2
4843
+ }
4844
+ }
4780
4845
"# ,
4781
4846
) ;
4782
4847
}
0 commit comments