Skip to content

Commit

Permalink
fix: Extract Function assist with generics
Browse files Browse the repository at this point in the history
This change attempts to resolve issue rust-lang#7637: Extract into Function does not
create a generic function with constraints when extracting generic code.

In `FunctionBody::analyze_container`, when the ancestor matches `ast::Fn`, we
can perserve both the `generic_param_list` and the `where_clause`. These can
then be included in the newly extracted function output via `format_function`.

From what I can tell, the only other ancestor type that could potentially have
a generic param list would be `ast::ClosureExpr`. In this case, we perserve the
`generic_param_list`, but no where clause is ever present.

In this inital implementation, all the generic params and where clauses from
the parent function will be copied to the newly extracted function. An obvious
improvement would be to filter this output in some way to only include generic
parameters that are actually used in the function body.

fixes rust-lang#7636
  • Loading branch information
DorianListens committed Jun 16, 2022
1 parent 65874df commit ecc3a48
Showing 1 changed file with 88 additions and 15 deletions.
103 changes: 88 additions & 15 deletions crates/ide-assists/src/handlers/extract_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use syntax::{
ast::{
self,
edit::{AstNodeEdit, IndentLevel},
AstNode,
AstNode, HasGenericParams,
},
match_ast, ted, SyntaxElement,
SyntaxKind::{self, COMMENT},
Expand Down Expand Up @@ -266,6 +266,8 @@ struct ContainerInfo {
parent_loop: Option<SyntaxNode>,
/// The function's return type, const's type etc.
ret_type: Option<hir::Type>,
generic_param_list: Option<ast::GenericParamList>,
where_clause: Option<ast::WhereClause>,
}

/// Control flow that is exported from extracted function
Expand Down Expand Up @@ -676,11 +678,11 @@ impl FunctionBody {
parent_loop.get_or_insert(loop_.syntax().clone());
}
};
let (is_const, expr, ty) = loop {
let (is_const, expr, ty, generic_param_list, where_clause) = loop {
let anc = ancestors.next()?;
break match_ast! {
match anc {
ast::ClosureExpr(closure) => (false, closure.body(), infer_expr_opt(closure.body())),
ast::ClosureExpr(closure) => (false, closure.body(), infer_expr_opt(closure.body()), closure.generic_param_list(), None),
ast::BlockExpr(block_expr) => {
let (constness, block) = match block_expr.modifier() {
Some(ast::BlockModifier::Const(_)) => (true, block_expr),
Expand All @@ -689,7 +691,7 @@ impl FunctionBody {
_ => continue,
};
let expr = Some(ast::Expr::BlockExpr(block));
(constness, expr.clone(), infer_expr_opt(expr))
(constness, expr.clone(), infer_expr_opt(expr), None, None)
},
ast::Fn(fn_) => {
let func = sema.to_def(&fn_)?;
Expand All @@ -699,23 +701,23 @@ impl FunctionBody {
ret_ty = async_ret;
}
}
(fn_.const_token().is_some(), fn_.body().map(ast::Expr::BlockExpr), Some(ret_ty))
(fn_.const_token().is_some(), fn_.body().map(ast::Expr::BlockExpr), Some(ret_ty), fn_.generic_param_list(), fn_.where_clause())
},
ast::Static(statik) => {
(true, statik.body(), Some(sema.to_def(&statik)?.ty(sema.db)))
(true, statik.body(), Some(sema.to_def(&statik)?.ty(sema.db)), None, None)
},
ast::ConstArg(ca) => {
(true, ca.expr(), infer_expr_opt(ca.expr()))
(true, ca.expr(), infer_expr_opt(ca.expr()), None, None)
},
ast::Const(konst) => {
(true, konst.body(), Some(sema.to_def(&konst)?.ty(sema.db)))
(true, konst.body(), Some(sema.to_def(&konst)?.ty(sema.db)), None, None)
},
ast::ConstParam(cp) => {
(true, cp.default_val(), Some(sema.to_def(&cp)?.ty(sema.db)))
(true, cp.default_val(), Some(sema.to_def(&cp)?.ty(sema.db)), None, None)
},
ast::ConstBlockPat(cbp) => {
let expr = cbp.block_expr().map(ast::Expr::BlockExpr);
(true, expr.clone(), infer_expr_opt(expr))
(true, expr.clone(), infer_expr_opt(expr), None, None)
},
ast::Variant(__) => return None,
ast::Meta(__) => return None,
Expand Down Expand Up @@ -743,7 +745,14 @@ impl FunctionBody {
container_tail.zip(self.tail_expr()).map_or(false, |(container_tail, body_tail)| {
container_tail.syntax().text_range().contains_range(body_tail.syntax().text_range())
});
Some(ContainerInfo { is_in_tail, is_const, parent_loop, ret_type: ty })
Some(ContainerInfo {
is_in_tail,
is_const,
parent_loop,
ret_type: ty,
generic_param_list,
where_clause,
})
}

fn return_ty(&self, ctx: &AssistContext) -> Option<RetType> {
Expand Down Expand Up @@ -890,6 +899,7 @@ impl FunctionBody {
// if the var is not used but defined outside a loop we are extracting from we can't move it either
// as the function will reuse it in the next iteration.
let move_local = (!has_usages && defined_outside_parent_loop) || ty.is_reference();

Param { var, ty, move_local, requires_mut, is_copy }
})
.collect()
Expand Down Expand Up @@ -1311,26 +1321,32 @@ fn format_function(
let const_kw = if fun.mods.is_const { "const " } else { "" };
let async_kw = if fun.control_flow.is_async { "async " } else { "" };
let unsafe_kw = if fun.control_flow.is_unsafe { "unsafe " } else { "" };
let generic_params = format_generic_param_list(fun);
let where_clause = format_where_clause(fun);
match ctx.config.snippet_cap {
Some(_) => format_to!(
fn_def,
"\n\n{}{}{}{}fn $0{}{}",
"\n\n{}{}{}{}fn $0{}{}{}{}",
new_indent,
const_kw,
async_kw,
unsafe_kw,
fun.name,
params
generic_params,
params,
where_clause
),
None => format_to!(
fn_def,
"\n\n{}{}{}{}fn {}{}",
"\n\n{}{}{}{}fn {}{}{}{}",
new_indent,
const_kw,
async_kw,
unsafe_kw,
fun.name,
params
generic_params,
params,
where_clause,
),
}
if let Some(ret_ty) = ret_ty {
Expand All @@ -1341,6 +1357,21 @@ fn format_function(
fn_def
}

fn format_where_clause(fun: &Function) -> String {
format_or_empty(&fun.mods.where_clause, |x| Some(format!(" {}", x)))
}

fn format_generic_param_list(fun: &Function) -> String {
format_or_empty(&fun.mods.generic_param_list, |x| Some(format!("{}", x)))
}

fn format_or_empty<'a, T: std::fmt::Display>(
x: &'a Option<T>,
formatter: impl FnOnce(&'a T) -> Option<String>,
) -> String {
x.as_ref().and_then(formatter).unwrap_or_default()
}

impl Function {
fn make_param_list(&self, ctx: &AssistContext, module: hir::Module) -> ast::ParamList {
let self_param = self.self_param.clone();
Expand Down Expand Up @@ -4709,6 +4740,48 @@ fn $0fun_name() {
/* a comment */
let x = 0;
}
"#,
);
}

#[test]
fn preserve_generics() {
check_assist(
extract_function,
r#"
fn func<T: Debug>(i: T) {
$0foo(i);$0
}
"#,
r#"
fn func<T: Debug>(i: T) {
fun_name(i);
}
fn $0fun_name<T: Debug>(i: T) {
foo(i);
}
"#,
);
}

#[test]
fn preserve_where_clause() {
check_assist(
extract_function,
r#"
fn func<T>(i: T) where T: Debug {
$0foo(i);$0
}
"#,
r#"
fn func<T>(i: T) where T: Debug {
fun_name(i);
}
fn $0fun_name<T>(i: T) where T: Debug {
foo(i);
}
"#,
);
}
Expand Down

0 comments on commit ecc3a48

Please sign in to comment.