@@ -290,7 +290,7 @@ def plot_split_indices(cv, cv_args, X, y, groups, n_splits, image_file_path=None
290
290
s = marker_size ,
291
291
)
292
292
293
- yticklabels = list (range (n_splits )) + ["group" ]
293
+ yticklabels = list (range (1 , n_splits + 1 )) + ["group" ]
294
294
ax .set (
295
295
yticks = np .arange (n_splits + 1 ) + 0.5 ,
296
296
yticklabels = yticklabels ,
@@ -299,15 +299,34 @@ def plot_split_indices(cv, cv_args, X, y, groups, n_splits, image_file_path=None
299
299
xlim = [- 0.5 , len (indices ) - 0.5 ],
300
300
)
301
301
302
- ax .legend (
302
+ legend_splits = ax .legend (
303
303
[Patch (color = cmap_cv (0.2 )), Patch (color = cmap_cv (0.8 ))],
304
304
["Training set" , "Testing set" ],
305
- loc = (1.02 , 0.8 ),
305
+ title = "Data Splits" ,
306
+ loc = "upper right" ,
307
+ fontsize = 13 ,
308
+ )
309
+
310
+ ax .add_artist (legend_splits )
311
+
312
+ group_labels = [f"{ group } " for group in np .unique (groups )]
313
+ cmap = plt .cm .get_cmap ("tab20" , len (group_labels ))
314
+
315
+ unique_patches = {}
316
+ for i , group in enumerate (np .unique (groups )):
317
+ unique_patches [group ] = Patch (color = cmap (i ), label = f"{ group } " )
318
+
319
+ ax .legend (
320
+ handles = list (unique_patches .values ()),
321
+ title = "Groups" ,
322
+ loc = "center left" ,
323
+ bbox_to_anchor = (1.02 , 0.5 ),
306
324
fontsize = 13 ,
307
325
)
308
326
309
327
ax .set_title ("{}\n {}" .format (type (cv ).__name__ , cv_args ), fontsize = 15 )
310
- ax .xaxis .set_major_locator (MaxNLocator (min_n_ticks = len (X ), integer = True ))
328
+ ax .set_xlim (0 , len (X ))
329
+ ax .xaxis .set_major_locator (MaxNLocator (integer = True ))
311
330
ax .set_xlabel (xlabel = "Sample index" , fontsize = 13 )
312
331
ax .set_ylabel (ylabel = "CV iteration" , fontsize = 13 )
313
332
ax .tick_params (axis = "both" , which = "major" , labelsize = 13 )
0 commit comments