From 13dc2b22c1efbfcfdc3e299b657a8e28c9b46a99 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner
Date: Thu, 23 Jan 2025 16:10:54 -0600
Subject: [PATCH 1/2] Enable ruff SIM rules

---
 doc/conf.py                          |  4 ++-
 pyproject.toml                       |  1 +
 pytools/__init__.py                  | 48 ++++++++++------------------
 pytools/batchjob.py                  | 23 ++++---------
 pytools/convergence.py               | 18 +++++------
 pytools/debug.py                     | 17 +++-------
 pytools/graph.py                     |  6 ++--
 pytools/obj_array.py                 |  5 +--
 pytools/spatial_btree.py             | 10 ++----
 pytools/test/test_persistent_dict.py | 34 ++++++++------------
 10 files changed, 60 insertions(+), 106 deletions(-)

diff --git a/doc/conf.py b/doc/conf.py
index 6b2b9ca5..8a9f8500 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -17,8 +17,10 @@
 #
 #
 # The short X.Y version.
 ver_dic = {}
-exec(compile(open("../pytools/version.py").read(), "../pytools/version.py", "exec"),
+with open("../pytools/version.py") as vfile:
+    exec(compile(vfile.read(), "../pytools/version.py", "exec"),
         ver_dic)
+
 version = ".".join(str(x) for x in ver_dic["VERSION"])
 release = ver_dic["VERSION_TEXT"]
diff --git a/pyproject.toml b/pyproject.toml
index 9c142776..80c7b8e6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -77,6 +77,7 @@ extend-select = [
     "RUF",   # ruff
     "W",     # pycodestyle
     "TC",
+    "SIM",
 ]
 extend-ignore = [
     "C90",   # McCabe complexity
diff --git a/pytools/__init__.py b/pytools/__init__.py
index 372e8dad..3be14b1b 100644
--- a/pytools/__init__.py
+++ b/pytools/__init__.py
@@ -28,6 +28,7 @@
 """
 
 import builtins
+import contextlib
 import logging
 import operator
 import re
@@ -421,10 +422,8 @@ def __init__(self,
     def get_copy_kwargs(self, **kwargs):
         for f in self.__class__.fields:
             if f not in kwargs:
-                try:
+                with contextlib.suppress(AttributeError):
                     kwargs[f] = getattr(self, f)
-                except AttributeError:
-                    pass
         return kwargs
 
     def copy(self, **kwargs):
@@ -615,10 +614,7 @@ def is_single_valued(
     except StopIteration:
         raise ValueError("empty iterable passed to 'single_valued()'") from None
 
-    for other_item in it:
-        if not equality_pred(other_item, first_item):
-            return False
-    return True
+    return all(equality_pred(other_item, first_item) for other_item in it)
 
 
 all_equal = is_single_valued
@@ -642,12 +638,7 @@ def single_valued(
     except StopIteration:
         raise ValueError("empty iterable passed to 'single_valued()'") from None
 
-    def others_same():
-        for other_item in it:
-            if not equality_pred(other_item, first_item):
-                return False
-        return True
-    assert others_same()
+    assert all(equality_pred(other_item, first_item) for other_item in it)
 
     return first_item
 
@@ -754,10 +745,7 @@ def memoize_on_first_arg(
             )
 
     def wrapper(obj: T, *args: P.args, **kwargs: P.kwargs) -> R:
-        if kwargs:
-            key = (_HasKwargs, frozenset(kwargs.items()), *args)
-        else:
-            key = args
+        key = (_HasKwargs, frozenset(kwargs.items()), *args) if kwargs else args
 
         assert cache_dict_name is not None
         try:
@@ -2093,9 +2081,8 @@ def invoke_editor(s, filename="edit.txt", descr="the file"):
     from os.path import join
     full_name = join(tempdir, filename)
-    outf = open(full_name, "w")
-    outf.write(str(s))
-    outf.close()
+    with open(full_name, "w") as outf:
+        outf.write(str(s))
 
     import os
     if "EDITOR" in os.environ:
@@ -2107,9 +2094,8 @@
                 "dropped directly into an editor next time.)")
         input(f"Edit {descr} at {full_name} now, then hit [Enter]:")
 
-    inf = open(full_name)
-    result = inf.read()
-    inf.close()
+    with open(full_name) as inf:
+        result = inf.read()
 
     return result
 
@@ -2634,16 +2620,14 @@ def __init__(
             use_late_start_logging = False
 
         if use_late_start_logging:
-            try:
+            # https://github.com/firedrakeproject/firedrake/issues/1422
+            #
+            # Starting a thread may fail in various environments, e.g. MPI.
+            # Since the late-start logging is an optional 'quality-of-life'
+            # feature for interactive use, tolerate failures of it without
+            # warning.
+            with contextlib.suppress(RuntimeError):
                 self.late_start_log_thread.start()
-            except RuntimeError:
-                # https://github.com/firedrakeproject/firedrake/issues/1422
-                #
-                # Starting a thread may fail in various environments, e.g. MPI.
-                # Since the late-start logging is an optional 'quality-of-life'
-                # feature for interactive use, tolerate failures of it without
-                # warning.
-                pass
 
         self.timer = ProcessTimer()
 
diff --git a/pytools/batchjob.py b/pytools/batchjob.py
index d4f2881f..84235a6c 100644
--- a/pytools/batchjob.py
+++ b/pytools/batchjob.py
@@ -5,15 +5,8 @@ def _cp(src, dest):
     from pytools import assert_not_a_file
     assert_not_a_file(dest)
 
-    inf = open(src, "rb")
-    try:
-        outf = open(dest, "wb")
-        try:
-            outf.write(inf.read())
-        finally:
-            outf.close()
-    finally:
-        inf.close()
+    with open(src, "rb") as inf, open(dest, "wb") as outf:
+        outf.write(inf.read())
 
 
 def get_timestamp():
@@ -43,10 +36,9 @@ def __init__(self, moniker, main_file, aux_files=(), timestamp=None):
 
         os.makedirs(self.path)
 
-        runscript = open(f"{self.path}/run.sh", "w")
-        import sys
-        runscript.write(f"{sys.executable} {main_file} setup.cpy")
-        runscript.close()
+        with open(f"{self.path}/run.sh", "w") as runscript:
+            import sys
+            runscript.write(f"{sys.executable} {main_file} setup.cpy")
 
         from os.path import basename
 
@@ -58,9 +50,8 @@ def __init__(self, moniker, main_file, aux_files=(), timestamp=None):
 
     def write_setup(self, lines):
         import os.path
-        setup = open(os.path.join(self.path, "setup.cpy"), "w")
-        setup.write("\n".join(lines))
-        setup.close()
+        with open(os.path.join(self.path, "setup.cpy"), "w") as setup:
+            setup.write("\n".join(lines))
 
 
 class INHERIT:
diff --git a/pytools/convergence.py b/pytools/convergence.py
index cd16e4f7..0b6ae0e7 100644
--- a/pytools/convergence.py
+++ b/pytools/convergence.py
@@ -145,15 +145,15 @@ def __str__(self):
         return self.pretty_print()
 
     def write_gnuplot_file(self, filename: str) -> None:
-        outfile = open(filename, "w")
-        for absc, err in self.history:
-            outfile.write(f"{absc:f} {err:f}\n")
-        result = self.estimate_order_of_convergence()
-        const = result[0, 0]
-        order = result[0, 1]
-        outfile.write("\n")
-        for absc, _err in self.history:
-            outfile.write(f"{absc:f} {const * absc**(-order):f}\n")
+        with open(filename, "w") as outfile:
+            for absc, err in self.history:
+                outfile.write(f"{absc:f} {err:f}\n")
+            result = self.estimate_order_of_convergence()
+            const = result[0, 0]
+            order = result[0, 1]
+            outfile.write("\n")
+            for absc, _err in self.history:
+                outfile.write(f"{absc:f} {const * absc**(-order):f}\n")
 
 
 def stringify_eocs(*eocs: EOCRecorder,
diff --git a/pytools/debug.py b/pytools/debug.py
index 6a0382ac..0fc125c8 100644
--- a/pytools/debug.py
+++ b/pytools/debug.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import contextlib
 import sys
 
 from pytools import memoize
@@ -66,17 +67,12 @@ def is_excluded(o):
                 return True
 
         from sys import _getframe
-        if isinstance(o, FrameType) and \
-                o.f_code.co_filename == _getframe().f_code.co_filename:
-            return True
-
-        return False
+        return bool(isinstance(o, FrameType)
+                    and o.f_code.co_filename == _getframe().f_code.co_filename)
 
     if top_level:
-        try:
+        with contextlib.suppress(RefDebugQuit):
             refdebug(obj, top_level=False, exclude=exclude)
-        except RefDebugQuit:
-            pass
         return
 
     import gc
@@ -94,10 +90,7 @@ def is_excluded(o):
                 print_head = False
             r = reflist[idx]
 
-            if isinstance(r, FrameType):
-                s = str(r.f_code)
-            else:
-                s = str(r)
+            s = str(r.f_code) if isinstance(r, FrameType) else str(r)
 
             print(f"{idx}/{len(reflist)}: ", id(r), type(r), s)
 
diff --git a/pytools/graph.py b/pytools/graph.py
index f8f1a1a0..f4760f43 100644
--- a/pytools/graph.py
+++ b/pytools/graph.py
@@ -368,9 +368,9 @@ def compute_transitive_closure(
     closure = deepcopy(graph)
 
     # (assumes all graph nodes are included in keys)
-    for k in graph.keys():
-        for n1 in graph.keys():
-            for n2 in graph.keys():
+    for k in graph:
+        for n1 in graph:
+            for n2 in graph:
                 if k in closure[n1] and n2 in closure[k]:
                     closure[n1].add(n2)
 
diff --git a/pytools/obj_array.py b/pytools/obj_array.py
index a4ee98b0..ccfa8a16 100644
--- a/pytools/obj_array.py
+++ b/pytools/obj_array.py
@@ -369,10 +369,7 @@ def with_object_array_or_scalar(f, field, obj_array_only=False):
             "use obj_array_vectorize", DeprecationWarning, stacklevel=2)
 
     if obj_array_only:
-        if is_obj_array(field):
-            ls = field.shape
-        else:
-            ls = ()
+        ls = field.shape if is_obj_array(field) else ()
     else:
         ls = log_shape(field)
     if ls != ():
diff --git a/pytools/spatial_btree.py b/pytools/spatial_btree.py
index 652705b8..0b267415 100644
--- a/pytools/spatial_btree.py
+++ b/pytools/spatial_btree.py
@@ -7,10 +7,7 @@ def do_boxes_intersect(bl, tr):
     (bl1, tr1) = bl
     (bl2, tr2) = tr
     (dimension,) = bl1.shape
-    for i in range(dimension):
-        if max(bl1[i], bl2[i]) > min(tr1[i], tr2[i]):
-            return False
-    return True
+    return all(max(bl1[i], bl2[i]) <= min(tr1[i], tr2[i]) for i in range(dimension))
 
 
 def make_buckets(bottom_left, top_right, allbuckets, max_elements_per_box):
@@ -131,10 +128,7 @@ def generate_matches(self, point):
         (dimensions,) = point.shape
         bucket = self.buckets
         for dim in range(dimensions):
-            if point[dim] < self.center[dim]:
-                bucket = bucket[0]
-            else:
-                bucket = bucket[1]
+            bucket = bucket[0] if point[dim] < self.center[dim] else bucket[1]
 
         yield from bucket.generate_matches(point)
 
diff --git a/pytools/test/test_persistent_dict.py b/pytools/test/test_persistent_dict.py
index 76135e05..cd805062 100644
--- a/pytools/test/test_persistent_dict.py
+++ b/pytools/test/test_persistent_dict.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import contextlib
 import shutil
 import sys
 import tempfile
@@ -232,14 +233,12 @@ def test_persistent_dict_cache_collisions() -> None:
         pdict[key1] = 1
 
         # check lookup
-        with pytest.warns(CollisionWarning):
-            with pytest.raises(NoSuchEntryCollisionError):
-                pdict.fetch(key2)
+        with pytest.warns(CollisionWarning), pytest.raises(NoSuchEntryCollisionError):
+            pdict.fetch(key2)
 
         # check deletion
-        with pytest.warns(CollisionWarning):
-            with pytest.raises(NoSuchEntryCollisionError):
-                del pdict[key2]
+        with pytest.warns(CollisionWarning), pytest.raises(NoSuchEntryCollisionError):
+            del pdict[key2]
 
         # check presence after deletion
         assert pdict[key1] == 1
@@ -377,9 +376,8 @@ def test_write_once_persistent_dict_cache_collisions() -> None:
         pdict[key1] = 1
 
         # check lookup
-        with pytest.warns(CollisionWarning):
-            with pytest.raises(NoSuchEntryCollisionError):
-                pdict.fetch(key2)
+        with pytest.warns(CollisionWarning), pytest.raises(NoSuchEntryCollisionError):
+            pdict.fetch(key2)
 
         # check update
         with pytest.raises(ReadOnlyEntryError):
@@ -854,7 +852,7 @@ def test_keys_values_items():
     assert list(pdict.values()) == list(range(10000))
     assert list(pdict.items()) == list(zip(pdict, range(10000), strict=True))
 
-    assert ([k for k in pdict.keys()]  # noqa: C416
+    assert ([k for k in pdict]  # noqa: C416
             == list(pdict.keys())
             == list(pdict)
             == [k for k in pdict])  # noqa: C416
@@ -959,25 +957,19 @@ def _conc_fn(tmpdir: str | None = None,
             print(f"i={i}")
 
         if isinstance(pdict, WriteOncePersistentDict):
-            try:
+            with contextlib.suppress(ReadOnlyEntryError):
                 pdict[i] = i
-            except ReadOnlyEntryError:
-                pass
         else:
             pdict[i] = i
 
-        try:
+        # Someone else already deleted the entry
+        with contextlib.suppress(NoSuchEntryError):
             s += pdict[i]
-        except NoSuchEntryError:
-            # Someone else already deleted the entry
-            pass
 
         if not isinstance(pdict, WriteOncePersistentDict):
-            try:
+            # Suppressed in case someone else already deleted the entry
+            with contextlib.suppress(NoSuchEntryError):
                 del pdict[i]
-            except NoSuchEntryError:
-                # Someone else already deleted the entry
-                pass
 
     end = time.time()
 

From c2859599623924beece88f23caab0a5f2fd93a40 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner
Date: Thu, 23 Jan 2025 16:11:09 -0600
Subject: [PATCH 2/2] Delete CPyUserInterface

---
 pytools/__init__.py | 83 ---------------------------------------------
 1 file changed, 83 deletions(-)

diff --git a/pytools/__init__.py b/pytools/__init__.py
index 3be14b1b..08307a82 100644
--- a/pytools/__init__.py
+++ b/pytools/__init__.py
@@ -1913,89 +1913,6 @@ def word_wrap(text, width, wrap_using="\n"):
 # }}}
 
 
-# {{{ command line interfaces
-
-def _exec_arg(arg, execenv):
-    import os
-    if os.access(arg, os.F_OK):
-        exec(compile(open(arg), arg, "exec"), execenv)
-    else:
-        exec(compile(arg, "<command line>", "exec"), execenv)
-
-
-class CPyUserInterface:
-    class Parameters(Record):
-        pass
-
-    def __init__(self, variables, constants=None, doc=None):
-        if constants is None:
-            constants = {}
-        if doc is None:
-            doc = {}
-        self.variables = variables
-        self.constants = constants
-        self.doc = doc
-
-    def show_usage(self, progname):
-        print(f"usage: {progname} <FILE-OR-STATEMENTS>")
-        print()
-        print("FILE-OR-STATEMENTS may either be Python statements of the form")
-        print("'variable1 = value1; variable2 = value2' or the name of a file")
-        print("containing such statements. Any valid Python code may be used")
-        print("on the command line or in a command file. If new variables are")
-        print("used, they must start with 'user_' or just '_'.")
-        print()
-        print("The following variables are recognized:")
-        for v in sorted(self.variables):
-            print(f"  {v} = {self.variables[v]}")
-            if v in self.doc:
-                print(f"    {self.doc[v]}")
-
-        print()
-        print("The following constants are supplied:")
-        for c in sorted(self.constants):
-            print(f"  {c} = {self.constants[c]}")
-            if c in self.doc:
-                print(f"    {self.doc[c]}")
-
-    def gather(self, argv=None):
-        if argv is None:
-            argv = sys.argv
-
-        if len(argv) == 1 or (
-                ("-h" in argv)
-                or ("help" in argv)
-                or ("-help" in argv)
-                or ("--help" in argv)):
-            self.show_usage(argv[0])
-            sys.exit(2)
-
-        execenv = self.variables.copy()
-        execenv.update(self.constants)
-
-        for arg in argv[1:]:
-            _exec_arg(arg, execenv)
-
-        # check if the user set invalid keys
-        for added_key in (
-                set(execenv.keys())
-                - set(self.variables.keys())
-                - set(self.constants.keys())):
-            if not (added_key.startswith("user_") or added_key.startswith("_")):
-                raise ValueError(
-                        f"invalid setup key: '{added_key}' "
-                        "(user variables must start with 'user_' or '_')")
-
-        result = self.Parameters({key: execenv[key] for key in self.variables})
-        self.validate(result)
-        return result
-
-    def validate(self, setup):
-        pass
-
-# }}}
-
-
 # {{{ debugging
 
 class StderrToStdout:
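
For readers unfamiliar with the ruff SIM ("flake8-simplify") rules enabled here, the snippet below is a rough standalone illustration of the two rewrites that dominate this patch: replacing try/except/pass with contextlib.suppress (SIM105), and letting a single with statement manage file handles instead of manual open()/close() calls or nested with blocks (SIM115/SIM117). It is not part of the patch, and the file names are made up.

import contextlib
import os

# SIM105: suppress one specific, expected exception instead of try/except/pass.
# Only FileNotFoundError is swallowed; any other exception still propagates.
with contextlib.suppress(FileNotFoundError):
    os.remove("scratch.txt")

# SIM115/SIM117: a single `with` opens both files and closes them even if an
# error occurs, replacing manual open()/close() pairs and nested `with` blocks.
with open("scratch.txt", "w") as outf:
    outf.write("hello\n")
with open("scratch.txt", "rb") as src, open("scratch-copy.txt", "wb") as dst:
    dst.write(src.read())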