Skip to content

Commit ec40b75

Browse files
authored
Improved plot_splits for time series splits (#1113)
* fixed test_end_idx calculation * added legend title * fixed intersection of xticklabels for Sample index * added changes to changelog * Update CHANGELOG.md * removed redundancy * start x-axis at the origin of coordinates * fixed CV iteration labels * revert idx change * Update CHANGELOG.md * add groups legend * fixed cmap in groups legend
1 parent c229178 commit ec40b75

File tree

2 files changed

+36
-4
lines changed

2 files changed

+36
-4
lines changed

docs/sources/CHANGELOG.md

+13
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,19 @@ The CHANGELOG for the current development version is available at
77

88
---
99

10+
### Version 0.23.3 (tbd)
11+
12+
##### Downloads
13+
...
14+
15+
##### New Features and Enhancements
16+
17+
Files updated:
18+
- ['mlxtend.evaluate.time_series.plot_splits'](https://github.com/rasbt/mlxtend/blob/master/mlxtend/evaluate/time_series.py)
19+
- Improved `plot_splits` for better visualization of time series splits
20+
21+
##### Changes
22+
...
1023

1124
### Version 0.23.2 (5 Nov 2024)
1225

mlxtend/evaluate/time_series.py

+23-4
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def plot_split_indices(cv, cv_args, X, y, groups, n_splits, image_file_path=None
290290
s=marker_size,
291291
)
292292

293-
yticklabels = list(range(n_splits)) + ["group"]
293+
yticklabels = list(range(1, n_splits + 1)) + ["group"]
294294
ax.set(
295295
yticks=np.arange(n_splits + 1) + 0.5,
296296
yticklabels=yticklabels,
@@ -299,15 +299,34 @@ def plot_split_indices(cv, cv_args, X, y, groups, n_splits, image_file_path=None
299299
xlim=[-0.5, len(indices) - 0.5],
300300
)
301301

302-
ax.legend(
302+
legend_splits = ax.legend(
303303
[Patch(color=cmap_cv(0.2)), Patch(color=cmap_cv(0.8))],
304304
["Training set", "Testing set"],
305-
loc=(1.02, 0.8),
305+
title="Data Splits",
306+
loc="upper right",
307+
fontsize=13,
308+
)
309+
310+
ax.add_artist(legend_splits)
311+
312+
group_labels = [f"{group}" for group in np.unique(groups)]
313+
cmap = plt.cm.get_cmap("tab20", len(group_labels))
314+
315+
unique_patches = {}
316+
for i, group in enumerate(np.unique(groups)):
317+
unique_patches[group] = Patch(color=cmap(i), label=f"{group}")
318+
319+
ax.legend(
320+
handles=list(unique_patches.values()),
321+
title="Groups",
322+
loc="center left",
323+
bbox_to_anchor=(1.02, 0.5),
306324
fontsize=13,
307325
)
308326

309327
ax.set_title("{}\n{}".format(type(cv).__name__, cv_args), fontsize=15)
310-
ax.xaxis.set_major_locator(MaxNLocator(min_n_ticks=len(X), integer=True))
328+
ax.set_xlim(0, len(X))
329+
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
311330
ax.set_xlabel(xlabel="Sample index", fontsize=13)
312331
ax.set_ylabel(ylabel="CV iteration", fontsize=13)
313332
ax.tick_params(axis="both", which="major", labelsize=13)

0 commit comments

Comments
 (0)