From a5d4aea033b2a026e8b2ebb1f567acda6b2bd197 Mon Sep 17 00:00:00 2001 From: Artur Toshev Date: Fri, 7 Jun 2024 21:56:48 +0200 Subject: [PATCH 1/5] get latest partition.py from lagrangebench. From now on this will be the only partition.py copy --- jax_sph/partition.py | 621 +++++++++++++--------------------------- tests/test_neighbors.py | 137 +++++++++ 2 files changed, 338 insertions(+), 420 deletions(-) create mode 100644 tests/test_neighbors.py diff --git a/jax_sph/partition.py b/jax_sph/partition.py index 15869b3..d0256db 100644 --- a/jax_sph/partition.py +++ b/jax_sph/partition.py @@ -1,8 +1,4 @@ -"""Neighbors search backends. - -Source: -https://github.com/tumaer/lagrangebench/blob/main/lagrangebench/case_setup/partition.py -""" +"""Neighbors search backends.""" from functools import partial from typing import Optional @@ -13,23 +9,26 @@ import numpy as np import numpy as onp from jax import jit +from jax_md import space from jax_md.partition import ( - CellList, MaskFn, NeighborFn, NeighborList, NeighborListFns, NeighborListFormat, + PartitionError, + PartitionErrorCode, _displacement_or_metric_to_metric_sq, _neighboring_cells, cell_list, is_format_valid, is_sparse, shift_array, - space, ) from jax_md.partition import neighbor_list as vmap_neighbor_list +PEC = PartitionErrorCode + def get_particle_cells(idx, cl_capacity, N): """ @@ -54,7 +53,7 @@ def get_particle_cells(idx, cl_capacity, N): def _scan_neighbor_list( displacement_or_metric: space.DisplacementOrMetricFn, - box_size: space.Box, + box: space.Box, r_cutoff: float, dr_threshold: float = 0.0, capacity_multiplier: float = 1.25, @@ -63,7 +62,7 @@ def _scan_neighbor_list( custom_mask_function: Optional[MaskFn] = None, fractional_coordinates: bool = False, format: NeighborListFormat = NeighborListFormat.Sparse, - num_partitions: int = 1, + num_partitions: int = 8, **static_kwargs, ) -> NeighborFn: """Modified JAX-MD neighbor list function that uses `lax.scan` to compute the @@ -116,7 +115,7 @@ def body_fn(i, state): Args: displacement: A function `d(R_a, R_b)` that computes the displacement between pairs of points. - box_size: Either a float specifying the size of the box or an array of + box: Either a float specifying the size of the box or an array of shape `[spatial_dim]` specifying the box size in each spatial dimension. r_cutoff: A scalar specifying the neighborhood radius. dr_threshold: A scalar specifying the maximum distance particles can move @@ -146,18 +145,17 @@ def body_fn(i, state): A NeighborListFns object that contains a method to allocate a new neighbor list and a method to update an existing neighbor list. """ - assert disable_cell_list is False, "Works only with a cell list" + assert not fractional_coordinates, "Works only with real coordinates" assert format == NeighborListFormat.Sparse, "Works only with sparse neighbor list" assert custom_mask_function is None, "Custom masking not implemented" - # assert mask_self == False, "Self edges cannot be excluded for now" is_format_valid(format) - box_size = lax.stop_gradient(box_size) + box = lax.stop_gradient(box) r_cutoff = lax.stop_gradient(r_cutoff) dr_threshold = lax.stop_gradient(dr_threshold) - box_size = jnp.float32(box_size) + box = jnp.float32(box) cutoff = r_cutoff + dr_threshold cutoff_sq = cutoff**2 @@ -165,430 +163,167 @@ def body_fn(i, state): metric_sq = _displacement_or_metric_to_metric_sq(displacement_or_metric) cell_size = cutoff - if fractional_coordinates: - cell_size = cutoff / box_size - box_size = ( - jnp.float32(box_size) - if onp.isscalar(box_size) - else onp.ones_like(box_size, jnp.float32) - ) + assert jnp.all(cell_size < box / 3.0), "Don't use scan with very few cells" - assert jnp.all(cell_size < box_size / 3.0), "Don't use scan with very few cells" + def neighbor_list_fn( + position: jnp.ndarray, + neighbors: Optional[NeighborList] = None, + extra_capacity: int = 0, + **kwargs, + ) -> NeighborList: + def neighbor_fn(position_and_error, max_occupancy=None): + position, err = position_and_error + N, dim = position.shape + cl_fn = None + cl = None + cell_size = None - cl_fn = cell_list(box_size, cell_size, capacity_multiplier) + if neighbors is None: # cl.shape = (nx, ny, nz, cell_capacity, dim) + cell_size = cutoff + cl_fn = cell_list(box, cell_size, capacity_multiplier) + cl = cl_fn.allocate(position, extra_capacity=extra_capacity) + else: + cell_size = neighbors.cell_size + cl_fn = neighbors.cell_list_fn + if cl_fn is not None: + cl = cl_fn.update(position, neighbors.cell_list_capacity) - @jit - def cell_list_candidate_fn(cl: CellList, position: jnp.ndarray) -> jnp.ndarray: - N, dim = position.shape - - idx = cl.id_buffer - - cell_idx = [idx] - - for dindex in _neighboring_cells( - dim - ): # here the expansion happens over all adjacent cells happens - if onp.all(dindex == 0): - continue - cell_idx += [shift_array(idx, dindex)] # 27* (nx,ny,nz,cell_capacity, 1) - - cell_idx = jnp.concatenate(cell_idx, axis=-2) - cell_idx = cell_idx[..., jnp.newaxis, :, :] # (nx,ny,nz,1,27*cell_capacity, 1) - cell_idx = jnp.broadcast_to( - cell_idx, idx.shape[:-1] + cell_idx.shape[-2:] - ) # (nx,ny,nz,cell_capacity,27*cell_capacity) TODO: memory blows up here - - def copy_values_from_cell(value, cell_value, cell_id): - scatter_indices = jnp.reshape(cell_id, (-1,)) # (nx*ny*nz*cell_capacity) - cell_value = jnp.reshape( - cell_value, (-1,) + cell_value.shape[-2:] - ) # (nx*ny*nz*cell_capacity, 27*cell_capacity, 1) - return value.at[scatter_indices].set(cell_value) - - neighbor_idx = jnp.zeros( - (N + 1,) + cell_idx.shape[-2:], jnp.int32 - ) # (N, 27*cell_capacity, 1) TODO: too much memory - neighbor_idx = copy_values_from_cell(neighbor_idx, cell_idx, idx) - return neighbor_idx[:-1, :, 0] # shape (N, 27*cell_capacity) + err = err.update(PEC.CELL_LIST_OVERFLOW, cl.did_buffer_overflow) + cl_capacity = cl.cell_capacity - @jit - def prune_neighbor_list_sparse( - position: jnp.ndarray, idx: jnp.ndarray, **kwargs - ) -> jnp.ndarray: - d = partial(metric_sq, **kwargs) - d = space.map_bond(d) + idx = cl.id_buffer - N = position.shape[0] - sender_idx = jnp.broadcast_to(jnp.arange(N)[:, None], idx.shape) + cell_idx = [idx] # shape: (nx, ny, nz, cell_capacity, 1) - sender_idx = jnp.reshape( - sender_idx, (-1,) - ) # (N, 27*cell_capacity) -> (N*27*cell_capacity) - receiver_idx = jnp.reshape(idx, (-1,)) - dR = d( - position[sender_idx], position[receiver_idx] - ) # (N*27*cell_capacity) eventually 3x during computation + for dindex in _neighboring_cells(dim): + if onp.all(dindex == 0): + continue + cell_idx += [shift_array(idx, dindex)] - mask = (dR < cutoff_sq) & (receiver_idx < N) - if format is NeighborListFormat.OrderedSparse: - mask = mask & (receiver_idx < sender_idx) + cell_idx = jnp.concatenate(cell_idx, axis=-2) + cell_idx = jnp.reshape(cell_idx, (-1, cell_idx.shape[-2])) + num_cells, considered_neighbors = cell_idx.shape - out_idx = N * jnp.ones(receiver_idx.shape, jnp.int32) + particle_cells = get_particle_cells(idx, cl_capacity, N) - cumsum = jnp.cumsum(mask) - index = jnp.where( - mask, cumsum - 1, len(receiver_idx) - 1 - ) # 7th object of shape (N*27*cell_capacity) - receiver_idx = out_idx.at[index].set(receiver_idx) - sender_idx = out_idx.at[index].set(sender_idx) - max_occupancy = cumsum[-1] + d = partial(metric_sq, **kwargs) + d = space.map_bond(d) - return jnp.stack((receiver_idx, sender_idx)), max_occupancy + # number of particles per partition N_sub + # np.ceil used to pad last partition with < num_partitions entries + N_sub = int(np.ceil(N / num_partitions)) + num_pad = N_sub * num_partitions - N + particle_cells = jnp.pad( + particle_cells, + ( + 0, + num_pad, + ), + constant_values=-1, + ) - def neighbor_list_fn( - position: jnp.ndarray, - neighbors: Optional[NeighborList] = None, - extra_capacity: int = 0, - **kwargs, - ) -> NeighborList: - nbrs = neighbors + if dim == 2: + # the area of a circle with r=1/3 is 0.34907 + volumetric_factor = 0.34907 + elif dim == 3: + # the volume of a sphere with r=1/3 is 0.15514 + volumetric_factor = 0.15514 - def neighbor_fn(position_and_overflow, max_occupancy=None): - position, overflow = position_and_overflow - N = position.shape[0] + num_edges_sub = int( + N_sub * considered_neighbors * volumetric_factor * capacity_multiplier + ) - if neighbors is None: # cl.shape = (nx, ny, nz, cell_capacity, dim) - cl = cl_fn.allocate(position, extra_capacity=extra_capacity) - else: - cl = cl_fn.update(position, neighbors.cell_list_capacity) - overflow = overflow | cl.did_buffer_overflow - cl_capacity = cl.cell_capacity + def scan_body(carry, input): + """Compute neighbors over a subset of particles - if num_partitions == 1: - implementation = "original" - elif num_partitions > 1: - implementation = ( - "numcells" # "numcells", "twentyseven", "vanilla", "original" - ) + The largest object here is of size (N_sub*considered_neighbors), where + considered_neighbors in 3D is 27 * cell_capacity. + """ - if implementation == "numcells": - # idx = cell_list_candidate_fn(cl, position) - # # idx.shape = (N, 27*cell_capacity) - # print("82 ", get_gpu_stats()) - # idx, occupancy = prune_neighbor_list_sparse(position, idx, **kwargs) - ################################################################ - - N, dim = position.shape - - idx = cl.id_buffer - - cell_idx = [idx] - - for dindex in _neighboring_cells( - dim - ): # here the expansion happens over all adjacent cells happens - if onp.all(dindex == 0): - continue - cell_idx += [ - shift_array(idx, dindex) - ] # 27* (nx,ny,nz,cell_capacity, 1) - - cell_idx = jnp.concatenate(cell_idx, axis=-2) - cell_idx = jnp.reshape( - cell_idx, (-1, cell_idx.shape[-2]) - ) # (num_cells, num_potential_connections) - num_cells, considered_neighbors = cell_idx.shape - - # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ - # Given is a cell list `cell_idx` of shape (nx, ny, nz, cell_capacity). - # Find which cell indices correspond to particle 0, 1, 2, ..., N-1 - # and write the results into a new array of shape (N, nx, ny, nz) - - def scan_body(carry, input): - occupancy = carry - slice_from = input - - _entries = lax.dynamic_slice( - particle_cells, (slice_from,), (N_sub,) - ) - _idx = cell_idx[_entries] + occupancy = carry + slice_from = input - if mask_self: - particle_idx = slice_from + jnp.arange(N_sub) - _idx = jnp.where(_idx == particle_idx[:, None], N, _idx) + _entries = lax.dynamic_slice(particle_cells, (slice_from,), (N_sub,)) + _idx = cell_idx[_entries] - if num_pad > 0: - _idx = jnp.where(_entries[:, None] != -1, _idx, N) + if mask_self: + particle_idx = slice_from + jnp.arange(N_sub) + _idx = jnp.where(_idx == particle_idx[:, None], N, _idx) - sender_idx = ( - jnp.broadcast_to( - jnp.arange(N_sub, dtype="int32")[:, None], _idx.shape - ) - + slice_from - ) - if num_pad > 0: - sender_idx = jnp.clip(sender_idx, a_max=N) - - sender_idx = jnp.reshape( - sender_idx, (-1,) - ) # (N, 27*cell_capacity) -> (N*27*cell_capacity) - receiver_idx = jnp.reshape(_idx, (-1,)) - dR = d( - position[sender_idx], position[receiver_idx] - ) # (N*27*cell_capacity) eventually 3x during computation - - mask = (dR < cutoff_sq) & (receiver_idx < N) - out_idx = N * jnp.ones(receiver_idx.shape, jnp.int32) - - cumsum = jnp.cumsum(mask) # + occupancy - index = jnp.where( - mask, cumsum - 1, considered_neighbors * N - 1 - ) # (N*27*cell_capacity) - receiver_idx = out_idx.at[index].set(receiver_idx) - sender_idx = out_idx.at[index].set(sender_idx) - occupancy += cumsum[-1] - - carry = occupancy - y = jnp.stack( - (receiver_idx[:num_edges_sub], sender_idx[:num_edges_sub]) + if num_pad > 0: + _idx = jnp.where(_entries[:, None] != -1, _idx, N) + + sender_idx = ( + jnp.broadcast_to( + jnp.arange(N_sub, dtype="int32")[:, None], _idx.shape ) - overflow = cumsum[-1] > num_edges_sub - return carry, (y, overflow) - - particle_cells = get_particle_cells(idx, cl_capacity, N) - - d = partial(metric_sq, **kwargs) - d = space.map_bond(d) - - N_sub = int( - np.ceil(N / num_partitions) - ) # to pad the last chunk with < num_partitions entries - num_pad = N_sub * num_partitions - N - particle_cells = jnp.pad( - particle_cells, - ( - 0, - num_pad, - ), - constant_values=-1, + + slice_from ) + if num_pad > 0: + sender_idx = jnp.clip(sender_idx, a_max=N) - if dim == 2: - # area of a circle with r=1/3 is 0.15514 of a unit cube volume - volumetric_factor = 0.34907 - elif dim == 3: - # volume of sphere with r=1/3 is 0.15514 of a unit cube volume - volumetric_factor = 0.15514 - - num_edges_sub = int( - N_sub - * considered_neighbors - * volumetric_factor - * capacity_multiplier - ) + sender_idx = jnp.reshape(sender_idx, (-1,)) + receiver_idx = jnp.reshape(_idx, (-1,)) + dR = d(position[sender_idx], position[receiver_idx]) - carry = jnp.array(0) - xs = jnp.array([i * N_sub for i in range(num_partitions)]) - # print("82 (numcells)", get_gpu_stats()) - occupancy, (idx, overflows) = lax.scan( - scan_body, carry, xs, length=num_partitions - ) - # print("83 ", get_gpu_stats()) - overflow = overflow | overflows.sum() + mask = (dR < cutoff_sq) & (receiver_idx < N) + out_idx = N * jnp.ones(receiver_idx.shape, jnp.int32) - # print(f"idx memory: {idx.nbytes / 1e6:.0f}MB, idx.shape={idx.shape}, - # cl.id_buffer.shape={cl.id_buffer.shape}" ) - idx = idx.transpose(1, 2, 0).reshape(2, -1) + cumsum = jnp.cumsum(mask) + index = jnp.where(mask, cumsum - 1, considered_neighbors * N - 1) + receiver_idx = out_idx.at[index].set(receiver_idx) + sender_idx = out_idx.at[index].set(sender_idx) + occupancy += cumsum[-1] - # sort to enable pruning later - ordering = jnp.argsort(idx[1]) - idx = idx[:, ordering] - - if max_occupancy is None: - _extra_capacity = N * extra_capacity - max_occupancy = int( - occupancy * capacity_multiplier + _extra_capacity - ) - if max_occupancy > idx.shape[-1]: - max_occupancy = idx.shape[-1] - if not is_sparse(format): - capacity_limit = N - 1 if mask_self else N - elif format is NeighborListFormat.Sparse: - capacity_limit = N * (N - 1) if mask_self else N**2 - else: - capacity_limit = N * (N - 1) // 2 - if max_occupancy > capacity_limit: - max_occupancy = capacity_limit - - # prune neighbors list to max_occupancy by removing paddings - idx = idx[:, :max_occupancy] - elif implementation == "original_expanded": - # TODO: here we expand on the 27 adjacent cells - #################################################################### - ### - # idx = cell_list_candidate_fn(cl, position) - # # shape (N, 27*cell_capacity) -> 19M too much! - N, dim = position.shape - - idx = cl.id_buffer # (5, 5, 5, 88, 1) - - cell_idx = [idx] - - for dindex in _neighboring_cells( - dim - ): # here the expansion happens over all adjacent cells happens - if onp.all(dindex == 0): - continue - cell_idx += [ - shift_array(idx, dindex) - ] # 27* (nx,ny,nz,cell_capacity) - - cell_idx = jnp.concatenate(cell_idx, axis=-2) # (5, 5, 5, 2376, 1) - cell_idx = cell_idx[..., jnp.newaxis, :, :] # (5, 5, 5, 1, 2376, 1) - # TODO: memory blows up here by factor "cell_capacity" - cell_idx = jnp.broadcast_to( - cell_idx, idx.shape[:-1] + cell_idx.shape[-2:] - ) # 1.2*X (nx,ny,nz,cell_capacity,27*cell_capacity) - - # def copy_values_from_cell(value, cell_value, cell_id): - # scatter_indices = jnp.reshape(cell_id, (-1,)) - # cell_value = jnp.reshape(cell_value, (-1,) +cell_value.shape[-2:]) - # return value.at[scatter_indices].set(cell_value) - # TODO: further memory increase in the next two lines - neighbor_idx = jnp.zeros( - (N + 1,) + cell_idx.shape[-2:], jnp.int32 - ) # X (N, 27*cell_capacity, 1) - # neighbor_idx = copy_values_from_cell(neighbor_idx, cell_idx, idx) # X - scatter_indices = jnp.reshape( - idx, (-1,) - ) # (11000,) each cell allocation over all cells expanded - cell_value = jnp.reshape( - cell_idx, (-1,) + cell_idx.shape[-2:] - ) # (nx*ny*nz*cell_capacity, 27*cell_capacity, 1) - neighbor_idx = neighbor_idx.at[scatter_indices].set( - cell_value - ) # X (N, 27*cell_capacity, 1) - - idx = neighbor_idx[ - :-1, :, 0 - ] # X shape (N, 27*cell_capacity) this only removes the 8001th element - # this is just expanded over all cells indices. Should work with - # arbitrary pices over the last dimension - - #################################################################### - # idx.shape = (nx*ny*nz*cell_capacity**2*27) - # -> 26M (or actually just 19M) too much! - # idx, occupancy = prune_neighbor_list_sparse(position, idx, **kwargs) - d = partial(metric_sq, **kwargs) - d = space.map_bond(d) - - N = position.shape[0] - sender_idx = jnp.broadcast_to( - jnp.arange(N)[:, None], idx.shape - ) # 2X (N, 27*cell_capacity) - - sender_idx = jnp.reshape( - sender_idx, (-1,) - ) # (N, 27*cell_capacity) -> (N*27*cell_capacity) - # [0,0,0,0...0, 1,1,1,1...1, ....] - receiver_idx = jnp.reshape( - idx, (-1,) - ) # flatten the stuff with all possible neighbors (27*cell_size) of - # particle 0, of 1, .... - dR = d( - position[sender_idx], position[receiver_idx] - ) # (N*27*cell_capacity) eventually 3x during computation - - mask = (dR < cutoff_sq) & (receiver_idx < N) # negligible - # if format is NeighborListFormat.OrderedSparse: - # mask = mask & (receiver_idx < sender_idx) - - out_idx = N * jnp.ones(receiver_idx.shape, jnp.int32) # X - cumsum = jnp.cumsum(mask) # 2X - index = jnp.where( - mask, cumsum - 1, len(receiver_idx) - 1 - ) # 2X 7th object of shape (N*27*cell_capacity) - receiver_idx = out_idx.at[index].set( - receiver_idx - ) # X # this operation sorts the entries - sender_idx = out_idx.at[index].set( - sender_idx - ) # 2X -> X # this operation also sorts the entries - max_occupancy_ = cumsum[-1] - - idx, occupancy = jnp.stack((receiver_idx, sender_idx)), max_occupancy_ - # Memory: idx 2X, neihbor_idx X, cell_idx 0.8X, sender_idx X, - # receiver_idx X, dR 2X, out_idx X, cumsum 2X, index 2X - # idx_final = jnp.zeros((N, max_occupancy), jnp.int32) # X -> 2X - # print("max occupancy2 ", occupancy) - - if max_occupancy is None: - _extra_capacity = N * extra_capacity - max_occupancy = int( - occupancy * capacity_multiplier + _extra_capacity - ) - if max_occupancy > idx.shape[-1]: - max_occupancy = idx.shape[-1] - if not is_sparse(format): - capacity_limit = N - 1 if mask_self else N - elif format is NeighborListFormat.Sparse: - capacity_limit = N * (N - 1) if mask_self else N**2 - else: - capacity_limit = N * (N - 1) // 2 - if max_occupancy > capacity_limit: - max_occupancy = capacity_limit - idx = idx[ - :, :max_occupancy - ] # shape (N, max_occupancy) -> 2M much smaller - # TODO: from here on the size is ~10x smaller after - # idx=idx[:, :max_occupancy] - # how can we run the previous part sequentially? - ### - #################################################################### - elif implementation == "original": - # print("82 (original)", get_gpu_stats()) - idx = cell_list_candidate_fn(cl, position) - - idx, occupancy = prune_neighbor_list_sparse(position, idx, **kwargs) - - if max_occupancy is None: - _extra_capacity = ( - extra_capacity if not is_sparse(format) else N * extra_capacity - ) - max_occupancy = int( - occupancy * capacity_multiplier + _extra_capacity - ) - if max_occupancy > idx.shape[-1]: - max_occupancy = idx.shape[-1] - if not is_sparse(format): - capacity_limit = N - 1 if mask_self else N - elif format is NeighborListFormat.Sparse: - capacity_limit = N * (N - 1) if mask_self else N**2 - else: - capacity_limit = N * (N - 1) // 2 - if max_occupancy > capacity_limit: - max_occupancy = capacity_limit - idx = idx[:, :max_occupancy] - - # print("83 ", get_gpu_stats()) - # print(f"idx memory: {idx.nbytes / 1e6:.0f}MB, idx.shape={idx.shape}, - # cl.id_buffer.shape={cl.id_buffer.shape}" ) - - # print("##### max occupancy", max_occupancy, "occupancy", occupancy) + carry = occupancy + y = jnp.stack( + (receiver_idx[:num_edges_sub], sender_idx[:num_edges_sub]) + ) + overflow = cumsum[-1] > num_edges_sub + return carry, (y, overflow) + carry = jnp.array(0) + xs = jnp.array([i * N_sub for i in range(num_partitions)]) + occupancy, (idx, overflows) = lax.scan( + scan_body, carry, xs, length=num_partitions + ) + err = err.update(PEC.CELL_LIST_OVERFLOW, overflows.sum()) + idx = idx.transpose(1, 2, 0).reshape(2, -1) + + # sort to enable pruning later + ordering = jnp.argsort(idx[1]) + idx = idx[:, ordering] + + if max_occupancy is None: + _extra_capacity = N * extra_capacity + max_occupancy = int(occupancy * capacity_multiplier + _extra_capacity) + if max_occupancy > idx.shape[-1]: + max_occupancy = idx.shape[-1] + if not is_sparse(format): + capacity_limit = N - 1 if mask_self else N + elif format is NeighborListFormat.Sparse: + capacity_limit = N * (N - 1) if mask_self else N**2 + else: + capacity_limit = N * (N - 1) // 2 + if max_occupancy > capacity_limit: + max_occupancy = capacity_limit + idx = idx[:, :max_occupancy] update_fn = neighbor_list_fn if neighbors is None else neighbors.update_fn return NeighborList( idx, position, - overflow | (occupancy > max_occupancy), + err.update(PEC.NEIGHBOR_LIST_OVERFLOW, occupancy > max_occupancy), cl_capacity, max_occupancy, format, + cell_size, + cl_fn, update_fn, ) # pytype: disable=wrong-arg-count + nbrs = neighbors if nbrs is None: - return neighbor_fn((position, False)) + return neighbor_fn((position, PartitionError(jnp.zeros((), jnp.uint8)))) neighbor_fn = partial(neighbor_fn, max_occupancy=nbrs.max_occupancy) @@ -597,7 +332,7 @@ def scan_body(carry, input): return lax.cond( jnp.any(d(position, nbrs.reference_position) > threshold_sq), - (position, nbrs.did_buffer_overflow), + (position, nbrs.error), neighbor_fn, nbrs, lambda x: x, @@ -646,19 +381,23 @@ def _matscipy_neighbor_list( else: pbc = np.asarray(pbc, dtype=bool) + dtype_idx = jnp.arange(0).dtype # just to get the correct dtype + def matscipy_wrapper(position, idx_shape, num_particles): position = position[:num_particles] + if position.shape[1] == 2: position = np.pad( position, ((0, 0), (0, 1)), mode="constant", constant_values=0.5 ) + edge_list = matscipy_nl( "ij", cutoff=r_cutoff, positions=position, cell=box_size, pbc=pbc ) - edge_list = np.asarray(edge_list, dtype=np.int32) + edge_list = np.asarray(edge_list, dtype=dtype_idx) if not mask_self: # add self connection, which matscipy does not do - self_connect = np.arange(num_particles, dtype=np.int32) + self_connect = np.arange(num_particles, dtype=dtype_idx) self_connect = np.array([self_connect, self_connect]) edge_list = np.concatenate((self_connect, edge_list), axis=-1) @@ -666,7 +405,7 @@ def matscipy_wrapper(position, idx_shape, num_particles): idx_new = np.asarray(edge_list[:, : idx_shape[1]]) buffer_overflow = np.array(True) else: - idx_new = np.ones(idx_shape, dtype=np.int32) * num_particles_max + idx_new = np.ones(idx_shape, dtype=dtype_idx) * num_particles_max idx_new[:, : edge_list.shape[1]] = edge_list buffer_overflow = np.array(False) @@ -686,10 +425,13 @@ def update_fn( idx, buffer_overflow = jax.pure_callback( matscipy_wrapper, shape_out, position, neighbors.idx.shape, num_particles ) + return NeighborList( idx, position, - buffer_overflow, + neighbors.error.update(PEC.NEIGHBOR_LIST_OVERFLOW, buffer_overflow), + None, + None, None, None, None, @@ -700,7 +442,6 @@ def allocate_fn( position: jnp.ndarray, extra_capacity: int = 0, **kwargs ) -> NeighborList: num_particles = kwargs["num_particles"] - position = position[:num_particles] if position.shape[1] == 2: @@ -711,24 +452,26 @@ def allocate_fn( edge_list = matscipy_nl( "ij", cutoff=r_cutoff, positions=position, cell=box_size, pbc=pbc ) - edge_list = np.asarray(edge_list, dtype=np.int32) + edge_list = jnp.asarray(edge_list, dtype=dtype_idx) if not mask_self: # add self connection, which matscipy does not do - self_connect = np.arange(num_particles, dtype=np.int32) - self_connect = np.array([self_connect, self_connect]) - edge_list = np.concatenate((self_connect, edge_list), axis=-1) + self_connect = jnp.arange(num_particles, dtype=dtype_idx) + self_connect = jnp.array([self_connect, self_connect]) + edge_list = jnp.concatenate((self_connect, edge_list), axis=-1) # in case this is a (2,M) pair list, we pad with N and capacity_multiplier factor = capacity_multiplier * num_particles_max / num_particles res = num_particles * jnp.ones( (2, round(edge_list.shape[1] * factor + extra_capacity)), - np.int32, + dtype_idx, ) res = res.at[:, : edge_list.shape[1]].set(edge_list) return NeighborList( res, position, - jnp.array(False), + PartitionError(jnp.zeros((), jnp.uint8)), + None, + None, None, None, None, @@ -761,14 +504,52 @@ def neighbor_list( num_partitions: int = 1, pbc: jnp.ndarray = None, ) -> NeighborFn: - """Neighbor lists wrapper. + """Neighbor lists wrapper. Its arguments are mainly based on the jax-md ones. Args: - backend: The backend to use. One of "jaxmd_vmap", "jaxmd_scan", "matscipy". - - - "jaxmd_vmap": Default jax-md neighbor list. Uses vmap. Fast. - - "jaxmd_scan": Modified jax-md neighbor list. Uses scan. Memory efficient. - - "matscipy": Matscipy neighbor list. Runs on cpu, allows dynamic shapes. + displacement: A function `d(R_a, R_b)` that computes the displacement + between pairs of points. + box_size: Either a float specifying the size of the box or an array of + shape `[spatial_dim]` specifying the box size in each spatial dimension. + r_cutoff: A scalar specifying the neighborhood radius. + dr_threshold: A scalar specifying the maximum distance particles can move + before rebuilding the neighbor list. + backend: The backend to use. Can be one of: 1) ``jaxmd_vmap`` - the default + jax-md neighbor list which vectorizes the computations. 2) ``jaxmd_scan`` - + a modified jax-md neighbor list which serializes the search into + ``num_partitions`` chunks to improve the memory efficiency. 3) ``matscipy`` + - a jit-able implementation with the matscipy neighbor list backend, which + runs on CPU and takes variable number of particles smaller or equal to + ``num_particles``. + capacity_multiplier: A floating point scalar specifying the fractional + increase in maximum neighborhood occupancy we allocate compared with the + maximum in the example positions. + disable_cell_list: An optional boolean. If set to `True` then the neighbor + list is constructed using only distances. This can be useful for + debugging but should generally be left as `False`. + mask_self: An optional boolean. Determines whether points can consider + themselves to be their own neighbors. + custom_mask_function: An optional function. Takes the neighbor array + and masks selected elements. Note: The input array to the function is + `(n_particles, m)` where the index of particle 1 is in index in the first + dimension of the array, the index of particle 2 is given by the value in + the array + fractional_coordinates: An optional boolean. Specifies whether positions will + be supplied in fractional coordinates in the unit cube, :math:`[0, 1]^d`. + If this is set to True then the `box_size` will be set to `1.0` and the + cell size used in the cell list will be set to `cutoff / box_size`. + format: The format of the neighbor list; see the :meth:`NeighborListFormat` enum + for details about the different choices for formats. Defaults to `Dense`. + num_particles_max: only used with the ``matscipy`` backend. Based + on the largest particles system in a dataset. + num_partitions: only used with the ``jaxmd_scan`` backend + pbc: only used with the ``matscipy`` backend. Defines the boundary conditions + for each dimension individually. Can have shape (2,) or (3,). + **static_kwargs: kwargs that get threaded through the calculation of + example positions. + Returns: + A NeighborListFns object that contains a method to allocate a new neighbor + list and a method to update an existing neighbor list. """ assert backend in BACKENDS, f"Unknown backend {backend}" diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py new file mode 100644 index 0000000..9e2ca71 --- /dev/null +++ b/tests/test_neighbors.py @@ -0,0 +1,137 @@ +import unittest + +import numpy as np +from jax import config + +config.update("jax_enable_x64", True) +import jax.numpy as jnp +from jax import jit +from jax_md import space + +from jax_sph import partition + + +@jit +def updater(nbrs_old, r_new, **kwargs): + nbrs_new = nbrs_old.update(r_new, **kwargs) + return nbrs_new + + +class BaseTest(unittest.TestCase): + def body(self, args, backend, num_partitions, verbose=False): + r = args["r"] + box_size = args["box_size"] + cutoff = args["cutoff"] + mask_self = args["mask_self"] + target = args["target"] + + if verbose: + print(f"Start with {backend} backend and {num_partitions} partition(s)") + + N, dim = r.shape + num_particles_max = r.shape[0] + displacement_fn, _ = space.periodic(side=box_size) + neighbor_fn = partition.neighbor_list( + displacement_fn, + box_size, + r_cutoff=cutoff, + backend=backend, + dr_threshold=0.0, + capacity_multiplier=1.25, + mask_self=mask_self, + format=partition.NeighborListFormat.Sparse, + num_particles_max=num_particles_max, + num_partitions=num_partitions, + pbc=np.array([True] * dim), + ) + + nbrs = neighbor_fn.allocate(r, num_particles=N) + + if backend == "matscipy": + nbrs2 = updater(nbrs_old=nbrs, r_new=r, num_particles=N) + else: + nbrs2 = updater(nbrs, r) + mask_real = nbrs.idx[0] < N + idx_real = nbrs.idx[:, mask_real] + + if verbose: + print("Idx: \n", nbrs.idx) + print("Idx_real: \n", idx_real) + + self.assertFalse(nbrs.did_buffer_overflow, "Buffer overflow (allocate)") + self.assertFalse(nbrs2.did_buffer_overflow, "Buffer overflow (update)") + + self.assertTrue((nbrs.idx == nbrs2.idx).all(), "allocate differes from update") + + self.assertTrue( + ((nbrs.idx[0] == N) == (nbrs.idx[1] == N)).all(), "One sided edges" + ) + + self_edges_mask = idx_real[0] == idx_real[1] + if mask_self: + self.assertEqual(sum(self_edges_mask), 0.0, "Self edges b/n real particles") + else: + self_edges = idx_real[:, self_edges_mask] + self.assertEqual(len(np.unique(self_edges[0])), N, "Self edges are broken") + + # sorted edge list based on second edge row (first sort by first row) + sort_idx = np.argsort(idx_real[0]) + idx_real_sorted = idx_real[:, sort_idx] + sort_idx = np.argsort(idx_real_sorted[1]) + idx_real_sorted = idx_real_sorted[:, sort_idx] + self.assertTrue((idx_real_sorted == target).all(), "Wrong edge list") + + if verbose: + print(f"Finish with {backend} backend and {num_partitions} partition(s)") + + def cases(self, backend, num_partitions=1, verbose=False): + # Simple test with pbc and with/without self-masking + args = { + "mask_self": False, + "cutoff": 0.33, + "box_size": np.array([1.0, 1.0]), + "r": jnp.array([[0.1, 0.1], [0.1, 0.3], [0.1, 0.9], [0.6, 0.5]]), + "target": jnp.array([[0, 1, 2, 0, 1, 0, 2, 3], [0, 0, 0, 1, 1, 2, 2, 3]]), + } + self.body(args, backend, num_partitions, verbose) + + args["mask_self"] = True + args["target"] = jnp.array([[1, 2, 0, 0], [0, 0, 1, 2]]) + self.body(args, backend, num_partitions, verbose) + + # Edge case at which the scan implementation almost breaks + args = { + "mask_self": False, + "cutoff": 0.33, + "box_size": np.array([1.0, 1.0]), + "r": jnp.array( + [[0.5, 0.2], [0.2, 0.5], [0.5, 0.5], [0.8, 0.5], [0.5, 0.8]] + ), + "target": jnp.array( + [ + [0, 2, 1, 2, 0, 1, 2, 3, 4, 2, 3, 2, 4], + [0, 0, 1, 1, 2, 2, 2, 2, 2, 3, 3, 4, 4], + ] + ), + } + self.body(args, backend, num_partitions, verbose) + + args["mask_self"] = True + args["target"] = jnp.array([[2, 2, 0, 1, 3, 4, 2, 2], [0, 1, 2, 2, 2, 2, 3, 4]]) + self.body(args, backend, num_partitions, verbose) + + def test_vmap(self): + self.cases("jaxmd_vmap") + + def test_scan1(self): + self.cases("jaxmd_scan") + + def test_scan2(self): + self.cases("jaxmd_scan", 2) + + def test_matscipy(self): + self.cases("matscipy") + + +if __name__ == "__main__": + unittest.main() From cbed8acbb535c6ceb08eafbe9e5f2f235ab8aba8 Mon Sep 17 00:00:00 2001 From: Artur Toshev Date: Fri, 7 Jun 2024 22:00:29 +0200 Subject: [PATCH 2/5] update codecov params --- .codecov.yml | 8 ++++---- .github/workflows/tests.yml | 2 +- jax_sph/simulate.py | 2 +- notebooks/iclr24_inverse.ipynb | 2 +- pyproject.toml | 2 +- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.codecov.yml b/.codecov.yml index a6500d9..3da95fc 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -1,12 +1,12 @@ -coverage: # TODO: increase required values once proper tests are implemented - range: 2..3 # red color under 50%, yellow at 50%..70%, green over 70% +coverage: + range: 50..60 # red color under 50%, yellow at 50%..60%, green over 60% precision: 1 status: project: default: - target: 2% # coverage success only above X%. Later to be changed to e.g. 60% + target: 60% # coverage success only above X% threshold: 5% # allow the coverage to drop by X% and being a success patch: default: - target: 5% # later to be changed to e.g. 50% + target: 50% threshold: 5% \ No newline at end of file diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index db1eb85..4b4b0ec 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -35,7 +35,7 @@ jobs: run: | .venv/bin/pytest --cov-report=xml - name: Upload coverage report to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} file: ./coverage.xml diff --git a/jax_sph/simulate.py b/jax_sph/simulate.py index 1ef7b6b..01e3263 100644 --- a/jax_sph/simulate.py +++ b/jax_sph/simulate.py @@ -86,7 +86,7 @@ def simulate(cfg: DictConfig): # Instantiate advance function for our use case advance = si_euler(cfg.solver.tvf, forward, shift_fn, bc_fn) - advance = advance if cfg.no_jit else jit(advance) # TODO: is this even needed? + advance = advance if cfg.no_jit else jit(advance) print("#" * 79, "\nStarting a JAX-SPH run with the following configs:") print(OmegaConf.to_yaml(cfg)) diff --git a/notebooks/iclr24_inverse.ipynb b/notebooks/iclr24_inverse.ipynb index dd75129..a80ab31 100644 --- a/notebooks/iclr24_inverse.ipynb +++ b/notebooks/iclr24_inverse.ipynb @@ -193,7 +193,7 @@ "\n", " # Instantiate advance function for our use case\n", " advance = si_euler(cfg.solver.tvf, forward, shift_fn, bc_fn)\n", - " advance = advance if cfg.no_jit else jit(advance) # TODO: is this even needed?\n", + " advance = advance if cfg.no_jit else jit(advance)\n", "\n", " # compile kernel and initialize accelerations\n", " _state, _neighbors = advance(0.0, optim_init_state, neighbors)\n", diff --git a/pyproject.toml b/pyproject.toml index bf99508..7484463 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,7 @@ select = [ [tool.pytest.ini_options] testpaths = "tests/" -addopts = "--cov=jax_sph --cov-fail-under=1" # TODO: increase later e.g. to 50 +addopts = "--cov=jax_sph --cov-fail-under=50" filterwarnings = [ # ignore all deprecation warnings except from jax-sph "ignore::DeprecationWarning:^(?!.*jax_sph).*" From 1fae780cc37dd1e7b165401727e0c64681612005 Mon Sep 17 00:00:00 2001 From: Artur Toshev Date: Fri, 7 Jun 2024 22:46:57 +0200 Subject: [PATCH 3/5] add badges --- README.md | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 6f928fb..ebdac7c 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,26 @@ # JAX-SPH: A Differentiable Smoothed Particle Hydrodynamics Framework -![HT_T.gif](https://s9.gifyu.com/images/SUwUD.gif) +
+ +[![Paper](http://img.shields.io/badge/paper-arxiv.2403.04750-B31B1B.svg)](https://arxiv.org/abs/2403.04750) +[![Docs](https://img.shields.io/readthedocs/jax-sph/latest)](https://jax-sph.readthedocs.io/en/latest/index.html) +[![PyPI - Version](https://img.shields.io/pypi/v/jax-sph)](https://pypi.org/project/jax-sph/) +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/jax-sph/blob/main/notebooks/tutorial.ipynb) +[![Discord](https://img.shields.io/badge/Discord-%235865F2?logo=discord&logoColor=white)](https://discord.gg/Ds8jRZ78hU) + +[![Tests](https://github.com/tumaer/jax-sph/actions/workflows/tests.yml/badge.svg)](https://github.com/tumaer/jax-sph/actions/workflows/tests.yml) +[![CodeCov](https://codecov.io/gh/tumaer/jax-sph/graph/badge.svg?token=ULMGSY71R1)](https://codecov.io/gh/tumaer/jax-sph) +[![License](https://img.shields.io/pypi/l/jax-sph)](https://github.com/tumaer/jax-sph/blob/main/LICENSE) + +
JAX-SPH [(Toshev et al., 2024)](https://arxiv.org/abs/2403.04750) is a modular JAX-based weakly compressible SPH framework, which implements the following SPH routines: - Standard SPH [(Adami et al., 2012)](https://www.sciencedirect.com/science/article/pii/S002199911200229X) - Transport velocity SPH [(Adami et al., 2013)](https://www.sciencedirect.com/science/article/pii/S002199911300096X) - Riemann SPH [(Zhang et al., 2017)](https://www.sciencedirect.com/science/article/abs/pii/S0021999117300438) +![HT_T.gif](https://s9.gifyu.com/images/SUwUD.gif) + ## Installation ### Standalone library From 65efb45953225f8767749aec168ab32d91222d12 Mon Sep 17 00:00:00 2001 From: Artur Toshev Date: Fri, 7 Jun 2024 22:54:13 +0200 Subject: [PATCH 4/5] poetry add matscipy --- poetry.lock | 182 ++++++++++++++++++++++++++++++++++--------------- pyproject.toml | 1 + 2 files changed, 128 insertions(+), 55 deletions(-) diff --git a/poetry.lock b/poetry.lock index bb2d1ac..a8bf95d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -43,6 +43,27 @@ files = [ {file = "appnope-0.1.4.tar.gz", hash = "sha256:1de3860566df9caf38f01f86f65e0e13e379af54f9e4bee1e66b48f2efffd1ee"}, ] +[[package]] +name = "ase" +version = "3.23.0" +description = "Atomic Simulation Environment" +optional = false +python-versions = ">=3.8" +files = [ + {file = "ase-3.23.0-py3-none-any.whl", hash = "sha256:52060410e720b6c701ea1ebecfdeb5ec6f9c1c63edc7cee68c15bd66d226dd43"}, + {file = "ase-3.23.0.tar.gz", hash = "sha256:91a2aa31d89bd90b0efdfe4a7e84264f32828b2abfc9f38e65e041ad76fec8ae"}, +] + +[package.dependencies] +matplotlib = ">=3.3.4" +numpy = ">=1.18.5" +scipy = ">=1.6.0" + +[package.extras] +docs = ["pillow", "sphinx", "sphinx-rtd-theme"] +spglib = ["spglib (>=1.9)"] +test = ["pytest (>=7.0.0)", "pytest-xdist (>=2.1.0)"] + [[package]] name = "asttokens" version = "2.4.1" @@ -1404,6 +1425,17 @@ jax = ">=0.4.26" jaxtyping = ">=0.2.20" typing-extensions = ">=4.5.0" +[[package]] +name = "looseversion" +version = "1.3.0" +description = "Version numbering for anarchists and software realists" +optional = false +python-versions = "*" +files = [ + {file = "looseversion-1.3.0-py2.py3-none-any.whl", hash = "sha256:781ef477b45946fc03dd4c84ea87734b21137ecda0e1e122bcb3c8d16d2a56e0"}, + {file = "looseversion-1.3.0.tar.gz", hash = "sha256:ebde65f3f6bb9531a81016c6fef3eb95a61181adc47b7f949e9c0ea47911669e"}, +] + [[package]] name = "markdown-it-py" version = "3.0.0" @@ -1564,6 +1596,46 @@ files = [ [package.dependencies] traitlets = "*" +[[package]] +name = "matscipy" +version = "1.0.0" +description = "Generic Python Materials Science tools" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "matscipy-1.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4038a0c917644a9144b4962c54fc77f3be8cccba2172dc79432a5c6c4a27b872"}, + {file = "matscipy-1.0.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:8b9014ad64f238af720631c955ecb1e8a6ec80cbc892f8f2b5c4a6ce1b3ced99"}, + {file = "matscipy-1.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:62b0f0699b3c7a59ad45b318cb5789595b24532e93d5ee49fce2c29529e4a010"}, + {file = "matscipy-1.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:cf714d19a470bd4c29a23703ccac51bfca81ba76db5464d24e2bb96fddaf91d9"}, + {file = "matscipy-1.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:56df29e8f55bbe7f6332dac1468456158ac954c01adc6b51a58dd622f9a99968"}, + {file = "matscipy-1.0.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:e0aaefefd0685326300bcb0e35a47b2548502126ee521a3d5fbd07485d5c08b8"}, + {file = "matscipy-1.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7ef29f5153f6fe9ebd147abcdabd4cd4050c1da94ca2d57ed2681c61b5ac808"}, + {file = "matscipy-1.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:0adedd5a5f95cfff9f43d1a6384aa3c066ca566b5068864e5155f6bd7227b3d9"}, + {file = "matscipy-1.0.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d6b86bf919e406163939e68c3f31bdd16b0b7954924f659da246f5b97180fcfb"}, + {file = "matscipy-1.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:45e7994ec4f91d6e038507fc07aec491f75a4ffa8e808aa8807b80be61eacee6"}, + {file = "matscipy-1.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:04ca7ba6ee820feb2ae3f9cf4c588de6e32a012270e510fd4685245982152d13"}, + {file = "matscipy-1.0.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e50ad1df9862234780951c2e5034738279883b6bc7f94045310e5ec76580a793"}, + {file = "matscipy-1.0.0-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:d9ea621f0ceff38e189da17d51eeb9a06395fbeb6b8cd660f6d22e9fd46e6d28"}, + {file = "matscipy-1.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:43977b190958bf69652eae419da0c931512584bc78df801c76777a55783c0e4c"}, + {file = "matscipy-1.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:db63cb5601466b85ea0017dab53b717fdf4910ddf21bc2d4a772901e2916351b"}, + {file = "matscipy-1.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f4a06d01f9c0a5c782bc16edcfb3334ef813c925c1561bb1d808b4165d79b69f"}, + {file = "matscipy-1.0.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:99ea22dca7a2798c3334fa28ee5acaf4c7dadd3df16c37d4a7fe1ecc60d1cdab"}, + {file = "matscipy-1.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c9032db508f37b052cb6d4b687f59b3736f4a8a3fee9a82d1fd6fca947d6f48"}, + {file = "matscipy-1.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:20bc5962c15addcd9624499805444d6fe5dcfac33d8d65aa5009f8a671c00015"}, + {file = "matscipy-1.0.0.tar.gz", hash = "sha256:50d896bf7527dc0acbba8ed548dbd247a77a129faf592a3e47a9fa128ca19483"}, +] + +[package.dependencies] +ase = ">=3.16.0" +looseversion = "*" +numpy = ">=1.16.0" +scipy = ">=1.2.3" + +[package.extras] +cli = ["argcomplete"] +docs = ["atomman", "jupytext", "myst_nb", "nglview", "nglview (==3.0.8)", "numpydoc", "ovito", "pydata-sphinx-theme", "sphinx", "sphinx_copybutton", "sphinx_rtd_theme", "sphinxcontrib-spelling"] +test = ["atomman", "ovito", "pytest", "pytest-subtests", "sympy"] + [[package]] name = "mdurl" version = "0.1.2" @@ -2115,13 +2187,13 @@ testing = ["pytest", "pytest-benchmark"] [[package]] name = "pooch" -version = "1.8.1" -description = "\"Pooch manages your Python library's sample data files: it automatically downloads and stores them in a local directory, with support for versioning and corruption checks.\"" +version = "1.8.2" +description = "A friend to fetch your data files" optional = false python-versions = ">=3.7" files = [ - {file = "pooch-1.8.1-py3-none-any.whl", hash = "sha256:6b56611ac320c239faece1ac51a60b25796792599ce5c0b1bb87bf01df55e0a9"}, - {file = "pooch-1.8.1.tar.gz", hash = "sha256:27ef63097dd9a6e4f9d2694f5cfbf2f0a5defa44fccafec08d601e731d746270"}, + {file = "pooch-1.8.2-py3-none-any.whl", hash = "sha256:3529a57096f7198778a5ceefd5ac3ef0e4d06a6ddaf9fc2d609b806f25302c47"}, + {file = "pooch-1.8.2.tar.gz", hash = "sha256:76561f0de68a01da4df6af38e9955c4c9d1a5c90da73f7e40276a5728ec83d10"}, ] [package.dependencies] @@ -2168,22 +2240,22 @@ wcwidth = "*" [[package]] name = "protobuf" -version = "5.27.0" +version = "5.27.1" description = "" optional = false python-versions = ">=3.8" files = [ - {file = "protobuf-5.27.0-cp310-abi3-win32.whl", hash = "sha256:2f83bf341d925650d550b8932b71763321d782529ac0eaf278f5242f513cc04e"}, - {file = "protobuf-5.27.0-cp310-abi3-win_amd64.whl", hash = "sha256:b276e3f477ea1eebff3c2e1515136cfcff5ac14519c45f9b4aa2f6a87ea627c4"}, - {file = "protobuf-5.27.0-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:744489f77c29174328d32f8921566fb0f7080a2f064c5137b9d6f4b790f9e0c1"}, - {file = "protobuf-5.27.0-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:f51f33d305e18646f03acfdb343aac15b8115235af98bc9f844bf9446573827b"}, - {file = "protobuf-5.27.0-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:56937f97ae0dcf4e220ff2abb1456c51a334144c9960b23597f044ce99c29c89"}, - {file = "protobuf-5.27.0-cp38-cp38-win32.whl", hash = "sha256:a17f4d664ea868102feaa30a674542255f9f4bf835d943d588440d1f49a3ed15"}, - {file = "protobuf-5.27.0-cp38-cp38-win_amd64.whl", hash = "sha256:aabbbcf794fbb4c692ff14ce06780a66d04758435717107c387f12fb477bf0d8"}, - {file = "protobuf-5.27.0-cp39-cp39-win32.whl", hash = "sha256:587be23f1212da7a14a6c65fd61995f8ef35779d4aea9e36aad81f5f3b80aec5"}, - {file = "protobuf-5.27.0-cp39-cp39-win_amd64.whl", hash = "sha256:7cb65fc8fba680b27cf7a07678084c6e68ee13cab7cace734954c25a43da6d0f"}, - {file = "protobuf-5.27.0-py3-none-any.whl", hash = "sha256:673ad60f1536b394b4fa0bcd3146a4130fcad85bfe3b60eaa86d6a0ace0fa374"}, - {file = "protobuf-5.27.0.tar.gz", hash = "sha256:07f2b9a15255e3cf3f137d884af7972407b556a7a220912b252f26dc3121e6bf"}, + {file = "protobuf-5.27.1-cp310-abi3-win32.whl", hash = "sha256:3adc15ec0ff35c5b2d0992f9345b04a540c1e73bfee3ff1643db43cc1d734333"}, + {file = "protobuf-5.27.1-cp310-abi3-win_amd64.whl", hash = "sha256:25236b69ab4ce1bec413fd4b68a15ef8141794427e0b4dc173e9d5d9dffc3bcd"}, + {file = "protobuf-5.27.1-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:4e38fc29d7df32e01a41cf118b5a968b1efd46b9c41ff515234e794011c78b17"}, + {file = "protobuf-5.27.1-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:917ed03c3eb8a2d51c3496359f5b53b4e4b7e40edfbdd3d3f34336e0eef6825a"}, + {file = "protobuf-5.27.1-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:ee52874a9e69a30271649be88ecbe69d374232e8fd0b4e4b0aaaa87f429f1631"}, + {file = "protobuf-5.27.1-cp38-cp38-win32.whl", hash = "sha256:7a97b9c5aed86b9ca289eb5148df6c208ab5bb6906930590961e08f097258107"}, + {file = "protobuf-5.27.1-cp38-cp38-win_amd64.whl", hash = "sha256:f6abd0f69968792da7460d3c2cfa7d94fd74e1c21df321eb6345b963f9ec3d8d"}, + {file = "protobuf-5.27.1-cp39-cp39-win32.whl", hash = "sha256:dfddb7537f789002cc4eb00752c92e67885badcc7005566f2c5de9d969d3282d"}, + {file = "protobuf-5.27.1-cp39-cp39-win_amd64.whl", hash = "sha256:39309898b912ca6febb0084ea912e976482834f401be35840a008da12d189340"}, + {file = "protobuf-5.27.1-py3-none-any.whl", hash = "sha256:4ac7249a1530a2ed50e24201d6630125ced04b30619262f06224616e0030b6cf"}, + {file = "protobuf-5.27.1.tar.gz", hash = "sha256:df5e5b8e39b7d1c25b186ffdf9f44f40f810bbcc9d2b71d9d3156fee5a9adf15"}, ] [[package]] @@ -2345,18 +2417,18 @@ files = [ [[package]] name = "pyvista" -version = "0.43.8" +version = "0.43.9" description = "Easier Pythonic interface to VTK" optional = false python-versions = ">=3.8" files = [ - {file = "pyvista-0.43.8-py3-none-any.whl", hash = "sha256:8b0769f6ac7a8dc93137ae659556e8e89de54b9a928eb4bd448c4c7c4d484cf7"}, - {file = "pyvista-0.43.8.tar.gz", hash = "sha256:b9220753ae94fb8ca3047d291a706a4046b06659016c0000c184b5f24504f8d0"}, + {file = "pyvista-0.43.9-py3-none-any.whl", hash = "sha256:f9f23baa74a8e2a4181c260e4c742ede00c73a7cc46e5275152f82a736f12c95"}, + {file = "pyvista-0.43.9.tar.gz", hash = "sha256:87d55ffe0efa6a8b15ca55f9f07f49a81980522c4a3ada29ca5caa2ab31179b7"}, ] [package.dependencies] matplotlib = ">=3.0.1" -numpy = ">=1.21.0" +numpy = ">=1.21.0,<2.0.0" pillow = "*" pooch = "*" scooby = ">=0.5.1" @@ -2591,28 +2663,28 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"] [[package]] name = "ruff" -version = "0.4.7" +version = "0.4.8" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.4.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:e089371c67892a73b6bb1525608e89a2aca1b77b5440acf7a71dda5dac958f9e"}, - {file = "ruff-0.4.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:10f973d521d910e5f9c72ab27e409e839089f955be8a4c8826601a6323a89753"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:59c3d110970001dfa494bcd95478e62286c751126dfb15c3c46e7915fc49694f"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fa9773c6c00f4958f73b317bc0fd125295110c3776089f6ef318f4b775f0abe4"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07fc80bbb61e42b3b23b10fda6a2a0f5a067f810180a3760c5ef1b456c21b9db"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:fa4dafe3fe66d90e2e2b63fa1591dd6e3f090ca2128daa0be33db894e6c18648"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a7c0083febdec17571455903b184a10026603a1de078428ba155e7ce9358c5f6"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ad1b20e66a44057c326168437d680a2166c177c939346b19c0d6b08a62a37589"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cbf5d818553add7511c38b05532d94a407f499d1a76ebb0cad0374e32bc67202"}, - {file = "ruff-0.4.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:50e9651578b629baec3d1513b2534de0ac7ed7753e1382272b8d609997e27e83"}, - {file = "ruff-0.4.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8874a9df7766cb956b218a0a239e0a5d23d9e843e4da1e113ae1d27ee420877a"}, - {file = "ruff-0.4.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:b9de9a6e49f7d529decd09381c0860c3f82fa0b0ea00ea78409b785d2308a567"}, - {file = "ruff-0.4.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:13a1768b0691619822ae6d446132dbdfd568b700ecd3652b20d4e8bc1e498f78"}, - {file = "ruff-0.4.7-py3-none-win32.whl", hash = "sha256:769e5a51df61e07e887b81e6f039e7ed3573316ab7dd9f635c5afaa310e4030e"}, - {file = "ruff-0.4.7-py3-none-win_amd64.whl", hash = "sha256:9e3ab684ad403a9ed1226894c32c3ab9c2e0718440f6f50c7c5829932bc9e054"}, - {file = "ruff-0.4.7-py3-none-win_arm64.whl", hash = "sha256:10f2204b9a613988e3484194c2c9e96a22079206b22b787605c255f130db5ed7"}, - {file = "ruff-0.4.7.tar.gz", hash = "sha256:2331d2b051dc77a289a653fcc6a42cce357087c5975738157cd966590b18b5e1"}, + {file = "ruff-0.4.8-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:7663a6d78f6adb0eab270fa9cf1ff2d28618ca3a652b60f2a234d92b9ec89066"}, + {file = "ruff-0.4.8-py3-none-macosx_11_0_arm64.whl", hash = "sha256:eeceb78da8afb6de0ddada93112869852d04f1cd0f6b80fe464fd4e35c330913"}, + {file = "ruff-0.4.8-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aad360893e92486662ef3be0a339c5ca3c1b109e0134fcd37d534d4be9fb8de3"}, + {file = "ruff-0.4.8-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:284c2e3f3396fb05f5f803c9fffb53ebbe09a3ebe7dda2929ed8d73ded736deb"}, + {file = "ruff-0.4.8-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a7354f921e3fbe04d2a62d46707e569f9315e1a613307f7311a935743c51a764"}, + {file = "ruff-0.4.8-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:72584676164e15a68a15778fd1b17c28a519e7a0622161eb2debdcdabdc71883"}, + {file = "ruff-0.4.8-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9678d5c9b43315f323af2233a04d747409d1e3aa6789620083a82d1066a35199"}, + {file = "ruff-0.4.8-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704977a658131651a22b5ebeb28b717ef42ac6ee3b11e91dc87b633b5d83142b"}, + {file = "ruff-0.4.8-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d05f8d6f0c3cce5026cecd83b7a143dcad503045857bc49662f736437380ad45"}, + {file = "ruff-0.4.8-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:6ea874950daca5697309d976c9afba830d3bf0ed66887481d6bca1673fc5b66a"}, + {file = "ruff-0.4.8-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:fc95aac2943ddf360376be9aa3107c8cf9640083940a8c5bd824be692d2216dc"}, + {file = "ruff-0.4.8-py3-none-musllinux_1_2_i686.whl", hash = "sha256:384154a1c3f4bf537bac69f33720957ee49ac8d484bfc91720cc94172026ceed"}, + {file = "ruff-0.4.8-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e9d5ce97cacc99878aa0d084c626a15cd21e6b3d53fd6f9112b7fc485918e1fa"}, + {file = "ruff-0.4.8-py3-none-win32.whl", hash = "sha256:6d795d7639212c2dfd01991259460101c22aabf420d9b943f153ab9d9706e6a9"}, + {file = "ruff-0.4.8-py3-none-win_amd64.whl", hash = "sha256:e14a3a095d07560a9d6769a72f781d73259655919d9b396c650fc98a8157555d"}, + {file = "ruff-0.4.8-py3-none-win_arm64.whl", hash = "sha256:14019a06dbe29b608f6b7cbcec300e3170a8d86efaddb7b23405cb7f7dcaf780"}, + {file = "ruff-0.4.8.tar.gz", hash = "sha256:16d717b1d57b2e2fd68bd0bf80fb43931b79d05a7131aa477d66fc40fbd86268"}, ] [[package]] @@ -2978,22 +3050,22 @@ files = [ [[package]] name = "tornado" -version = "6.4" +version = "6.4.1" description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." optional = false -python-versions = ">= 3.8" +python-versions = ">=3.8" files = [ - {file = "tornado-6.4-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:02ccefc7d8211e5a7f9e8bc3f9e5b0ad6262ba2fbb683a6443ecc804e5224ce0"}, - {file = "tornado-6.4-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:27787de946a9cffd63ce5814c33f734c627a87072ec7eed71f7fc4417bb16263"}, - {file = "tornado-6.4-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f7894c581ecdcf91666a0912f18ce5e757213999e183ebfc2c3fdbf4d5bd764e"}, - {file = "tornado-6.4-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e43bc2e5370a6a8e413e1e1cd0c91bedc5bd62a74a532371042a18ef19e10579"}, - {file = "tornado-6.4-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f0251554cdd50b4b44362f73ad5ba7126fc5b2c2895cc62b14a1c2d7ea32f212"}, - {file = "tornado-6.4-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:fd03192e287fbd0899dd8f81c6fb9cbbc69194d2074b38f384cb6fa72b80e9c2"}, - {file = "tornado-6.4-cp38-abi3-musllinux_1_1_i686.whl", hash = "sha256:88b84956273fbd73420e6d4b8d5ccbe913c65d31351b4c004ae362eba06e1f78"}, - {file = "tornado-6.4-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:71ddfc23a0e03ef2df1c1397d859868d158c8276a0603b96cf86892bff58149f"}, - {file = "tornado-6.4-cp38-abi3-win32.whl", hash = "sha256:6f8a6c77900f5ae93d8b4ae1196472d0ccc2775cc1dfdc9e7727889145c45052"}, - {file = "tornado-6.4-cp38-abi3-win_amd64.whl", hash = "sha256:10aeaa8006333433da48dec9fe417877f8bcc21f48dda8d661ae79da357b2a63"}, - {file = "tornado-6.4.tar.gz", hash = "sha256:72291fa6e6bc84e626589f1c29d90a5a6d593ef5ae68052ee2ef000dfd273dee"}, + {file = "tornado-6.4.1-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:163b0aafc8e23d8cdc3c9dfb24c5368af84a81e3364745ccb4427669bf84aec8"}, + {file = "tornado-6.4.1-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:6d5ce3437e18a2b66fbadb183c1d3364fb03f2be71299e7d10dbeeb69f4b2a14"}, + {file = "tornado-6.4.1-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2e20b9113cd7293f164dc46fffb13535266e713cdb87bd2d15ddb336e96cfc4"}, + {file = "tornado-6.4.1-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8ae50a504a740365267b2a8d1a90c9fbc86b780a39170feca9bcc1787ff80842"}, + {file = "tornado-6.4.1-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:613bf4ddf5c7a95509218b149b555621497a6cc0d46ac341b30bd9ec19eac7f3"}, + {file = "tornado-6.4.1-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:25486eb223babe3eed4b8aecbac33b37e3dd6d776bc730ca14e1bf93888b979f"}, + {file = "tornado-6.4.1-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:454db8a7ecfcf2ff6042dde58404164d969b6f5d58b926da15e6b23817950fc4"}, + {file = "tornado-6.4.1-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a02a08cc7a9314b006f653ce40483b9b3c12cda222d6a46d4ac63bb6c9057698"}, + {file = "tornado-6.4.1-cp38-abi3-win32.whl", hash = "sha256:d9a566c40b89757c9aa8e6f032bcdb8ca8795d7c1a9762910c722b1635c9de4d"}, + {file = "tornado-6.4.1-cp38-abi3-win_amd64.whl", hash = "sha256:b24b8982ed444378d7f21d563f4180a2de31ced9d8d84443907a0a64da2072e7"}, + {file = "tornado-6.4.1.tar.gz", hash = "sha256:92d3ab53183d8c50f8204a51e6f91d18a15d5ef261e84d452800d4ff6fc504e9"}, ] [[package]] @@ -3028,13 +3100,13 @@ test = ["mypy", "pytest", "typing-extensions"] [[package]] name = "typing-extensions" -version = "4.12.1" +version = "4.12.2" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.12.1-py3-none-any.whl", hash = "sha256:6024b58b69089e5a89c347397254e35f1bf02a907728ec7fee9bf0fe837d203a"}, - {file = "typing_extensions-4.12.1.tar.gz", hash = "sha256:915f5e35ff76f56588223f15fdd5938f9a1cf9195c0de25130c627e4d597f6d1"}, + {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, + {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] [[package]] @@ -3155,4 +3227,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.9,<=3.11" -content-hash = "4b2c24b5b351b69f69d5c1305178c8711b0ca1188f8f31a63ed6d1ecd02885de" +content-hash = "c5b1bbcfbb18730f6e573f9bbd35ee80e2be5e905618a17c3a465d58b0aa04ac" diff --git a/pyproject.toml b/pyproject.toml index 7484463..87d4525 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ ruff = ">=0.1.8" [tool.poetry.group.temp.dependencies] ott-jax = ">=0.4.2" ipykernel = ">=6.25.1" +matscipy = ">=0.8.0" [tool.poetry.group.docs.dependencies] sphinx = "7.2.6" From bb313a8f0fc039798656645461116d607b953bfa Mon Sep 17 00:00:00 2001 From: Artur Toshev Date: Sat, 8 Jun 2024 00:28:49 +0200 Subject: [PATCH 5/5] bump version to 0.0.1 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 87d4525..051ecc5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "jax-sph" -version = "0.0.0a1" +version = "0.0.1" description = "JAX-SPH: Smoothed Particle Hydrodynamics in JAX" authors = ["Artur Toshev ",] maintainers = ["Artur Toshev ",]