diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index d446c5d..4d7ef92 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -20,6 +20,8 @@ jobs: matrix: version: - '1.8' + - '1' + os: [ubuntu-latest, windows-latest, macOS-latest] arch: - x64 diff --git a/.gitignore b/.gitignore index 11d80c3..59c0e0f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /Manifest.toml +.CondaPkg/* .CondaPkg diff --git a/Project.toml b/Project.toml index 529b9cb..1d24159 100644 --- a/Project.toml +++ b/Project.toml @@ -4,10 +4,13 @@ authors = ["Essam and contributors"] version = "1.0.0-DEV" [deps] +JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] MLJBase = "0.21" @@ -17,9 +20,10 @@ julia = "1.6" [extras] DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Imbalance = "c709b415-507b-45b7-9a3d-1767c89fde68" -MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692" MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7" +MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Imbalance", "DataFrames", "MLJLinearModels", "MLJModels"] +test = ["Test", "Imbalance", "DataFrames", "MLJLinearModels", "MLJModels", "Tables"] diff --git a/examples/BalancedBagging.ipynb b/examples/BalancedBagging.ipynb new file mode 100644 index 0000000..0304c5b --- /dev/null +++ b/examples/BalancedBagging.ipynb @@ -0,0 +1,165 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ENV[\"JULIA_PKG_SERVER\"] = \"\"\n", + "using Pkg\n", + "Pkg.activate(@__DIR__)\n", + "Pkg.instantiate()\n", + "\n", + "\n", + "using MLJBalancing\n", + "using Imbalance\n", + "using MLJ\n", + "using Random" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Load Data" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((Column1 = [0.9695150609084499, 0.012898301755861596, 0.7555027304121053, 0.3467415729179013, 0.35969402837473463, 0.2601876747805505, 0.9522580699968279, 0.06304475092339623, 0.18909001622655808, 0.19934942931986965 … 0.021532597906190776, 0.8482825697641306, 0.10773487816863903, 0.32189982199036116, 0.12662208474317038, 0.28529465447429614, 0.2907506630258835, 0.36872799387588473, 0.061489791166806085, 0.45645058368583713], Column2 = [0.06546916714160167, 0.7243956502957003, 0.5183099801474415, 0.7555562860508294, 0.11226218114407538, 0.9135150277876691, 0.8739421974558176, 0.2268482788660101, 0.580604436651146, 0.4142252330250549 … 0.6517425913240111, 0.01713263102740481, 0.7175499403837856, 0.7362894157420817, 0.24893665902538054, 0.41499951381631595, 0.2159527717429719, 0.8966879835264249, 0.87252430655793, 0.41461921031276117], Column3 = [0.5939320702328891, 0.19329886972497456, 0.04656947038518311, 0.22095698685781184, 0.678807659662497, 0.12720198818430306, 0.6795750371448686, 0.9314917999820301, 0.22920734893984274, 0.5148148980955375 … 0.55049773593343, 0.038576459283091946, 0.27765727942909757, 0.2753072414696357, 0.8823620780359746, 0.44831794170895023, 0.9073846432163745, 0.4648550947905655, 0.311984726769037, 0.25829997798611304], Column4 = [0.12253944650540982, 0.8259140842535423, 0.4034477332184384, 0.5279399406265695, 0.5579944087437719, 0.24650366028608328, 0.6874897000162434, 0.23391406844015605, 0.5641254897013973, 0.6250622796341656 … 0.21708181942178983, 0.35224683896541464, 0.8444113778983325, 0.4547214584884428, 0.13508852017592232, 0.9510137735662383, 0.5723463533029658, 0.626377972762265, 0.7854013810594317, 0.15394691114473347], Column5 = [0.47958743625921163, 0.45779753417165514, 0.6367059235247621, 0.8601116026079643, 0.3334020182022719, 0.41593698717526373, 0.13208968772625174, 0.16951044109747648, 0.8137887839507706, 0.4429229861115882 … 0.01308976221980429, 0.48597926808091163, 0.20768781798463476, 0.30045611276046247, 0.15759293576302558, 0.975806377881983, 0.19451065500145392, 0.9638103356367584, 0.3594043445295293, 0.7792867217495332], Column6 = [3.0, 3.0, 1.0, 3.0, 1.0, 2.0, 3.0, 2.0, 3.0, 3.0 … 3.0, 2.0, 1.0, 2.0, 1.0, 2.0, 2.0, 3.0, 3.0, 1.0], Column7 = [2.0, 2.0, 2.0, 2.0, 1.0, 2.0, 2.0, 2.0, 1.0, 1.0 … 2.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 1.0, 1.0]), CategoricalArrays.CategoricalValue{Int64, UInt32}[0, 0, 0, 0, 0, 0, 0, 0, 1, 0 … 0, 0, 1, 0, 1, 0, 0, 0, 0, 0])" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "X, y = generate_imbalanced_data(100, 5; cat_feats_num_vals = [3, 2], \n", + " probs = [0.9, 0.1], \n", + " type = \"ColTable\", \n", + " rng=42)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Construct `BalancedBaggingClassifier` Model" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "BalancedBaggingClassifier(\n", + " model = LogisticClassifier(\n", + " lambda = 2.220446049250313e-16, \n", + " gamma = 0.0, \n", + " penalty = :l2, \n", + " fit_intercept = true, \n", + " penalize_intercept = false, \n", + " scale_penalty_with_samples = true, \n", + " solver = nothing), \n", + " T = 10, \n", + " rng = Xoshiro(0xa379de7eeeb2a4e8, 0x953dccb6b532b3af, 0xf597b8ff8cfd652a, 0xccd7337c571680d1))" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "LogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels verbosity=0\n", + "logistic_model = LogisticClassifier()\n", + "model = BalancedBaggingClassifier(classifier=logistic_model, T=10, rng=Random.Xoshiro(42))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Train & Evaluate the Model" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Info: Training machine(LogisticClassifier(lambda = 2.220446049250313e-16, …), …).\n", + "└ @ MLJBase /Users/essam/.julia/packages/MLJBase/ByFwA/src/machines.jl:492\n", + "┌ Info: Solver: MLJLinearModels.LBFGS{Optim.Options{Float64, Nothing}, NamedTuple{(), Tuple{}}}\n", + "│ optim_options: Optim.Options{Float64, Nothing}\n", + "│ lbfgs_options: NamedTuple{(), Tuple{}} NamedTuple()\n", + "└ @ MLJLinearModels /Users/essam/.julia/packages/MLJLinearModels/zSQnL/src/mlj/interface.jl:72\n" + ] + }, + { + "data": { + "text/plain": [ + "100-element CategoricalDistributions.UnivariateFiniteVector{Multiclass{2}, Int64, UInt32, Float64}:\n", + " UnivariateFinite{Multiclass{2}}(0=>0.928, 1=>0.0722)\n", + " UnivariateFinite{Multiclass{2}}(0=>0.845, 1=>0.155)\n", + " UnivariateFinite{Multiclass{2}}(0=>0.749, 1=>0.251)\n", + " UnivariateFinite{Multiclass{2}}(0=>0.902, 1=>0.0977)\n", + " UnivariateFinite{Multiclass{2}}(0=>0.804, 1=>0.196)\n", + " UnivariateFinite{Multiclass{2}}(0=>0.864, 1=>0.136)\n", + " UnivariateFinite{Multiclass{2}}(0=>0.851, 1=>0.149)\n", + " UnivariateFinite{Multiclass{2}}(0=>0.954, 1=>0.0458)\n", + " UnivariateFinite{Multiclass{2}}(0=>0.853, 1=>0.147)\n", + " UnivariateFinite{Multiclass{2}}(0=>0.86, 1=>0.14)\n", + " ⋮\n", + " UnivariateFinite{Multiclass{2}}(0=>0.671, 1=>0.329)\n", + " UnivariateFinite{Multiclass{2}}(0=>0.73, 1=>0.27)\n", + " UnivariateFinite{Multiclass{2}}(0=>0.843, 1=>0.157)\n", + " UnivariateFinite{Multiclass{2}}(0=>0.941, 1=>0.0594)\n", + " UnivariateFinite{Multiclass{2}}(0=>0.872, 1=>0.128)\n", + " UnivariateFinite{Multiclass{2}}(0=>0.92, 1=>0.0797)\n", + " UnivariateFinite{Multiclass{2}}(0=>0.929, 1=>0.0714)\n", + " UnivariateFinite{Multiclass{2}}(0=>0.791, 1=>0.209)\n", + " UnivariateFinite{Multiclass{2}}(0=>0.827, 1=>0.173)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "mach = machine(logistic_model, X, y)\n", + "fit!(mach)\n", + "pred = predict(mach, X)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Julia 1.8.5", + "language": "julia", + "name": "julia-1.8" + }, + "language_info": { + "file_extension": ".jl", + "mimetype": "application/julia", + "name": "julia", + "version": "1.8.5" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/Manifest.toml b/examples/Manifest.toml index b43e53e..e69de29 100644 --- a/examples/Manifest.toml +++ b/examples/Manifest.toml @@ -1,994 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -julia_version = "1.8.5" -manifest_format = "2.0" -project_hash = "bb8a7bcab9f98b02a846587066cc275d77de9db7" - -[[deps.ARFFFiles]] -deps = ["CategoricalArrays", "Dates", "Parsers", "Tables"] -git-tree-sha1 = "e8c8e0a2be6eb4f56b1672e46004463033daa409" -uuid = "da404889-ca92-49ff-9e8b-0aa6b4d38dc8" -version = "1.4.1" - -[[deps.AbstractTrees]] -git-tree-sha1 = "faa260e4cb5aba097a73fab382dd4b5819d8ec8c" -uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" -version = "0.4.4" - -[[deps.Adapt]] -deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "76289dc51920fdc6e0013c872ba9551d54961c24" -uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.6.2" - -[[deps.ArgCheck]] -git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" -uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" -version = "2.3.0" - -[[deps.ArgTools]] -uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" -version = "1.1.1" - -[[deps.ArrayInterface]] -deps = ["Adapt", "LinearAlgebra", "Requires", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "f83ec24f76d4c8f525099b2ac475fc098138ec31" -uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "7.4.11" - -[[deps.Artifacts]] -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" - -[[deps.BangBang]] -deps = ["Compat", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables"] -git-tree-sha1 = "e28912ce94077686443433c2800104b061a827ed" -uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" -version = "0.3.39" - -[[deps.Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[deps.Baselet]] -git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" -uuid = "9718e550-a3fa-408a-8086-8db961cd8217" -version = "0.1.1" - -[[deps.BitFlags]] -git-tree-sha1 = "43b1a4a8f797c1cddadf60499a8a077d4af2cd2d" -uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" -version = "0.1.7" - -[[deps.Calculus]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad" -uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" -version = "0.5.1" - -[[deps.CategoricalArrays]] -deps = ["DataAPI", "Future", "Missings", "Printf", "Requires", "Statistics", "Unicode"] -git-tree-sha1 = "1568b28f91293458345dabba6a5ea3f183250a61" -uuid = "324d7699-5711-5eae-9e2f-1d82baa6b597" -version = "0.10.8" - -[[deps.CategoricalDistributions]] -deps = ["CategoricalArrays", "Distributions", "Missings", "OrderedCollections", "Random", "ScientificTypes", "UnicodePlots"] -git-tree-sha1 = "ed760a4fde49997ff9360a780abe6e20175162aa" -uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e" -version = "0.1.11" - -[[deps.ChainRulesCore]] -deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "e30f2f4e20f7f186dc36529910beaedc60cfa644" -uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.16.0" - -[[deps.ChangesOfVariables]] -deps = ["InverseFunctions", "LinearAlgebra", "Test"] -git-tree-sha1 = "2fba81a302a7be671aefe194f0525ef231104e7f" -uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" -version = "0.1.8" - -[[deps.CodecZlib]] -deps = ["TranscodingStreams", "Zlib_jll"] -git-tree-sha1 = "02aa26a4cf76381be7f66e020a3eddeb27b0a092" -uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.7.2" - -[[deps.ColorSchemes]] -deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] -git-tree-sha1 = "d9a8f86737b665e15a9641ecbac64deef9ce6724" -uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" -version = "3.23.0" - -[[deps.ColorTypes]] -deps = ["FixedPointNumbers", "Random"] -git-tree-sha1 = "eb7f0f8307f71fac7c606984ea5fb2817275d6e4" -uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" -version = "0.11.4" - -[[deps.ColorVectorSpace]] -deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "Requires", "Statistics", "TensorCore"] -git-tree-sha1 = "a1f44953f2382ebb937d60dafbe2deea4bd23249" -uuid = "c3611d14-8923-5661-9e6a-0046d554d3a4" -version = "0.10.0" - -[[deps.Colors]] -deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] -git-tree-sha1 = "fc08e5930ee9a4e03f84bfb5211cb54e7769758a" -uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" -version = "0.12.10" - -[[deps.Combinatorics]] -git-tree-sha1 = "08c8b6831dc00bfea825826be0bc8336fc369860" -uuid = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" -version = "1.0.2" - -[[deps.CommonSubexpressions]] -deps = ["MacroTools", "Test"] -git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" -uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" -version = "0.3.0" - -[[deps.Compat]] -deps = ["Dates", "LinearAlgebra", "UUIDs"] -git-tree-sha1 = "e460f044ca8b99be31d35fe54fc33a5c33dd8ed7" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.9.0" - -[[deps.CompilerSupportLibraries_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.0.1+0" - -[[deps.CompositionsBase]] -git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" -uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" -version = "0.1.2" - -[[deps.ComputationalResources]] -git-tree-sha1 = "52cb3ec90e8a8bea0e62e275ba577ad0f74821f7" -uuid = "ed09eef8-17a6-5b46-8889-db040fac31e3" -version = "0.3.2" - -[[deps.ConcurrentUtilities]] -deps = ["Serialization", "Sockets"] -git-tree-sha1 = "5372dbbf8f0bdb8c700db5367132925c0771ef7e" -uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" -version = "2.2.1" - -[[deps.ConstructionBase]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "c53fc348ca4d40d7b371e71fd52251839080cbc9" -uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.5.4" - -[[deps.Contour]] -git-tree-sha1 = "d05d9e7b7aedff4e5b51a029dced05cfb6125781" -uuid = "d38c429a-6771-53c6-b99e-75d170b6e991" -version = "0.6.2" - -[[deps.Crayons]] -git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" -uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" -version = "4.1.1" - -[[deps.DataAPI]] -git-tree-sha1 = "8da84edb865b0b5b0100c0666a9bc9a0b71c553c" -uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.15.0" - -[[deps.DataFrames]] -deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] -git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8" -uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -version = "1.6.1" - -[[deps.DataStructures]] -deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "3dbd312d370723b6bb43ba9d02fc36abade4518d" -uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.15" - -[[deps.DataValueInterfaces]] -git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" -uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" -version = "1.0.0" - -[[deps.Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[deps.DefineSingletons]] -git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" -uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" -version = "0.1.2" - -[[deps.DelimitedFiles]] -deps = ["Mmap"] -uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" - -[[deps.DensityInterface]] -deps = ["InverseFunctions", "Test"] -git-tree-sha1 = "80c3e8639e3353e5d2912fb3a1916b8455e2494b" -uuid = "b429d917-457f-4dbc-8f4c-0cc954292b1d" -version = "0.4.0" - -[[deps.DiffResults]] -deps = ["StaticArraysCore"] -git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" -uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "1.1.0" - -[[deps.DiffRules]] -deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" -uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.15.1" - -[[deps.Distances]] -deps = ["LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "b6def76ffad15143924a2199f72a5cd883a2e8a9" -uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" -version = "0.10.9" - -[[deps.Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - -[[deps.Distributions]] -deps = ["ChainRulesCore", "DensityInterface", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns", "Test"] -git-tree-sha1 = "938fe2981db009f531b6332e31c58e9584a2f9bd" -uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.100" - -[[deps.DocStringExtensions]] -deps = ["LibGit2"] -git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" -uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.9.3" - -[[deps.Downloads]] -deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] -uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" -version = "1.6.0" - -[[deps.DualNumbers]] -deps = ["Calculus", "NaNMath", "SpecialFunctions"] -git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" -uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" -version = "0.6.8" - -[[deps.EarlyStopping]] -deps = ["Dates", "Statistics"] -git-tree-sha1 = "98fdf08b707aaf69f524a6cd0a67858cefe0cfb6" -uuid = "792122b4-ca99-40de-a6bc-6742525f08b6" -version = "0.3.0" - -[[deps.ExceptionUnwrapping]] -deps = ["Test"] -git-tree-sha1 = "e90caa41f5a86296e014e148ee061bd6c3edec96" -uuid = "460bff9d-24e4-43bc-9d9f-a8973cb893f4" -version = "0.1.9" - -[[deps.FilePathsBase]] -deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"] -git-tree-sha1 = "e27c4ebe80e8699540f2d6c805cc12203b614f12" -uuid = "48062228-2e41-5def-b9a4-89aafe57970f" -version = "0.9.20" - -[[deps.FileWatching]] -uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" - -[[deps.FillArrays]] -deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] -git-tree-sha1 = "a20eaa3ad64254c61eeb5f230d9306e937405434" -uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.6.1" - -[[deps.FiniteDiff]] -deps = ["ArrayInterface", "LinearAlgebra", "Requires", "Setfield", "SparseArrays", "StaticArrays"] -git-tree-sha1 = "c6e4a1fbe73b31a3dea94b1da449503b8830c306" -uuid = "6a86dc24-6348-571c-b903-95158fe2bd41" -version = "2.21.1" - -[[deps.FixedPointNumbers]] -deps = ["Statistics"] -git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc" -uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" -version = "0.8.4" - -[[deps.ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"] -git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" -uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.36" - -[[deps.Future]] -deps = ["Random"] -uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" - -[[deps.HTTP]] -deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] -git-tree-sha1 = "5eab648309e2e060198b45820af1a37182de3cce" -uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" -version = "1.10.0" - -[[deps.HypergeometricFunctions]] -deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] -git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685" -uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" -version = "0.3.23" - -[[deps.Imbalance]] -deps = ["CategoricalArrays", "CategoricalDistributions", "Distances", "LinearAlgebra", "MLJModelInterface", "MLJTestInterface", "NearestNeighbors", "OrderedCollections", "ProgressMeter", "Random", "ScientificTypes", "Statistics", "StatsBase", "TableOperations", "TableTransforms", "Tables", "TransformsBase"] -git-tree-sha1 = "53eeb73d88913134cab0b0e04dd58901769fc7db" -uuid = "c709b415-507b-45b7-9a3d-1767c89fde68" -version = "0.1.0" - -[[deps.InitialValues]] -git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" -uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" -version = "0.3.1" - -[[deps.InlineStrings]] -deps = ["Parsers"] -git-tree-sha1 = "9cc2baf75c6d09f9da536ddf58eb2f29dedaf461" -uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" -version = "1.4.0" - -[[deps.InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[deps.InverseFunctions]] -deps = ["Test"] -git-tree-sha1 = "68772f49f54b479fa88ace904f6127f0a3bb2e46" -uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.12" - -[[deps.InvertedIndices]] -git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" -uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" -version = "1.3.0" - -[[deps.IrrationalConstants]] -git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" -uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" -version = "0.2.2" - -[[deps.IterationControl]] -deps = ["EarlyStopping", "InteractiveUtils"] -git-tree-sha1 = "d7df9a6fdd82a8cfdfe93a94fcce35515be634da" -uuid = "b3c1a2ee-3fec-4384-bf48-272ea71de57c" -version = "0.5.3" - -[[deps.IterativeSolvers]] -deps = ["LinearAlgebra", "Printf", "Random", "RecipesBase", "SparseArrays"] -git-tree-sha1 = "1169632f425f79429f245113b775a0e3d121457c" -uuid = "42fd0dbc-a981-5370-80f2-aaf504508153" -version = "0.9.2" - -[[deps.IteratorInterfaceExtensions]] -git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" -uuid = "82899510-4779-5014-852e-03e436cf321d" -version = "1.0.0" - -[[deps.JLLWrappers]] -deps = ["Artifacts", "Preferences"] -git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" -uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.5.0" - -[[deps.JSON]] -deps = ["Dates", "Mmap", "Parsers", "Unicode"] -git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" -uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" -version = "0.21.4" - -[[deps.LaTeXStrings]] -git-tree-sha1 = "f2355693d6778a178ade15952b7ac47a4ff97996" -uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" -version = "1.3.0" - -[[deps.LatinHypercubeSampling]] -deps = ["Random", "StableRNGs", "StatsBase", "Test"] -git-tree-sha1 = "825289d43c753c7f1bf9bed334c253e9913997f8" -uuid = "a5e1c1ea-c99a-51d3-a14d-a9a37257b02d" -version = "1.9.0" - -[[deps.LibCURL]] -deps = ["LibCURL_jll", "MozillaCACerts_jll"] -uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.3" - -[[deps.LibCURL_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] -uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "7.84.0+0" - -[[deps.LibGit2]] -deps = ["Base64", "NetworkOptions", "Printf", "SHA"] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[deps.LibSSH2_jll]] -deps = ["Artifacts", "Libdl", "MbedTLS_jll"] -uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.10.2+0" - -[[deps.Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[deps.LineSearches]] -deps = ["LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "Printf"] -git-tree-sha1 = "7bbea35cec17305fc70a0e5b4641477dc0789d9d" -uuid = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" -version = "7.2.0" - -[[deps.LinearAlgebra]] -deps = ["Libdl", "libblastrampoline_jll"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[deps.LinearMaps]] -deps = ["ChainRulesCore", "LinearAlgebra", "SparseArrays", "Statistics"] -git-tree-sha1 = "6698ab5e662b47ffc63a82b2f43c1cee015cf80d" -uuid = "7a12625a-238d-50fd-b39a-03d52299707e" -version = "3.11.0" - -[[deps.LogExpFunctions]] -deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "7d6dd4e9212aebaeed356de34ccf262a3cd415aa" -uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.26" - -[[deps.Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[deps.LoggingExtras]] -deps = ["Dates", "Logging"] -git-tree-sha1 = "0d097476b6c381ab7906460ef1ef1638fbce1d91" -uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" -version = "1.0.2" - -[[deps.LossFunctions]] -deps = ["Markdown", "Requires", "Statistics"] -git-tree-sha1 = "df9da07efb9b05ca7ef701acec891ee8f73c99e2" -uuid = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" -version = "0.11.1" - -[[deps.MLFlowClient]] -deps = ["Dates", "FilePathsBase", "HTTP", "JSON", "ShowCases", "URIs", "UUIDs"] -git-tree-sha1 = "32cee10a6527476bef0c6484ff4c60c2cead5d3e" -uuid = "64a0f543-368b-4a9a-827a-e71edb2a0b83" -version = "0.4.4" - -[[deps.MLJ]] -deps = ["CategoricalArrays", "ComputationalResources", "Distributed", "Distributions", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlow", "MLJIteration", "MLJModels", "MLJTuning", "OpenML", "Pkg", "ProgressMeter", "Random", "Reexport", "ScientificTypes", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "193f1f1ac77d91eabe1ac81ff48646b378270eef" -uuid = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" -version = "0.19.5" - -[[deps.MLJBalancing]] -deps = ["MLJBase", "MLJModelInterface", "OrderedCollections", "Random"] -path = "/Users/essam/.julia/dev/MLJBalancing" -uuid = "45f359ea-796d-4f51-95a5-deb1a414c586" -version = "1.0.0-DEV" - -[[deps.MLJBase]] -deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LinearAlgebra", "LossFunctions", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "ScientificTypes", "Serialization", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "0b7307d1a7214ec3c0ba305571e713f9492ea984" -uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" -version = "0.21.14" - -[[deps.MLJEnsembles]] -deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Distributed", "Distributions", "MLJBase", "MLJModelInterface", "ProgressMeter", "Random", "ScientificTypesBase", "StatsBase"] -git-tree-sha1 = "95b306ef8108067d26dfde9ff3457d59911cc0d6" -uuid = "50ed68f4-41fd-4504-931a-ed422449fee0" -version = "0.3.3" - -[[deps.MLJFlow]] -deps = ["MLFlowClient", "MLJBase", "MLJModelInterface"] -git-tree-sha1 = "bceeeb648c9aa2fc6f65f957c688b164d30f2905" -uuid = "7b7b8358-b45c-48ea-a8ef-7ca328ad328f" -version = "0.1.1" - -[[deps.MLJIteration]] -deps = ["IterationControl", "MLJBase", "Random", "Serialization"] -git-tree-sha1 = "be6d5c71ab499a59e82d65e00a89ceba8732fcd5" -uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55" -version = "0.5.1" - -[[deps.MLJLinearModels]] -deps = ["DocStringExtensions", "IterativeSolvers", "LinearAlgebra", "LinearMaps", "MLJModelInterface", "Optim", "Parameters"] -git-tree-sha1 = "c92bf0ea37bf51e1ef0160069c572825819748b8" -uuid = "6ee0df7b-362f-4a72-a706-9e79364fb692" -version = "0.9.2" - -[[deps.MLJModelInterface]] -deps = ["Random", "ScientificTypesBase", "StatisticalTraits"] -git-tree-sha1 = "03ae109be87f460fe3c96b8a0dbbf9c7bf840bd5" -uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" -version = "1.9.2" - -[[deps.MLJModels]] -deps = ["CategoricalArrays", "CategoricalDistributions", "Combinatorics", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "Markdown", "OrderedCollections", "Parameters", "Pkg", "PrettyPrinting", "REPL", "Random", "RelocatableFolders", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "2b49f04f70266a2b040eb46ece157c4f5c1b0c13" -uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7" -version = "0.16.10" - -[[deps.MLJTestInterface]] -deps = ["MLJBase", "Pkg", "Test"] -git-tree-sha1 = "9131806695e6a6d32c61ed5f7bccaadef9fef57e" -uuid = "72560011-54dd-4dc2-94f3-c5de45b75ecd" -version = "0.2.2" - -[[deps.MLJTuning]] -deps = ["ComputationalResources", "Distributed", "Distributions", "LatinHypercubeSampling", "MLJBase", "ProgressMeter", "Random", "RecipesBase"] -git-tree-sha1 = "02688098bd77827b64ed8ad747c14f715f98cfc4" -uuid = "03970b2e-30c4-11ea-3135-d1576263f10f" -version = "0.7.4" - -[[deps.MacroTools]] -deps = ["Markdown", "Random"] -git-tree-sha1 = "9ee1618cbf5240e6d4e0371d6f24065083f60c48" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.11" - -[[deps.MarchingCubes]] -deps = ["PrecompileTools", "StaticArrays"] -git-tree-sha1 = "c8e29e2bacb98c9b6f10445227a8b0402f2f173a" -uuid = "299715c1-40a9-479a-aaf9-4a633d36f717" -version = "0.1.8" - -[[deps.Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[deps.MbedTLS]] -deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "Random", "Sockets"] -git-tree-sha1 = "03a9b9718f5682ecb107ac9f7308991db4ce395b" -uuid = "739be429-bea8-5141-9913-cc70e7f3736d" -version = "1.1.7" - -[[deps.MbedTLS_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.0+0" - -[[deps.MicroCollections]] -deps = ["BangBang", "InitialValues", "Setfield"] -git-tree-sha1 = "629afd7d10dbc6935ec59b32daeb33bc4460a42e" -uuid = "128add7d-3638-4c79-886c-908ea0c25c34" -version = "0.1.4" - -[[deps.Missings]] -deps = ["DataAPI"] -git-tree-sha1 = "f66bdc5de519e8f8ae43bdc598782d35a25b1272" -uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "1.1.0" - -[[deps.Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[deps.MozillaCACerts_jll]] -uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2022.2.1" - -[[deps.NLSolversBase]] -deps = ["DiffResults", "Distributed", "FiniteDiff", "ForwardDiff"] -git-tree-sha1 = "a0b464d183da839699f4c79e7606d9d186ec172c" -uuid = "d41bc354-129a-5804-8e4c-c37616107c6c" -version = "7.8.3" - -[[deps.NaNMath]] -deps = ["OpenLibm_jll"] -git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" -uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "1.0.2" - -[[deps.NearestNeighbors]] -deps = ["Distances", "StaticArrays"] -git-tree-sha1 = "2c3726ceb3388917602169bed973dbc97f1b51a8" -uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" -version = "0.4.13" - -[[deps.NelderMead]] -git-tree-sha1 = "25abc2f9b1c752e69229f37909461befa7c1f85d" -uuid = "2f6b4ddb-b4ff-44c0-b59b-2ab99302f970" -version = "0.4.0" - -[[deps.NetworkOptions]] -uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" -version = "1.2.0" - -[[deps.OpenBLAS_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] -uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.20+0" - -[[deps.OpenLibm_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+0" - -[[deps.OpenML]] -deps = ["ARFFFiles", "HTTP", "JSON", "Markdown", "Pkg", "Scratch"] -git-tree-sha1 = "6efb039ae888699d5a74fb593f6f3e10c7193e33" -uuid = "8b6db2d4-7670-4922-a472-f9537c81ab66" -version = "0.3.1" - -[[deps.OpenSSL]] -deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] -git-tree-sha1 = "51901a49222b09e3743c65b8847687ae5fc78eb2" -uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c" -version = "1.4.1" - -[[deps.OpenSSL_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "e78db7bd5c26fc5a6911b50a47ee302219157ea8" -uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "3.0.10+0" - -[[deps.OpenSpecFun_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" -uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.5+0" - -[[deps.Optim]] -deps = ["Compat", "FillArrays", "ForwardDiff", "LineSearches", "LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "PositiveFactorizations", "Printf", "SparseArrays", "StatsBase"] -git-tree-sha1 = "963b004d15216f8129f6c0f7d187efa136570be0" -uuid = "429524aa-4258-5aef-a3af-852621145aeb" -version = "1.7.7" - -[[deps.OrderedCollections]] -git-tree-sha1 = "2e73fe17cac3c62ad1aebe70d44c963c3cfdc3e3" -uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.6.2" - -[[deps.PDMats]] -deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "67eae2738d63117a196f497d7db789821bce61d1" -uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" -version = "0.11.17" - -[[deps.Parameters]] -deps = ["OrderedCollections", "UnPack"] -git-tree-sha1 = "34c0e9ad262e5f7fc75b10a9952ca7692cfc5fbe" -uuid = "d96e819e-fc66-5662-9728-84c9c7592b0a" -version = "0.12.3" - -[[deps.Parsers]] -deps = ["Dates", "PrecompileTools", "UUIDs"] -git-tree-sha1 = "716e24b21538abc91f6205fd1d8363f39b442851" -uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.7.2" - -[[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.8.0" - -[[deps.PooledArrays]] -deps = ["DataAPI", "Future"] -git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3" -uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" -version = "1.4.3" - -[[deps.PositiveFactorizations]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "17275485f373e6673f7e7f97051f703ed5b15b20" -uuid = "85a6dd25-e78a-55b7-8502-1745935b8125" -version = "0.2.4" - -[[deps.PrecompileTools]] -deps = ["Preferences"] -git-tree-sha1 = "03b4c25b43cb84cee5c90aa9b5ea0a78fd848d2f" -uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -version = "1.2.0" - -[[deps.Preferences]] -deps = ["TOML"] -git-tree-sha1 = "7eb1686b4f04b82f96ed7a4ea5890a4f0c7a09f1" -uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.0" - -[[deps.PrettyPrinting]] -git-tree-sha1 = "22a601b04a154ca38867b991d5017469dc75f2db" -uuid = "54e16d92-306c-5ea0-a30b-337be88ac337" -version = "0.4.1" - -[[deps.PrettyTables]] -deps = ["Crayons", "LaTeXStrings", "Markdown", "Printf", "Reexport", "StringManipulation", "Tables"] -git-tree-sha1 = "ee094908d720185ddbdc58dbe0c1cbe35453ec7a" -uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -version = "2.2.7" - -[[deps.Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[deps.ProgressMeter]] -deps = ["Distributed", "Printf"] -git-tree-sha1 = "00099623ffee15972c16111bcf84c58a0051257c" -uuid = "92933f4c-e287-5a05-a399-4b506db050ca" -version = "1.9.0" - -[[deps.QuadGK]] -deps = ["DataStructures", "LinearAlgebra"] -git-tree-sha1 = "6ec7ac8412e83d57e313393220879ede1740f9ee" -uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" -version = "2.8.2" - -[[deps.REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[deps.Random]] -deps = ["SHA", "Serialization"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[deps.RecipesBase]] -deps = ["PrecompileTools"] -git-tree-sha1 = "5c3d09cc4f31f5fc6af001c250bf1278733100ff" -uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" -version = "1.3.4" - -[[deps.Reexport]] -git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" -uuid = "189a3867-3050-52da-a836-e630ba90ab69" -version = "1.2.2" - -[[deps.RelocatableFolders]] -deps = ["SHA", "Scratch"] -git-tree-sha1 = "90bc7a7c96410424509e4263e277e43250c05691" -uuid = "05181044-ff0b-4ac5-8273-598c1e38db00" -version = "1.0.0" - -[[deps.Requires]] -deps = ["UUIDs"] -git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.3.0" - -[[deps.Rmath]] -deps = ["Random", "Rmath_jll"] -git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b" -uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" -version = "0.7.1" - -[[deps.Rmath_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "6ed52fdd3382cf21947b15e8870ac0ddbff736da" -uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" -version = "0.4.0+0" - -[[deps.SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" -version = "0.7.0" - -[[deps.ScientificTypes]] -deps = ["CategoricalArrays", "ColorTypes", "Dates", "Distributions", "PrettyTables", "Reexport", "ScientificTypesBase", "StatisticalTraits", "Tables"] -git-tree-sha1 = "75ccd10ca65b939dab03b812994e571bf1e3e1da" -uuid = "321657f4-b219-11e9-178b-2701a2544e81" -version = "3.0.2" - -[[deps.ScientificTypesBase]] -git-tree-sha1 = "a8e18eb383b5ecf1b5e6fc237eb39255044fd92b" -uuid = "30f210dd-8aff-4c5f-94ba-8e64358c1161" -version = "3.0.0" - -[[deps.Scratch]] -deps = ["Dates"] -git-tree-sha1 = "30449ee12237627992a99d5e30ae63e4d78cd24a" -uuid = "6c6a2e73-6563-6170-7368-637461726353" -version = "1.2.0" - -[[deps.SentinelArrays]] -deps = ["Dates", "Random"] -git-tree-sha1 = "04bdff0b09c65ff3e06a05e3eb7b120223da3d39" -uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" -version = "1.4.0" - -[[deps.Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[deps.Setfield]] -deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] -git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" -uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" -version = "1.1.1" - -[[deps.ShowCases]] -git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" -uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" -version = "0.1.0" - -[[deps.SimpleBufferStream]] -git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" -uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" -version = "1.1.0" - -[[deps.Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[deps.SortingAlgorithms]] -deps = ["DataStructures"] -git-tree-sha1 = "c60ec5c62180f27efea3ba2908480f8055e17cee" -uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.1.1" - -[[deps.SparseArrays]] -deps = ["LinearAlgebra", "Random"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - -[[deps.SpecialFunctions]] -deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" -uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.3.1" - -[[deps.SplittablesBase]] -deps = ["Setfield", "Test"] -git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" -uuid = "171d559e-b47b-412a-8079-5efa626c420e" -version = "0.1.15" - -[[deps.StableRNGs]] -deps = ["Random", "Test"] -git-tree-sha1 = "3be7d49667040add7ee151fefaf1f8c04c8c8276" -uuid = "860ef19b-820b-49d6-a774-d7a799459cd3" -version = "1.0.0" - -[[deps.StaticArrays]] -deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"] -git-tree-sha1 = "51621cca8651d9e334a659443a74ce50a3b6dfab" -uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.6.3" - -[[deps.StaticArraysCore]] -git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" -uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -version = "1.4.2" - -[[deps.StatisticalTraits]] -deps = ["ScientificTypesBase"] -git-tree-sha1 = "30b9236691858e13f167ce829490a68e1a597782" -uuid = "64bff920-2084-43da-a3e6-9bb72801c0c9" -version = "3.2.0" - -[[deps.Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" - -[[deps.StatsAPI]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" -uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.7.0" - -[[deps.StatsBase]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "75ebe04c5bed70b91614d684259b661c9e6274a4" -uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.34.0" - -[[deps.StatsFuns]] -deps = ["ChainRulesCore", "HypergeometricFunctions", "InverseFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] -git-tree-sha1 = "f625d686d5a88bcd2b15cd81f18f98186fdc0c9a" -uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -version = "1.3.0" - -[[deps.StringManipulation]] -deps = ["PrecompileTools"] -git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" -uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" -version = "0.3.4" - -[[deps.SuiteSparse]] -deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] -uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" - -[[deps.TOML]] -deps = ["Dates"] -uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" -version = "1.0.0" - -[[deps.TableOperations]] -deps = ["SentinelArrays", "Tables", "Test"] -git-tree-sha1 = "e383c87cf2a1dc41fa30c093b2a19877c83e1bc1" -uuid = "ab02a1b2-a7df-11e8-156e-fb1833f50b87" -version = "1.2.0" - -[[deps.TableTraits]] -deps = ["IteratorInterfaceExtensions"] -git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" -uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" -version = "1.0.1" - -[[deps.TableTransforms]] -deps = ["AbstractTrees", "CategoricalArrays", "Distributions", "LinearAlgebra", "NelderMead", "PrettyTables", "Random", "ScientificTypes", "Statistics", "StatsBase", "Tables", "Transducers", "TransformsBase"] -git-tree-sha1 = "d2fc117cc24ad1e459c9ff9d839e201431ec608a" -uuid = "0d432bfd-3ee1-4ac1-886a-39f05cc69a3e" -version = "1.10.0" - -[[deps.Tables]] -deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits"] -git-tree-sha1 = "a1f34829d5ac0ef499f6d84428bd6b4c71f02ead" -uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.11.0" - -[[deps.Tar]] -deps = ["ArgTools", "SHA"] -uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" -version = "1.10.1" - -[[deps.TensorCore]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "1feb45f88d133a655e001435632f019a9a1bcdb6" -uuid = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" -version = "0.1.1" - -[[deps.Test]] -deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[deps.TranscodingStreams]] -deps = ["Random", "Test"] -git-tree-sha1 = "9a6ae7ed916312b41236fcef7e0af564ef934769" -uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.9.13" - -[[deps.Transducers]] -deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] -git-tree-sha1 = "53bd5978b182fa7c57577bdb452c35e5b4fb73a5" -uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" -version = "0.4.78" - -[[deps.TransformsBase]] -deps = ["AbstractTrees"] -git-tree-sha1 = "53e92e907bd67eef12e319ca932a7dd036428bfc" -uuid = "28dd2a49-a57a-4bfb-84ca-1a49db9b96b8" -version = "1.2.1" - -[[deps.URIs]] -git-tree-sha1 = "b7a5e99f24892b6824a954199a45e9ffcc1c70f0" -uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" -version = "1.5.0" - -[[deps.UUIDs]] -deps = ["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[deps.UnPack]] -git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b" -uuid = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" -version = "1.0.2" - -[[deps.Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" - -[[deps.UnicodePlots]] -deps = ["ColorSchemes", "ColorTypes", "Contour", "Crayons", "Dates", "LinearAlgebra", "MarchingCubes", "NaNMath", "PrecompileTools", "Printf", "Requires", "SparseArrays", "StaticArrays", "StatsBase"] -git-tree-sha1 = "b96de03092fe4b18ac7e4786bee55578d4b75ae8" -uuid = "b8865327-cd53-5732-bb35-84acbb429228" -version = "3.6.0" - -[[deps.Zlib_jll]] -deps = ["Libdl"] -uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.12+3" - -[[deps.libblastrampoline_jll]] -deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] -uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.1.1+0" - -[[deps.nghttp2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.48.0+0" - -[[deps.p7zip_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.4.0+0" diff --git a/examples/Project.toml b/examples/Project.toml index d236ff2..c1d18e1 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -1,7 +1,9 @@ [deps] -DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Imbalance = "c709b415-507b-45b7-9a3d-1767c89fde68" MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" -MLJBalancing = "45f359ea-796d-4f51-95a5-deb1a414c586" +MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692" +MLJBalancing = "45f359ea-796d-4f51-95a5-deb1a414c586" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +Imbalance = "c709b415-507b-45b7-9a3d-1767c89fde68" diff --git a/src/MLJBalancing.jl b/src/MLJBalancing.jl index d891030..82fc04b 100644 --- a/src/MLJBalancing.jl +++ b/src/MLJBalancing.jl @@ -2,10 +2,17 @@ module MLJBalancing using MLJBase using MLJModelInterface +using MLUtils using OrderedCollections +using Random +using Random: AbstractRNG, Xoshiro, rand +using StatsBase: sample + MMI = MLJModelInterface +include("balanced_bagging.jl") +export BalancedBaggingClassifier include("balanced_model.jl") export BalancedModel -end \ No newline at end of file +end diff --git a/src/balanced_bagging.jl b/src/balanced_bagging.jl new file mode 100644 index 0000000..15a70d1 --- /dev/null +++ b/src/balanced_bagging.jl @@ -0,0 +1,252 @@ + +""" +Return a dictionary `result` mapping each unique value in a given abstract vector `y` + to the vector of indices where that value occurs. +""" +function group_inds(y::AbstractVector{T}) where {T} + result = LittleDict{T,AbstractVector{Int}}() + for (i, v) in enumerate(y) + # Make a new entry in the dict if it doesn't exist + if !haskey(result, v) + result[v] = [] + end + # It exists, so push the index belonging to the class + push!(result[v], i) + end + return freeze(result) +end + +const ERR_MULTICLASS_UNSUPP(num_classes) = + "Only binary classification supported by BalancedBaggingClassifier. Got $num_classes classes" + +""" +Given an abstract vector `y` where any element takes one of two values, return the indices of the + most frequent of them, the indices of the least frequent of them, and the counts of each. +""" +function get_majority_minority_inds_counts(y) + # a tuple mapping each class to its indices + labels_inds = collect(group_inds(y)) + num_classes = length(labels_inds) + num_classes == 2 || throw(ArgumentError(ERR_MULTICLASS_UNSUPP(num_classes))) + # get the length of each class + first_class_count = length(labels_inds[1][2]) + second_class_count = length(labels_inds[2][2]) + # get the majority and minority inds by comparing lengths + if first_class_count > second_class_count + majority_inds, minority_inds = labels_inds[1][2], labels_inds[2][2] + return majority_inds, minority_inds, first_class_count, second_class_count + else + majority_inds, minority_inds = labels_inds[2][2], labels_inds[1][2] + return majority_inds, minority_inds, second_class_count, first_class_count + end +end + +""" +Given data `X`, `y` where `X` is a table and `y` is an abstract vector (which may be wrapped in nodes), + the indices and counts of the majority and minority classes and abstract rng, + return `X_sub`, `y_sub`, in the form of nodes, which are the result of randomly undersampling + the majority class data in `X`, `y` so that both classes occur equally frequently. +""" +function get_some_balanced_subset( + X, + y, + majority_inds, + minority_inds, + majority_count, + minority_count, + rng::AbstractRNG, +) + # randomly sample a subset of size minority_count indices from those belonging to majority class + random_inds = sample(rng, 1:majority_count, minority_count, replace = true) + majority_inds_undersampled = majority_inds[random_inds] + # find the corresponding subset of data which includes all minority and majority subset + balanced_subset_inds = vcat(minority_inds, majority_inds_undersampled) + X_sub = node(X -> getobs(X, balanced_subset_inds), X) + y_sub = node(y -> y[balanced_subset_inds], y) + return X_sub, y_sub +end + + +""" +Construct an BalancedBaggingClassifier model. +""" +mutable struct BalancedBaggingClassifier{RI<:Union{AbstractRNG, Integer},I<:Integer,P<:Probabilistic} <: + ProbabilisticNetworkComposite + model::P + T::I + rng::RI +end + +rng_handler(rng::Integer) = Random.Xoshiro(rng) +rng_handler(rng::AbstractRNG) = rng +const ERR_MISSING_CLF = "No model specified. Please specify a probabilistic classifier using the `model` keyword argument." +const ERR_BAD_T = "The number of ensemble models `T` cannot be negative." +const INFO_DEF_T(T_def) = "The number of ensemble models was not given and was thus, automatically set to $T_def"* + " which is the ratio of the frequency of the majority class to that of the minority class" +function BalancedBaggingClassifier(; + model = nothing, + T = 0, + rng = Random.default_rng(), +) + model === nothing && error(ERR_MISSING_CLF) + T < 0 && error(ERR_BAD_T) + rng = rng_handler(rng) + return BalancedBaggingClassifier(model, T, rng) +end + +function MLJBase.prefit(composite_model::BalancedBaggingClassifier, verbosity, X, y) + Xs, ys = source(X), source(y) + majority_inds, minority_inds, majority_count, minority_count = + get_majority_minority_inds_counts(y) + T = composite_model.T + if composite_model.T == 0 + T_def = round(Int, majority_count/minority_count) + T = T_def + @info INFO_DEF_T(T_def) + end + # get as much balanced subsets as needed + X_y_list_s = [ + get_some_balanced_subset( + Xs, + ys, + majority_inds, + minority_inds, + majority_count, + minority_count, + composite_model.rng, + ) for i in 1:T + ] + # Make a machine for each + machines = (machine(:model, Xsub, ysub) for (Xsub, ysub) in X_y_list_s) + # Average the predictions from nodes + all_preds = [MLJBase.predict(mach, Xs) for (mach, (X, _)) in zip(machines, X_y_list_s)] + yhat = mean(all_preds) + return (; predict=yhat ) +end + +### To register with MLJ +MMI.metadata_pkg( + BalancedBaggingClassifier, + name = "MLJBalancing", + package_uuid = "45f359ea-796d-4f51-95a5-deb1a414c586", + package_url = "https://github.com/JuliaAI/MLJBalancing.jl", + is_pure_julia = true, +) + +MMI.metadata_model( + BalancedBaggingClassifier, + input_scitype = Union{Union{Infinite,Finite}}, + output_scitype = Union{Union{Infinite,Finite}}, + target_scitype = AbstractVector, + load_path = "MLJBalancing." * string(BalancedBaggingClassifier), +) + +MMI.iteration_parameter(::Type{<:BalancedBaggingClassifier{P}}) where {P} = + MLJBase.prepend(:model, iteration_parameter(P)) +for trait in [ + :input_scitype, + :output_scitype, + :target_scitype, + :fit_data_scitype, + :predict_scitype, + :transform_scitype, + :inverse_transform_scitype, + :is_pure_julia, + :supports_weights, + :supports_class_weights, + :supports_online, + :supports_training_losses, + :is_supervised, + :prediction_type, +] + quote + MMI.$trait(::Type{<:BalancedBaggingClassifier{P}}) where {P} = MMI.$trait(P) + end |> eval +end + +""" + BalancedBaggingClassifier + A model type for constructing a balanced bagging classifier, based on [MLJBalancing.jl](https://github.com/JuliaAI/MLJBalancing). + + From MLJ, the type can be imported using + + `BalancedBaggingClassifier = @load BalancedBaggingClassifier pkg=MLJBalancing`` + + Construct an instance with default hyper-parameters using the syntax `bagging_model = BalancedBaggingClassifier(model=...)` + + Given a probablistic classifier.`BalancedBaggingClassifier` performs bagging by undersampling + only majority data in each bag so that its includes as much samples as in the minority data. + This is proposed with an Adaboost classifier where the output scores are averaged in the paper + Xu-Ying Liu, Jianxin Wu, & Zhi-Hua Zhou. (2009). Exploratory Undersampling for Class-Imbalance Learning. + IEEE Transactions on Systems, Man, and Cybernetics, Part B (Cybernetics), 39 (2), 539–5501 + + + # Training data + + In MLJ or MLJBase, bind an instance `model` to data with + + mach = machine(model, X, y) + + where + + - `X`: input features of a form supported by the `model` being wrapped (typically a table, e.g., `DataFrame`, + with `Continuous` columns will be supported, as a minimum) + + - `y`: the binary target, which can be any `AbstractVector` where `length(unique(y)) == 2` + + + Train the machine with `fit!(mach, rows=...)`. + + + # Hyperparameters + + - `model<:Probabilistic`: The classifier to use to train on each bag. + + - `T::integer=0`: The number of bags to be used in the ensemble. If not given, will be set as + the ratio between the frequency of the majority and minority classes. + + - `rng::Union{AbstractRNG, Integer}=default_rng()`: Either an `AbstractRNG` object or an `Integer` + seed to be used with `Xoshiro` + + # Operations + + - `predict(mach, Xnew)`: return predictions of the target given + features `Xnew` having the same scitype as `X` above. Predictions + are probabilistic, but uncalibrated. + + - `predict_mode(mach, Xnew)`: return the mode of each prediction above + + + + # Example + + ```julia + using MLJ + using Imbalance + + # Load base classifier and BalancedBaggingClassifier + BalancedBaggingClassifier = @load BalancedBaggingClassifier pkg=MLJBalancing + LogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels verbosity=0 + + # Construct the base classifier and use it to construct a BalancedBaggingClassifier + logistic_model = LogisticClassifier() + model = BalancedBaggingClassifier(model=logistic_model, T=5) + + # Load the data and train the BalancedBaggingClassifier + X, y = Imbalance.generate_imbalanced_data(100, 5; cat_feats_num_vals = [3, 2], + probs = [0.9, 0.1], + type = "ColTable", + rng=42) + julia> Imbalance.checkbalance(y) + 1: ▇▇▇▇▇▇▇▇▇▇ 16 (19.0%) + 0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 84 (100.0%) + + mach = machine(model, X, y) |> fit! + + # Predict using the trained model + + yhat = predict(mach, X) # probabilistic predictions + predict_mode(mach, X) # point predictions + ``` +""" +BalancedBaggingClassifier diff --git a/test/balanced_bagging.jl b/test/balanced_bagging.jl new file mode 100644 index 0000000..085bb8d --- /dev/null +++ b/test/balanced_bagging.jl @@ -0,0 +1,118 @@ + +@testset "group_inds and get_majority_minority_inds_counts" begin + y = [0, 0, 0, 0, 1, 1, 1, 0] + @test MLJBalancing.group_inds(y) == Dict(0 => [1, 2, 3, 4, 8], 1 => [5, 6, 7]) + @test MLJBalancing.get_majority_minority_inds_counts(y) == + ([1, 2, 3, 4, 8], [5, 6, 7], 5, 3) + y = [0, 0, 0, 0, 1, 1, 1, 0, 2, 2, 2] + @test_throws MLJBalancing.ERR_MULTICLASS_UNSUPP(3) MLJBalancing.get_majority_minority_inds_counts( + y, + ) +end + +@testset "BalancedBaggingClassifier" begin + X, y = generate_imbalanced_data( + 100, + 5; + cat_feats_num_vals = [3, 2, 1, 2], + probs = [0.9, 0.1], + type = "ColTable", + rng = 42, + ) + majority_inds, minority_inds, majority_count, minority_count = + MLJBalancing.get_majority_minority_inds_counts(y) + Xs, ys = MLJBase.source(X), MLJBase.source(y) + X_sub, y_sub = MLJBalancing.get_some_balanced_subset( + Xs, + ys, + majority_inds, + minority_inds, + majority_count, + minority_count, + Random.Xoshiro(42) + ) + X_sub, y_sub = X_sub(rows = 1:100), y_sub(rows = 1:100) + majority_inds_sub, minority_inds_sub, _, _ = + MLJBalancing.get_majority_minority_inds_counts(y_sub) + + X_sub = Tables.matrix(X_sub) + X = Tables.matrix(X) + # minority untouched + @test sum(X_sub[minority_inds_sub, :]) == sum(X[minority_inds, :]) + # majority undersampled + @test issubset( + Set(eachrow(X_sub[majority_inds_sub, :])), + Set(eachrow(X[majority_inds, :])), + ) + # balances the data + @test length(y_sub[minority_inds_sub]) === length(y_sub[majority_inds_sub]) +end + +@testset "End-to-end Test" begin + ## setup parameters + R = Random.Xoshiro(42) + T = 2 + LogisticClassifier = @load LogisticClassifier pkg = MLJLinearModels verbosity = 0 + model = LogisticClassifier() + + ## setup data + # training + X, y = generate_imbalanced_data( + 100, + 5; + cat_feats_num_vals = [3, 2, 1, 2], + probs = [0.9, 0.1], + type = "ColTable", + rng = 42, + ) + # testing + Xt, yt = generate_imbalanced_data( + 5, + 5; + cat_feats_num_vals = [3, 2, 1, 2], + probs = [0.9, 0.1], + type = "ColTable", + rng = 42, + ) + + ## prepare subsets + majority_inds, minority_inds, majority_count, minority_count = + MLJBalancing.get_majority_minority_inds_counts(y) + Xs, ys = MLJBase.source(X), MLJBase.source(y) + X_sub1, y_sub1 = MLJBalancing.get_some_balanced_subset( + Xs, + ys, + majority_inds, + minority_inds, + majority_count, + minority_count, + R, + ) + X_sub1, y_sub1 = X_sub1(rows = 1:100), y_sub1(rows = 1:100) + X_sub2, y_sub2 = MLJBalancing.get_some_balanced_subset( + Xs, + ys, + majority_inds, + minority_inds, + majority_count, + minority_count, + R, + ) + X_sub2, y_sub2 = X_sub2(rows = 1:100), y_sub2(rows = 1:100) + + # training manually + mach1 = machine(model, X_sub1, y_sub1) + fit!(mach1) + mach2 = machine(model, X_sub2, y_sub2) + fit!(mach2) + pred1 = MLJBase.predict(mach1, Xt) + pred2 = MLJBase.predict(mach2, Xt) + pred_manual = mean([pred1, pred2]) + + ## using BalancedBagging + modelo = BalancedBaggingClassifier(model = model, T = 2, rng = Random.Xoshiro(42)) + mach = machine(modelo, X, y) + fit!(mach) + pred_auto = MLJBase.predict(mach, Xt) + @test sum(pred_manual) ≈ sum(pred_auto) +end diff --git a/test/runtests.jl b/test/runtests.jl index 1008fc8..00e9aac 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,6 +5,7 @@ using MLJModels using Imbalance using Random using DataFrames +using Tables - +include("balanced_bagging.jl") include("balanced_model.jl")