From 7b21ade0760b951587c20b46b85108de9366e195 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Mon, 2 Sep 2024 12:51:15 +0100 Subject: [PATCH] Simplifications --- .../src/checkers/ast/analyze/statement.rs | 6 +- .../src/rules/ruff/rules/post_init_default.rs | 227 +++++++----------- ..._rules__ruff__tests__RUF033_RUF033.py.snap | 16 +- 3 files changed, 100 insertions(+), 149 deletions(-) diff --git a/crates/ruff_linter/src/checkers/ast/analyze/statement.rs b/crates/ruff_linter/src/checkers/ast/analyze/statement.rs index f852e2b3c7688..4f73f7fe3764c 100644 --- a/crates/ruff_linter/src/checkers/ast/analyze/statement.rs +++ b/crates/ruff_linter/src/checkers/ast/analyze/statement.rs @@ -380,6 +380,9 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) { if checker.enabled(Rule::WhitespaceAfterDecorator) { pycodestyle::rules::whitespace_after_decorator(checker, decorator_list); } + if checker.enabled(Rule::PostInitDefault) { + ruff::rules::post_init_default(checker, function_def); + } } Stmt::Return(_) => { if checker.enabled(Rule::ReturnOutsideFunction) { @@ -546,9 +549,6 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) { if checker.enabled(Rule::WhitespaceAfterDecorator) { pycodestyle::rules::whitespace_after_decorator(checker, decorator_list); } - if checker.enabled(Rule::PostInitDefault) { - ruff::rules::post_init_default(checker, class_def); - } } Stmt::Import(ast::StmtImport { names, range: _ }) => { if checker.enabled(Rule::MultipleImportsOnOneLine) { diff --git a/crates/ruff_linter/src/rules/ruff/rules/post_init_default.rs b/crates/ruff_linter/src/rules/ruff/rules/post_init_default.rs index c5c094180e16e..a13949f167a70 100644 --- a/crates/ruff_linter/src/rules/ruff/rules/post_init_default.rs +++ b/crates/ruff_linter/src/rules/ruff/rules/post_init_default.rs @@ -1,30 +1,32 @@ +use anyhow::Context; + use ruff_diagnostics::{Diagnostic, Edit, Fix, FixAvailability, Violation}; use ruff_macros::{derive_message_formats, violation}; -use ruff_python_ast::{ - self as ast, helpers::is_docstring_stmt, Expr, Parameter, ParameterWithDefault, Stmt, -}; +use ruff_python_ast as ast; +use ruff_python_semantic::{Scope, ScopeKind}; use ruff_python_trivia::{indentation_at_offset, textwrap}; -use ruff_text_size::{Ranged, TextRange}; +use ruff_text_size::Ranged; use crate::{checkers::ast::Checker, importer::ImportRequest}; use super::helpers::is_dataclass; /// ## What it does -/// Checks for `__post_init__` dataclass methods with argument defaults. +/// Checks for `__post_init__` dataclass methods with parameter defaults. /// /// ## Why is this bad? -/// Variables that are only used during initialization should be instantiated -/// as an init-only pseudo-field using `dataclasses.InitVar`. According to the -/// [documentation]: +/// Adding a default value to a parameter in a `__post_init__` method has no +/// impact on whether the parameter will have a default value in the dataclass's +/// generated `__init__` method. To create an init-only dataclass parameter with +/// a default value, you should use an `InitVar` field in the dataclass's class +/// body and give that `InitVar` field a default value. +/// +/// As the [documentation] states: /// /// > Init-only fields are added as parameters to the generated `__init__()` /// > method, and are passed to the optional `__post_init__()` method. They are /// > not otherwise used by dataclasses. /// -/// Default values for `__post_init__` arguments that exist as init-only fields -/// as well will be overridden by the `dataclasses.InitVar` value. -/// /// ## Example /// ```python /// from dataclasses import InitVar, dataclass @@ -80,153 +82,106 @@ impl Violation for PostInitDefault { } /// RUF033 -pub(crate) fn post_init_default(checker: &mut Checker, class_def: &ast::StmtClassDef) { - if !is_dataclass(class_def, checker.semantic()) { +pub(crate) fn post_init_default(checker: &mut Checker, function_def: &ast::StmtFunctionDef) { + if &function_def.name != "__post_init__" { return; } - for statement in &class_def.body { - if let Stmt::FunctionDef(function_def) = statement { - if &function_def.name != "__post_init__" { - continue; + let current_scope = checker.semantic().current_scope(); + match current_scope.kind { + ScopeKind::Class(class_def) => { + if !is_dataclass(class_def, checker.semantic()) { + return; } + } + _ => return, + } - let mut stopped_fixes = false; - - for ParameterWithDefault { - parameter, - default, - range, - } in function_def.parameters.iter_non_variadic_params() - { - let Some(default) = default else { - continue; - }; - let mut diagnostic = Diagnostic::new(PostInitDefault, default.range()); - - // If the name is already bound on the class, no fix is available. - if !already_bound(parameter, class_def) && !stopped_fixes { - if let Some(fix) = use_initvar(class_def, parameter, default, *range, checker) { - diagnostic.set_fix(fix); - } - } else { - // Need to stop fixes as soon as there is a parameter we - // cannot fix. Otherwise, we risk a syntax error (a - // parameter without a default following parameter with a - // default). - stopped_fixes = true; - } - - checker.diagnostics.push(diagnostic); - } + let mut stopped_fixes = false; + let mut diagnostics = vec![]; + + for ast::ParameterWithDefault { + parameter, + default, + range: _, + } in function_def.parameters.iter_non_variadic_params() + { + let Some(default) = default else { + continue; + }; + let mut diagnostic = Diagnostic::new(PostInitDefault, default.range()); + + if !stopped_fixes { + diagnostic.try_set_fix(|| { + use_initvar(current_scope, function_def, parameter, default, checker) + }); + // Need to stop fixes as soon as there is a parameter we cannot fix. + // Otherwise, we risk a syntax error (a parameter without a default + // following parameter with a default). + stopped_fixes |= diagnostic.fix.is_none(); } + + diagnostics.push(diagnostic); } + + checker.diagnostics.extend(diagnostics); } /// Generate a [`Fix`] to transform a `__post_init__` default argument into a /// `dataclasses.InitVar` pseudo-field. fn use_initvar( - class_def: &ast::StmtClassDef, - parameter: &Parameter, - default: &Expr, - range: TextRange, + current_scope: &Scope, + post_init_def: &ast::StmtFunctionDef, + parameter: &ast::Parameter, + default: &ast::Expr, checker: &Checker, -) -> Option { +) -> anyhow::Result { + if current_scope.has(¶meter.name) { + return Err(anyhow::anyhow!( + "Cannot add a `{}: InitVar` field to the class body, as a field by that name already exists", + parameter.name + )); + } + // Ensure that `dataclasses.InitVar` is accessible. For example, // + `from dataclasses import InitVar` - let (import_edit, binding) = checker - .importer() - .get_or_import_symbol( - &ImportRequest::import("dataclasses", "InitVar"), - default.start(), - checker.semantic(), - ) - .ok()?; + let (import_edit, initvar_binding) = checker.importer().get_or_import_symbol( + &ImportRequest::import("dataclasses", "InitVar"), + default.start(), + checker.semantic(), + )?; // Delete the default value. For example, // - def __post_init__(self, foo: int = 0) -> None: ... // + def __post_init__(self, foo: int) -> None: ... - let (default_start, default_end) = match ¶meter.annotation { - Some(annotation) => (annotation.range().end(), range.end()), - None => (parameter.range().end(), range.end()), - }; - let default_edit = Edit::deletion(default_start, default_end); + let default_edit = Edit::deletion(parameter.end(), default.end()); // Add `dataclasses.InitVar` field to class body. - let mut body = class_def.body.iter().peekable(); - let statement = body.peek()?; - let mut content = String::new(); - match ¶meter.annotation { - Some(annotation) => { - content.push_str(&format!( - "{}: {}[{}] = {}", - checker.locator().slice(¶meter.name), - binding, - checker.locator().slice(annotation.range()), - checker.locator().slice(default) - )); - } - None => { - content.push_str(&format!( - "{}: {} = {}", - checker.locator().slice(¶meter.name), - binding, - checker.locator().slice(default) - )); - } - } - content.push_str(checker.stylist().line_ending().as_str()); - let indentation = indentation_at_offset(statement.start(), checker.locator())?; - let content = textwrap::indent(&content, indentation).to_string(); - - // Find the position after any docstring to insert the field. - let mut pos = checker.locator().line_start(statement.start()); - while let Some(statement) = body.next() { - if is_docstring_stmt(statement) { - if let Some(next) = body.peek() { - if checker - .indexer() - .in_multi_statement_line(statement, checker.locator()) - { - continue; - } - pos = checker.locator().line_start(next.start()); - } else { - pos = checker.locator().full_line_end(statement.end()); - break; - } + let locator = checker.locator(); + + let content = { + let parameter_name = locator.slice(¶meter.name); + let default = locator.slice(default); + let line_ending = checker.stylist().line_ending().as_str(); + + if let Some(annotation) = ¶meter + .annotation + .as_deref() + .map(|annotation| locator.slice(annotation)) + { + format!("{parameter_name}: {initvar_binding}[{annotation}] = {default}{line_ending}") } else { - break; - }; - } - let initvar_edit = Edit::insertion(content, pos); + format!("{parameter_name}: {initvar_binding} = {default}{line_ending}") + } + }; - Some(Fix::unsafe_edits(import_edit, [default_edit, initvar_edit])) -} + let indentation = indentation_at_offset(post_init_def.start(), checker.locator()) + .context("Failed to calculate leading indentation of `__post_init__` method")?; + let content = textwrap::indent(&content, indentation); -/// Check if a name is already bound as a class variable. -fn already_bound(parameter: &Parameter, class_def: &ast::StmtClassDef) -> bool { - for statement in &class_def.body { - let target = match statement { - Stmt::Assign(ast::StmtAssign { targets, .. }) => { - if let [Expr::Name(id)] = targets.as_slice() { - id - } else { - continue; - } - } - Stmt::AnnAssign(ast::StmtAnnAssign { target, .. }) => { - if let Expr::Name(id) = target.as_ref() { - id - } else { - continue; - } - } - _ => continue, - }; - if *target.id == *parameter.name { - return true; - } - } - false + let initvar_edit = Edit::insertion( + content.into_owned(), + locator.line_start(post_init_def.start()), + ); + Ok(Fix::unsafe_edits(import_edit, [default_edit, initvar_edit])) } diff --git a/crates/ruff_linter/src/rules/ruff/snapshots/ruff_linter__rules__ruff__tests__RUF033_RUF033.py.snap b/crates/ruff_linter/src/rules/ruff/snapshots/ruff_linter__rules__ruff__tests__RUF033_RUF033.py.snap index 8a145e329f9da..a5b42e3f4a71e 100644 --- a/crates/ruff_linter/src/rules/ruff/snapshots/ruff_linter__rules__ruff__tests__RUF033_RUF033.py.snap +++ b/crates/ruff_linter/src/rules/ruff/snapshots/ruff_linter__rules__ruff__tests__RUF033_RUF033.py.snap @@ -109,13 +109,11 @@ RUF033.py:59:40: RUF033 [*] `__post_init__` method with argument defaults = help: Use `dataclasses.InitVar` instead ℹ Unsafe fix -54 54 | Docstrings are very important and totally not a waste of time. -55 55 | """ 56 56 | - 57 |+ bar: InitVar[int] = 11 -57 58 | ping = "pong" -58 59 | +57 57 | ping = "pong" +58 58 | 59 |- def __post_init__(self, bar: int = 11, baz: int = 12) -> None: ... + 59 |+ bar: InitVar[int] = 11 60 |+ def __post_init__(self, bar: int, baz: int = 12) -> None: ... 60 61 | 61 62 | @@ -131,13 +129,11 @@ RUF033.py:59:55: RUF033 [*] `__post_init__` method with argument defaults = help: Use `dataclasses.InitVar` instead ℹ Unsafe fix -54 54 | Docstrings are very important and totally not a waste of time. -55 55 | """ 56 56 | - 57 |+ baz: InitVar[int] = 12 -57 58 | ping = "pong" -58 59 | +57 57 | ping = "pong" +58 58 | 59 |- def __post_init__(self, bar: int = 11, baz: int = 12) -> None: ... + 59 |+ baz: InitVar[int] = 12 60 |+ def __post_init__(self, bar: int = 11, baz: int) -> None: ... 60 61 | 61 62 |