Skip to content

Commit

Permalink
Merge pull request #35 from BiAPoL/jo-mueller/add-alpha
Browse files Browse the repository at this point in the history
removed alpha from base class
  • Loading branch information
jo-mueller authored Jan 30, 2025
2 parents 896112b + 06cbe5c commit 58aadba
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 55 deletions.
174 changes: 144 additions & 30 deletions docs/examples/scatter_artist_example.ipynb

Large diffs are not rendered by default.

16 changes: 12 additions & 4 deletions src/biaplotter/_tests/test_artists.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,18 +62,26 @@ def on_color_indices_changed(color_indices):
assert np.all(sizes == np.linspace(1, 10, size))

# Test size reset when new data is set
new_data = np.random.rand(size, 2)
scatter.data = new_data
assert scatter.size == 50.0 # that's the default
scatter.data = np.random.rand(size//2, 2)
assert np.all(scatter.size == 50.0) # that's the default
sizes = scatter._scatter.get_sizes()
assert np.all(sizes == 50.0)

# test alpha
scatter.alpha = 0.5
assert np.all(scatter._scatter.get_alpha() == 0.5)

# test alpha reset when new data is set
scatter.data = np.random.rand(size, 2)
assert np.all(scatter._scatter.get_alpha() == 1.0)

# test handling NaNs
data_with_nans = np.copy(new_data)
data_with_nans = np.copy(scatter.data)
data_with_nans[0, 0] = np.nan
scatter.data = data_with_nans



def test_histogram2d():
# Inputs
threshold = 20
Expand Down
70 changes: 49 additions & 21 deletions src/biaplotter/artists.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def __init__(self, ax: plt.Axes = None, data: np.ndarray = None, categorical_col
#: Stores the scatter plot matplotlib object
self._scatter = None
self.data = data
self._alpha = 1 # Default alpha
self._size = 50 # Default size
self.draw() # Initial draw of the scatter plot

Expand All @@ -152,38 +153,41 @@ def data(self) -> np.ndarray:

@data.setter
def data(self, value: np.ndarray):
"""Sets the data for the scatter plot, updating the display as needed."""
if value is None:
return
if len(value) == 0:
"""Sets the data for the scatter plot, resetting other properties to defaults."""
if value is None or len(value) == 0:
return

if self._data is not None:
data_length_changed = len(value) != len(self._data)
else:
data_length_changed = True
self._data = value
# emit signal

# Emit the data changed signal
self.data_changed_signal.emit(self._data)
if self._scatter is None:
self._scatter = self.ax.scatter(value[:, 0], value[:, 1], s=self._size)
self.color_indices = 0 # Set default color index
else:
# If the scatter plot already exists, just update its data
self._scatter.set_offsets(value)

if self._color_indices is None:
self.color_indices = 0 # Set default color index
if self._scatter is None or data_length_changed:
# Create the scatter plot if it doesn't exist yet
if self._scatter is not None:
self._scatter.remove()

# Create a new scatter plot with the updated data
self._scatter = self.ax.scatter(self._data[:, 0], self._data[:, 1])
self.size = 50 # Default size
self.alpha = 1 # Default alpha
self.color_indices = np.zeros(len(value), dtype=int) # Default color indices
else:
# Update colors if color indices are set, resize if data shape has changed
color_indices_size = len(self._color_indices)
color_indices = np.resize(self._color_indices, self._data.shape[0])
if len(color_indices) > color_indices_size:
# fill with zeros where new data is larger
color_indices[color_indices_size:] = 0
self.color_indices = color_indices
self.size = 50
self._scatter.set_offsets(value) # somehow resets the size and alpha
self.color_indices = self._color_indices
self.size = self._size
self.alpha = self._alpha

x_margin = 0.05 * (np.nanmax(value[:, 0]) - np.nanmin(value[:, 0]))
y_margin = 0.05 * (np.nanmax(value[:, 1]) - np.nanmin(value[:, 1]))
self.ax.set_xlim(np.nanmin(value[:, 0]) - x_margin, np.nanmax(value[:, 0]) + x_margin)
self.ax.set_ylim(np.nanmin(value[:, 1]) - y_margin, np.nanmax(value[:, 1]) + y_margin)

# Redraw the plot
self.draw()

@property
Expand Down Expand Up @@ -244,6 +248,30 @@ def color_indices(self, indices: np.ndarray):
self.color_indices_changed_signal.emit(self._color_indices)
self.draw()

@property
def alpha(self) -> Union[float, np.ndarray]:
"""Gets or sets the alpha value of the scatter plot.
Triggers a draw idle command.
Returns
-------
alpha : float
alpha value of the scatter plot.
"""
return self._scatter.get_alpha()

@alpha.setter
def alpha(self, value: Union[float, np.ndarray]):
"""Sets the alpha value of the scatter plot."""
self._alpha = value

if np.isscalar(value):
value = np.ones(len(self._data)) * value
if self._scatter is not None:
self._scatter.set_alpha(value)
self.draw()

@property
def size(self) -> Union[float, np.ndarray]:
"""Gets or sets the size of the points in the scatter plot.
Expand Down

0 comments on commit 58aadba

Please sign in to comment.