Skip to content

Commit

Permalink
fix: Support generics in extract_function assist
Browse files Browse the repository at this point in the history
This change attempts to resolve issue rust-lang#7637: Extract into Function does not
create a generic function with constraints when extracting generic code.

In `FunctionBody::analyze_container`, when the ancestor matches `ast::Fn`, we
can perserve both the `generic_param_list` and the `where_clause`. These can
then be included in the newly extracted function output via `format_function`.

From what I can tell, the only other ancestor type that could potentially have
a generic param list would be `ast::ClosureExpr`. In this case, we perserve the
`generic_param_list`, but no where clause is ever present.

In this inital implementation, all the generic params and where clauses from
the parent function will be copied to the newly extracted function. An obvious
improvement would be to filter this output in some way to only include generic
parameters that are actually used in the function body. I'm not experienced
enough with this codebase to know how challenging doing this kind of filtration
would be.

I don't believe this implementation will work in contexts where the generic
parameters and where clauses are defined multiple layers above the function
being extracted, such as with nested function declarations. Resolving this
seems like another obvious improvement, but one that will potentially require
more significant changes to the structure of `analyze_container` that I wasn't
comfortable trying to make as a first change.
  • Loading branch information
DorianListens committed Jun 30, 2022
1 parent ce36446 commit e2fea73
Showing 1 changed file with 86 additions and 15 deletions.
101 changes: 86 additions & 15 deletions crates/ide-assists/src/handlers/extract_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use syntax::{
ast::{
self,
edit::{AstNodeEdit, IndentLevel},
AstNode,
AstNode, HasGenericParams,
},
match_ast, ted, SyntaxElement,
SyntaxKind::{self, COMMENT},
Expand Down Expand Up @@ -266,6 +266,8 @@ struct ContainerInfo {
parent_loop: Option<SyntaxNode>,
/// The function's return type, const's type etc.
ret_type: Option<hir::Type>,
generic_param_list: Option<ast::GenericParamList>,
where_clause: Option<ast::WhereClause>,
}

/// Control flow that is exported from extracted function
Expand Down Expand Up @@ -676,11 +678,11 @@ impl FunctionBody {
parent_loop.get_or_insert(loop_.syntax().clone());
}
};
let (is_const, expr, ty) = loop {
let (is_const, expr, ty, generic_param_list, where_clause) = loop {
let anc = ancestors.next()?;
break match_ast! {
match anc {
ast::ClosureExpr(closure) => (false, closure.body(), infer_expr_opt(closure.body())),
ast::ClosureExpr(closure) => (false, closure.body(), infer_expr_opt(closure.body()), closure.generic_param_list(), None),
ast::BlockExpr(block_expr) => {
let (constness, block) = match block_expr.modifier() {
Some(ast::BlockModifier::Const(_)) => (true, block_expr),
Expand All @@ -689,7 +691,7 @@ impl FunctionBody {
_ => continue,
};
let expr = Some(ast::Expr::BlockExpr(block));
(constness, expr.clone(), infer_expr_opt(expr))
(constness, expr.clone(), infer_expr_opt(expr), None, None)
},
ast::Fn(fn_) => {
let func = sema.to_def(&fn_)?;
Expand All @@ -699,23 +701,23 @@ impl FunctionBody {
ret_ty = async_ret;
}
}
(fn_.const_token().is_some(), fn_.body().map(ast::Expr::BlockExpr), Some(ret_ty))
(fn_.const_token().is_some(), fn_.body().map(ast::Expr::BlockExpr), Some(ret_ty), fn_.generic_param_list(), fn_.where_clause())
},
ast::Static(statik) => {
(true, statik.body(), Some(sema.to_def(&statik)?.ty(sema.db)))
(true, statik.body(), Some(sema.to_def(&statik)?.ty(sema.db)), None, None)
},
ast::ConstArg(ca) => {
(true, ca.expr(), infer_expr_opt(ca.expr()))
(true, ca.expr(), infer_expr_opt(ca.expr()), None, None)
},
ast::Const(konst) => {
(true, konst.body(), Some(sema.to_def(&konst)?.ty(sema.db)))
(true, konst.body(), Some(sema.to_def(&konst)?.ty(sema.db)), None, None)
},
ast::ConstParam(cp) => {
(true, cp.default_val(), Some(sema.to_def(&cp)?.ty(sema.db)))
(true, cp.default_val(), Some(sema.to_def(&cp)?.ty(sema.db)), None, None)
},
ast::ConstBlockPat(cbp) => {
let expr = cbp.block_expr().map(ast::Expr::BlockExpr);
(true, expr.clone(), infer_expr_opt(expr))
(true, expr.clone(), infer_expr_opt(expr), None, None)
},
ast::Variant(__) => return None,
ast::Meta(__) => return None,
Expand Down Expand Up @@ -743,7 +745,14 @@ impl FunctionBody {
container_tail.zip(self.tail_expr()).map_or(false, |(container_tail, body_tail)| {
container_tail.syntax().text_range().contains_range(body_tail.syntax().text_range())
});
Some(ContainerInfo { is_in_tail, is_const, parent_loop, ret_type: ty })
Some(ContainerInfo {
is_in_tail,
is_const,
parent_loop,
ret_type: ty,
generic_param_list,
where_clause,
})
}

fn return_ty(&self, ctx: &AssistContext) -> Option<RetType> {
Expand Down Expand Up @@ -1311,26 +1320,32 @@ fn format_function(
let const_kw = if fun.mods.is_const { "const " } else { "" };
let async_kw = if fun.control_flow.is_async { "async " } else { "" };
let unsafe_kw = if fun.control_flow.is_unsafe { "unsafe " } else { "" };
let generic_params = format_generic_param_list(fun);
let where_clause = format_where_clause(fun);
match ctx.config.snippet_cap {
Some(_) => format_to!(
fn_def,
"\n\n{}{}{}{}fn $0{}{}",
"\n\n{}{}{}{}fn $0{}{}{}{}",
new_indent,
const_kw,
async_kw,
unsafe_kw,
fun.name,
params
generic_params,
params,
where_clause
),
None => format_to!(
fn_def,
"\n\n{}{}{}{}fn {}{}",
"\n\n{}{}{}{}fn {}{}{}{}",
new_indent,
const_kw,
async_kw,
unsafe_kw,
fun.name,
params
generic_params,
params,
where_clause,
),
}
if let Some(ret_ty) = ret_ty {
Expand All @@ -1341,6 +1356,20 @@ fn format_function(
fn_def
}

fn format_generic_param_list(fun: &Function) -> String {
match &fun.mods.generic_param_list {
Some(it) => format!("{}", it),
None => "".to_string(),
}
}

fn format_where_clause(fun: &Function) -> String {
match &fun.mods.where_clause {
Some(it) => format!(" {}", it),
None => "".to_string(),
}
}

impl Function {
fn make_param_list(&self, ctx: &AssistContext, module: hir::Module) -> ast::ParamList {
let self_param = self.self_param.clone();
Expand Down Expand Up @@ -4709,6 +4738,48 @@ fn $0fun_name() {
/* a comment */
let x = 0;
}
"#,
);
}

#[test]
fn preserve_generics() {
check_assist(
extract_function,
r#"
fn func<T: Debug>(i: T) {
$0foo(i);$0
}
"#,
r#"
fn func<T: Debug>(i: T) {
fun_name(i);
}
fn $0fun_name<T: Debug>(i: T) {
foo(i);
}
"#,
);
}

#[test]
fn preserve_where_clause() {
check_assist(
extract_function,
r#"
fn func<T>(i: T) where T: Debug {
$0foo(i);$0
}
"#,
r#"
fn func<T>(i: T) where T: Debug {
fun_name(i);
}
fn $0fun_name<T>(i: T) where T: Debug {
foo(i);
}
"#,
);
}
Expand Down

0 comments on commit e2fea73

Please sign in to comment.