From bf99a536b04768fd372b2441dc38b6e79e171c6c Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 15:12:14 +0100 Subject: [PATCH 001/131] Added tensors library --- Bonsai.ML.sln | 7 +++++++ .../Bonsai.ML.Tensors.csproj | 19 +++++++++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/Bonsai.ML.Tensors.csproj diff --git a/Bonsai.ML.sln b/Bonsai.ML.sln index c5a91b13..22b8a35a 100644 --- a/Bonsai.ML.sln +++ b/Bonsai.ML.sln @@ -30,6 +30,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.LinearDynamicalSy EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.HiddenMarkovModels.Design", "src\Bonsai.ML.HiddenMarkovModels.Design\Bonsai.ML.HiddenMarkovModels.Design.csproj", "{FC395DDC-62A4-4E14-A198-272AB05B33C7}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Tensors", "src\Bonsai.ML.Tensors\Bonsai.ML.Tensors.csproj", "{06FCC9AF-CE38-44BB-92B3-0D451BE88537}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -72,6 +74,10 @@ Global {FC395DDC-62A4-4E14-A198-272AB05B33C7}.Debug|Any CPU.Build.0 = Debug|Any CPU {FC395DDC-62A4-4E14-A198-272AB05B33C7}.Release|Any CPU.ActiveCfg = Release|Any CPU {FC395DDC-62A4-4E14-A198-272AB05B33C7}.Release|Any CPU.Build.0 = Release|Any CPU + {06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Debug|Any CPU.Build.0 = Debug|Any CPU + {06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Release|Any CPU.ActiveCfg = Release|Any CPU + {06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -86,6 +92,7 @@ Global {39A4414F-52B1-42D7-82FA-E65DAD885264} = {12312384-8828-4786-AE19-EFCEDF968290} {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13} = {12312384-8828-4786-AE19-EFCEDF968290} {17DF50BE-F481-4904-A4C8-5DF9725B2CA1} = {12312384-8828-4786-AE19-EFCEDF968290} + {06FCC9AF-CE38-44BB-92B3-0D451BE88537} = {12312384-8828-4786-AE19-EFCEDF968290} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {B6468F13-97CD-45E0-9E1E-C122D7F1E09F} diff --git a/src/Bonsai.ML.Tensors/Bonsai.ML.Tensors.csproj b/src/Bonsai.ML.Tensors/Bonsai.ML.Tensors.csproj new file mode 100644 index 00000000..2a0a76e2 --- /dev/null +++ b/src/Bonsai.ML.Tensors/Bonsai.ML.Tensors.csproj @@ -0,0 +1,19 @@ + + + Bonsai.ML.Tensors + A Bonsai package for TorchSharp tensor manipulations. + Bonsai Rx ML Tensors TorchSharp + net472;netstandard2.0 + 12.0 + + + + + + + + + + + + \ No newline at end of file From 19497a535e4cd1e0e984bfe9da197e1d3db8d5a4 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 15:12:24 +0100 Subject: [PATCH 002/131] Added arange function --- src/Bonsai.ML.Tensors/Arange.cs | 39 +++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/Arange.cs diff --git a/src/Bonsai.ML.Tensors/Arange.cs b/src/Bonsai.ML.Tensors/Arange.cs new file mode 100644 index 00000000..e3c355d0 --- /dev/null +++ b/src/Bonsai.ML.Tensors/Arange.cs @@ -0,0 +1,39 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Creates a 1-D tensor of values within a given range given the start, end, and step. + /// + [Combinator] + [Description("Creates a 1-D tensor of values within a given range given the start, end, and step.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class Arange + { + /// + /// The start of the range. + /// + public int Start { get; set; } = 0; + + /// + /// The end of the range. + /// + public int End { get; set; } = 10; + + /// + /// The step of the range. + /// + public int Step { get; set; } = 1; + + /// + /// Generates an observable sequence of 1-D tensors created with the function. + /// + public IObservable Process() + { + return Observable.Defer(() => Observable.Return(arange(Start, End, Step))); + } + } +} From fd57f8cab76bb950c3e71182f4377c94c85c12af Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 15:12:34 +0100 Subject: [PATCH 003/131] Added concat class --- src/Bonsai.ML.Tensors/Concat.cs | 44 +++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/Concat.cs diff --git a/src/Bonsai.ML.Tensors/Concat.cs b/src/Bonsai.ML.Tensors/Concat.cs new file mode 100644 index 00000000..1a11eb0e --- /dev/null +++ b/src/Bonsai.ML.Tensors/Concat.cs @@ -0,0 +1,44 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Concatenates tensors along a given dimension. + /// + [Combinator] + [Description("Concatenates tensors along a given dimension.")] + [WorkflowElementCategory(ElementCategory.Combinator)] + public class Concat + { + /// + /// The dimension along which to concatenate the tensors. + /// + public long Dimension { get; set; } = 0; + + /// + /// Takes any number of observable sequences of tensors and concatenates the input tensors along the specified dimension by zipping each tensor together. + /// + public IObservable Process(params IObservable[] sources) + { + return sources.Aggregate((current, next) => + current.Zip(next, (tensor1, tensor2) => + cat(new Tensor[] { tensor1, tensor2 }, Dimension))); + } + + /// + /// Concatenates the input tensors along the specified dimension. + /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + var tensor1 = value.Item1; + var tensor2 = value.Item2; + return cat(new Tensor[] { tensor1, tensor2 }, Dimension); + }); + } + } +} \ No newline at end of file From 0d3145763ecac55d63b0918cf133e8f7ab10396d Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 17:55:34 +0100 Subject: [PATCH 004/131] Added arange function --- src/Bonsai.ML.Tensors/Arange.cs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Bonsai.ML.Tensors/Arange.cs b/src/Bonsai.ML.Tensors/Arange.cs index e3c355d0..2a1eda40 100644 --- a/src/Bonsai.ML.Tensors/Arange.cs +++ b/src/Bonsai.ML.Tensors/Arange.cs @@ -2,6 +2,7 @@ using System.ComponentModel; using System.Reactive.Linq; using static TorchSharp.torch; +using TorchSharp; namespace Bonsai.ML.Tensors { @@ -29,7 +30,7 @@ public class Arange public int Step { get; set; } = 1; /// - /// Generates an observable sequence of 1-D tensors created with the function. + /// Generates an observable sequence of 1-D tensors created with the function. /// public IObservable Process() { From 0858c93b6a1e1771a7ff6ca9a5a5b98439b356b5 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 17:55:48 +0100 Subject: [PATCH 005/131] Added linspace --- src/Bonsai.ML.Tensors/Linspace.cs | 40 +++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/Linspace.cs diff --git a/src/Bonsai.ML.Tensors/Linspace.cs b/src/Bonsai.ML.Tensors/Linspace.cs new file mode 100644 index 00000000..aa263500 --- /dev/null +++ b/src/Bonsai.ML.Tensors/Linspace.cs @@ -0,0 +1,40 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Creates a 1-D tensor of linearly interpolated values within a given range given the start, end, and count. + /// + [Combinator] + [Description("Creates a 1-D tensor of linearly interpolated values within a given range given the start, end, and count.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class Linspace + { + /// + /// The start of the range. + /// + public int Start { get; set; } = 0; + + /// + /// The end of the range. + /// + public int End { get; set; } = 1; + + /// + /// The number of points to generate. + /// + public int Count { get; set; } = 10; + + /// + /// Generates an observable sequence of 1-D tensors created with the function. + /// + /// + public IObservable Process() + { + return Observable.Defer(() => Observable.Return(linspace(Start, End, Count))); + } + } +} \ No newline at end of file From 812be103aa4dc12fccea21e281eb73f55a96b8f7 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 17:56:15 +0100 Subject: [PATCH 006/131] Added meshgrid --- src/Bonsai.ML.Tensors/MeshGrid.cs | 33 +++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/MeshGrid.cs diff --git a/src/Bonsai.ML.Tensors/MeshGrid.cs b/src/Bonsai.ML.Tensors/MeshGrid.cs new file mode 100644 index 00000000..6b0a2c73 --- /dev/null +++ b/src/Bonsai.ML.Tensors/MeshGrid.cs @@ -0,0 +1,33 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Collections.Generic; +using static TorchSharp.torch; +using System.Linq; + +namespace Bonsai.ML.Tensors +{ + /// + /// Creates a mesh grid from an observable sequence of enumerable of 1-D tensors. + /// + [Combinator] + [Description("")] + [WorkflowElementCategory(ElementCategory.Source)] + public class MeshGrid + { + /// + /// The indexing mode to use for the mesh grid. + /// + public string Indexing { get; set; } = "ij"; + + /// + /// Creates a mesh grid from the input tensors. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(tensors => meshgrid(tensors, indexing: Indexing)); + } + } +} \ No newline at end of file From 0d8f732e93d399f2b1da735cbd8056776ce8e048 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 17:56:22 +0100 Subject: [PATCH 007/131] Added ones --- src/Bonsai.ML.Tensors/Ones.cs | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/Ones.cs diff --git a/src/Bonsai.ML.Tensors/Ones.cs b/src/Bonsai.ML.Tensors/Ones.cs new file mode 100644 index 00000000..499012bd --- /dev/null +++ b/src/Bonsai.ML.Tensors/Ones.cs @@ -0,0 +1,30 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Creates a tensor filled with ones. + /// + [Combinator] + [Description("Creates a tensor filled with ones.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class Ones + { + /// + /// The size of the tensor. + /// + public long[] Size { get; set; } = [0]; + + /// + /// Generates an observable sequence of tensors filled with ones. + /// + /// + public IObservable Process() + { + return Observable.Defer(() => Observable.Return(ones(Size))); + } + } +} \ No newline at end of file From c679dd487530c4acfa5fc4dcfd5601934ecf184e Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 17:56:30 +0100 Subject: [PATCH 008/131] Added zeros --- src/Bonsai.ML.Tensors/Zeros.cs | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/Zeros.cs diff --git a/src/Bonsai.ML.Tensors/Zeros.cs b/src/Bonsai.ML.Tensors/Zeros.cs new file mode 100644 index 00000000..af220641 --- /dev/null +++ b/src/Bonsai.ML.Tensors/Zeros.cs @@ -0,0 +1,30 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Creates a tensor filled with zeros. + /// + [Combinator] + [Description("Creates a tensor filled with zeros.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class Zeros + { + /// + /// The size of the tensor. + /// + public long[] Size { get; set; } = [0]; + + /// + /// Generates an observable sequence of tensors filled with zeros. + /// + /// + public IObservable Process() + { + return Observable.Defer(() => Observable.Return(ones(Size))); + } + } +} \ No newline at end of file From 044f30ee0f9d69156b1b86d760f3ba672e01e41e Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 17:56:40 +0100 Subject: [PATCH 009/131] Added device initialization --- .../InitializeTorchDevice.cs | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/InitializeTorchDevice.cs diff --git a/src/Bonsai.ML.Tensors/InitializeTorchDevice.cs b/src/Bonsai.ML.Tensors/InitializeTorchDevice.cs new file mode 100644 index 00000000..dc9123f0 --- /dev/null +++ b/src/Bonsai.ML.Tensors/InitializeTorchDevice.cs @@ -0,0 +1,35 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using TorchSharp; + +namespace Bonsai.ML.Tensors +{ + /// + /// Initializes the Torch device with the specified device type. + /// + [Combinator] + [Description("Initializes the Torch device with the specified device type.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class InitializeTorchDevice + { + /// + /// The device type to initialize. + /// + public DeviceType DeviceType { get; set; } + + /// + /// Initializes the Torch device with the specified device type. + /// + /// + public IObservable Process() + { + return Observable.Defer(() => + { + InitializeDeviceType(DeviceType); + return Observable.Return(new Device(DeviceType)); + }); + } + } +} \ No newline at end of file From 9aeca78a07e3482bbf0917d7b8caac09760bc402 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 17:56:56 +0100 Subject: [PATCH 010/131] Added ability to move tensor to device --- src/Bonsai.ML.Tensors/ToDevice.cs | 34 +++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/ToDevice.cs diff --git a/src/Bonsai.ML.Tensors/ToDevice.cs b/src/Bonsai.ML.Tensors/ToDevice.cs new file mode 100644 index 00000000..574be5f3 --- /dev/null +++ b/src/Bonsai.ML.Tensors/ToDevice.cs @@ -0,0 +1,34 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Moves the input tensor to the specified device. + /// + [Combinator] + [Description("Moves the input tensor to the specified device.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class ToDevice + { + /// + /// The device to which the input tensor should be moved. + /// + public Device Device { get; set; } + + /// + /// Returns the input tensor moved to the specified device. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => { + return tensor.to(Device); + }); + } + } +} \ No newline at end of file From daa2519b58a19ab5fbee1867eefdea5dc1807478 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 17:57:11 +0100 Subject: [PATCH 011/131] Added permute --- src/Bonsai.ML.Tensors/Permute.cs | 33 ++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/Permute.cs diff --git a/src/Bonsai.ML.Tensors/Permute.cs b/src/Bonsai.ML.Tensors/Permute.cs new file mode 100644 index 00000000..7f037d79 --- /dev/null +++ b/src/Bonsai.ML.Tensors/Permute.cs @@ -0,0 +1,33 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Permutes the dimensions of the input tensor according to the specified permutation. + /// + [Combinator] + [Description("Permutes the dimensions of the input tensor according to the specified permutation.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Permute + { + /// + /// The permutation of the dimensions. + /// + public long[] Dimensions { get; set; } = [0]; + + /// + /// Returns an observable sequence that permutes the dimensions of the input tensor according to the specified permutation. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => { + return tensor.permute(Dimensions); + }); + } + } +} \ No newline at end of file From 05d88d5a568da7fd27e3fc15f140256b34983246 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 17:57:18 +0100 Subject: [PATCH 012/131] Added reshape --- src/Bonsai.ML.Tensors/Reshape.cs | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/Reshape.cs diff --git a/src/Bonsai.ML.Tensors/Reshape.cs b/src/Bonsai.ML.Tensors/Reshape.cs new file mode 100644 index 00000000..4fef3d83 --- /dev/null +++ b/src/Bonsai.ML.Tensors/Reshape.cs @@ -0,0 +1,32 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Reshapes the input tensor according to the specified dimensions. + /// + [Combinator] + [Description("Reshapes the input tensor according to the specified dimensions.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Reshape + { + /// + /// The dimensions of the reshaped tensor. + /// + public long[] Dimensions { get; set; } = [0]; + + /// + /// Reshapes the input tensor according to the specified dimensions. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(input => input.reshape(Dimensions)); + } + } +} \ No newline at end of file From e07a7f7c680e8f6f28bb134ffa7f6acd922d6dfa Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 17:57:26 +0100 Subject: [PATCH 013/131] Added set --- src/Bonsai.ML.Tensors/Set.cs | 48 ++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/Set.cs diff --git a/src/Bonsai.ML.Tensors/Set.cs b/src/Bonsai.ML.Tensors/Set.cs new file mode 100644 index 00000000..3f2a6f50 --- /dev/null +++ b/src/Bonsai.ML.Tensors/Set.cs @@ -0,0 +1,48 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Sets the value of the input tensor at the specified index. + /// + [Combinator] + [Description("Sets the value of the input tensor at the specified index.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Set + { + /// + /// The index at which to set the value. + /// + public string Index + { + get => Helpers.IndexParser.SerializeIndexes(indexes); + set => indexes = Helpers.IndexParser.ParseString(value); + } + + private TensorIndex[] indexes; + + /// + /// The value to set at the specified index. + /// + [XmlIgnore] + public Tensor Value { get; set; } = null; + + /// + /// Returns an observable sequence that sets the value of the input tensor at the specified index. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => { + return tensor.index_put_(Value, indexes); + }); + } + } +} \ No newline at end of file From 59ec8cd5e241610fc33c3cbbdb5e51ccc3328867 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 18:01:42 +0100 Subject: [PATCH 014/131] Updated csproj with opencv.net package --- src/Bonsai.ML.Tensors/Bonsai.ML.Tensors.csproj | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/Bonsai.ML.Tensors/Bonsai.ML.Tensors.csproj b/src/Bonsai.ML.Tensors/Bonsai.ML.Tensors.csproj index 2a0a76e2..8d87ac9b 100644 --- a/src/Bonsai.ML.Tensors/Bonsai.ML.Tensors.csproj +++ b/src/Bonsai.ML.Tensors/Bonsai.ML.Tensors.csproj @@ -4,15 +4,13 @@ A Bonsai package for TorchSharp tensor manipulations. Bonsai Rx ML Tensors TorchSharp net472;netstandard2.0 - 12.0 + true + - - - From b0519998bd38ad66a0565cadcfd32d945b7b2f6b Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 18:02:07 +0100 Subject: [PATCH 015/131] Added concatenate class --- src/Bonsai.ML.Tensors/Concat.cs | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/Bonsai.ML.Tensors/Concat.cs b/src/Bonsai.ML.Tensors/Concat.cs index 1a11eb0e..1dd99b7b 100644 --- a/src/Bonsai.ML.Tensors/Concat.cs +++ b/src/Bonsai.ML.Tensors/Concat.cs @@ -1,5 +1,6 @@ -using System; +using System; using System.ComponentModel; +using System.Linq; using System.Reactive.Linq; using static TorchSharp.torch; @@ -23,9 +24,9 @@ public class Concat /// public IObservable Process(params IObservable[] sources) { - return sources.Aggregate((current, next) => - current.Zip(next, (tensor1, tensor2) => - cat(new Tensor[] { tensor1, tensor2 }, Dimension))); + return sources.Aggregate((current, next) => + current.Zip(next, (tensor1, tensor2) => + cat([tensor1, tensor2], Dimension))); } /// @@ -33,12 +34,12 @@ public IObservable Process(params IObservable[] sources) /// public IObservable Process(IObservable> source) { - return source.Select(value => + return source.Select(value => { var tensor1 = value.Item1; var tensor2 = value.Item2; - return cat(new Tensor[] { tensor1, tensor2 }, Dimension); + return cat([tensor1, tensor2], Dimension); }); } } -} \ No newline at end of file +} From 31b39816f667a08518f6668e2c0ea14f8411cd7c Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 18:03:04 +0100 Subject: [PATCH 016/131] Added convert data type --- src/Bonsai.ML.Tensors/ConvertDataType.cs | 32 ++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/ConvertDataType.cs diff --git a/src/Bonsai.ML.Tensors/ConvertDataType.cs b/src/Bonsai.ML.Tensors/ConvertDataType.cs new file mode 100644 index 00000000..14b0db84 --- /dev/null +++ b/src/Bonsai.ML.Tensors/ConvertDataType.cs @@ -0,0 +1,32 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Converts the input tensor to the specified scalar type. + /// + [Combinator] + [Description("Converts the input tensor to the specified scalar type.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class ConvertDataType + { + /// + /// The scalar type to which to convert the input tensor. + /// + public ScalarType Type { get; set; } = ScalarType.Float32; + + /// + /// Returns an observable sequence that converts the input tensor to the specified scalar type. + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => + { + return tensor.to_type(Type); + }); + } + } +} From 78e70261f30db29d5ef49def5cff6bdc9f276a2c Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 18:03:22 +0100 Subject: [PATCH 017/131] Added create tensor method --- src/Bonsai.ML.Tensors/CreateTensor.cs | 245 ++++++++++++++++++++ src/Bonsai.ML.Tensors/Helpers/DataHelper.cs | 190 +++++++++++++++ 2 files changed, 435 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/CreateTensor.cs create mode 100644 src/Bonsai.ML.Tensors/Helpers/DataHelper.cs diff --git a/src/Bonsai.ML.Tensors/CreateTensor.cs b/src/Bonsai.ML.Tensors/CreateTensor.cs new file mode 100644 index 00000000..712c7243 --- /dev/null +++ b/src/Bonsai.ML.Tensors/CreateTensor.cs @@ -0,0 +1,245 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Linq.Expressions; +using System.Reactive.Linq; +using System.Reflection; +using System.Xml.Serialization; +using Bonsai.Expressions; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Creates a tensor from the specified values. Uses Python-like syntax to specify the tensor values. For example, a 2x2 tensor can be created with the following values: "[[1, 2], [3, 4]]". + /// + [Combinator] + [Description("Creates a tensor from the specified values. Uses Python-like syntax to specify the tensor values. For example, a 2x2 tensor can be created with the following values: \"[[1, 2], [3, 4]]\".")] + [WorkflowElementCategory(ElementCategory.Source)] + public class CreateTensor : ExpressionBuilder + { + Range argumentRange = new Range(0, 1); + + /// + public override Range ArgumentRange => argumentRange; + + /// + /// The data type of the tensor elements. + /// + public TensorDataType Type + { + get => scalarType; + set => scalarType = value; + } + + private TensorDataType scalarType = TensorDataType.Float32; + + /// + /// The values of the tensor elements. Uses Python-like syntax to specify the tensor values. + /// + public string Values + { + get => values; + set + { + values = value.Replace("False", "false").Replace("True", "true"); + } + } + + private string values = "[0]"; + + /// + /// The device on which to create the tensor. + /// + [XmlIgnore] + public Device Device { get => device; set => device = value; } + + private Device device = null; + + private Expression BuildTensorFromArray(Array arrayValues, Type returnType) + { + var rank = arrayValues.Rank; + var lengths = new int[rank]; + for (int i = 0; i < rank; i++) + { + lengths[i] = arrayValues.GetLength(i); + } + + var arrayCreationExpression = Expression.NewArrayBounds(returnType, lengths.Select(len => Expression.Constant(len)).ToArray()); + var arrayVariable = Expression.Variable(arrayCreationExpression.Type, "array"); + var assignArray = Expression.Assign(arrayVariable, arrayCreationExpression); + + var assignments = new List(); + for (int i = 0; i < values.Length; i++) + { + var indices = new Expression[rank]; + int temp = i; + for (int j = rank - 1; j >= 0; j--) + { + indices[j] = Expression.Constant(temp % lengths[j]); + temp /= lengths[j]; + } + var value = Expression.Constant(arrayValues.GetValue(indices.Select(e => ((ConstantExpression)e).Value).Cast().ToArray())); + var arrayAccess = Expression.ArrayAccess(arrayVariable, indices); + var assignArrayValue = Expression.Assign(arrayAccess, value); + assignments.Add(assignArrayValue); + } + + var tensorDataInitializationBlock = Expression.Block( + arrayVariable, + assignArray, + Expression.Block(assignments), + arrayVariable + ); + + var tensorCreationMethodInfo = typeof(TorchSharp.torch).GetMethod( + "tensor", [ + arrayVariable.Type, + typeof(ScalarType?), + typeof(Device), + typeof(bool), + typeof(string).MakeArrayType() + ] + ); + + var tensorAssignment = Expression.Call( + tensorCreationMethodInfo, + tensorDataInitializationBlock, + Expression.Constant(scalarType, typeof(ScalarType?)), + Expression.Constant(device, typeof(Device)), + Expression.Constant(false, typeof(bool)), + Expression.Constant(null, typeof(string).MakeArrayType()) + ); + + var tensorVariable = Expression.Variable(typeof(Tensor), "tensor"); + var assignTensor = Expression.Assign(tensorVariable, tensorAssignment); + + var buildTensor = Expression.Block( + tensorVariable, + assignTensor, + tensorVariable + ); + + return buildTensor; + } + + private Expression BuildTensorFromScalarValue(object scalarValue, Type returnType) + { + var valueVariable = Expression.Variable(returnType, "value"); + var assignValue = Expression.Assign(valueVariable, Expression.Constant(scalarValue, returnType)); + + var tensorDataInitializationBlock = Expression.Block( + valueVariable, + assignValue, + valueVariable + ); + + var tensorCreationMethodInfo = typeof(TorchSharp.torch).GetMethod( + "tensor", [ + valueVariable.Type, + typeof(Device), + typeof(bool) + ] + ); + + var tensorCreationMethodArguments = new Expression[] { + Expression.Constant(device, typeof(Device) ), + Expression.Constant(false, typeof(bool) ) + }; + + if (tensorCreationMethodInfo == null) + { + tensorCreationMethodInfo = typeof(TorchSharp.torch).GetMethod( + "tensor", [ + valueVariable.Type, + typeof(ScalarType?), + typeof(Device), + typeof(bool) + ] + ); + + tensorCreationMethodArguments = tensorCreationMethodArguments.Prepend( + Expression.Constant(scalarType, typeof(ScalarType?)) + ).ToArray(); + } + + tensorCreationMethodArguments = tensorCreationMethodArguments.Prepend( + tensorDataInitializationBlock + ).ToArray(); + + var tensorAssignment = Expression.Call( + tensorCreationMethodInfo, + tensorCreationMethodArguments + ); + + var tensorVariable = Expression.Variable(typeof(Tensor), "tensor"); + var assignTensor = Expression.Assign(tensorVariable, tensorAssignment); + + var buildTensor = Expression.Block( + tensorVariable, + assignTensor, + tensorVariable + ); + + return buildTensor; + } + + /// + public override Expression Build(IEnumerable arguments) + { + var returnType = Helpers.TensorDataTypeHelper.GetTypeFromTensorDataType(scalarType); + var argTypes = arguments.Select(arg => arg.Type).ToArray(); + + var methodInfoArgumentTypes = new Type[] { + typeof(Tensor) + }; + + var methods = typeof(CreateTensor).GetMethods(BindingFlags.Public | BindingFlags.Instance) + .Where(m => m.Name == "Process") + .ToArray(); + + var methodInfo = arguments.Count() > 0 ? methods.FirstOrDefault(m => m.IsGenericMethod) + .MakeGenericMethod( + arguments + .First() + .Type + .GetGenericArguments()[0] + ) : methods.FirstOrDefault(m => !m.IsGenericMethod); + + var tensorValues = Helpers.DataHelper.ParseString(values, returnType); + var buildTensor = tensorValues is Array arrayValues ? BuildTensorFromArray(arrayValues, returnType) : BuildTensorFromScalarValue(tensorValues, returnType); + var methodArguments = arguments.Count() == 0 ? [ buildTensor ] : arguments.Concat([ buildTensor ]); + + try + { + return Expression.Call( + Expression.Constant(this), + methodInfo, + methodArguments + ); + } + finally + { + values = Helpers.DataHelper.SerializeData(tensorValues).Replace("False", "false").Replace("True", "true"); + scalarType = Helpers.TensorDataTypeHelper.GetTensorDataTypeFromType(returnType); + } + } + + /// + /// Returns an observable sequence that creates a tensor from the specified values. + /// + public IObservable Process(Tensor tensor) + { + return Observable.Return(tensor); + } + + /// + /// Returns an observable sequence that creates a tensor from the specified values for each element in the input sequence. + /// + public IObservable Process(IObservable source, Tensor tensor) + { + return Observable.Select(source, (_) => tensor); + } + } +} diff --git a/src/Bonsai.ML.Tensors/Helpers/DataHelper.cs b/src/Bonsai.ML.Tensors/Helpers/DataHelper.cs new file mode 100644 index 00000000..1bbf3228 --- /dev/null +++ b/src/Bonsai.ML.Tensors/Helpers/DataHelper.cs @@ -0,0 +1,190 @@ +using System; +using System.Text; +using System.Collections.Generic; +using System.Linq; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; + +namespace Bonsai.ML.Tensors.Helpers +{ + /// + /// Provides helper methods for parsing tensor data types. + /// + public static class DataHelper + { + + /// + /// Serializes the input data into a string representation. + /// + public static string SerializeData(object data) + { + if (data is Array array) + { + return SerializeArray(array); + } + else + { + return JsonConvert.SerializeObject(data); + } + } + + /// + /// Serializes the input array into a string representation. + /// + public static string SerializeArray(Array array) + { + StringBuilder sb = new StringBuilder(); + SerializeArrayRecursive(array, sb, [0]); + return sb.ToString(); + } + + private static void SerializeArrayRecursive(Array array, StringBuilder sb, int[] indices) + { + if (indices.Length < array.Rank) + { + sb.Append("["); + int length = array.GetLength(indices.Length); + for (int i = 0; i < length; i++) + { + int[] newIndices = new int[indices.Length + 1]; + indices.CopyTo(newIndices, 0); + newIndices[indices.Length] = i; + SerializeArrayRecursive(array, sb, newIndices); + if (i < length - 1) + { + sb.Append(", "); + } + } + sb.Append("]"); + } + else + { + object value = array.GetValue(indices); + sb.Append(value.ToString()); + } + } + + private static bool IsValidJson(string input) + { + int squareBrackets = 0; + foreach (char c in input) + { + if (c == '[') squareBrackets++; + else if (c == ']') squareBrackets--; + } + return squareBrackets == 0; + } + + /// + /// Parses the input string into an object of the specified type. + /// + public static object ParseString(string input, Type dtype) + { + if (!IsValidJson(input)) + { + throw new ArgumentException("JSON is invalid."); + } + var obj = JsonConvert.DeserializeObject(input); + int depth = ParseDepth(obj); + if (depth == 0) + { + return Convert.ChangeType(input, dtype); + } + int[] dimensions = ParseDimensions(obj, depth); + var resultArray = Array.CreateInstance(dtype, dimensions); + PopulateArray(obj, resultArray, [0], dtype); + return resultArray; + } + + private static int ParseDepth(JToken token, int currentDepth = 0) + { + if (token is JArray arr && arr.Count > 0) + { + return ParseDepth(arr[0], currentDepth + 1); + } + return currentDepth; + } + + private static int[] ParseDimensions(JToken token, int depth, int currentLevel = 0) + { + if (depth == 0 || !(token is JArray)) + { + return [0]; + } + + List dimensions = new List(); + JToken current = token; + + while (current != null && current is JArray) + { + JArray currentArray = current as JArray; + dimensions.Add(currentArray.Count); + if (currentArray.Count > 0) + { + if (currentArray.Any(item => !(item is JArray)) && currentArray.Any(item => item is JArray) || currentArray.All(item => item is JArray) && currentArray.Any(item => ((JArray)item).Count != ((JArray)currentArray.First()).Count)) + { + throw new Exception("Error parsing input. Dimensions are inconsistent."); + } + + if (!(currentArray.First() is JArray)) + { + if (!currentArray.All(item => double.TryParse(item.ToString(), out _)) && !currentArray.All(item => bool.TryParse(item.ToString(), out _))) + { + throw new Exception("Error parsing types. All values must be of the same type and only numeric or boolean types are supported."); + } + } + } + + current = currentArray.Count > 0 ? currentArray[0] : null; + } + + if (currentLevel > 0 && token is JArray arr && arr.All(x => x is JArray)) + { + var subArrayDimensions = new HashSet(); + foreach (JArray subArr in arr) + { + int[] subDims = ParseDimensions(subArr, depth - currentLevel, currentLevel + 1); + subArrayDimensions.Add(string.Join(",", subDims)); + } + + if (subArrayDimensions.Count > 1) + { + throw new ArgumentException("Inconsistent array dimensions."); + } + } + + return dimensions.ToArray(); + } + + private static void PopulateArray(JToken token, Array array, int[] indices, Type dtype) + { + if (token is JArray arr) + { + for (int i = 0; i < arr.Count; i++) + { + int[] newIndices = new int[indices.Length + 1]; + Array.Copy(indices, newIndices, indices.Length); + newIndices[newIndices.Length - 1] = i; + PopulateArray(arr[i], array, newIndices, dtype); + } + } + else + { + var values = ConvertType(token, dtype); + array.SetValue(values, indices); + } + } + + private static object ConvertType(object value, Type targetType) + { + try + { + return Convert.ChangeType(value, targetType); + } + catch (Exception ex) + { + throw new Exception("Error parsing type: ", ex); + } + } + } +} \ No newline at end of file From ac2f0cb138700397e1e59a999d1a4eb7f7221a01 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 18:03:45 +0100 Subject: [PATCH 018/131] Added index method and updated set method --- src/Bonsai.ML.Tensors/Helpers/IndexHelper.cs | 91 ++++++++++++++++++++ src/Bonsai.ML.Tensors/Index.cs | 35 ++++++++ src/Bonsai.ML.Tensors/Set.cs | 4 +- 3 files changed, 128 insertions(+), 2 deletions(-) create mode 100644 src/Bonsai.ML.Tensors/Helpers/IndexHelper.cs create mode 100644 src/Bonsai.ML.Tensors/Index.cs diff --git a/src/Bonsai.ML.Tensors/Helpers/IndexHelper.cs b/src/Bonsai.ML.Tensors/Helpers/IndexHelper.cs new file mode 100644 index 00000000..785eccea --- /dev/null +++ b/src/Bonsai.ML.Tensors/Helpers/IndexHelper.cs @@ -0,0 +1,91 @@ +using System; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors.Helpers +{ + /// + /// Provides helper methods to parse tensor indexes. + /// + public static class IndexHelper + { + + /// + /// Parses the input string into an array of tensor indexes. + /// + /// + public static TensorIndex[] ParseString(string input) + { + if (string.IsNullOrEmpty(input)) + { + return [0]; + } + + var indexStrings = input.Split(','); + var indices = new TensorIndex[indexStrings.Length]; + + for (int i = 0; i < indexStrings.Length; i++) + { + var indexString = indexStrings[i].Trim(); + if (int.TryParse(indexString, out int intIndex)) + { + indices[i] = TensorIndex.Single(intIndex); + } + else if (indexString == ":") + { + indices[i] = TensorIndex.Colon; + } + else if (indexString == "None") + { + indices[i] = TensorIndex.None; + } + else if (indexString == "...") + { + indices[i] = TensorIndex.Ellipsis; + } + else if (indexString.ToLower() == "false" || indexString.ToLower() == "true") + { + indices[i] = TensorIndex.Bool(indexString.ToLower() == "true"); + } + else if (indexString.Contains(":")) + { + var rangeParts = indexString.Split(':'); + if (rangeParts.Length == 0) + { + indices[i] = TensorIndex.Slice(); + } + else if (rangeParts.Length == 1) + { + indices[i] = TensorIndex.Slice(int.Parse(rangeParts[0])); + } + else if (rangeParts.Length == 2) + { + indices[i] = TensorIndex.Slice(int.Parse(rangeParts[0]), int.Parse(rangeParts[1])); + } + else if (rangeParts.Length == 3) + { + indices[i] = TensorIndex.Slice(int.Parse(rangeParts[0]), int.Parse(rangeParts[1]), int.Parse(rangeParts[2])); + } + else + { + throw new Exception($"Invalid index format: {indexString}"); + } + } + else + { + throw new Exception($"Invalid index format: {indexString}"); + } + } + return indices; + } + + /// + /// Serializes the input array of tensor indexes into a string representation. + /// + /// + /// + public static string SerializeIndexes(TensorIndex[] indexes) + { + return string.Join(", ", indexes); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/Index.cs b/src/Bonsai.ML.Tensors/Index.cs new file mode 100644 index 00000000..3c1948f9 --- /dev/null +++ b/src/Bonsai.ML.Tensors/Index.cs @@ -0,0 +1,35 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Indexes a tensor with the specified indices. Indices are specified as a comma-separated values. + /// Currently supports Python-style slicing syntax. This includes numeric indices, None, slices, and ellipsis. + /// + [Combinator] + [Description("Indexes a tensor with the specified indices. Indices are specified as a comma-separated values. Currently supports Python-style slicing syntax. This includes numeric indices, None, slices, and ellipsis.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Index + { + /// + /// The indices to use for indexing the tensor. + /// + public string Indexes { get; set; } = string.Empty; + + /// + /// Indexes the input tensor with the specified indices. + /// + /// + /// + public IObservable Process(IObservable source) + { + var index = Helpers.IndexHelper.ParseString(Indexes); + return source.Select(tensor => { + return tensor.index(index); + }); + } + } +} diff --git a/src/Bonsai.ML.Tensors/Set.cs b/src/Bonsai.ML.Tensors/Set.cs index 3f2a6f50..7f6f8b92 100644 --- a/src/Bonsai.ML.Tensors/Set.cs +++ b/src/Bonsai.ML.Tensors/Set.cs @@ -21,8 +21,8 @@ public class Set /// public string Index { - get => Helpers.IndexParser.SerializeIndexes(indexes); - set => indexes = Helpers.IndexParser.ParseString(value); + get => Helpers.IndexHelper.SerializeIndexes(indexes); + set => indexes = Helpers.IndexHelper.ParseString(value); } private TensorIndex[] indexes; From eb62f4bb8aa352ebed1081fd7c2f2597db3c0eb3 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 18:04:08 +0100 Subject: [PATCH 019/131] Defined tensor data types as subset of ScalarType --- src/Bonsai.ML.Tensors/TensorDataType.cs | 56 +++++++++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/TensorDataType.cs diff --git a/src/Bonsai.ML.Tensors/TensorDataType.cs b/src/Bonsai.ML.Tensors/TensorDataType.cs new file mode 100644 index 00000000..a710a9ed --- /dev/null +++ b/src/Bonsai.ML.Tensors/TensorDataType.cs @@ -0,0 +1,56 @@ +using System; +using System.Text; +using System.Collections.Generic; +using System.Linq; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Represents the data type of the tensor elements. Contains currently supported data types. A subset of the available ScalarType data types in TorchSharp. + /// + public enum TensorDataType + { + /// + /// 8-bit unsigned integer. + /// + Byte = ScalarType.Byte, + + /// + /// 8-bit signed integer. + /// + Int8 = ScalarType.Int8, + + /// + /// 16-bit signed integer. + /// + Int16 = ScalarType.Int16, + + /// + /// 32-bit signed integer. + /// + Int32 = ScalarType.Int32, + + /// + /// 64-bit signed integer. + /// + Int64 = ScalarType.Int64, + + /// + /// 32-bit floating point. + /// + Float32 = ScalarType.Float32, + + /// + /// 64-bit floating point. + /// + Float64 = ScalarType.Float64, + + /// + /// Boolean. + /// + Bool = ScalarType.Bool + } +} \ No newline at end of file From c36931c41a850c2ea023bbd2cd9debb17194bad1 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 18:04:33 +0100 Subject: [PATCH 020/131] Added to array method --- src/Bonsai.ML.Tensors/ToArray.cs | 73 ++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/ToArray.cs diff --git a/src/Bonsai.ML.Tensors/ToArray.cs b/src/Bonsai.ML.Tensors/ToArray.cs new file mode 100644 index 00000000..af35ab4f --- /dev/null +++ b/src/Bonsai.ML.Tensors/ToArray.cs @@ -0,0 +1,73 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using System.Xml.Serialization; +using System.Linq.Expressions; +using System.Reflection; +using Bonsai.Expressions; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Converts the input tensor into an array of the specified element type. + /// + [Combinator] + [Description("Converts the input tensor into an array of the specified element type.")] + [WorkflowElementCategory(ElementCategory.Transform)] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + public class ToArray : SingleArgumentExpressionBuilder + { + /// + /// Initializes a new instance of the class. + /// + public ToArray() + { + Type = new TypeMapping(); + } + + /// + /// Gets or sets the type mapping used to convert the input tensor into an array. + /// + public TypeMapping Type { get; set; } + + /// + public override Expression Build(IEnumerable arguments) + { + TypeMapping typeMapping = Type; + var returnType = typeMapping.GetType().GetGenericArguments()[0]; + MethodInfo methodInfo = GetType().GetMethod("Process", BindingFlags.Public | BindingFlags.Instance); + methodInfo = methodInfo.MakeGenericMethod(returnType); + Expression sourceExpression = arguments.First(); + + return Expression.Call( + Expression.Constant(this), + methodInfo, + sourceExpression + ); + } + + /// + /// Converts the input tensor into an array of the specified element type. + /// + /// + /// + /// + public IObservable Process(IObservable source) where T : unmanaged + { + return source.Select(tensor => + { + return tensor.data().ToArray(); + }); + } + } +} \ No newline at end of file From 4f012509879f36df6583b404c0d485219e05e675 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 18:04:47 +0100 Subject: [PATCH 021/131] Added tensor data type helper --- .../Helpers/TensorDataTypeHelper.cs | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/Helpers/TensorDataTypeHelper.cs diff --git a/src/Bonsai.ML.Tensors/Helpers/TensorDataTypeHelper.cs b/src/Bonsai.ML.Tensors/Helpers/TensorDataTypeHelper.cs new file mode 100644 index 00000000..7ea03f65 --- /dev/null +++ b/src/Bonsai.ML.Tensors/Helpers/TensorDataTypeHelper.cs @@ -0,0 +1,52 @@ +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Bonsai.ML.Tensors.Helpers +{ + /// + /// Provides helper methods for working with tensor data types. + /// + public class TensorDataTypeHelper + { + private static readonly Dictionary _lookup = new Dictionary + { + { TensorDataType.Byte, (typeof(byte), "byte") }, + { TensorDataType.Int16, (typeof(short), "short") }, + { TensorDataType.Int32, (typeof(int), "int") }, + { TensorDataType.Int64, (typeof(long), "long") }, + { TensorDataType.Float32, (typeof(float), "float") }, + { TensorDataType.Float64, (typeof(double), "double") }, + { TensorDataType.Bool, (typeof(bool), "bool") }, + { TensorDataType.Int8, (typeof(sbyte), "sbyte") }, + }; + + /// + /// Returns the type corresponding to the specified tensor data type. + /// + /// + /// + public static Type GetTypeFromTensorDataType(TensorDataType type) => _lookup[type].Type; + + /// + /// Returns the string representation corresponding to the specified tensor data type. + /// + /// + /// + public static string GetStringFromTensorDataType(TensorDataType type) => _lookup[type].StringValue; + + /// + /// Returns the tensor data type corresponding to the specified string representation. + /// + /// + /// + public static TensorDataType GetTensorDataTypeFromString(string value) => _lookup.First(x => x.Value.StringValue == value).Key; + + /// + /// Returns the tensor data type corresponding to the specified type. + /// + /// + /// + public static TensorDataType GetTensorDataTypeFromType(Type type) => _lookup.First(x => x.Value.Type == type).Key; + } +} \ No newline at end of file From f879960dabf456826a401dc97f8ae7b708a28d4f Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 18:05:16 +0100 Subject: [PATCH 022/131] Added methods to convert OpenCV types --- src/Bonsai.ML.Tensors/Helpers/OpenCVHelper.cs | 169 ++++++++++++++++++ src/Bonsai.ML.Tensors/ToImage.cs | 28 +++ src/Bonsai.ML.Tensors/ToMat.cs | 28 +++ src/Bonsai.ML.Tensors/ToTensor.cs | 134 ++++++++++++++ 4 files changed, 359 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/Helpers/OpenCVHelper.cs create mode 100644 src/Bonsai.ML.Tensors/ToImage.cs create mode 100644 src/Bonsai.ML.Tensors/ToMat.cs create mode 100644 src/Bonsai.ML.Tensors/ToTensor.cs diff --git a/src/Bonsai.ML.Tensors/Helpers/OpenCVHelper.cs b/src/Bonsai.ML.Tensors/Helpers/OpenCVHelper.cs new file mode 100644 index 00000000..265f2119 --- /dev/null +++ b/src/Bonsai.ML.Tensors/Helpers/OpenCVHelper.cs @@ -0,0 +1,169 @@ +using System; +using System.Runtime.InteropServices; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using OpenCV.Net; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors.Helpers +{ + /// + /// Helper class to convert between OpenCV mats, images and Torch tensors. + /// + public static class OpenCVHelper + { + private static Dictionary bitDepthLookup = new Dictionary { + { ScalarType.Byte, (IplDepth.U8, Depth.U8) }, + { ScalarType.Int16, (IplDepth.S16, Depth.S16) }, + { ScalarType.Int32, (IplDepth.S32, Depth.S32) }, + { ScalarType.Float32, (IplDepth.F32, Depth.F32) }, + { ScalarType.Float64, (IplDepth.F64, Depth.F64) }, + { ScalarType.Int8, (IplDepth.S8, Depth.S8) } + }; + + private static ConcurrentDictionary deleters = new ConcurrentDictionary(); + + internal delegate void GCHandleDeleter(IntPtr memory); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_data(IntPtr handle); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_new(IntPtr rawArray, GCHandleDeleter deleter, IntPtr dimensions, int numDimensions, sbyte type, sbyte dtype, int deviceType, int deviceIndex, [MarshalAs(UnmanagedType.U1)] bool requires_grad); + + /// + /// Creates a tensor from a pointer to the data and the dimensions of the tensor. + /// + /// + /// + /// + /// + public static unsafe Tensor CreateTensorFromPtr(IntPtr tensorDataPtr, long[] dimensions, ScalarType dtype = ScalarType.Byte) + { + var dataHandle = GCHandle.Alloc(tensorDataPtr, GCHandleType.Pinned); + var gchp = GCHandle.ToIntPtr(dataHandle); + GCHandleDeleter deleter = null; + + deleter = new GCHandleDeleter((IntPtr ptrHandler) => + { + GCHandle.FromIntPtr(gchp).Free(); + deleters.TryRemove(deleter, out deleter); + }); + deleters.TryAdd(deleter, deleter); + + fixed (long* dimensionsPtr = dimensions) + { + IntPtr tensorHandle = THSTensor_new(tensorDataPtr, deleter, (IntPtr)dimensionsPtr, dimensions.Length, (sbyte)dtype, (sbyte)dtype, 0, 0, false); + if (tensorHandle == IntPtr.Zero) { + GC.Collect(); + GC.WaitForPendingFinalizers(); + tensorHandle = THSTensor_new(tensorDataPtr, deleter, (IntPtr)dimensionsPtr, dimensions.Length, (sbyte)dtype, (sbyte)dtype, 0, 0, false); + } + if (tensorHandle == IntPtr.Zero) { CheckForErrors(); } + var output = Tensor.UnsafeCreateTensor(tensorHandle); + return output; + } + } + + /// + /// Converts an OpenCV image to a Torch tensor. + /// + /// + /// + public static Tensor ToTensor(IplImage image) + { + if (image == null) + { + return empty([ 0, 0, 0 ]); + } + + int width = image.Width; + int height = image.Height; + int channels = image.Channels; + + var iplDepth = image.Depth; + var tensorType = bitDepthLookup.FirstOrDefault(x => x.Value.IplDepth == iplDepth).Key; + + IntPtr tensorDataPtr = image.ImageData; + long[] dimensions = [ height, width, channels ]; + if (tensorDataPtr == IntPtr.Zero) + { + return empty(dimensions); + } + return CreateTensorFromPtr(tensorDataPtr, dimensions, tensorType); + } + + /// + /// Converts an OpenCV mat to a Torch tensor. + /// + /// + /// + public static Tensor ToTensor(Mat mat) + { + if (mat == null) + { + return empty([0, 0, 0 ]); + } + + int width = mat.Size.Width; + int height = mat.Size.Height; + int channels = mat.Channels; + + var depth = mat.Depth; + var tensorType = bitDepthLookup.FirstOrDefault(x => x.Value.Depth == depth).Key; + + IntPtr tensorDataPtr = mat.Data; + long[] dimensions = [ height, width, channels ]; + if (tensorDataPtr == IntPtr.Zero) + { + return empty(dimensions); + } + return CreateTensorFromPtr(tensorDataPtr, dimensions, tensorType); + } + + /// + /// Converts a Torch tensor to an OpenCV image. + /// + /// + /// + public unsafe static IplImage ToImage(Tensor tensor) + { + var height = (int)tensor.shape[0]; + var width = (int)tensor.shape[1]; + var channels = (int)tensor.shape[2]; + + var tensorType = tensor.dtype; + var iplDepth = bitDepthLookup[tensorType].IplDepth; + + var new_tensor = zeros(new long[] { height, width, channels }, tensorType).copy_(tensor); + + var res = THSTensor_data(new_tensor.Handle); + var image = new IplImage(new OpenCV.Net.Size(width, height), iplDepth, channels, res); + + return image; + } + + /// + /// Converts a Torch tensor to an OpenCV mat. + /// + /// + /// + public unsafe static Mat ToMat(Tensor tensor) + { + var height = (int)tensor.shape[0]; + var width = (int)tensor.shape[1]; + var channels = (int)tensor.shape[2]; + + var tensorType = tensor.dtype; + var depth = bitDepthLookup[tensorType].Depth; + + var new_tensor = zeros(new long[] { height, width, channels }, tensorType).copy_(tensor); + + var res = THSTensor_data(new_tensor.Handle); + var mat = new Mat(new OpenCV.Net.Size(width, height), depth, channels, res); + + return mat; + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/ToImage.cs b/src/Bonsai.ML.Tensors/ToImage.cs new file mode 100644 index 00000000..e29a3825 --- /dev/null +++ b/src/Bonsai.ML.Tensors/ToImage.cs @@ -0,0 +1,28 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using OpenCV.Net; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Converts the input tensor into an OpenCV image. + /// + [Combinator] + [Description("")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class ToImage + { + /// + /// Converts the input tensor into an OpenCV image. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(Helpers.OpenCVHelper.ToImage); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/ToMat.cs b/src/Bonsai.ML.Tensors/ToMat.cs new file mode 100644 index 00000000..8a22f408 --- /dev/null +++ b/src/Bonsai.ML.Tensors/ToMat.cs @@ -0,0 +1,28 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using OpenCV.Net; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Converts the input tensor into an OpenCV mat. + /// + [Combinator] + [Description("Converts the input tensor into an OpenCV mat.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class ToMat + { + /// + /// Converts the input tensor into an OpenCV mat. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(Helpers.OpenCVHelper.ToMat); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/ToTensor.cs b/src/Bonsai.ML.Tensors/ToTensor.cs new file mode 100644 index 00000000..083e2797 --- /dev/null +++ b/src/Bonsai.ML.Tensors/ToTensor.cs @@ -0,0 +1,134 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using OpenCV.Net; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Converts the input value into a tensor. + /// + [Combinator] + [Description("")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class ToTensor + { + /// + /// Converts an int into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a double into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a byte into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a bool into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a float into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a long into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a short into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts an array into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts an IplImage into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(Helpers.OpenCVHelper.ToTensor); + } + + /// + /// Converts a Mat into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(Helpers.OpenCVHelper.ToTensor); + } + } +} \ No newline at end of file From 40615ab31f008d78edc9952b53fed308b75da89f Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 27 Aug 2024 12:51:18 +0100 Subject: [PATCH 023/131] Refactored to torch namespace instead of tensors --- src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj | 17 ++ src/Bonsai.ML.Torch/Helpers/DataHelper.cs | 190 ++++++++++++++ src/Bonsai.ML.Torch/Helpers/IndexHelper.cs | 91 +++++++ src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs | 169 ++++++++++++ .../Helpers/TensorDataTypeHelper.cs | 53 ++++ .../Configuration/ModelConfiguration.cs | 5 + .../Configuration/ModuleConfiguration.cs | 5 + .../NeuralNets/LoadPretrainedModel.cs | 47 ++++ .../NeuralNets/ModelManager.cs | 5 + .../NeuralNets/Models/AlexNet.cs | 70 +++++ .../NeuralNets/Models/MNIST.cs | 61 +++++ .../NeuralNets/Models/MobileNet.cs | 72 +++++ .../NeuralNets/Models/PretrainedModels.cs | 9 + src/Bonsai.ML.Torch/Tensors/Arange.cs | 40 +++ src/Bonsai.ML.Torch/Tensors/Concat.cs | 45 ++++ .../Tensors/ConvertDataType.cs | 32 +++ src/Bonsai.ML.Torch/Tensors/CreateTensor.cs | 245 ++++++++++++++++++ src/Bonsai.ML.Torch/Tensors/Index.cs | 35 +++ .../Tensors/InitializeTorchDevice.cs | 35 +++ src/Bonsai.ML.Torch/Tensors/Linspace.cs | 40 +++ src/Bonsai.ML.Torch/Tensors/MeshGrid.cs | 33 +++ src/Bonsai.ML.Torch/Tensors/Ones.cs | 30 +++ src/Bonsai.ML.Torch/Tensors/Permute.cs | 33 +++ src/Bonsai.ML.Torch/Tensors/Reshape.cs | 32 +++ src/Bonsai.ML.Torch/Tensors/Set.cs | 48 ++++ src/Bonsai.ML.Torch/Tensors/TensorDataType.cs | 56 ++++ src/Bonsai.ML.Torch/Tensors/ToArray.cs | 73 ++++++ src/Bonsai.ML.Torch/Tensors/ToDevice.cs | 34 +++ src/Bonsai.ML.Torch/Tensors/ToImage.cs | 28 ++ src/Bonsai.ML.Torch/Tensors/ToMat.cs | 28 ++ src/Bonsai.ML.Torch/Tensors/ToTensor.cs | 134 ++++++++++ src/Bonsai.ML.Torch/Tensors/Zeros.cs | 30 +++ 32 files changed, 1825 insertions(+) create mode 100644 src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj create mode 100644 src/Bonsai.ML.Torch/Helpers/DataHelper.cs create mode 100644 src/Bonsai.ML.Torch/Helpers/IndexHelper.cs create mode 100644 src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs create mode 100644 src/Bonsai.ML.Torch/Helpers/TensorDataTypeHelper.cs create mode 100644 src/Bonsai.ML.Torch/NeuralNets/Configuration/ModelConfiguration.cs create mode 100644 src/Bonsai.ML.Torch/NeuralNets/Configuration/ModuleConfiguration.cs create mode 100644 src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs create mode 100644 src/Bonsai.ML.Torch/NeuralNets/ModelManager.cs create mode 100644 src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs create mode 100644 src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs create mode 100644 src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs create mode 100644 src/Bonsai.ML.Torch/NeuralNets/Models/PretrainedModels.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/Arange.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/Concat.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/ConvertDataType.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/CreateTensor.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/Index.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/InitializeTorchDevice.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/Linspace.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/MeshGrid.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/Ones.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/Permute.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/Reshape.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/Set.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/TensorDataType.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/ToArray.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/ToDevice.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/ToImage.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/ToMat.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/ToTensor.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/Zeros.cs diff --git a/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj b/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj new file mode 100644 index 00000000..9ed3c5d8 --- /dev/null +++ b/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj @@ -0,0 +1,17 @@ + + + Bonsai.ML.Torch + A Bonsai package for TorchSharp tensor manipulations. + Bonsai Rx ML Tensors TorchSharp + net472;netstandard2.0 + true + + + + + + + + + + \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Helpers/DataHelper.cs b/src/Bonsai.ML.Torch/Helpers/DataHelper.cs new file mode 100644 index 00000000..ffed053a --- /dev/null +++ b/src/Bonsai.ML.Torch/Helpers/DataHelper.cs @@ -0,0 +1,190 @@ +using System; +using System.Text; +using System.Collections.Generic; +using System.Linq; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; + +namespace Bonsai.ML.Torch.Helpers +{ + /// + /// Provides helper methods for parsing tensor data types. + /// + public static class DataHelper + { + + /// + /// Serializes the input data into a string representation. + /// + public static string SerializeData(object data) + { + if (data is Array array) + { + return SerializeArray(array); + } + else + { + return JsonConvert.SerializeObject(data); + } + } + + /// + /// Serializes the input array into a string representation. + /// + public static string SerializeArray(Array array) + { + StringBuilder sb = new StringBuilder(); + SerializeArrayRecursive(array, sb, [0]); + return sb.ToString(); + } + + private static void SerializeArrayRecursive(Array array, StringBuilder sb, int[] indices) + { + if (indices.Length < array.Rank) + { + sb.Append("["); + int length = array.GetLength(indices.Length); + for (int i = 0; i < length; i++) + { + int[] newIndices = new int[indices.Length + 1]; + indices.CopyTo(newIndices, 0); + newIndices[indices.Length] = i; + SerializeArrayRecursive(array, sb, newIndices); + if (i < length - 1) + { + sb.Append(", "); + } + } + sb.Append("]"); + } + else + { + object value = array.GetValue(indices); + sb.Append(value.ToString()); + } + } + + private static bool IsValidJson(string input) + { + int squareBrackets = 0; + foreach (char c in input) + { + if (c == '[') squareBrackets++; + else if (c == ']') squareBrackets--; + } + return squareBrackets == 0; + } + + /// + /// Parses the input string into an object of the specified type. + /// + public static object ParseString(string input, Type dtype) + { + if (!IsValidJson(input)) + { + throw new ArgumentException("JSON is invalid."); + } + var obj = JsonConvert.DeserializeObject(input); + int depth = ParseDepth(obj); + if (depth == 0) + { + return Convert.ChangeType(input, dtype); + } + int[] dimensions = ParseDimensions(obj, depth); + var resultArray = Array.CreateInstance(dtype, dimensions); + PopulateArray(obj, resultArray, [0], dtype); + return resultArray; + } + + private static int ParseDepth(JToken token, int currentDepth = 0) + { + if (token is JArray arr && arr.Count > 0) + { + return ParseDepth(arr[0], currentDepth + 1); + } + return currentDepth; + } + + private static int[] ParseDimensions(JToken token, int depth, int currentLevel = 0) + { + if (depth == 0 || !(token is JArray)) + { + return [0]; + } + + List dimensions = new List(); + JToken current = token; + + while (current != null && current is JArray) + { + JArray currentArray = current as JArray; + dimensions.Add(currentArray.Count); + if (currentArray.Count > 0) + { + if (currentArray.Any(item => !(item is JArray)) && currentArray.Any(item => item is JArray) || currentArray.All(item => item is JArray) && currentArray.Any(item => ((JArray)item).Count != ((JArray)currentArray.First()).Count)) + { + throw new Exception("Error parsing input. Dimensions are inconsistent."); + } + + if (!(currentArray.First() is JArray)) + { + if (!currentArray.All(item => double.TryParse(item.ToString(), out _)) && !currentArray.All(item => bool.TryParse(item.ToString(), out _))) + { + throw new Exception("Error parsing types. All values must be of the same type and only numeric or boolean types are supported."); + } + } + } + + current = currentArray.Count > 0 ? currentArray[0] : null; + } + + if (currentLevel > 0 && token is JArray arr && arr.All(x => x is JArray)) + { + var subArrayDimensions = new HashSet(); + foreach (JArray subArr in arr) + { + int[] subDims = ParseDimensions(subArr, depth - currentLevel, currentLevel + 1); + subArrayDimensions.Add(string.Join(",", subDims)); + } + + if (subArrayDimensions.Count > 1) + { + throw new ArgumentException("Inconsistent array dimensions."); + } + } + + return dimensions.ToArray(); + } + + private static void PopulateArray(JToken token, Array array, int[] indices, Type dtype) + { + if (token is JArray arr) + { + for (int i = 0; i < arr.Count; i++) + { + int[] newIndices = new int[indices.Length + 1]; + Array.Copy(indices, newIndices, indices.Length); + newIndices[newIndices.Length - 1] = i; + PopulateArray(arr[i], array, newIndices, dtype); + } + } + else + { + var values = ConvertType(token, dtype); + array.SetValue(values, indices); + } + } + + private static object ConvertType(object value, Type targetType) + { + try + { + return Convert.ChangeType(value, targetType); + } + catch (Exception ex) + { + throw new Exception("Error parsing type: ", ex); + } + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Helpers/IndexHelper.cs b/src/Bonsai.ML.Torch/Helpers/IndexHelper.cs new file mode 100644 index 00000000..541ae443 --- /dev/null +++ b/src/Bonsai.ML.Torch/Helpers/IndexHelper.cs @@ -0,0 +1,91 @@ +using System; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Helpers +{ + /// + /// Provides helper methods to parse tensor indexes. + /// + public static class IndexHelper + { + + /// + /// Parses the input string into an array of tensor indexes. + /// + /// + public static TensorIndex[] ParseString(string input) + { + if (string.IsNullOrEmpty(input)) + { + return [0]; + } + + var indexStrings = input.Split(','); + var indices = new TensorIndex[indexStrings.Length]; + + for (int i = 0; i < indexStrings.Length; i++) + { + var indexString = indexStrings[i].Trim(); + if (int.TryParse(indexString, out int intIndex)) + { + indices[i] = TensorIndex.Single(intIndex); + } + else if (indexString == ":") + { + indices[i] = TensorIndex.Colon; + } + else if (indexString == "None") + { + indices[i] = TensorIndex.None; + } + else if (indexString == "...") + { + indices[i] = TensorIndex.Ellipsis; + } + else if (indexString.ToLower() == "false" || indexString.ToLower() == "true") + { + indices[i] = TensorIndex.Bool(indexString.ToLower() == "true"); + } + else if (indexString.Contains(":")) + { + var rangeParts = indexString.Split(':'); + if (rangeParts.Length == 0) + { + indices[i] = TensorIndex.Slice(); + } + else if (rangeParts.Length == 1) + { + indices[i] = TensorIndex.Slice(int.Parse(rangeParts[0])); + } + else if (rangeParts.Length == 2) + { + indices[i] = TensorIndex.Slice(int.Parse(rangeParts[0]), int.Parse(rangeParts[1])); + } + else if (rangeParts.Length == 3) + { + indices[i] = TensorIndex.Slice(int.Parse(rangeParts[0]), int.Parse(rangeParts[1]), int.Parse(rangeParts[2])); + } + else + { + throw new Exception($"Invalid index format: {indexString}"); + } + } + else + { + throw new Exception($"Invalid index format: {indexString}"); + } + } + return indices; + } + + /// + /// Serializes the input array of tensor indexes into a string representation. + /// + /// + /// + public static string SerializeIndexes(TensorIndex[] indexes) + { + return string.Join(", ", indexes); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs b/src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs new file mode 100644 index 00000000..4e90fa35 --- /dev/null +++ b/src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs @@ -0,0 +1,169 @@ +using System; +using System.Runtime.InteropServices; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using OpenCV.Net; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Helpers +{ + /// + /// Helper class to convert between OpenCV mats, images and Torch tensors. + /// + public static class OpenCVHelper + { + private static Dictionary bitDepthLookup = new Dictionary { + { ScalarType.Byte, (IplDepth.U8, Depth.U8) }, + { ScalarType.Int16, (IplDepth.S16, Depth.S16) }, + { ScalarType.Int32, (IplDepth.S32, Depth.S32) }, + { ScalarType.Float32, (IplDepth.F32, Depth.F32) }, + { ScalarType.Float64, (IplDepth.F64, Depth.F64) }, + { ScalarType.Int8, (IplDepth.S8, Depth.S8) } + }; + + private static ConcurrentDictionary deleters = new ConcurrentDictionary(); + + internal delegate void GCHandleDeleter(IntPtr memory); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_data(IntPtr handle); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_new(IntPtr rawArray, GCHandleDeleter deleter, IntPtr dimensions, int numDimensions, sbyte type, sbyte dtype, int deviceType, int deviceIndex, [MarshalAs(UnmanagedType.U1)] bool requires_grad); + + /// + /// Creates a tensor from a pointer to the data and the dimensions of the tensor. + /// + /// + /// + /// + /// + public static unsafe Tensor CreateTensorFromPtr(IntPtr tensorDataPtr, long[] dimensions, ScalarType dtype = ScalarType.Byte) + { + var dataHandle = GCHandle.Alloc(tensorDataPtr, GCHandleType.Pinned); + var gchp = GCHandle.ToIntPtr(dataHandle); + GCHandleDeleter deleter = null; + + deleter = new GCHandleDeleter((IntPtr ptrHandler) => + { + GCHandle.FromIntPtr(gchp).Free(); + deleters.TryRemove(deleter, out deleter); + }); + deleters.TryAdd(deleter, deleter); + + fixed (long* dimensionsPtr = dimensions) + { + IntPtr tensorHandle = THSTensor_new(tensorDataPtr, deleter, (IntPtr)dimensionsPtr, dimensions.Length, (sbyte)dtype, (sbyte)dtype, 0, 0, false); + if (tensorHandle == IntPtr.Zero) { + GC.Collect(); + GC.WaitForPendingFinalizers(); + tensorHandle = THSTensor_new(tensorDataPtr, deleter, (IntPtr)dimensionsPtr, dimensions.Length, (sbyte)dtype, (sbyte)dtype, 0, 0, false); + } + if (tensorHandle == IntPtr.Zero) { CheckForErrors(); } + var output = Tensor.UnsafeCreateTensor(tensorHandle); + return output; + } + } + + /// + /// Converts an OpenCV image to a Torch tensor. + /// + /// + /// + public static Tensor ToTensor(IplImage image) + { + if (image == null) + { + return empty([ 0, 0, 0 ]); + } + + int width = image.Width; + int height = image.Height; + int channels = image.Channels; + + var iplDepth = image.Depth; + var tensorType = bitDepthLookup.FirstOrDefault(x => x.Value.IplDepth == iplDepth).Key; + + IntPtr tensorDataPtr = image.ImageData; + long[] dimensions = [ height, width, channels ]; + if (tensorDataPtr == IntPtr.Zero) + { + return empty(dimensions); + } + return CreateTensorFromPtr(tensorDataPtr, dimensions, tensorType); + } + + /// + /// Converts an OpenCV mat to a Torch tensor. + /// + /// + /// + public static Tensor ToTensor(Mat mat) + { + if (mat == null) + { + return empty([0, 0, 0 ]); + } + + int width = mat.Size.Width; + int height = mat.Size.Height; + int channels = mat.Channels; + + var depth = mat.Depth; + var tensorType = bitDepthLookup.FirstOrDefault(x => x.Value.Depth == depth).Key; + + IntPtr tensorDataPtr = mat.Data; + long[] dimensions = [ height, width, channels ]; + if (tensorDataPtr == IntPtr.Zero) + { + return empty(dimensions); + } + return CreateTensorFromPtr(tensorDataPtr, dimensions, tensorType); + } + + /// + /// Converts a Torch tensor to an OpenCV image. + /// + /// + /// + public unsafe static IplImage ToImage(Tensor tensor) + { + var height = (int)tensor.shape[0]; + var width = (int)tensor.shape[1]; + var channels = (int)tensor.shape[2]; + + var tensorType = tensor.dtype; + var iplDepth = bitDepthLookup[tensorType].IplDepth; + + var new_tensor = zeros(new long[] { height, width, channels }, tensorType).copy_(tensor); + + var res = THSTensor_data(new_tensor.Handle); + var image = new IplImage(new OpenCV.Net.Size(width, height), iplDepth, channels, res); + + return image; + } + + /// + /// Converts a Torch tensor to an OpenCV mat. + /// + /// + /// + public unsafe static Mat ToMat(Tensor tensor) + { + var height = (int)tensor.shape[0]; + var width = (int)tensor.shape[1]; + var channels = (int)tensor.shape[2]; + + var tensorType = tensor.dtype; + var depth = bitDepthLookup[tensorType].Depth; + + var new_tensor = zeros(new long[] { height, width, channels }, tensorType).copy_(tensor); + + var res = THSTensor_data(new_tensor.Handle); + var mat = new Mat(new OpenCV.Net.Size(width, height), depth, channels, res); + + return mat; + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Helpers/TensorDataTypeHelper.cs b/src/Bonsai.ML.Torch/Helpers/TensorDataTypeHelper.cs new file mode 100644 index 00000000..91faf20b --- /dev/null +++ b/src/Bonsai.ML.Torch/Helpers/TensorDataTypeHelper.cs @@ -0,0 +1,53 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Bonsai.ML.Torch.Tensors; + +namespace Bonsai.ML.Torch.Helpers +{ + /// + /// Provides helper methods for working with tensor data types. + /// + public class TensorDataTypeHelper + { + private static readonly Dictionary _lookup = new Dictionary + { + { TensorDataType.Byte, (typeof(byte), "byte") }, + { TensorDataType.Int16, (typeof(short), "short") }, + { TensorDataType.Int32, (typeof(int), "int") }, + { TensorDataType.Int64, (typeof(long), "long") }, + { TensorDataType.Float32, (typeof(float), "float") }, + { TensorDataType.Float64, (typeof(double), "double") }, + { TensorDataType.Bool, (typeof(bool), "bool") }, + { TensorDataType.Int8, (typeof(sbyte), "sbyte") }, + }; + + /// + /// Returns the type corresponding to the specified tensor data type. + /// + /// + /// + public static Type GetTypeFromTensorDataType(TensorDataType type) => _lookup[type].Type; + + /// + /// Returns the string representation corresponding to the specified tensor data type. + /// + /// + /// + public static string GetStringFromTensorDataType(TensorDataType type) => _lookup[type].StringValue; + + /// + /// Returns the tensor data type corresponding to the specified string representation. + /// + /// + /// + public static TensorDataType GetTensorDataTypeFromString(string value) => _lookup.First(x => x.Value.StringValue == value).Key; + + /// + /// Returns the tensor data type corresponding to the specified type. + /// + /// + /// + public static TensorDataType GetTensorDataTypeFromType(Type type) => _lookup.First(x => x.Value.Type == type).Key; + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/Configuration/ModelConfiguration.cs b/src/Bonsai.ML.Torch/NeuralNets/Configuration/ModelConfiguration.cs new file mode 100644 index 00000000..7628f72d --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Configuration/ModelConfiguration.cs @@ -0,0 +1,5 @@ +namespace Bonsai.ML.Torch.NeuralNets.Configuration; + +public class ModelConfiguration +{ +} diff --git a/src/Bonsai.ML.Torch/NeuralNets/Configuration/ModuleConfiguration.cs b/src/Bonsai.ML.Torch/NeuralNets/Configuration/ModuleConfiguration.cs new file mode 100644 index 00000000..dfe56272 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Configuration/ModuleConfiguration.cs @@ -0,0 +1,5 @@ +namespace Bonsai.ML.Torch.NeuralNets.Configuration; + +public class ModuleConfiguration +{ +} diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs new file mode 100644 index 00000000..fb7722f2 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs @@ -0,0 +1,47 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using System.Xml.Serialization; +using static TorchSharp.torch.nn; +using Bonsai.Expressions; + +namespace Bonsai.ML.Torch.NeuralNets +{ + [Combinator] + [Description("")] + [WorkflowElementCategory(ElementCategory.Source)] + public class LoadPretrainedModel + { + public Models.PretrainedModels ModelName { get; set; } + public Device Device { get; set; } + + private int numClasses = 10; + + public IObservable Process() + { + Module model = null; + var modelName = ModelName.ToString().ToLower(); + var device = Device; + + switch (modelName) + { + case "alexnet": + model = new Models.AlexNet(modelName, numClasses, device); + break; + case "mobilenet": + model = new Models.MobileNet(modelName, numClasses, device); + break; + case "mnist": + model = new Models.MNIST(modelName, device); + break; + default: + throw new ArgumentException($"Model {modelName} not supported."); + } + + return Observable.Defer(() => { + return Observable.Return(model); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/ModelManager.cs b/src/Bonsai.ML.Torch/NeuralNets/ModelManager.cs new file mode 100644 index 00000000..035b3ca1 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/ModelManager.cs @@ -0,0 +1,5 @@ +namespace Bonsai.ML.Torch.NeuralNets; + +public class ModelManager +{ +} diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs new file mode 100644 index 00000000..4ca9f79c --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs @@ -0,0 +1,70 @@ +using System; +using System.IO; +using System.Linq; +using System.Collections.Generic; +using System.Diagnostics; + +using TorchSharp; +using static TorchSharp.torch; +using static TorchSharp.torch.nn; +using static TorchSharp.torch.nn.functional; + +namespace Bonsai.ML.Torch.NeuralNets.Models +{ + /// + /// Modified version of original AlexNet to fix CIFAR10 32x32 images. + /// + public class AlexNet : Module + { + private readonly Module features; + private readonly Module avgPool; + private readonly Module classifier; + + public AlexNet(string name, int numClasses, Device device = null) : base(name) + { + features = Sequential( + ("c1", Conv2d(3, 64, kernelSize: 3, stride: 2, padding: 1)), + ("r1", ReLU(inplace: true)), + ("mp1", MaxPool2d(kernelSize: new long[] { 2, 2 })), + ("c2", Conv2d(64, 192, kernelSize: 3, padding: 1)), + ("r2", ReLU(inplace: true)), + ("mp2", MaxPool2d(kernelSize: new long[] { 2, 2 })), + ("c3", Conv2d(192, 384, kernelSize: 3, padding: 1)), + ("r3", ReLU(inplace: true)), + ("c4", Conv2d(384, 256, kernelSize: 3, padding: 1)), + ("r4", ReLU(inplace: true)), + ("c5", Conv2d(256, 256, kernelSize: 3, padding: 1)), + ("r5", ReLU(inplace: true)), + ("mp3", MaxPool2d(kernelSize: new long[] { 2, 2 }))); + + avgPool = AdaptiveAvgPool2d(new long[] { 2, 2 }); + + classifier = Sequential( + ("d1", Dropout()), + ("l1", Linear(256 * 2 * 2, 4096)), + ("r1", ReLU(inplace: true)), + ("d2", Dropout()), + ("l2", Linear(4096, 4096)), + ("r3", ReLU(inplace: true)), + ("d3", Dropout()), + ("l3", Linear(4096, numClasses)) + ); + + RegisterComponents(); + + if (device != null && device.type != DeviceType.CPU) + this.to(device); + } + + public override Tensor forward(Tensor input) + { + var f = features.forward(input); + var avg = avgPool.forward(f); + + var x = avg.view(new long[] { avg.shape[0], 256 * 2 * 2 }); + + return classifier.forward(x); + } + } + +} diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs new file mode 100644 index 00000000..b707e2d5 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs @@ -0,0 +1,61 @@ +using System; +using System.IO; +using System.Collections.Generic; +using System.Diagnostics; +using TorchSharp; +using static TorchSharp.torch; + +using static TorchSharp.torch.nn; +using static TorchSharp.torch.nn.functional; + +namespace Bonsai.ML.Torch.NeuralNets.Models +{ + public class MNIST : Module + { + private Module conv1 = Conv2d(1, 32, 3); + private Module conv2 = Conv2d(32, 64, 3); + private Module fc1 = Linear(9216, 128); + private Module fc2 = Linear(128, 10); + + private Module pool1 = MaxPool2d(kernelSize: new long[] { 2, 2 }); + + private Module relu1 = ReLU(); + private Module relu2 = ReLU(); + private Module relu3 = ReLU(); + + private Module dropout1 = Dropout(0.25); + private Module dropout2 = Dropout(0.5); + + private Module flatten = Flatten(); + private Module logsm = LogSoftmax(1); + + public MNIST(string name, Device device = null) : base(name) + { + RegisterComponents(); + + if (device != null && device.type != DeviceType.CPU) + this.to(device); + } + + public override Tensor forward(Tensor input) + { + var l11 = conv1.forward(input); + var l12 = relu1.forward(l11); + + var l21 = conv2.forward(l12); + var l22 = relu2.forward(l21); + var l23 = pool1.forward(l22); + var l24 = dropout1.forward(l23); + + var x = flatten.forward(l24); + + var l31 = fc1.forward(x); + var l32 = relu3.forward(l31); + var l33 = dropout2.forward(l32); + + var l41 = fc2.forward(l33); + + return logsm.forward(l41); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs new file mode 100644 index 00000000..e9d66038 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs @@ -0,0 +1,72 @@ +using System; +using System.Collections.Generic; +using Bonsai.ML.Torch.Tensors; +using TorchSharp; +using static TorchSharp.torch; +using static TorchSharp.torch.nn; + +namespace Bonsai.ML.Torch.NeuralNets.Models +{ + /// + /// Modified version of MobileNet to classify CIFAR10 32x32 images. + /// + /// + /// With an unaugmented CIFAR-10 data set, the author of this saw training converge + /// at roughly 75% accuracy on the test set, over the course of 1500 epochs. + /// + public class MobileNet : Module + { + // The code here is is loosely based on https://github.com/kuangliu/pytorch-cifar/blob/master/models/mobilenet.py + // Licence and copypright notice at: https://github.com/kuangliu/pytorch-cifar/blob/master/LICENSE + + private readonly long[] planes = new long[] { 64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024 }; + private readonly long[] strides = new long[] { 1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1 }; + + private readonly Module layers; + + public MobileNet(string name, int numClasses, Device device = null) : base(name) + { + if (planes.Length != strides.Length) throw new ArgumentException("'planes' and 'strides' must have the same length."); + + var modules = new List<(string, Module)>(); + + modules.Add(($"conv2d-first", Conv2d(3, 32, kernelSize: 3, stride: 1, padding: 1, bias: false))); + modules.Add(($"bnrm2d-first", BatchNorm2d(32))); + modules.Add(($"relu-first", ReLU())); + MakeLayers(modules, 32); + modules.Add(("avgpool", AvgPool2d(new long[] { 2, 2 }))); + modules.Add(("flatten", Flatten())); + modules.Add(($"linear", Linear(planes[planes.Length-1], numClasses))); + + layers = Sequential(modules); + + RegisterComponents(); + + if (device != null && device.type != DeviceType.CPU) + this.to(device); + } + + private void MakeLayers(List<(string, Module)> modules, long in_planes) + { + + for (var i = 0; i < strides.Length; i++) { + var out_planes = planes[i]; + var stride = strides[i]; + + modules.Add(($"conv2d-{i}a", Conv2d(in_planes, in_planes, kernelSize: 3, stride: stride, padding: 1, groups: in_planes, bias: false))); + modules.Add(($"bnrm2d-{i}a", BatchNorm2d(in_planes))); + modules.Add(($"relu-{i}a", ReLU())); + modules.Add(($"conv2d-{i}b", Conv2d(in_planes, out_planes, kernelSize: 1L, stride: 1L, padding: 0L, bias: false))); + modules.Add(($"bnrm2d-{i}b", BatchNorm2d(out_planes))); + modules.Add(($"relu-{i}b", ReLU())); + + in_planes = out_planes; + } + } + + public override Tensor forward(Tensor input) + { + return layers.forward(input); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/PretrainedModels.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/PretrainedModels.cs new file mode 100644 index 00000000..a3c65bdc --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/PretrainedModels.cs @@ -0,0 +1,9 @@ +namespace Bonsai.ML.Torch.NeuralNets.Models +{ + public enum PretrainedModels + { + AlexNet, + MobileNet, + MNIST + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/Arange.cs b/src/Bonsai.ML.Torch/Tensors/Arange.cs new file mode 100644 index 00000000..011d0708 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/Arange.cs @@ -0,0 +1,40 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using TorchSharp; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Creates a 1-D tensor of values within a given range given the start, end, and step. + /// + [Combinator] + [Description("Creates a 1-D tensor of values within a given range given the start, end, and step.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class Arange + { + /// + /// The start of the range. + /// + public int Start { get; set; } = 0; + + /// + /// The end of the range. + /// + public int End { get; set; } = 10; + + /// + /// The step of the range. + /// + public int Step { get; set; } = 1; + + /// + /// Generates an observable sequence of 1-D tensors created with the function. + /// + public IObservable Process() + { + return Observable.Defer(() => Observable.Return(arange(Start, End, Step))); + } + } +} diff --git a/src/Bonsai.ML.Torch/Tensors/Concat.cs b/src/Bonsai.ML.Torch/Tensors/Concat.cs new file mode 100644 index 00000000..52275bb7 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/Concat.cs @@ -0,0 +1,45 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Concatenates tensors along a given dimension. + /// + [Combinator] + [Description("Concatenates tensors along a given dimension.")] + [WorkflowElementCategory(ElementCategory.Combinator)] + public class Concat + { + /// + /// The dimension along which to concatenate the tensors. + /// + public long Dimension { get; set; } = 0; + + /// + /// Takes any number of observable sequences of tensors and concatenates the input tensors along the specified dimension by zipping each tensor together. + /// + public IObservable Process(params IObservable[] sources) + { + return sources.Aggregate((current, next) => + current.Zip(next, (tensor1, tensor2) => + cat([tensor1, tensor2], Dimension))); + } + + /// + /// Concatenates the input tensors along the specified dimension. + /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + var tensor1 = value.Item1; + var tensor2 = value.Item2; + return cat([tensor1, tensor2], Dimension); + }); + } + } +} diff --git a/src/Bonsai.ML.Torch/Tensors/ConvertDataType.cs b/src/Bonsai.ML.Torch/Tensors/ConvertDataType.cs new file mode 100644 index 00000000..3683d2a6 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/ConvertDataType.cs @@ -0,0 +1,32 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Converts the input tensor to the specified scalar type. + /// + [Combinator] + [Description("Converts the input tensor to the specified scalar type.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class ConvertDataType + { + /// + /// The scalar type to which to convert the input tensor. + /// + public ScalarType Type { get; set; } = ScalarType.Float32; + + /// + /// Returns an observable sequence that converts the input tensor to the specified scalar type. + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => + { + return tensor.to_type(Type); + }); + } + } +} diff --git a/src/Bonsai.ML.Torch/Tensors/CreateTensor.cs b/src/Bonsai.ML.Torch/Tensors/CreateTensor.cs new file mode 100644 index 00000000..4585b70b --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/CreateTensor.cs @@ -0,0 +1,245 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Linq.Expressions; +using System.Reactive.Linq; +using System.Reflection; +using System.Xml.Serialization; +using Bonsai.Expressions; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Creates a tensor from the specified values. Uses Python-like syntax to specify the tensor values. For example, a 2x2 tensor can be created with the following values: "[[1, 2], [3, 4]]". + /// + [Combinator] + [Description("Creates a tensor from the specified values. Uses Python-like syntax to specify the tensor values. For example, a 2x2 tensor can be created with the following values: \"[[1, 2], [3, 4]]\".")] + [WorkflowElementCategory(ElementCategory.Source)] + public class CreateTensor : ExpressionBuilder + { + Range argumentRange = new Range(0, 1); + + /// + public override Range ArgumentRange => argumentRange; + + /// + /// The data type of the tensor elements. + /// + public TensorDataType Type + { + get => scalarType; + set => scalarType = value; + } + + private TensorDataType scalarType = TensorDataType.Float32; + + /// + /// The values of the tensor elements. Uses Python-like syntax to specify the tensor values. + /// + public string Values + { + get => values; + set + { + values = value.Replace("False", "false").Replace("True", "true"); + } + } + + private string values = "[0]"; + + /// + /// The device on which to create the tensor. + /// + [XmlIgnore] + public Device Device { get => device; set => device = value; } + + private Device device = null; + + private Expression BuildTensorFromArray(Array arrayValues, Type returnType) + { + var rank = arrayValues.Rank; + var lengths = new int[rank]; + for (int i = 0; i < rank; i++) + { + lengths[i] = arrayValues.GetLength(i); + } + + var arrayCreationExpression = Expression.NewArrayBounds(returnType, lengths.Select(len => Expression.Constant(len)).ToArray()); + var arrayVariable = Expression.Variable(arrayCreationExpression.Type, "array"); + var assignArray = Expression.Assign(arrayVariable, arrayCreationExpression); + + var assignments = new List(); + for (int i = 0; i < values.Length; i++) + { + var indices = new Expression[rank]; + int temp = i; + for (int j = rank - 1; j >= 0; j--) + { + indices[j] = Expression.Constant(temp % lengths[j]); + temp /= lengths[j]; + } + var value = Expression.Constant(arrayValues.GetValue(indices.Select(e => ((ConstantExpression)e).Value).Cast().ToArray())); + var arrayAccess = Expression.ArrayAccess(arrayVariable, indices); + var assignArrayValue = Expression.Assign(arrayAccess, value); + assignments.Add(assignArrayValue); + } + + var tensorDataInitializationBlock = Expression.Block( + arrayVariable, + assignArray, + Expression.Block(assignments), + arrayVariable + ); + + var tensorCreationMethodInfo = typeof(TorchSharp.torch).GetMethod( + "tensor", [ + arrayVariable.Type, + typeof(ScalarType?), + typeof(Device), + typeof(bool), + typeof(string).MakeArrayType() + ] + ); + + var tensorAssignment = Expression.Call( + tensorCreationMethodInfo, + tensorDataInitializationBlock, + Expression.Constant(scalarType, typeof(ScalarType?)), + Expression.Constant(device, typeof(Device)), + Expression.Constant(false, typeof(bool)), + Expression.Constant(null, typeof(string).MakeArrayType()) + ); + + var tensorVariable = Expression.Variable(typeof(Tensor), "tensor"); + var assignTensor = Expression.Assign(tensorVariable, tensorAssignment); + + var buildTensor = Expression.Block( + tensorVariable, + assignTensor, + tensorVariable + ); + + return buildTensor; + } + + private Expression BuildTensorFromScalarValue(object scalarValue, Type returnType) + { + var valueVariable = Expression.Variable(returnType, "value"); + var assignValue = Expression.Assign(valueVariable, Expression.Constant(scalarValue, returnType)); + + var tensorDataInitializationBlock = Expression.Block( + valueVariable, + assignValue, + valueVariable + ); + + var tensorCreationMethodInfo = typeof(TorchSharp.torch).GetMethod( + "tensor", [ + valueVariable.Type, + typeof(Device), + typeof(bool) + ] + ); + + var tensorCreationMethodArguments = new Expression[] { + Expression.Constant(device, typeof(Device) ), + Expression.Constant(false, typeof(bool) ) + }; + + if (tensorCreationMethodInfo == null) + { + tensorCreationMethodInfo = typeof(TorchSharp.torch).GetMethod( + "tensor", [ + valueVariable.Type, + typeof(ScalarType?), + typeof(Device), + typeof(bool) + ] + ); + + tensorCreationMethodArguments = tensorCreationMethodArguments.Prepend( + Expression.Constant(scalarType, typeof(ScalarType?)) + ).ToArray(); + } + + tensorCreationMethodArguments = tensorCreationMethodArguments.Prepend( + tensorDataInitializationBlock + ).ToArray(); + + var tensorAssignment = Expression.Call( + tensorCreationMethodInfo, + tensorCreationMethodArguments + ); + + var tensorVariable = Expression.Variable(typeof(Tensor), "tensor"); + var assignTensor = Expression.Assign(tensorVariable, tensorAssignment); + + var buildTensor = Expression.Block( + tensorVariable, + assignTensor, + tensorVariable + ); + + return buildTensor; + } + + /// + public override Expression Build(IEnumerable arguments) + { + var returnType = Helpers.TensorDataTypeHelper.GetTypeFromTensorDataType(scalarType); + var argTypes = arguments.Select(arg => arg.Type).ToArray(); + + var methodInfoArgumentTypes = new Type[] { + typeof(Tensor) + }; + + var methods = typeof(CreateTensor).GetMethods(BindingFlags.Public | BindingFlags.Instance) + .Where(m => m.Name == "Process") + .ToArray(); + + var methodInfo = arguments.Count() > 0 ? methods.FirstOrDefault(m => m.IsGenericMethod) + .MakeGenericMethod( + arguments + .First() + .Type + .GetGenericArguments()[0] + ) : methods.FirstOrDefault(m => !m.IsGenericMethod); + + var tensorValues = Helpers.DataHelper.ParseString(values, returnType); + var buildTensor = tensorValues is Array arrayValues ? BuildTensorFromArray(arrayValues, returnType) : BuildTensorFromScalarValue(tensorValues, returnType); + var methodArguments = arguments.Count() == 0 ? [ buildTensor ] : arguments.Concat([ buildTensor ]); + + try + { + return Expression.Call( + Expression.Constant(this), + methodInfo, + methodArguments + ); + } + finally + { + values = Helpers.DataHelper.SerializeData(tensorValues).Replace("False", "false").Replace("True", "true"); + scalarType = Helpers.TensorDataTypeHelper.GetTensorDataTypeFromType(returnType); + } + } + + /// + /// Returns an observable sequence that creates a tensor from the specified values. + /// + public IObservable Process(Tensor tensor) + { + return Observable.Return(tensor); + } + + /// + /// Returns an observable sequence that creates a tensor from the specified values for each element in the input sequence. + /// + public IObservable Process(IObservable source, Tensor tensor) + { + return Observable.Select(source, (_) => tensor); + } + } +} diff --git a/src/Bonsai.ML.Torch/Tensors/Index.cs b/src/Bonsai.ML.Torch/Tensors/Index.cs new file mode 100644 index 00000000..78024237 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/Index.cs @@ -0,0 +1,35 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Indexes a tensor with the specified indices. Indices are specified as a comma-separated values. + /// Currently supports Python-style slicing syntax. This includes numeric indices, None, slices, and ellipsis. + /// + [Combinator] + [Description("Indexes a tensor with the specified indices. Indices are specified as a comma-separated values. Currently supports Python-style slicing syntax. This includes numeric indices, None, slices, and ellipsis.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Index + { + /// + /// The indices to use for indexing the tensor. + /// + public string Indexes { get; set; } = string.Empty; + + /// + /// Indexes the input tensor with the specified indices. + /// + /// + /// + public IObservable Process(IObservable source) + { + var index = Helpers.IndexHelper.ParseString(Indexes); + return source.Select(tensor => { + return tensor.index(index); + }); + } + } +} diff --git a/src/Bonsai.ML.Torch/Tensors/InitializeTorchDevice.cs b/src/Bonsai.ML.Torch/Tensors/InitializeTorchDevice.cs new file mode 100644 index 00000000..2258467f --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/InitializeTorchDevice.cs @@ -0,0 +1,35 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using TorchSharp; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Initializes the Torch device with the specified device type. + /// + [Combinator] + [Description("Initializes the Torch device with the specified device type.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class InitializeTorchDevice + { + /// + /// The device type to initialize. + /// + public DeviceType DeviceType { get; set; } + + /// + /// Initializes the Torch device with the specified device type. + /// + /// + public IObservable Process() + { + return Observable.Defer(() => + { + InitializeDeviceType(DeviceType); + return Observable.Return(new Device(DeviceType)); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/Linspace.cs b/src/Bonsai.ML.Torch/Tensors/Linspace.cs new file mode 100644 index 00000000..6e7495f8 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/Linspace.cs @@ -0,0 +1,40 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Creates a 1-D tensor of linearly interpolated values within a given range given the start, end, and count. + /// + [Combinator] + [Description("Creates a 1-D tensor of linearly interpolated values within a given range given the start, end, and count.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class Linspace + { + /// + /// The start of the range. + /// + public int Start { get; set; } = 0; + + /// + /// The end of the range. + /// + public int End { get; set; } = 1; + + /// + /// The number of points to generate. + /// + public int Count { get; set; } = 10; + + /// + /// Generates an observable sequence of 1-D tensors created with the function. + /// + /// + public IObservable Process() + { + return Observable.Defer(() => Observable.Return(linspace(Start, End, Count))); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/MeshGrid.cs b/src/Bonsai.ML.Torch/Tensors/MeshGrid.cs new file mode 100644 index 00000000..77f4cecb --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/MeshGrid.cs @@ -0,0 +1,33 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Collections.Generic; +using static TorchSharp.torch; +using System.Linq; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Creates a mesh grid from an observable sequence of enumerable of 1-D tensors. + /// + [Combinator] + [Description("")] + [WorkflowElementCategory(ElementCategory.Source)] + public class MeshGrid + { + /// + /// The indexing mode to use for the mesh grid. + /// + public string Indexing { get; set; } = "ij"; + + /// + /// Creates a mesh grid from the input tensors. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(tensors => meshgrid(tensors, indexing: Indexing)); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/Ones.cs b/src/Bonsai.ML.Torch/Tensors/Ones.cs new file mode 100644 index 00000000..77768dd1 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/Ones.cs @@ -0,0 +1,30 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Creates a tensor filled with ones. + /// + [Combinator] + [Description("Creates a tensor filled with ones.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class Ones + { + /// + /// The size of the tensor. + /// + public long[] Size { get; set; } = [0]; + + /// + /// Generates an observable sequence of tensors filled with ones. + /// + /// + public IObservable Process() + { + return Observable.Defer(() => Observable.Return(ones(Size))); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/Permute.cs b/src/Bonsai.ML.Torch/Tensors/Permute.cs new file mode 100644 index 00000000..317e34f8 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/Permute.cs @@ -0,0 +1,33 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Permutes the dimensions of the input tensor according to the specified permutation. + /// + [Combinator] + [Description("Permutes the dimensions of the input tensor according to the specified permutation.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Permute + { + /// + /// The permutation of the dimensions. + /// + public long[] Dimensions { get; set; } = [0]; + + /// + /// Returns an observable sequence that permutes the dimensions of the input tensor according to the specified permutation. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => { + return tensor.permute(Dimensions); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/Reshape.cs b/src/Bonsai.ML.Torch/Tensors/Reshape.cs new file mode 100644 index 00000000..5d3e9412 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/Reshape.cs @@ -0,0 +1,32 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Reshapes the input tensor according to the specified dimensions. + /// + [Combinator] + [Description("Reshapes the input tensor according to the specified dimensions.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Reshape + { + /// + /// The dimensions of the reshaped tensor. + /// + public long[] Dimensions { get; set; } = [0]; + + /// + /// Reshapes the input tensor according to the specified dimensions. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(input => input.reshape(Dimensions)); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/Set.cs b/src/Bonsai.ML.Torch/Tensors/Set.cs new file mode 100644 index 00000000..a4d8b2d2 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/Set.cs @@ -0,0 +1,48 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Sets the value of the input tensor at the specified index. + /// + [Combinator] + [Description("Sets the value of the input tensor at the specified index.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Set + { + /// + /// The index at which to set the value. + /// + public string Index + { + get => Helpers.IndexHelper.SerializeIndexes(indexes); + set => indexes = Helpers.IndexHelper.ParseString(value); + } + + private TensorIndex[] indexes; + + /// + /// The value to set at the specified index. + /// + [XmlIgnore] + public Tensor Value { get; set; } = null; + + /// + /// Returns an observable sequence that sets the value of the input tensor at the specified index. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => { + return tensor.index_put_(Value, indexes); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/TensorDataType.cs b/src/Bonsai.ML.Torch/Tensors/TensorDataType.cs new file mode 100644 index 00000000..de1ba8d2 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/TensorDataType.cs @@ -0,0 +1,56 @@ +using System; +using System.Text; +using System.Collections.Generic; +using System.Linq; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Represents the data type of the tensor elements. Contains currently supported data types. A subset of the available ScalarType data types in TorchSharp. + /// + public enum TensorDataType + { + /// + /// 8-bit unsigned integer. + /// + Byte = ScalarType.Byte, + + /// + /// 8-bit signed integer. + /// + Int8 = ScalarType.Int8, + + /// + /// 16-bit signed integer. + /// + Int16 = ScalarType.Int16, + + /// + /// 32-bit signed integer. + /// + Int32 = ScalarType.Int32, + + /// + /// 64-bit signed integer. + /// + Int64 = ScalarType.Int64, + + /// + /// 32-bit floating point. + /// + Float32 = ScalarType.Float32, + + /// + /// 64-bit floating point. + /// + Float64 = ScalarType.Float64, + + /// + /// Boolean. + /// + Bool = ScalarType.Bool + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/ToArray.cs b/src/Bonsai.ML.Torch/Tensors/ToArray.cs new file mode 100644 index 00000000..70083ad2 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/ToArray.cs @@ -0,0 +1,73 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using System.Xml.Serialization; +using System.Linq.Expressions; +using System.Reflection; +using Bonsai.Expressions; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Converts the input tensor into an array of the specified element type. + /// + [Combinator] + [Description("Converts the input tensor into an array of the specified element type.")] + [WorkflowElementCategory(ElementCategory.Transform)] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + public class ToArray : SingleArgumentExpressionBuilder + { + /// + /// Initializes a new instance of the class. + /// + public ToArray() + { + Type = new TypeMapping(); + } + + /// + /// Gets or sets the type mapping used to convert the input tensor into an array. + /// + public TypeMapping Type { get; set; } + + /// + public override Expression Build(IEnumerable arguments) + { + TypeMapping typeMapping = Type; + var returnType = typeMapping.GetType().GetGenericArguments()[0]; + MethodInfo methodInfo = GetType().GetMethod("Process", BindingFlags.Public | BindingFlags.Instance); + methodInfo = methodInfo.MakeGenericMethod(returnType); + Expression sourceExpression = arguments.First(); + + return Expression.Call( + Expression.Constant(this), + methodInfo, + sourceExpression + ); + } + + /// + /// Converts the input tensor into an array of the specified element type. + /// + /// + /// + /// + public IObservable Process(IObservable source) where T : unmanaged + { + return source.Select(tensor => + { + return tensor.data().ToArray(); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/ToDevice.cs b/src/Bonsai.ML.Torch/Tensors/ToDevice.cs new file mode 100644 index 00000000..4aa1b92a --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/ToDevice.cs @@ -0,0 +1,34 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Moves the input tensor to the specified device. + /// + [Combinator] + [Description("Moves the input tensor to the specified device.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class ToDevice + { + /// + /// The device to which the input tensor should be moved. + /// + public Device Device { get; set; } + + /// + /// Returns the input tensor moved to the specified device. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => { + return tensor.to(Device); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/ToImage.cs b/src/Bonsai.ML.Torch/Tensors/ToImage.cs new file mode 100644 index 00000000..eebf8399 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/ToImage.cs @@ -0,0 +1,28 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using OpenCV.Net; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Converts the input tensor into an OpenCV image. + /// + [Combinator] + [Description("")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class ToImage + { + /// + /// Converts the input tensor into an OpenCV image. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(Helpers.OpenCVHelper.ToImage); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/ToMat.cs b/src/Bonsai.ML.Torch/Tensors/ToMat.cs new file mode 100644 index 00000000..756ac636 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/ToMat.cs @@ -0,0 +1,28 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using OpenCV.Net; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Converts the input tensor into an OpenCV mat. + /// + [Combinator] + [Description("Converts the input tensor into an OpenCV mat.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class ToMat + { + /// + /// Converts the input tensor into an OpenCV mat. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(Helpers.OpenCVHelper.ToMat); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/ToTensor.cs b/src/Bonsai.ML.Torch/Tensors/ToTensor.cs new file mode 100644 index 00000000..753d4422 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/ToTensor.cs @@ -0,0 +1,134 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using OpenCV.Net; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Converts the input value into a tensor. + /// + [Combinator] + [Description("")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class ToTensor + { + /// + /// Converts an int into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a double into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a byte into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a bool into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a float into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a long into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a short into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts an array into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts an IplImage into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(Helpers.OpenCVHelper.ToTensor); + } + + /// + /// Converts a Mat into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(Helpers.OpenCVHelper.ToTensor); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/Zeros.cs b/src/Bonsai.ML.Torch/Tensors/Zeros.cs new file mode 100644 index 00000000..256a43ed --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/Zeros.cs @@ -0,0 +1,30 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Creates a tensor filled with zeros. + /// + [Combinator] + [Description("Creates a tensor filled with zeros.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class Zeros + { + /// + /// The size of the tensor. + /// + public long[] Size { get; set; } = [0]; + + /// + /// Generates an observable sequence of tensors filled with zeros. + /// + /// + public IObservable Process() + { + return Observable.Defer(() => Observable.Return(ones(Size))); + } + } +} \ No newline at end of file From 1386ff3ac3605ef4a8a6f9ea9fc1cdb833d2f5f4 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 27 Aug 2024 12:53:06 +0100 Subject: [PATCH 024/131] Removed previous Bonsai.ML.Tensors directory and contents --- src/Bonsai.ML.Tensors/Arange.cs | 40 --- .../Bonsai.ML.Tensors.csproj | 17 -- src/Bonsai.ML.Tensors/Concat.cs | 45 ---- src/Bonsai.ML.Tensors/ConvertDataType.cs | 32 --- src/Bonsai.ML.Tensors/CreateTensor.cs | 245 ------------------ src/Bonsai.ML.Tensors/Helpers/DataHelper.cs | 190 -------------- src/Bonsai.ML.Tensors/Helpers/IndexHelper.cs | 91 ------- src/Bonsai.ML.Tensors/Helpers/OpenCVHelper.cs | 169 ------------ .../Helpers/TensorDataTypeHelper.cs | 52 ---- src/Bonsai.ML.Tensors/Index.cs | 35 --- .../InitializeTorchDevice.cs | 35 --- src/Bonsai.ML.Tensors/Linspace.cs | 40 --- src/Bonsai.ML.Tensors/MeshGrid.cs | 33 --- src/Bonsai.ML.Tensors/Ones.cs | 30 --- src/Bonsai.ML.Tensors/Permute.cs | 33 --- src/Bonsai.ML.Tensors/Reshape.cs | 32 --- src/Bonsai.ML.Tensors/Set.cs | 48 ---- src/Bonsai.ML.Tensors/TensorDataType.cs | 56 ---- src/Bonsai.ML.Tensors/ToArray.cs | 73 ------ src/Bonsai.ML.Tensors/ToDevice.cs | 34 --- src/Bonsai.ML.Tensors/ToImage.cs | 28 -- src/Bonsai.ML.Tensors/ToMat.cs | 28 -- src/Bonsai.ML.Tensors/ToTensor.cs | 134 ---------- src/Bonsai.ML.Tensors/Zeros.cs | 30 --- 24 files changed, 1550 deletions(-) delete mode 100644 src/Bonsai.ML.Tensors/Arange.cs delete mode 100644 src/Bonsai.ML.Tensors/Bonsai.ML.Tensors.csproj delete mode 100644 src/Bonsai.ML.Tensors/Concat.cs delete mode 100644 src/Bonsai.ML.Tensors/ConvertDataType.cs delete mode 100644 src/Bonsai.ML.Tensors/CreateTensor.cs delete mode 100644 src/Bonsai.ML.Tensors/Helpers/DataHelper.cs delete mode 100644 src/Bonsai.ML.Tensors/Helpers/IndexHelper.cs delete mode 100644 src/Bonsai.ML.Tensors/Helpers/OpenCVHelper.cs delete mode 100644 src/Bonsai.ML.Tensors/Helpers/TensorDataTypeHelper.cs delete mode 100644 src/Bonsai.ML.Tensors/Index.cs delete mode 100644 src/Bonsai.ML.Tensors/InitializeTorchDevice.cs delete mode 100644 src/Bonsai.ML.Tensors/Linspace.cs delete mode 100644 src/Bonsai.ML.Tensors/MeshGrid.cs delete mode 100644 src/Bonsai.ML.Tensors/Ones.cs delete mode 100644 src/Bonsai.ML.Tensors/Permute.cs delete mode 100644 src/Bonsai.ML.Tensors/Reshape.cs delete mode 100644 src/Bonsai.ML.Tensors/Set.cs delete mode 100644 src/Bonsai.ML.Tensors/TensorDataType.cs delete mode 100644 src/Bonsai.ML.Tensors/ToArray.cs delete mode 100644 src/Bonsai.ML.Tensors/ToDevice.cs delete mode 100644 src/Bonsai.ML.Tensors/ToImage.cs delete mode 100644 src/Bonsai.ML.Tensors/ToMat.cs delete mode 100644 src/Bonsai.ML.Tensors/ToTensor.cs delete mode 100644 src/Bonsai.ML.Tensors/Zeros.cs diff --git a/src/Bonsai.ML.Tensors/Arange.cs b/src/Bonsai.ML.Tensors/Arange.cs deleted file mode 100644 index 2a1eda40..00000000 --- a/src/Bonsai.ML.Tensors/Arange.cs +++ /dev/null @@ -1,40 +0,0 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using static TorchSharp.torch; -using TorchSharp; - -namespace Bonsai.ML.Tensors -{ - /// - /// Creates a 1-D tensor of values within a given range given the start, end, and step. - /// - [Combinator] - [Description("Creates a 1-D tensor of values within a given range given the start, end, and step.")] - [WorkflowElementCategory(ElementCategory.Source)] - public class Arange - { - /// - /// The start of the range. - /// - public int Start { get; set; } = 0; - - /// - /// The end of the range. - /// - public int End { get; set; } = 10; - - /// - /// The step of the range. - /// - public int Step { get; set; } = 1; - - /// - /// Generates an observable sequence of 1-D tensors created with the function. - /// - public IObservable Process() - { - return Observable.Defer(() => Observable.Return(arange(Start, End, Step))); - } - } -} diff --git a/src/Bonsai.ML.Tensors/Bonsai.ML.Tensors.csproj b/src/Bonsai.ML.Tensors/Bonsai.ML.Tensors.csproj deleted file mode 100644 index 8d87ac9b..00000000 --- a/src/Bonsai.ML.Tensors/Bonsai.ML.Tensors.csproj +++ /dev/null @@ -1,17 +0,0 @@ - - - Bonsai.ML.Tensors - A Bonsai package for TorchSharp tensor manipulations. - Bonsai Rx ML Tensors TorchSharp - net472;netstandard2.0 - true - - - - - - - - - - \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/Concat.cs b/src/Bonsai.ML.Tensors/Concat.cs deleted file mode 100644 index 1dd99b7b..00000000 --- a/src/Bonsai.ML.Tensors/Concat.cs +++ /dev/null @@ -1,45 +0,0 @@ -using System; -using System.ComponentModel; -using System.Linq; -using System.Reactive.Linq; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Concatenates tensors along a given dimension. - /// - [Combinator] - [Description("Concatenates tensors along a given dimension.")] - [WorkflowElementCategory(ElementCategory.Combinator)] - public class Concat - { - /// - /// The dimension along which to concatenate the tensors. - /// - public long Dimension { get; set; } = 0; - - /// - /// Takes any number of observable sequences of tensors and concatenates the input tensors along the specified dimension by zipping each tensor together. - /// - public IObservable Process(params IObservable[] sources) - { - return sources.Aggregate((current, next) => - current.Zip(next, (tensor1, tensor2) => - cat([tensor1, tensor2], Dimension))); - } - - /// - /// Concatenates the input tensors along the specified dimension. - /// - public IObservable Process(IObservable> source) - { - return source.Select(value => - { - var tensor1 = value.Item1; - var tensor2 = value.Item2; - return cat([tensor1, tensor2], Dimension); - }); - } - } -} diff --git a/src/Bonsai.ML.Tensors/ConvertDataType.cs b/src/Bonsai.ML.Tensors/ConvertDataType.cs deleted file mode 100644 index 14b0db84..00000000 --- a/src/Bonsai.ML.Tensors/ConvertDataType.cs +++ /dev/null @@ -1,32 +0,0 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Converts the input tensor to the specified scalar type. - /// - [Combinator] - [Description("Converts the input tensor to the specified scalar type.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class ConvertDataType - { - /// - /// The scalar type to which to convert the input tensor. - /// - public ScalarType Type { get; set; } = ScalarType.Float32; - - /// - /// Returns an observable sequence that converts the input tensor to the specified scalar type. - /// - public IObservable Process(IObservable source) - { - return source.Select(tensor => - { - return tensor.to_type(Type); - }); - } - } -} diff --git a/src/Bonsai.ML.Tensors/CreateTensor.cs b/src/Bonsai.ML.Tensors/CreateTensor.cs deleted file mode 100644 index 712c7243..00000000 --- a/src/Bonsai.ML.Tensors/CreateTensor.cs +++ /dev/null @@ -1,245 +0,0 @@ -using System; -using System.Collections.Generic; -using System.ComponentModel; -using System.Linq; -using System.Linq.Expressions; -using System.Reactive.Linq; -using System.Reflection; -using System.Xml.Serialization; -using Bonsai.Expressions; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Creates a tensor from the specified values. Uses Python-like syntax to specify the tensor values. For example, a 2x2 tensor can be created with the following values: "[[1, 2], [3, 4]]". - /// - [Combinator] - [Description("Creates a tensor from the specified values. Uses Python-like syntax to specify the tensor values. For example, a 2x2 tensor can be created with the following values: \"[[1, 2], [3, 4]]\".")] - [WorkflowElementCategory(ElementCategory.Source)] - public class CreateTensor : ExpressionBuilder - { - Range argumentRange = new Range(0, 1); - - /// - public override Range ArgumentRange => argumentRange; - - /// - /// The data type of the tensor elements. - /// - public TensorDataType Type - { - get => scalarType; - set => scalarType = value; - } - - private TensorDataType scalarType = TensorDataType.Float32; - - /// - /// The values of the tensor elements. Uses Python-like syntax to specify the tensor values. - /// - public string Values - { - get => values; - set - { - values = value.Replace("False", "false").Replace("True", "true"); - } - } - - private string values = "[0]"; - - /// - /// The device on which to create the tensor. - /// - [XmlIgnore] - public Device Device { get => device; set => device = value; } - - private Device device = null; - - private Expression BuildTensorFromArray(Array arrayValues, Type returnType) - { - var rank = arrayValues.Rank; - var lengths = new int[rank]; - for (int i = 0; i < rank; i++) - { - lengths[i] = arrayValues.GetLength(i); - } - - var arrayCreationExpression = Expression.NewArrayBounds(returnType, lengths.Select(len => Expression.Constant(len)).ToArray()); - var arrayVariable = Expression.Variable(arrayCreationExpression.Type, "array"); - var assignArray = Expression.Assign(arrayVariable, arrayCreationExpression); - - var assignments = new List(); - for (int i = 0; i < values.Length; i++) - { - var indices = new Expression[rank]; - int temp = i; - for (int j = rank - 1; j >= 0; j--) - { - indices[j] = Expression.Constant(temp % lengths[j]); - temp /= lengths[j]; - } - var value = Expression.Constant(arrayValues.GetValue(indices.Select(e => ((ConstantExpression)e).Value).Cast().ToArray())); - var arrayAccess = Expression.ArrayAccess(arrayVariable, indices); - var assignArrayValue = Expression.Assign(arrayAccess, value); - assignments.Add(assignArrayValue); - } - - var tensorDataInitializationBlock = Expression.Block( - arrayVariable, - assignArray, - Expression.Block(assignments), - arrayVariable - ); - - var tensorCreationMethodInfo = typeof(TorchSharp.torch).GetMethod( - "tensor", [ - arrayVariable.Type, - typeof(ScalarType?), - typeof(Device), - typeof(bool), - typeof(string).MakeArrayType() - ] - ); - - var tensorAssignment = Expression.Call( - tensorCreationMethodInfo, - tensorDataInitializationBlock, - Expression.Constant(scalarType, typeof(ScalarType?)), - Expression.Constant(device, typeof(Device)), - Expression.Constant(false, typeof(bool)), - Expression.Constant(null, typeof(string).MakeArrayType()) - ); - - var tensorVariable = Expression.Variable(typeof(Tensor), "tensor"); - var assignTensor = Expression.Assign(tensorVariable, tensorAssignment); - - var buildTensor = Expression.Block( - tensorVariable, - assignTensor, - tensorVariable - ); - - return buildTensor; - } - - private Expression BuildTensorFromScalarValue(object scalarValue, Type returnType) - { - var valueVariable = Expression.Variable(returnType, "value"); - var assignValue = Expression.Assign(valueVariable, Expression.Constant(scalarValue, returnType)); - - var tensorDataInitializationBlock = Expression.Block( - valueVariable, - assignValue, - valueVariable - ); - - var tensorCreationMethodInfo = typeof(TorchSharp.torch).GetMethod( - "tensor", [ - valueVariable.Type, - typeof(Device), - typeof(bool) - ] - ); - - var tensorCreationMethodArguments = new Expression[] { - Expression.Constant(device, typeof(Device) ), - Expression.Constant(false, typeof(bool) ) - }; - - if (tensorCreationMethodInfo == null) - { - tensorCreationMethodInfo = typeof(TorchSharp.torch).GetMethod( - "tensor", [ - valueVariable.Type, - typeof(ScalarType?), - typeof(Device), - typeof(bool) - ] - ); - - tensorCreationMethodArguments = tensorCreationMethodArguments.Prepend( - Expression.Constant(scalarType, typeof(ScalarType?)) - ).ToArray(); - } - - tensorCreationMethodArguments = tensorCreationMethodArguments.Prepend( - tensorDataInitializationBlock - ).ToArray(); - - var tensorAssignment = Expression.Call( - tensorCreationMethodInfo, - tensorCreationMethodArguments - ); - - var tensorVariable = Expression.Variable(typeof(Tensor), "tensor"); - var assignTensor = Expression.Assign(tensorVariable, tensorAssignment); - - var buildTensor = Expression.Block( - tensorVariable, - assignTensor, - tensorVariable - ); - - return buildTensor; - } - - /// - public override Expression Build(IEnumerable arguments) - { - var returnType = Helpers.TensorDataTypeHelper.GetTypeFromTensorDataType(scalarType); - var argTypes = arguments.Select(arg => arg.Type).ToArray(); - - var methodInfoArgumentTypes = new Type[] { - typeof(Tensor) - }; - - var methods = typeof(CreateTensor).GetMethods(BindingFlags.Public | BindingFlags.Instance) - .Where(m => m.Name == "Process") - .ToArray(); - - var methodInfo = arguments.Count() > 0 ? methods.FirstOrDefault(m => m.IsGenericMethod) - .MakeGenericMethod( - arguments - .First() - .Type - .GetGenericArguments()[0] - ) : methods.FirstOrDefault(m => !m.IsGenericMethod); - - var tensorValues = Helpers.DataHelper.ParseString(values, returnType); - var buildTensor = tensorValues is Array arrayValues ? BuildTensorFromArray(arrayValues, returnType) : BuildTensorFromScalarValue(tensorValues, returnType); - var methodArguments = arguments.Count() == 0 ? [ buildTensor ] : arguments.Concat([ buildTensor ]); - - try - { - return Expression.Call( - Expression.Constant(this), - methodInfo, - methodArguments - ); - } - finally - { - values = Helpers.DataHelper.SerializeData(tensorValues).Replace("False", "false").Replace("True", "true"); - scalarType = Helpers.TensorDataTypeHelper.GetTensorDataTypeFromType(returnType); - } - } - - /// - /// Returns an observable sequence that creates a tensor from the specified values. - /// - public IObservable Process(Tensor tensor) - { - return Observable.Return(tensor); - } - - /// - /// Returns an observable sequence that creates a tensor from the specified values for each element in the input sequence. - /// - public IObservable Process(IObservable source, Tensor tensor) - { - return Observable.Select(source, (_) => tensor); - } - } -} diff --git a/src/Bonsai.ML.Tensors/Helpers/DataHelper.cs b/src/Bonsai.ML.Tensors/Helpers/DataHelper.cs deleted file mode 100644 index 1bbf3228..00000000 --- a/src/Bonsai.ML.Tensors/Helpers/DataHelper.cs +++ /dev/null @@ -1,190 +0,0 @@ -using System; -using System.Text; -using System.Collections.Generic; -using System.Linq; -using Newtonsoft.Json; -using Newtonsoft.Json.Linq; - -namespace Bonsai.ML.Tensors.Helpers -{ - /// - /// Provides helper methods for parsing tensor data types. - /// - public static class DataHelper - { - - /// - /// Serializes the input data into a string representation. - /// - public static string SerializeData(object data) - { - if (data is Array array) - { - return SerializeArray(array); - } - else - { - return JsonConvert.SerializeObject(data); - } - } - - /// - /// Serializes the input array into a string representation. - /// - public static string SerializeArray(Array array) - { - StringBuilder sb = new StringBuilder(); - SerializeArrayRecursive(array, sb, [0]); - return sb.ToString(); - } - - private static void SerializeArrayRecursive(Array array, StringBuilder sb, int[] indices) - { - if (indices.Length < array.Rank) - { - sb.Append("["); - int length = array.GetLength(indices.Length); - for (int i = 0; i < length; i++) - { - int[] newIndices = new int[indices.Length + 1]; - indices.CopyTo(newIndices, 0); - newIndices[indices.Length] = i; - SerializeArrayRecursive(array, sb, newIndices); - if (i < length - 1) - { - sb.Append(", "); - } - } - sb.Append("]"); - } - else - { - object value = array.GetValue(indices); - sb.Append(value.ToString()); - } - } - - private static bool IsValidJson(string input) - { - int squareBrackets = 0; - foreach (char c in input) - { - if (c == '[') squareBrackets++; - else if (c == ']') squareBrackets--; - } - return squareBrackets == 0; - } - - /// - /// Parses the input string into an object of the specified type. - /// - public static object ParseString(string input, Type dtype) - { - if (!IsValidJson(input)) - { - throw new ArgumentException("JSON is invalid."); - } - var obj = JsonConvert.DeserializeObject(input); - int depth = ParseDepth(obj); - if (depth == 0) - { - return Convert.ChangeType(input, dtype); - } - int[] dimensions = ParseDimensions(obj, depth); - var resultArray = Array.CreateInstance(dtype, dimensions); - PopulateArray(obj, resultArray, [0], dtype); - return resultArray; - } - - private static int ParseDepth(JToken token, int currentDepth = 0) - { - if (token is JArray arr && arr.Count > 0) - { - return ParseDepth(arr[0], currentDepth + 1); - } - return currentDepth; - } - - private static int[] ParseDimensions(JToken token, int depth, int currentLevel = 0) - { - if (depth == 0 || !(token is JArray)) - { - return [0]; - } - - List dimensions = new List(); - JToken current = token; - - while (current != null && current is JArray) - { - JArray currentArray = current as JArray; - dimensions.Add(currentArray.Count); - if (currentArray.Count > 0) - { - if (currentArray.Any(item => !(item is JArray)) && currentArray.Any(item => item is JArray) || currentArray.All(item => item is JArray) && currentArray.Any(item => ((JArray)item).Count != ((JArray)currentArray.First()).Count)) - { - throw new Exception("Error parsing input. Dimensions are inconsistent."); - } - - if (!(currentArray.First() is JArray)) - { - if (!currentArray.All(item => double.TryParse(item.ToString(), out _)) && !currentArray.All(item => bool.TryParse(item.ToString(), out _))) - { - throw new Exception("Error parsing types. All values must be of the same type and only numeric or boolean types are supported."); - } - } - } - - current = currentArray.Count > 0 ? currentArray[0] : null; - } - - if (currentLevel > 0 && token is JArray arr && arr.All(x => x is JArray)) - { - var subArrayDimensions = new HashSet(); - foreach (JArray subArr in arr) - { - int[] subDims = ParseDimensions(subArr, depth - currentLevel, currentLevel + 1); - subArrayDimensions.Add(string.Join(",", subDims)); - } - - if (subArrayDimensions.Count > 1) - { - throw new ArgumentException("Inconsistent array dimensions."); - } - } - - return dimensions.ToArray(); - } - - private static void PopulateArray(JToken token, Array array, int[] indices, Type dtype) - { - if (token is JArray arr) - { - for (int i = 0; i < arr.Count; i++) - { - int[] newIndices = new int[indices.Length + 1]; - Array.Copy(indices, newIndices, indices.Length); - newIndices[newIndices.Length - 1] = i; - PopulateArray(arr[i], array, newIndices, dtype); - } - } - else - { - var values = ConvertType(token, dtype); - array.SetValue(values, indices); - } - } - - private static object ConvertType(object value, Type targetType) - { - try - { - return Convert.ChangeType(value, targetType); - } - catch (Exception ex) - { - throw new Exception("Error parsing type: ", ex); - } - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/Helpers/IndexHelper.cs b/src/Bonsai.ML.Tensors/Helpers/IndexHelper.cs deleted file mode 100644 index 785eccea..00000000 --- a/src/Bonsai.ML.Tensors/Helpers/IndexHelper.cs +++ /dev/null @@ -1,91 +0,0 @@ -using System; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors.Helpers -{ - /// - /// Provides helper methods to parse tensor indexes. - /// - public static class IndexHelper - { - - /// - /// Parses the input string into an array of tensor indexes. - /// - /// - public static TensorIndex[] ParseString(string input) - { - if (string.IsNullOrEmpty(input)) - { - return [0]; - } - - var indexStrings = input.Split(','); - var indices = new TensorIndex[indexStrings.Length]; - - for (int i = 0; i < indexStrings.Length; i++) - { - var indexString = indexStrings[i].Trim(); - if (int.TryParse(indexString, out int intIndex)) - { - indices[i] = TensorIndex.Single(intIndex); - } - else if (indexString == ":") - { - indices[i] = TensorIndex.Colon; - } - else if (indexString == "None") - { - indices[i] = TensorIndex.None; - } - else if (indexString == "...") - { - indices[i] = TensorIndex.Ellipsis; - } - else if (indexString.ToLower() == "false" || indexString.ToLower() == "true") - { - indices[i] = TensorIndex.Bool(indexString.ToLower() == "true"); - } - else if (indexString.Contains(":")) - { - var rangeParts = indexString.Split(':'); - if (rangeParts.Length == 0) - { - indices[i] = TensorIndex.Slice(); - } - else if (rangeParts.Length == 1) - { - indices[i] = TensorIndex.Slice(int.Parse(rangeParts[0])); - } - else if (rangeParts.Length == 2) - { - indices[i] = TensorIndex.Slice(int.Parse(rangeParts[0]), int.Parse(rangeParts[1])); - } - else if (rangeParts.Length == 3) - { - indices[i] = TensorIndex.Slice(int.Parse(rangeParts[0]), int.Parse(rangeParts[1]), int.Parse(rangeParts[2])); - } - else - { - throw new Exception($"Invalid index format: {indexString}"); - } - } - else - { - throw new Exception($"Invalid index format: {indexString}"); - } - } - return indices; - } - - /// - /// Serializes the input array of tensor indexes into a string representation. - /// - /// - /// - public static string SerializeIndexes(TensorIndex[] indexes) - { - return string.Join(", ", indexes); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/Helpers/OpenCVHelper.cs b/src/Bonsai.ML.Tensors/Helpers/OpenCVHelper.cs deleted file mode 100644 index 265f2119..00000000 --- a/src/Bonsai.ML.Tensors/Helpers/OpenCVHelper.cs +++ /dev/null @@ -1,169 +0,0 @@ -using System; -using System.Runtime.InteropServices; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Linq; -using OpenCV.Net; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors.Helpers -{ - /// - /// Helper class to convert between OpenCV mats, images and Torch tensors. - /// - public static class OpenCVHelper - { - private static Dictionary bitDepthLookup = new Dictionary { - { ScalarType.Byte, (IplDepth.U8, Depth.U8) }, - { ScalarType.Int16, (IplDepth.S16, Depth.S16) }, - { ScalarType.Int32, (IplDepth.S32, Depth.S32) }, - { ScalarType.Float32, (IplDepth.F32, Depth.F32) }, - { ScalarType.Float64, (IplDepth.F64, Depth.F64) }, - { ScalarType.Int8, (IplDepth.S8, Depth.S8) } - }; - - private static ConcurrentDictionary deleters = new ConcurrentDictionary(); - - internal delegate void GCHandleDeleter(IntPtr memory); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSTensor_data(IntPtr handle); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSTensor_new(IntPtr rawArray, GCHandleDeleter deleter, IntPtr dimensions, int numDimensions, sbyte type, sbyte dtype, int deviceType, int deviceIndex, [MarshalAs(UnmanagedType.U1)] bool requires_grad); - - /// - /// Creates a tensor from a pointer to the data and the dimensions of the tensor. - /// - /// - /// - /// - /// - public static unsafe Tensor CreateTensorFromPtr(IntPtr tensorDataPtr, long[] dimensions, ScalarType dtype = ScalarType.Byte) - { - var dataHandle = GCHandle.Alloc(tensorDataPtr, GCHandleType.Pinned); - var gchp = GCHandle.ToIntPtr(dataHandle); - GCHandleDeleter deleter = null; - - deleter = new GCHandleDeleter((IntPtr ptrHandler) => - { - GCHandle.FromIntPtr(gchp).Free(); - deleters.TryRemove(deleter, out deleter); - }); - deleters.TryAdd(deleter, deleter); - - fixed (long* dimensionsPtr = dimensions) - { - IntPtr tensorHandle = THSTensor_new(tensorDataPtr, deleter, (IntPtr)dimensionsPtr, dimensions.Length, (sbyte)dtype, (sbyte)dtype, 0, 0, false); - if (tensorHandle == IntPtr.Zero) { - GC.Collect(); - GC.WaitForPendingFinalizers(); - tensorHandle = THSTensor_new(tensorDataPtr, deleter, (IntPtr)dimensionsPtr, dimensions.Length, (sbyte)dtype, (sbyte)dtype, 0, 0, false); - } - if (tensorHandle == IntPtr.Zero) { CheckForErrors(); } - var output = Tensor.UnsafeCreateTensor(tensorHandle); - return output; - } - } - - /// - /// Converts an OpenCV image to a Torch tensor. - /// - /// - /// - public static Tensor ToTensor(IplImage image) - { - if (image == null) - { - return empty([ 0, 0, 0 ]); - } - - int width = image.Width; - int height = image.Height; - int channels = image.Channels; - - var iplDepth = image.Depth; - var tensorType = bitDepthLookup.FirstOrDefault(x => x.Value.IplDepth == iplDepth).Key; - - IntPtr tensorDataPtr = image.ImageData; - long[] dimensions = [ height, width, channels ]; - if (tensorDataPtr == IntPtr.Zero) - { - return empty(dimensions); - } - return CreateTensorFromPtr(tensorDataPtr, dimensions, tensorType); - } - - /// - /// Converts an OpenCV mat to a Torch tensor. - /// - /// - /// - public static Tensor ToTensor(Mat mat) - { - if (mat == null) - { - return empty([0, 0, 0 ]); - } - - int width = mat.Size.Width; - int height = mat.Size.Height; - int channels = mat.Channels; - - var depth = mat.Depth; - var tensorType = bitDepthLookup.FirstOrDefault(x => x.Value.Depth == depth).Key; - - IntPtr tensorDataPtr = mat.Data; - long[] dimensions = [ height, width, channels ]; - if (tensorDataPtr == IntPtr.Zero) - { - return empty(dimensions); - } - return CreateTensorFromPtr(tensorDataPtr, dimensions, tensorType); - } - - /// - /// Converts a Torch tensor to an OpenCV image. - /// - /// - /// - public unsafe static IplImage ToImage(Tensor tensor) - { - var height = (int)tensor.shape[0]; - var width = (int)tensor.shape[1]; - var channels = (int)tensor.shape[2]; - - var tensorType = tensor.dtype; - var iplDepth = bitDepthLookup[tensorType].IplDepth; - - var new_tensor = zeros(new long[] { height, width, channels }, tensorType).copy_(tensor); - - var res = THSTensor_data(new_tensor.Handle); - var image = new IplImage(new OpenCV.Net.Size(width, height), iplDepth, channels, res); - - return image; - } - - /// - /// Converts a Torch tensor to an OpenCV mat. - /// - /// - /// - public unsafe static Mat ToMat(Tensor tensor) - { - var height = (int)tensor.shape[0]; - var width = (int)tensor.shape[1]; - var channels = (int)tensor.shape[2]; - - var tensorType = tensor.dtype; - var depth = bitDepthLookup[tensorType].Depth; - - var new_tensor = zeros(new long[] { height, width, channels }, tensorType).copy_(tensor); - - var res = THSTensor_data(new_tensor.Handle); - var mat = new Mat(new OpenCV.Net.Size(width, height), depth, channels, res); - - return mat; - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/Helpers/TensorDataTypeHelper.cs b/src/Bonsai.ML.Tensors/Helpers/TensorDataTypeHelper.cs deleted file mode 100644 index 7ea03f65..00000000 --- a/src/Bonsai.ML.Tensors/Helpers/TensorDataTypeHelper.cs +++ /dev/null @@ -1,52 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; - -namespace Bonsai.ML.Tensors.Helpers -{ - /// - /// Provides helper methods for working with tensor data types. - /// - public class TensorDataTypeHelper - { - private static readonly Dictionary _lookup = new Dictionary - { - { TensorDataType.Byte, (typeof(byte), "byte") }, - { TensorDataType.Int16, (typeof(short), "short") }, - { TensorDataType.Int32, (typeof(int), "int") }, - { TensorDataType.Int64, (typeof(long), "long") }, - { TensorDataType.Float32, (typeof(float), "float") }, - { TensorDataType.Float64, (typeof(double), "double") }, - { TensorDataType.Bool, (typeof(bool), "bool") }, - { TensorDataType.Int8, (typeof(sbyte), "sbyte") }, - }; - - /// - /// Returns the type corresponding to the specified tensor data type. - /// - /// - /// - public static Type GetTypeFromTensorDataType(TensorDataType type) => _lookup[type].Type; - - /// - /// Returns the string representation corresponding to the specified tensor data type. - /// - /// - /// - public static string GetStringFromTensorDataType(TensorDataType type) => _lookup[type].StringValue; - - /// - /// Returns the tensor data type corresponding to the specified string representation. - /// - /// - /// - public static TensorDataType GetTensorDataTypeFromString(string value) => _lookup.First(x => x.Value.StringValue == value).Key; - - /// - /// Returns the tensor data type corresponding to the specified type. - /// - /// - /// - public static TensorDataType GetTensorDataTypeFromType(Type type) => _lookup.First(x => x.Value.Type == type).Key; - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/Index.cs b/src/Bonsai.ML.Tensors/Index.cs deleted file mode 100644 index 3c1948f9..00000000 --- a/src/Bonsai.ML.Tensors/Index.cs +++ /dev/null @@ -1,35 +0,0 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Indexes a tensor with the specified indices. Indices are specified as a comma-separated values. - /// Currently supports Python-style slicing syntax. This includes numeric indices, None, slices, and ellipsis. - /// - [Combinator] - [Description("Indexes a tensor with the specified indices. Indices are specified as a comma-separated values. Currently supports Python-style slicing syntax. This includes numeric indices, None, slices, and ellipsis.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class Index - { - /// - /// The indices to use for indexing the tensor. - /// - public string Indexes { get; set; } = string.Empty; - - /// - /// Indexes the input tensor with the specified indices. - /// - /// - /// - public IObservable Process(IObservable source) - { - var index = Helpers.IndexHelper.ParseString(Indexes); - return source.Select(tensor => { - return tensor.index(index); - }); - } - } -} diff --git a/src/Bonsai.ML.Tensors/InitializeTorchDevice.cs b/src/Bonsai.ML.Tensors/InitializeTorchDevice.cs deleted file mode 100644 index dc9123f0..00000000 --- a/src/Bonsai.ML.Tensors/InitializeTorchDevice.cs +++ /dev/null @@ -1,35 +0,0 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using static TorchSharp.torch; -using TorchSharp; - -namespace Bonsai.ML.Tensors -{ - /// - /// Initializes the Torch device with the specified device type. - /// - [Combinator] - [Description("Initializes the Torch device with the specified device type.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class InitializeTorchDevice - { - /// - /// The device type to initialize. - /// - public DeviceType DeviceType { get; set; } - - /// - /// Initializes the Torch device with the specified device type. - /// - /// - public IObservable Process() - { - return Observable.Defer(() => - { - InitializeDeviceType(DeviceType); - return Observable.Return(new Device(DeviceType)); - }); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/Linspace.cs b/src/Bonsai.ML.Tensors/Linspace.cs deleted file mode 100644 index aa263500..00000000 --- a/src/Bonsai.ML.Tensors/Linspace.cs +++ /dev/null @@ -1,40 +0,0 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Creates a 1-D tensor of linearly interpolated values within a given range given the start, end, and count. - /// - [Combinator] - [Description("Creates a 1-D tensor of linearly interpolated values within a given range given the start, end, and count.")] - [WorkflowElementCategory(ElementCategory.Source)] - public class Linspace - { - /// - /// The start of the range. - /// - public int Start { get; set; } = 0; - - /// - /// The end of the range. - /// - public int End { get; set; } = 1; - - /// - /// The number of points to generate. - /// - public int Count { get; set; } = 10; - - /// - /// Generates an observable sequence of 1-D tensors created with the function. - /// - /// - public IObservable Process() - { - return Observable.Defer(() => Observable.Return(linspace(Start, End, Count))); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/MeshGrid.cs b/src/Bonsai.ML.Tensors/MeshGrid.cs deleted file mode 100644 index 6b0a2c73..00000000 --- a/src/Bonsai.ML.Tensors/MeshGrid.cs +++ /dev/null @@ -1,33 +0,0 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using System.Collections.Generic; -using static TorchSharp.torch; -using System.Linq; - -namespace Bonsai.ML.Tensors -{ - /// - /// Creates a mesh grid from an observable sequence of enumerable of 1-D tensors. - /// - [Combinator] - [Description("")] - [WorkflowElementCategory(ElementCategory.Source)] - public class MeshGrid - { - /// - /// The indexing mode to use for the mesh grid. - /// - public string Indexing { get; set; } = "ij"; - - /// - /// Creates a mesh grid from the input tensors. - /// - /// - /// - public IObservable Process(IObservable> source) - { - return source.Select(tensors => meshgrid(tensors, indexing: Indexing)); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/Ones.cs b/src/Bonsai.ML.Tensors/Ones.cs deleted file mode 100644 index 499012bd..00000000 --- a/src/Bonsai.ML.Tensors/Ones.cs +++ /dev/null @@ -1,30 +0,0 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Creates a tensor filled with ones. - /// - [Combinator] - [Description("Creates a tensor filled with ones.")] - [WorkflowElementCategory(ElementCategory.Source)] - public class Ones - { - /// - /// The size of the tensor. - /// - public long[] Size { get; set; } = [0]; - - /// - /// Generates an observable sequence of tensors filled with ones. - /// - /// - public IObservable Process() - { - return Observable.Defer(() => Observable.Return(ones(Size))); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/Permute.cs b/src/Bonsai.ML.Tensors/Permute.cs deleted file mode 100644 index 7f037d79..00000000 --- a/src/Bonsai.ML.Tensors/Permute.cs +++ /dev/null @@ -1,33 +0,0 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Permutes the dimensions of the input tensor according to the specified permutation. - /// - [Combinator] - [Description("Permutes the dimensions of the input tensor according to the specified permutation.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class Permute - { - /// - /// The permutation of the dimensions. - /// - public long[] Dimensions { get; set; } = [0]; - - /// - /// Returns an observable sequence that permutes the dimensions of the input tensor according to the specified permutation. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(tensor => { - return tensor.permute(Dimensions); - }); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/Reshape.cs b/src/Bonsai.ML.Tensors/Reshape.cs deleted file mode 100644 index 4fef3d83..00000000 --- a/src/Bonsai.ML.Tensors/Reshape.cs +++ /dev/null @@ -1,32 +0,0 @@ -using System; -using System.ComponentModel; -using System.Linq; -using System.Reactive.Linq; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Reshapes the input tensor according to the specified dimensions. - /// - [Combinator] - [Description("Reshapes the input tensor according to the specified dimensions.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class Reshape - { - /// - /// The dimensions of the reshaped tensor. - /// - public long[] Dimensions { get; set; } = [0]; - - /// - /// Reshapes the input tensor according to the specified dimensions. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(input => input.reshape(Dimensions)); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/Set.cs b/src/Bonsai.ML.Tensors/Set.cs deleted file mode 100644 index 7f6f8b92..00000000 --- a/src/Bonsai.ML.Tensors/Set.cs +++ /dev/null @@ -1,48 +0,0 @@ -using System; -using System.Collections.Generic; -using System.ComponentModel; -using System.Linq; -using System.Reactive.Linq; -using System.Xml.Serialization; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Sets the value of the input tensor at the specified index. - /// - [Combinator] - [Description("Sets the value of the input tensor at the specified index.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class Set - { - /// - /// The index at which to set the value. - /// - public string Index - { - get => Helpers.IndexHelper.SerializeIndexes(indexes); - set => indexes = Helpers.IndexHelper.ParseString(value); - } - - private TensorIndex[] indexes; - - /// - /// The value to set at the specified index. - /// - [XmlIgnore] - public Tensor Value { get; set; } = null; - - /// - /// Returns an observable sequence that sets the value of the input tensor at the specified index. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(tensor => { - return tensor.index_put_(Value, indexes); - }); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/TensorDataType.cs b/src/Bonsai.ML.Tensors/TensorDataType.cs deleted file mode 100644 index a710a9ed..00000000 --- a/src/Bonsai.ML.Tensors/TensorDataType.cs +++ /dev/null @@ -1,56 +0,0 @@ -using System; -using System.Text; -using System.Collections.Generic; -using System.Linq; -using Newtonsoft.Json; -using Newtonsoft.Json.Linq; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Represents the data type of the tensor elements. Contains currently supported data types. A subset of the available ScalarType data types in TorchSharp. - /// - public enum TensorDataType - { - /// - /// 8-bit unsigned integer. - /// - Byte = ScalarType.Byte, - - /// - /// 8-bit signed integer. - /// - Int8 = ScalarType.Int8, - - /// - /// 16-bit signed integer. - /// - Int16 = ScalarType.Int16, - - /// - /// 32-bit signed integer. - /// - Int32 = ScalarType.Int32, - - /// - /// 64-bit signed integer. - /// - Int64 = ScalarType.Int64, - - /// - /// 32-bit floating point. - /// - Float32 = ScalarType.Float32, - - /// - /// 64-bit floating point. - /// - Float64 = ScalarType.Float64, - - /// - /// Boolean. - /// - Bool = ScalarType.Bool - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/ToArray.cs b/src/Bonsai.ML.Tensors/ToArray.cs deleted file mode 100644 index af35ab4f..00000000 --- a/src/Bonsai.ML.Tensors/ToArray.cs +++ /dev/null @@ -1,73 +0,0 @@ -using System; -using System.Collections.Generic; -using System.ComponentModel; -using System.Linq; -using System.Reactive.Linq; -using System.Xml.Serialization; -using System.Linq.Expressions; -using System.Reflection; -using Bonsai.Expressions; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Converts the input tensor into an array of the specified element type. - /// - [Combinator] - [Description("Converts the input tensor into an array of the specified element type.")] - [WorkflowElementCategory(ElementCategory.Transform)] - [XmlInclude(typeof(TypeMapping))] - [XmlInclude(typeof(TypeMapping))] - [XmlInclude(typeof(TypeMapping))] - [XmlInclude(typeof(TypeMapping))] - [XmlInclude(typeof(TypeMapping))] - [XmlInclude(typeof(TypeMapping))] - [XmlInclude(typeof(TypeMapping))] - [XmlInclude(typeof(TypeMapping))] - public class ToArray : SingleArgumentExpressionBuilder - { - /// - /// Initializes a new instance of the class. - /// - public ToArray() - { - Type = new TypeMapping(); - } - - /// - /// Gets or sets the type mapping used to convert the input tensor into an array. - /// - public TypeMapping Type { get; set; } - - /// - public override Expression Build(IEnumerable arguments) - { - TypeMapping typeMapping = Type; - var returnType = typeMapping.GetType().GetGenericArguments()[0]; - MethodInfo methodInfo = GetType().GetMethod("Process", BindingFlags.Public | BindingFlags.Instance); - methodInfo = methodInfo.MakeGenericMethod(returnType); - Expression sourceExpression = arguments.First(); - - return Expression.Call( - Expression.Constant(this), - methodInfo, - sourceExpression - ); - } - - /// - /// Converts the input tensor into an array of the specified element type. - /// - /// - /// - /// - public IObservable Process(IObservable source) where T : unmanaged - { - return source.Select(tensor => - { - return tensor.data().ToArray(); - }); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/ToDevice.cs b/src/Bonsai.ML.Tensors/ToDevice.cs deleted file mode 100644 index 574be5f3..00000000 --- a/src/Bonsai.ML.Tensors/ToDevice.cs +++ /dev/null @@ -1,34 +0,0 @@ -using System; -using System.ComponentModel; -using System.Linq; -using System.Reactive.Linq; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Moves the input tensor to the specified device. - /// - [Combinator] - [Description("Moves the input tensor to the specified device.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class ToDevice - { - /// - /// The device to which the input tensor should be moved. - /// - public Device Device { get; set; } - - /// - /// Returns the input tensor moved to the specified device. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(tensor => { - return tensor.to(Device); - }); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/ToImage.cs b/src/Bonsai.ML.Tensors/ToImage.cs deleted file mode 100644 index e29a3825..00000000 --- a/src/Bonsai.ML.Tensors/ToImage.cs +++ /dev/null @@ -1,28 +0,0 @@ -using System; -using System.ComponentModel; -using System.Linq; -using System.Reactive.Linq; -using OpenCV.Net; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Converts the input tensor into an OpenCV image. - /// - [Combinator] - [Description("")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class ToImage - { - /// - /// Converts the input tensor into an OpenCV image. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(Helpers.OpenCVHelper.ToImage); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/ToMat.cs b/src/Bonsai.ML.Tensors/ToMat.cs deleted file mode 100644 index 8a22f408..00000000 --- a/src/Bonsai.ML.Tensors/ToMat.cs +++ /dev/null @@ -1,28 +0,0 @@ -using System; -using System.ComponentModel; -using System.Linq; -using System.Reactive.Linq; -using OpenCV.Net; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Converts the input tensor into an OpenCV mat. - /// - [Combinator] - [Description("Converts the input tensor into an OpenCV mat.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class ToMat - { - /// - /// Converts the input tensor into an OpenCV mat. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(Helpers.OpenCVHelper.ToMat); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/ToTensor.cs b/src/Bonsai.ML.Tensors/ToTensor.cs deleted file mode 100644 index 083e2797..00000000 --- a/src/Bonsai.ML.Tensors/ToTensor.cs +++ /dev/null @@ -1,134 +0,0 @@ -using System; -using System.ComponentModel; -using System.Linq; -using System.Reactive.Linq; -using OpenCV.Net; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Converts the input value into a tensor. - /// - [Combinator] - [Description("")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class ToTensor - { - /// - /// Converts an int into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => { - return as_tensor(value); - }); - } - - /// - /// Converts a double into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => { - return as_tensor(value); - }); - } - - /// - /// Converts a byte into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => { - return as_tensor(value); - }); - } - - /// - /// Converts a bool into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => { - return as_tensor(value); - }); - } - - /// - /// Converts a float into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => { - return as_tensor(value); - }); - } - - /// - /// Converts a long into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => { - return as_tensor(value); - }); - } - - /// - /// Converts a short into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => { - return as_tensor(value); - }); - } - - /// - /// Converts an array into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => { - return as_tensor(value); - }); - } - - /// - /// Converts an IplImage into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(Helpers.OpenCVHelper.ToTensor); - } - - /// - /// Converts a Mat into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(Helpers.OpenCVHelper.ToTensor); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/Zeros.cs b/src/Bonsai.ML.Tensors/Zeros.cs deleted file mode 100644 index af220641..00000000 --- a/src/Bonsai.ML.Tensors/Zeros.cs +++ /dev/null @@ -1,30 +0,0 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Creates a tensor filled with zeros. - /// - [Combinator] - [Description("Creates a tensor filled with zeros.")] - [WorkflowElementCategory(ElementCategory.Source)] - public class Zeros - { - /// - /// The size of the tensor. - /// - public long[] Size { get; set; } = [0]; - - /// - /// Generates an observable sequence of tensors filled with zeros. - /// - /// - public IObservable Process() - { - return Observable.Defer(() => Observable.Return(ones(Size))); - } - } -} \ No newline at end of file From 0e820d7a4cc306f973b927b0cd0224e450606af0 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 27 Aug 2024 12:54:13 +0100 Subject: [PATCH 025/131] Updated solution to reflect change --- Bonsai.ML.sln | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Bonsai.ML.sln b/Bonsai.ML.sln index 22b8a35a..30c6b6f1 100644 --- a/Bonsai.ML.sln +++ b/Bonsai.ML.sln @@ -30,7 +30,7 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.LinearDynamicalSy EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.HiddenMarkovModels.Design", "src\Bonsai.ML.HiddenMarkovModels.Design\Bonsai.ML.HiddenMarkovModels.Design.csproj", "{FC395DDC-62A4-4E14-A198-272AB05B33C7}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Tensors", "src\Bonsai.ML.Tensors\Bonsai.ML.Tensors.csproj", "{06FCC9AF-CE38-44BB-92B3-0D451BE88537}" +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Torch", "src\Bonsai.ML.Torch\Bonsai.ML.Torch.csproj", "{06FCC9AF-CE38-44BB-92B3-0D451BE88537}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution From f06a4451c8fb803e55cc4b0571ea9aa06ddb4f3b Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 27 Aug 2024 13:30:34 +0100 Subject: [PATCH 026/131] Moved Tensors namespace to main Torch namespace --- src/Bonsai.ML.Torch/{Tensors => }/Arange.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/Concat.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/ConvertDataType.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/CreateTensor.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/Index.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/InitializeTorchDevice.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/Linspace.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/MeshGrid.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/Ones.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/Permute.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/Reshape.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/Set.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/TensorDataType.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/ToArray.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/ToDevice.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/ToImage.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/ToMat.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/ToTensor.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/Zeros.cs | 2 +- 19 files changed, 19 insertions(+), 19 deletions(-) rename src/Bonsai.ML.Torch/{Tensors => }/Arange.cs (97%) rename src/Bonsai.ML.Torch/{Tensors => }/Concat.cs (97%) rename src/Bonsai.ML.Torch/{Tensors => }/ConvertDataType.cs (96%) rename src/Bonsai.ML.Torch/{Tensors => }/CreateTensor.cs (99%) rename src/Bonsai.ML.Torch/{Tensors => }/Index.cs (97%) rename src/Bonsai.ML.Torch/{Tensors => }/InitializeTorchDevice.cs (96%) rename src/Bonsai.ML.Torch/{Tensors => }/Linspace.cs (97%) rename src/Bonsai.ML.Torch/{Tensors => }/MeshGrid.cs (96%) rename src/Bonsai.ML.Torch/{Tensors => }/Ones.cs (95%) rename src/Bonsai.ML.Torch/{Tensors => }/Permute.cs (96%) rename src/Bonsai.ML.Torch/{Tensors => }/Reshape.cs (96%) rename src/Bonsai.ML.Torch/{Tensors => }/Set.cs (97%) rename src/Bonsai.ML.Torch/{Tensors => }/TensorDataType.cs (97%) rename src/Bonsai.ML.Torch/{Tensors => }/ToArray.cs (98%) rename src/Bonsai.ML.Torch/{Tensors => }/ToDevice.cs (96%) rename src/Bonsai.ML.Torch/{Tensors => }/ToImage.cs (95%) rename src/Bonsai.ML.Torch/{Tensors => }/ToMat.cs (95%) rename src/Bonsai.ML.Torch/{Tensors => }/ToTensor.cs (99%) rename src/Bonsai.ML.Torch/{Tensors => }/Zeros.cs (95%) diff --git a/src/Bonsai.ML.Torch/Tensors/Arange.cs b/src/Bonsai.ML.Torch/Arange.cs similarity index 97% rename from src/Bonsai.ML.Torch/Tensors/Arange.cs rename to src/Bonsai.ML.Torch/Arange.cs index 011d0708..14e3259b 100644 --- a/src/Bonsai.ML.Torch/Tensors/Arange.cs +++ b/src/Bonsai.ML.Torch/Arange.cs @@ -4,7 +4,7 @@ using static TorchSharp.torch; using TorchSharp; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Creates a 1-D tensor of values within a given range given the start, end, and step. diff --git a/src/Bonsai.ML.Torch/Tensors/Concat.cs b/src/Bonsai.ML.Torch/Concat.cs similarity index 97% rename from src/Bonsai.ML.Torch/Tensors/Concat.cs rename to src/Bonsai.ML.Torch/Concat.cs index 52275bb7..b07b211d 100644 --- a/src/Bonsai.ML.Torch/Tensors/Concat.cs +++ b/src/Bonsai.ML.Torch/Concat.cs @@ -4,7 +4,7 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Concatenates tensors along a given dimension. diff --git a/src/Bonsai.ML.Torch/Tensors/ConvertDataType.cs b/src/Bonsai.ML.Torch/ConvertDataType.cs similarity index 96% rename from src/Bonsai.ML.Torch/Tensors/ConvertDataType.cs rename to src/Bonsai.ML.Torch/ConvertDataType.cs index 3683d2a6..59981adc 100644 --- a/src/Bonsai.ML.Torch/Tensors/ConvertDataType.cs +++ b/src/Bonsai.ML.Torch/ConvertDataType.cs @@ -3,7 +3,7 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Converts the input tensor to the specified scalar type. diff --git a/src/Bonsai.ML.Torch/Tensors/CreateTensor.cs b/src/Bonsai.ML.Torch/CreateTensor.cs similarity index 99% rename from src/Bonsai.ML.Torch/Tensors/CreateTensor.cs rename to src/Bonsai.ML.Torch/CreateTensor.cs index 4585b70b..0100f920 100644 --- a/src/Bonsai.ML.Torch/Tensors/CreateTensor.cs +++ b/src/Bonsai.ML.Torch/CreateTensor.cs @@ -9,7 +9,7 @@ using Bonsai.Expressions; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Creates a tensor from the specified values. Uses Python-like syntax to specify the tensor values. For example, a 2x2 tensor can be created with the following values: "[[1, 2], [3, 4]]". diff --git a/src/Bonsai.ML.Torch/Tensors/Index.cs b/src/Bonsai.ML.Torch/Index.cs similarity index 97% rename from src/Bonsai.ML.Torch/Tensors/Index.cs rename to src/Bonsai.ML.Torch/Index.cs index 78024237..5b7b9192 100644 --- a/src/Bonsai.ML.Torch/Tensors/Index.cs +++ b/src/Bonsai.ML.Torch/Index.cs @@ -3,7 +3,7 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Indexes a tensor with the specified indices. Indices are specified as a comma-separated values. diff --git a/src/Bonsai.ML.Torch/Tensors/InitializeTorchDevice.cs b/src/Bonsai.ML.Torch/InitializeTorchDevice.cs similarity index 96% rename from src/Bonsai.ML.Torch/Tensors/InitializeTorchDevice.cs rename to src/Bonsai.ML.Torch/InitializeTorchDevice.cs index 2258467f..e82daa36 100644 --- a/src/Bonsai.ML.Torch/Tensors/InitializeTorchDevice.cs +++ b/src/Bonsai.ML.Torch/InitializeTorchDevice.cs @@ -4,7 +4,7 @@ using static TorchSharp.torch; using TorchSharp; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Initializes the Torch device with the specified device type. diff --git a/src/Bonsai.ML.Torch/Tensors/Linspace.cs b/src/Bonsai.ML.Torch/Linspace.cs similarity index 97% rename from src/Bonsai.ML.Torch/Tensors/Linspace.cs rename to src/Bonsai.ML.Torch/Linspace.cs index 6e7495f8..ee6516cf 100644 --- a/src/Bonsai.ML.Torch/Tensors/Linspace.cs +++ b/src/Bonsai.ML.Torch/Linspace.cs @@ -3,7 +3,7 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Creates a 1-D tensor of linearly interpolated values within a given range given the start, end, and count. diff --git a/src/Bonsai.ML.Torch/Tensors/MeshGrid.cs b/src/Bonsai.ML.Torch/MeshGrid.cs similarity index 96% rename from src/Bonsai.ML.Torch/Tensors/MeshGrid.cs rename to src/Bonsai.ML.Torch/MeshGrid.cs index 77f4cecb..725b12a9 100644 --- a/src/Bonsai.ML.Torch/Tensors/MeshGrid.cs +++ b/src/Bonsai.ML.Torch/MeshGrid.cs @@ -5,7 +5,7 @@ using static TorchSharp.torch; using System.Linq; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Creates a mesh grid from an observable sequence of enumerable of 1-D tensors. diff --git a/src/Bonsai.ML.Torch/Tensors/Ones.cs b/src/Bonsai.ML.Torch/Ones.cs similarity index 95% rename from src/Bonsai.ML.Torch/Tensors/Ones.cs rename to src/Bonsai.ML.Torch/Ones.cs index 77768dd1..52bf8732 100644 --- a/src/Bonsai.ML.Torch/Tensors/Ones.cs +++ b/src/Bonsai.ML.Torch/Ones.cs @@ -3,7 +3,7 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Creates a tensor filled with ones. diff --git a/src/Bonsai.ML.Torch/Tensors/Permute.cs b/src/Bonsai.ML.Torch/Permute.cs similarity index 96% rename from src/Bonsai.ML.Torch/Tensors/Permute.cs rename to src/Bonsai.ML.Torch/Permute.cs index 317e34f8..a82107ba 100644 --- a/src/Bonsai.ML.Torch/Tensors/Permute.cs +++ b/src/Bonsai.ML.Torch/Permute.cs @@ -3,7 +3,7 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Permutes the dimensions of the input tensor according to the specified permutation. diff --git a/src/Bonsai.ML.Torch/Tensors/Reshape.cs b/src/Bonsai.ML.Torch/Reshape.cs similarity index 96% rename from src/Bonsai.ML.Torch/Tensors/Reshape.cs rename to src/Bonsai.ML.Torch/Reshape.cs index 5d3e9412..ebdc8e41 100644 --- a/src/Bonsai.ML.Torch/Tensors/Reshape.cs +++ b/src/Bonsai.ML.Torch/Reshape.cs @@ -4,7 +4,7 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Reshapes the input tensor according to the specified dimensions. diff --git a/src/Bonsai.ML.Torch/Tensors/Set.cs b/src/Bonsai.ML.Torch/Set.cs similarity index 97% rename from src/Bonsai.ML.Torch/Tensors/Set.cs rename to src/Bonsai.ML.Torch/Set.cs index a4d8b2d2..14bf3dad 100644 --- a/src/Bonsai.ML.Torch/Tensors/Set.cs +++ b/src/Bonsai.ML.Torch/Set.cs @@ -6,7 +6,7 @@ using System.Xml.Serialization; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Sets the value of the input tensor at the specified index. diff --git a/src/Bonsai.ML.Torch/Tensors/TensorDataType.cs b/src/Bonsai.ML.Torch/TensorDataType.cs similarity index 97% rename from src/Bonsai.ML.Torch/Tensors/TensorDataType.cs rename to src/Bonsai.ML.Torch/TensorDataType.cs index de1ba8d2..fe8861f3 100644 --- a/src/Bonsai.ML.Torch/Tensors/TensorDataType.cs +++ b/src/Bonsai.ML.Torch/TensorDataType.cs @@ -6,7 +6,7 @@ using Newtonsoft.Json.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Represents the data type of the tensor elements. Contains currently supported data types. A subset of the available ScalarType data types in TorchSharp. diff --git a/src/Bonsai.ML.Torch/Tensors/ToArray.cs b/src/Bonsai.ML.Torch/ToArray.cs similarity index 98% rename from src/Bonsai.ML.Torch/Tensors/ToArray.cs rename to src/Bonsai.ML.Torch/ToArray.cs index 70083ad2..1c2c721a 100644 --- a/src/Bonsai.ML.Torch/Tensors/ToArray.cs +++ b/src/Bonsai.ML.Torch/ToArray.cs @@ -9,7 +9,7 @@ using Bonsai.Expressions; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Converts the input tensor into an array of the specified element type. diff --git a/src/Bonsai.ML.Torch/Tensors/ToDevice.cs b/src/Bonsai.ML.Torch/ToDevice.cs similarity index 96% rename from src/Bonsai.ML.Torch/Tensors/ToDevice.cs rename to src/Bonsai.ML.Torch/ToDevice.cs index 4aa1b92a..cb73f733 100644 --- a/src/Bonsai.ML.Torch/Tensors/ToDevice.cs +++ b/src/Bonsai.ML.Torch/ToDevice.cs @@ -4,7 +4,7 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Moves the input tensor to the specified device. diff --git a/src/Bonsai.ML.Torch/Tensors/ToImage.cs b/src/Bonsai.ML.Torch/ToImage.cs similarity index 95% rename from src/Bonsai.ML.Torch/Tensors/ToImage.cs rename to src/Bonsai.ML.Torch/ToImage.cs index eebf8399..894a9602 100644 --- a/src/Bonsai.ML.Torch/Tensors/ToImage.cs +++ b/src/Bonsai.ML.Torch/ToImage.cs @@ -5,7 +5,7 @@ using OpenCV.Net; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Converts the input tensor into an OpenCV image. diff --git a/src/Bonsai.ML.Torch/Tensors/ToMat.cs b/src/Bonsai.ML.Torch/ToMat.cs similarity index 95% rename from src/Bonsai.ML.Torch/Tensors/ToMat.cs rename to src/Bonsai.ML.Torch/ToMat.cs index 756ac636..fa50020c 100644 --- a/src/Bonsai.ML.Torch/Tensors/ToMat.cs +++ b/src/Bonsai.ML.Torch/ToMat.cs @@ -5,7 +5,7 @@ using OpenCV.Net; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Converts the input tensor into an OpenCV mat. diff --git a/src/Bonsai.ML.Torch/Tensors/ToTensor.cs b/src/Bonsai.ML.Torch/ToTensor.cs similarity index 99% rename from src/Bonsai.ML.Torch/Tensors/ToTensor.cs rename to src/Bonsai.ML.Torch/ToTensor.cs index 753d4422..5bb460de 100644 --- a/src/Bonsai.ML.Torch/Tensors/ToTensor.cs +++ b/src/Bonsai.ML.Torch/ToTensor.cs @@ -5,7 +5,7 @@ using OpenCV.Net; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Converts the input value into a tensor. diff --git a/src/Bonsai.ML.Torch/Tensors/Zeros.cs b/src/Bonsai.ML.Torch/Zeros.cs similarity index 95% rename from src/Bonsai.ML.Torch/Tensors/Zeros.cs rename to src/Bonsai.ML.Torch/Zeros.cs index 256a43ed..5af526d6 100644 --- a/src/Bonsai.ML.Torch/Tensors/Zeros.cs +++ b/src/Bonsai.ML.Torch/Zeros.cs @@ -3,7 +3,7 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Creates a tensor filled with zeros. From 980a03d01d3c389729a441dbca1491ebe1098ed4 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 27 Aug 2024 13:35:03 +0100 Subject: [PATCH 027/131] Updated to reflect new namespace --- src/Bonsai.ML.Torch/Helpers/TensorDataTypeHelper.cs | 1 - src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs | 1 - 2 files changed, 2 deletions(-) diff --git a/src/Bonsai.ML.Torch/Helpers/TensorDataTypeHelper.cs b/src/Bonsai.ML.Torch/Helpers/TensorDataTypeHelper.cs index 91faf20b..66a5396b 100644 --- a/src/Bonsai.ML.Torch/Helpers/TensorDataTypeHelper.cs +++ b/src/Bonsai.ML.Torch/Helpers/TensorDataTypeHelper.cs @@ -1,7 +1,6 @@ using System; using System.Collections.Generic; using System.Linq; -using Bonsai.ML.Torch.Tensors; namespace Bonsai.ML.Torch.Helpers { diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs index e9d66038..f82a33f9 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs @@ -1,6 +1,5 @@ using System; using System.Collections.Generic; -using Bonsai.ML.Torch.Tensors; using TorchSharp; using static TorchSharp.torch; using static TorchSharp.torch.nn; From b3f34173db5f6536e589698b1dce9229fd9973dd Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 27 Aug 2024 13:42:25 +0100 Subject: [PATCH 028/131] Removed unfinished classes --- .../NeuralNets/Configuration/ModelConfiguration.cs | 5 ----- .../NeuralNets/Configuration/ModuleConfiguration.cs | 5 ----- src/Bonsai.ML.Torch/NeuralNets/ModelManager.cs | 5 ----- 3 files changed, 15 deletions(-) delete mode 100644 src/Bonsai.ML.Torch/NeuralNets/Configuration/ModelConfiguration.cs delete mode 100644 src/Bonsai.ML.Torch/NeuralNets/Configuration/ModuleConfiguration.cs delete mode 100644 src/Bonsai.ML.Torch/NeuralNets/ModelManager.cs diff --git a/src/Bonsai.ML.Torch/NeuralNets/Configuration/ModelConfiguration.cs b/src/Bonsai.ML.Torch/NeuralNets/Configuration/ModelConfiguration.cs deleted file mode 100644 index 7628f72d..00000000 --- a/src/Bonsai.ML.Torch/NeuralNets/Configuration/ModelConfiguration.cs +++ /dev/null @@ -1,5 +0,0 @@ -namespace Bonsai.ML.Torch.NeuralNets.Configuration; - -public class ModelConfiguration -{ -} diff --git a/src/Bonsai.ML.Torch/NeuralNets/Configuration/ModuleConfiguration.cs b/src/Bonsai.ML.Torch/NeuralNets/Configuration/ModuleConfiguration.cs deleted file mode 100644 index dfe56272..00000000 --- a/src/Bonsai.ML.Torch/NeuralNets/Configuration/ModuleConfiguration.cs +++ /dev/null @@ -1,5 +0,0 @@ -namespace Bonsai.ML.Torch.NeuralNets.Configuration; - -public class ModuleConfiguration -{ -} diff --git a/src/Bonsai.ML.Torch/NeuralNets/ModelManager.cs b/src/Bonsai.ML.Torch/NeuralNets/ModelManager.cs deleted file mode 100644 index 035b3ca1..00000000 --- a/src/Bonsai.ML.Torch/NeuralNets/ModelManager.cs +++ /dev/null @@ -1,5 +0,0 @@ -namespace Bonsai.ML.Torch.NeuralNets; - -public class ModelManager -{ -} From df313d738092eb9bc09fa85e9b25acafffa4fa44 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 4 Sep 2024 18:14:31 +0100 Subject: [PATCH 029/131] Updated to use common Bonsai.ML.Data project --- src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj | 1 + src/Bonsai.ML.Torch/Helpers/DataHelper.cs | 190 --------------------- 2 files changed, 1 insertion(+), 190 deletions(-) delete mode 100644 src/Bonsai.ML.Torch/Helpers/DataHelper.cs diff --git a/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj b/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj index 9ed3c5d8..bb401adc 100644 --- a/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj +++ b/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj @@ -13,5 +13,6 @@ + \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Helpers/DataHelper.cs b/src/Bonsai.ML.Torch/Helpers/DataHelper.cs deleted file mode 100644 index ffed053a..00000000 --- a/src/Bonsai.ML.Torch/Helpers/DataHelper.cs +++ /dev/null @@ -1,190 +0,0 @@ -using System; -using System.Text; -using System.Collections.Generic; -using System.Linq; -using Newtonsoft.Json; -using Newtonsoft.Json.Linq; - -namespace Bonsai.ML.Torch.Helpers -{ - /// - /// Provides helper methods for parsing tensor data types. - /// - public static class DataHelper - { - - /// - /// Serializes the input data into a string representation. - /// - public static string SerializeData(object data) - { - if (data is Array array) - { - return SerializeArray(array); - } - else - { - return JsonConvert.SerializeObject(data); - } - } - - /// - /// Serializes the input array into a string representation. - /// - public static string SerializeArray(Array array) - { - StringBuilder sb = new StringBuilder(); - SerializeArrayRecursive(array, sb, [0]); - return sb.ToString(); - } - - private static void SerializeArrayRecursive(Array array, StringBuilder sb, int[] indices) - { - if (indices.Length < array.Rank) - { - sb.Append("["); - int length = array.GetLength(indices.Length); - for (int i = 0; i < length; i++) - { - int[] newIndices = new int[indices.Length + 1]; - indices.CopyTo(newIndices, 0); - newIndices[indices.Length] = i; - SerializeArrayRecursive(array, sb, newIndices); - if (i < length - 1) - { - sb.Append(", "); - } - } - sb.Append("]"); - } - else - { - object value = array.GetValue(indices); - sb.Append(value.ToString()); - } - } - - private static bool IsValidJson(string input) - { - int squareBrackets = 0; - foreach (char c in input) - { - if (c == '[') squareBrackets++; - else if (c == ']') squareBrackets--; - } - return squareBrackets == 0; - } - - /// - /// Parses the input string into an object of the specified type. - /// - public static object ParseString(string input, Type dtype) - { - if (!IsValidJson(input)) - { - throw new ArgumentException("JSON is invalid."); - } - var obj = JsonConvert.DeserializeObject(input); - int depth = ParseDepth(obj); - if (depth == 0) - { - return Convert.ChangeType(input, dtype); - } - int[] dimensions = ParseDimensions(obj, depth); - var resultArray = Array.CreateInstance(dtype, dimensions); - PopulateArray(obj, resultArray, [0], dtype); - return resultArray; - } - - private static int ParseDepth(JToken token, int currentDepth = 0) - { - if (token is JArray arr && arr.Count > 0) - { - return ParseDepth(arr[0], currentDepth + 1); - } - return currentDepth; - } - - private static int[] ParseDimensions(JToken token, int depth, int currentLevel = 0) - { - if (depth == 0 || !(token is JArray)) - { - return [0]; - } - - List dimensions = new List(); - JToken current = token; - - while (current != null && current is JArray) - { - JArray currentArray = current as JArray; - dimensions.Add(currentArray.Count); - if (currentArray.Count > 0) - { - if (currentArray.Any(item => !(item is JArray)) && currentArray.Any(item => item is JArray) || currentArray.All(item => item is JArray) && currentArray.Any(item => ((JArray)item).Count != ((JArray)currentArray.First()).Count)) - { - throw new Exception("Error parsing input. Dimensions are inconsistent."); - } - - if (!(currentArray.First() is JArray)) - { - if (!currentArray.All(item => double.TryParse(item.ToString(), out _)) && !currentArray.All(item => bool.TryParse(item.ToString(), out _))) - { - throw new Exception("Error parsing types. All values must be of the same type and only numeric or boolean types are supported."); - } - } - } - - current = currentArray.Count > 0 ? currentArray[0] : null; - } - - if (currentLevel > 0 && token is JArray arr && arr.All(x => x is JArray)) - { - var subArrayDimensions = new HashSet(); - foreach (JArray subArr in arr) - { - int[] subDims = ParseDimensions(subArr, depth - currentLevel, currentLevel + 1); - subArrayDimensions.Add(string.Join(",", subDims)); - } - - if (subArrayDimensions.Count > 1) - { - throw new ArgumentException("Inconsistent array dimensions."); - } - } - - return dimensions.ToArray(); - } - - private static void PopulateArray(JToken token, Array array, int[] indices, Type dtype) - { - if (token is JArray arr) - { - for (int i = 0; i < arr.Count; i++) - { - int[] newIndices = new int[indices.Length + 1]; - Array.Copy(indices, newIndices, indices.Length); - newIndices[newIndices.Length - 1] = i; - PopulateArray(arr[i], array, newIndices, dtype); - } - } - else - { - var values = ConvertType(token, dtype); - array.SetValue(values, indices); - } - } - - private static object ConvertType(object value, Type targetType) - { - try - { - return Convert.ChangeType(value, targetType); - } - catch (Exception ex) - { - throw new Exception("Error parsing type: ", ex); - } - } - } -} \ No newline at end of file From e75e2a9faa385b4dd82b87ed4cdd172fd0e60428 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 4 Sep 2024 18:15:11 +0100 Subject: [PATCH 030/131] Added additional overloads to process method --- src/Bonsai.ML.Torch/Concat.cs | 60 +++++++++++++++++++++++++++++++++-- 1 file changed, 57 insertions(+), 3 deletions(-) diff --git a/src/Bonsai.ML.Torch/Concat.cs b/src/Bonsai.ML.Torch/Concat.cs index b07b211d..34adf731 100644 --- a/src/Bonsai.ML.Torch/Concat.cs +++ b/src/Bonsai.ML.Torch/Concat.cs @@ -2,6 +2,7 @@ using System.ComponentModel; using System.Linq; using System.Reactive.Linq; +using System.Collections.Generic; using static TorchSharp.torch; namespace Bonsai.ML.Torch @@ -36,9 +37,62 @@ public IObservable Process(IObservable> source) { return source.Select(value => { - var tensor1 = value.Item1; - var tensor2 = value.Item2; - return cat([tensor1, tensor2], Dimension); + return cat([value.Item1, value.Item2], Dimension); + }); + } + + /// + /// Concatenates the input tensors along the specified dimension. + /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return cat([value.Item1, value.Item2, value.Item3], Dimension); + }); + } + + /// + /// Concatenates the input tensors along the specified dimension. + /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return cat([value.Item1, value.Item2, value.Item3, value.Item4], Dimension); + }); + } + + /// + /// Concatenates the input tensors along the specified dimension. + /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return cat([value.Item1, value.Item2, value.Item3, value.Item4, value.Item5], Dimension); + }); + } + + /// + /// Concatenates the input tensors along the specified dimension. + /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return cat([value.Item1, value.Item2, value.Item3, value.Item4, value.Item5, value.Item6], Dimension); + }); + } + + /// + /// Concatenates the input tensors along the specified dimension. + /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return cat(value.ToList(), Dimension); }); } } From 35f5308ee6dce04b2cb5a8f920db603b1cd8c4ba Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 4 Sep 2024 18:15:54 +0100 Subject: [PATCH 031/131] Updated to use common data tools --- src/Bonsai.ML.Torch/CreateTensor.cs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/Bonsai.ML.Torch/CreateTensor.cs b/src/Bonsai.ML.Torch/CreateTensor.cs index 0100f920..1fc8ef5e 100644 --- a/src/Bonsai.ML.Torch/CreateTensor.cs +++ b/src/Bonsai.ML.Torch/CreateTensor.cs @@ -8,6 +8,8 @@ using System.Xml.Serialization; using Bonsai.Expressions; using static TorchSharp.torch; +using Bonsai.ML.Data; +using Bonsai.ML.Torch.Helpers; namespace Bonsai.ML.Torch { @@ -188,7 +190,7 @@ private Expression BuildTensorFromScalarValue(object scalarValue, Type returnTyp /// public override Expression Build(IEnumerable arguments) { - var returnType = Helpers.TensorDataTypeHelper.GetTypeFromTensorDataType(scalarType); + var returnType = TensorDataTypeLookup.GetTypeFromTensorDataType(scalarType); var argTypes = arguments.Select(arg => arg.Type).ToArray(); var methodInfoArgumentTypes = new Type[] { @@ -207,7 +209,7 @@ public override Expression Build(IEnumerable arguments) .GetGenericArguments()[0] ) : methods.FirstOrDefault(m => !m.IsGenericMethod); - var tensorValues = Helpers.DataHelper.ParseString(values, returnType); + var tensorValues = ArrayHelper.ParseString(values, returnType); var buildTensor = tensorValues is Array arrayValues ? BuildTensorFromArray(arrayValues, returnType) : BuildTensorFromScalarValue(tensorValues, returnType); var methodArguments = arguments.Count() == 0 ? [ buildTensor ] : arguments.Concat([ buildTensor ]); @@ -221,8 +223,8 @@ public override Expression Build(IEnumerable arguments) } finally { - values = Helpers.DataHelper.SerializeData(tensorValues).Replace("False", "false").Replace("True", "true"); - scalarType = Helpers.TensorDataTypeHelper.GetTensorDataTypeFromType(returnType); + values = ArrayHelper.SerializeToJson(tensorValues).ToLower(); + scalarType = TensorDataTypeLookup.GetTensorDataTypeFromType(returnType); } } From c13c5106b9d1169ac253ee6cfa758af10d03dd58 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 4 Sep 2024 18:16:11 +0100 Subject: [PATCH 032/131] Updated tensor data type and helper --- src/Bonsai.ML.Torch/TensorDataType.cs | 6 ------ .../TensorDataTypeHelper.cs => TensorDataTypeLookup.cs} | 2 +- 2 files changed, 1 insertion(+), 7 deletions(-) rename src/Bonsai.ML.Torch/{Helpers/TensorDataTypeHelper.cs => TensorDataTypeLookup.cs} (98%) diff --git a/src/Bonsai.ML.Torch/TensorDataType.cs b/src/Bonsai.ML.Torch/TensorDataType.cs index fe8861f3..f76a04c1 100644 --- a/src/Bonsai.ML.Torch/TensorDataType.cs +++ b/src/Bonsai.ML.Torch/TensorDataType.cs @@ -1,9 +1,3 @@ -using System; -using System.Text; -using System.Collections.Generic; -using System.Linq; -using Newtonsoft.Json; -using Newtonsoft.Json.Linq; using static TorchSharp.torch; namespace Bonsai.ML.Torch diff --git a/src/Bonsai.ML.Torch/Helpers/TensorDataTypeHelper.cs b/src/Bonsai.ML.Torch/TensorDataTypeLookup.cs similarity index 98% rename from src/Bonsai.ML.Torch/Helpers/TensorDataTypeHelper.cs rename to src/Bonsai.ML.Torch/TensorDataTypeLookup.cs index 66a5396b..6e2b1be0 100644 --- a/src/Bonsai.ML.Torch/Helpers/TensorDataTypeHelper.cs +++ b/src/Bonsai.ML.Torch/TensorDataTypeLookup.cs @@ -7,7 +7,7 @@ namespace Bonsai.ML.Torch.Helpers /// /// Provides helper methods for working with tensor data types. /// - public class TensorDataTypeHelper + public class TensorDataTypeLookup { private static readonly Dictionary _lookup = new Dictionary { From a7742ef63848d04353d6dbfdb77a149de8e4c4f5 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 4 Sep 2024 18:16:46 +0100 Subject: [PATCH 033/131] Updated formatting --- src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs b/src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs index 4e90fa35..b0938f42 100644 --- a/src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs +++ b/src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs @@ -13,7 +13,8 @@ namespace Bonsai.ML.Torch.Helpers /// public static class OpenCVHelper { - private static Dictionary bitDepthLookup = new Dictionary { + private static Dictionary bitDepthLookup = new Dictionary + { { ScalarType.Byte, (IplDepth.U8, Depth.U8) }, { ScalarType.Int16, (IplDepth.S16, Depth.S16) }, { ScalarType.Int32, (IplDepth.S32, Depth.S32) }, @@ -55,12 +56,16 @@ public static unsafe Tensor CreateTensorFromPtr(IntPtr tensorDataPtr, long[] dim fixed (long* dimensionsPtr = dimensions) { IntPtr tensorHandle = THSTensor_new(tensorDataPtr, deleter, (IntPtr)dimensionsPtr, dimensions.Length, (sbyte)dtype, (sbyte)dtype, 0, 0, false); - if (tensorHandle == IntPtr.Zero) { + if (tensorHandle == IntPtr.Zero) + { GC.Collect(); GC.WaitForPendingFinalizers(); tensorHandle = THSTensor_new(tensorDataPtr, deleter, (IntPtr)dimensionsPtr, dimensions.Length, (sbyte)dtype, (sbyte)dtype, 0, 0, false); } - if (tensorHandle == IntPtr.Zero) { CheckForErrors(); } + if (tensorHandle == IntPtr.Zero) + { + CheckForErrors(); + } var output = Tensor.UnsafeCreateTensor(tensorHandle); return output; } From 6bd2e666fd369d99f4d7e3c1ab78caabec0375d2 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 26 Sep 2024 15:01:06 +0100 Subject: [PATCH 034/131] Removed bonsai core dependency in favor of bonsai.ml dependency --- src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj b/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj index bb401adc..2a0c1d53 100644 --- a/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj +++ b/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj @@ -7,12 +7,12 @@ true - + \ No newline at end of file From df6bd160c206e07ecb0fbb8f618f250373cba569 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 26 Sep 2024 15:03:36 +0100 Subject: [PATCH 035/131] Fixed bugs with create tensor method and updated to use string formatter --- src/Bonsai.ML.Torch/CreateTensor.cs | 41 +++++++++++++++-------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/src/Bonsai.ML.Torch/CreateTensor.cs b/src/Bonsai.ML.Torch/CreateTensor.cs index 1fc8ef5e..66509bbc 100644 --- a/src/Bonsai.ML.Torch/CreateTensor.cs +++ b/src/Bonsai.ML.Torch/CreateTensor.cs @@ -9,6 +9,7 @@ using Bonsai.Expressions; using static TorchSharp.torch; using Bonsai.ML.Data; +using Bonsai.ML.Python; using Bonsai.ML.Torch.Helpers; namespace Bonsai.ML.Torch @@ -45,7 +46,7 @@ public string Values get => values; set { - values = value.Replace("False", "false").Replace("True", "true"); + values = value.ToLower(); } } @@ -55,18 +56,20 @@ public string Values /// The device on which to create the tensor. /// [XmlIgnore] - public Device Device { get => device; set => device = value; } + public Device Device + { + get => device; + set => device = value; + } private Device device = null; private Expression BuildTensorFromArray(Array arrayValues, Type returnType) { var rank = arrayValues.Rank; - var lengths = new int[rank]; - for (int i = 0; i < rank; i++) - { - lengths[i] = arrayValues.GetLength(i); - } + var lengths = Enumerable.Range(0, rank) + .Select(arrayValues.GetLength) + .ToArray(); var arrayCreationExpression = Expression.NewArrayBounds(returnType, lengths.Select(len => Expression.Constant(len)).ToArray()); var arrayVariable = Expression.Variable(arrayCreationExpression.Type, "array"); @@ -89,7 +92,7 @@ private Expression BuildTensorFromArray(Array arrayValues, Type returnType) } var tensorDataInitializationBlock = Expression.Block( - arrayVariable, + [arrayVariable], assignArray, Expression.Block(assignments), arrayVariable @@ -108,7 +111,7 @@ private Expression BuildTensorFromArray(Array arrayValues, Type returnType) var tensorAssignment = Expression.Call( tensorCreationMethodInfo, tensorDataInitializationBlock, - Expression.Constant(scalarType, typeof(ScalarType?)), + Expression.Constant((ScalarType)scalarType, typeof(ScalarType?)), Expression.Constant(device, typeof(Device)), Expression.Constant(false, typeof(bool)), Expression.Constant(null, typeof(string).MakeArrayType()) @@ -118,7 +121,7 @@ private Expression BuildTensorFromArray(Array arrayValues, Type returnType) var assignTensor = Expression.Assign(tensorVariable, tensorAssignment); var buildTensor = Expression.Block( - tensorVariable, + [tensorVariable], assignTensor, tensorVariable ); @@ -132,7 +135,7 @@ private Expression BuildTensorFromScalarValue(object scalarValue, Type returnTyp var assignValue = Expression.Assign(valueVariable, Expression.Constant(scalarValue, returnType)); var tensorDataInitializationBlock = Expression.Block( - valueVariable, + [valueVariable], assignValue, valueVariable ); @@ -145,10 +148,10 @@ private Expression BuildTensorFromScalarValue(object scalarValue, Type returnTyp ] ); - var tensorCreationMethodArguments = new Expression[] { - Expression.Constant(device, typeof(Device) ), - Expression.Constant(false, typeof(bool) ) - }; + Expression[] tensorCreationMethodArguments = [ + Expression.Constant(device, typeof(Device)), + Expression.Constant(false, typeof(bool)) + ]; if (tensorCreationMethodInfo == null) { @@ -193,9 +196,7 @@ public override Expression Build(IEnumerable arguments) var returnType = TensorDataTypeLookup.GetTypeFromTensorDataType(scalarType); var argTypes = arguments.Select(arg => arg.Type).ToArray(); - var methodInfoArgumentTypes = new Type[] { - typeof(Tensor) - }; + Type[] methodInfoArgumentTypes = [typeof(Tensor)]; var methods = typeof(CreateTensor).GetMethods(BindingFlags.Public | BindingFlags.Instance) .Where(m => m.Name == "Process") @@ -211,7 +212,7 @@ public override Expression Build(IEnumerable arguments) var tensorValues = ArrayHelper.ParseString(values, returnType); var buildTensor = tensorValues is Array arrayValues ? BuildTensorFromArray(arrayValues, returnType) : BuildTensorFromScalarValue(tensorValues, returnType); - var methodArguments = arguments.Count() == 0 ? [ buildTensor ] : arguments.Concat([ buildTensor ]); + var methodArguments = arguments.Count() == 0 ? [buildTensor] : arguments.Concat([buildTensor]); try { @@ -223,7 +224,7 @@ public override Expression Build(IEnumerable arguments) } finally { - values = ArrayHelper.SerializeToJson(tensorValues).ToLower(); + values = StringFormatter.FormatToPython(tensorValues).ToLower(); scalarType = TensorDataTypeLookup.GetTensorDataTypeFromType(returnType); } } From 9b3c4f9f01cdda39b5691c87045c3d12362ea167 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 26 Sep 2024 15:04:58 +0100 Subject: [PATCH 036/131] Added empty tensor creator --- src/Bonsai.ML.Torch/Empty.cs | 38 ++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 src/Bonsai.ML.Torch/Empty.cs diff --git a/src/Bonsai.ML.Torch/Empty.cs b/src/Bonsai.ML.Torch/Empty.cs new file mode 100644 index 00000000..1c4f6af5 --- /dev/null +++ b/src/Bonsai.ML.Torch/Empty.cs @@ -0,0 +1,38 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Returns an empty tensor with the given data type and size. + /// + [Combinator] + [Description("Converts the input tensor into an OpenCV mat.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Empty + { + + /// + /// The size of the tensor. + /// + public long[] Size { get; set; } = [0]; + + /// + /// The data type of the tensor elements. + /// + public ScalarType Type { get; set; } = ScalarType.Float32; + + /// + /// Returns an empty tensor with the given data type and size. + /// + public IObservable Process() + { + return Observable.Defer(() => + { + return Observable.Return(empty(Size, Type)); + }); + } + } +} From 2ced9dd8cf3956d8eb99d5a826ff7137f5d35d36 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 26 Sep 2024 15:05:25 +0100 Subject: [PATCH 037/131] Moved index helper to main library --- src/Bonsai.ML.Torch/Index.cs | 2 +- src/Bonsai.ML.Torch/{Helpers => }/IndexHelper.cs | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) rename src/Bonsai.ML.Torch/{Helpers => }/IndexHelper.cs (94%) diff --git a/src/Bonsai.ML.Torch/Index.cs b/src/Bonsai.ML.Torch/Index.cs index 5b7b9192..818bb401 100644 --- a/src/Bonsai.ML.Torch/Index.cs +++ b/src/Bonsai.ML.Torch/Index.cs @@ -26,7 +26,7 @@ public class Index /// public IObservable Process(IObservable source) { - var index = Helpers.IndexHelper.ParseString(Indexes); + var index = IndexHelper.Parse(Indexes); return source.Select(tensor => { return tensor.index(index); }); diff --git a/src/Bonsai.ML.Torch/Helpers/IndexHelper.cs b/src/Bonsai.ML.Torch/IndexHelper.cs similarity index 94% rename from src/Bonsai.ML.Torch/Helpers/IndexHelper.cs rename to src/Bonsai.ML.Torch/IndexHelper.cs index 541ae443..2af466a0 100644 --- a/src/Bonsai.ML.Torch/Helpers/IndexHelper.cs +++ b/src/Bonsai.ML.Torch/IndexHelper.cs @@ -1,7 +1,7 @@ using System; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Helpers +namespace Bonsai.ML.Torch { /// /// Provides helper methods to parse tensor indexes. @@ -13,7 +13,7 @@ public static class IndexHelper /// Parses the input string into an array of tensor indexes. /// /// - public static TensorIndex[] ParseString(string input) + public static TensorIndex[] Parse(string input) { if (string.IsNullOrEmpty(input)) { @@ -83,7 +83,7 @@ public static TensorIndex[] ParseString(string input) /// /// /// - public static string SerializeIndexes(TensorIndex[] indexes) + public static string Serialize(TensorIndex[] indexes) { return string.Join(", ", indexes); } From a2fa42037d075395b3e6a9cf1fcf5a820199529f Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 26 Sep 2024 15:05:38 +0100 Subject: [PATCH 038/131] Moved opencv helper to main library --- src/Bonsai.ML.Torch/OpenCVHelper.cs | 174 ++++++++++++++++++++++++++++ 1 file changed, 174 insertions(+) create mode 100644 src/Bonsai.ML.Torch/OpenCVHelper.cs diff --git a/src/Bonsai.ML.Torch/OpenCVHelper.cs b/src/Bonsai.ML.Torch/OpenCVHelper.cs new file mode 100644 index 00000000..1ca049c9 --- /dev/null +++ b/src/Bonsai.ML.Torch/OpenCVHelper.cs @@ -0,0 +1,174 @@ +using System; +using System.Runtime.InteropServices; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using OpenCV.Net; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Helper class to convert between OpenCV mats, images and Torch tensors. + /// + public static class OpenCVHelper + { + private static Dictionary bitDepthLookup = new Dictionary + { + { ScalarType.Byte, (IplDepth.U8, Depth.U8) }, + { ScalarType.Int16, (IplDepth.S16, Depth.S16) }, + { ScalarType.Int32, (IplDepth.S32, Depth.S32) }, + { ScalarType.Float32, (IplDepth.F32, Depth.F32) }, + { ScalarType.Float64, (IplDepth.F64, Depth.F64) }, + { ScalarType.Int8, (IplDepth.S8, Depth.S8) } + }; + + private static ConcurrentDictionary deleters = new ConcurrentDictionary(); + + internal delegate void GCHandleDeleter(IntPtr memory); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_data(IntPtr handle); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_new(IntPtr rawArray, GCHandleDeleter deleter, IntPtr dimensions, int numDimensions, sbyte type, sbyte dtype, int deviceType, int deviceIndex, [MarshalAs(UnmanagedType.U1)] bool requires_grad); + + /// + /// Creates a tensor from a pointer to the data and the dimensions of the tensor. + /// + /// + /// + /// + /// + public static unsafe Tensor CreateTensorFromPtr(IntPtr tensorDataPtr, long[] dimensions, ScalarType dtype = ScalarType.Byte) + { + var dataHandle = GCHandle.Alloc(tensorDataPtr, GCHandleType.Pinned); + var gchp = GCHandle.ToIntPtr(dataHandle); + GCHandleDeleter deleter = null; + + deleter = new GCHandleDeleter((IntPtr ptrHandler) => + { + GCHandle.FromIntPtr(gchp).Free(); + deleters.TryRemove(deleter, out deleter); + }); + deleters.TryAdd(deleter, deleter); + + fixed (long* dimensionsPtr = dimensions) + { + IntPtr tensorHandle = THSTensor_new(tensorDataPtr, deleter, (IntPtr)dimensionsPtr, dimensions.Length, (sbyte)dtype, (sbyte)dtype, 0, 0, false); + if (tensorHandle == IntPtr.Zero) + { + GC.Collect(); + GC.WaitForPendingFinalizers(); + tensorHandle = THSTensor_new(tensorDataPtr, deleter, (IntPtr)dimensionsPtr, dimensions.Length, (sbyte)dtype, (sbyte)dtype, 0, 0, false); + } + if (tensorHandle == IntPtr.Zero) + { + CheckForErrors(); + } + var output = Tensor.UnsafeCreateTensor(tensorHandle); + return output; + } + } + + /// + /// Converts an OpenCV image to a Torch tensor. + /// + /// + /// + public static Tensor ToTensor(IplImage image) + { + if (image == null) + { + return empty([ 0, 0, 0 ]); + } + + int width = image.Width; + int height = image.Height; + int channels = image.Channels; + + var iplDepth = image.Depth; + var tensorType = bitDepthLookup.FirstOrDefault(x => x.Value.IplDepth == iplDepth).Key; + + IntPtr tensorDataPtr = image.ImageData; + long[] dimensions = [ height, width, channels ]; + if (tensorDataPtr == IntPtr.Zero) + { + return empty(dimensions); + } + return CreateTensorFromPtr(tensorDataPtr, dimensions, tensorType); + } + + /// + /// Converts an OpenCV mat to a Torch tensor. + /// + /// + /// + public static Tensor ToTensor(Mat mat) + { + if (mat == null) + { + return empty([0, 0, 0 ]); + } + + int width = mat.Size.Width; + int height = mat.Size.Height; + int channels = mat.Channels; + + var depth = mat.Depth; + var tensorType = bitDepthLookup.FirstOrDefault(x => x.Value.Depth == depth).Key; + + IntPtr tensorDataPtr = mat.Data; + long[] dimensions = [ height, width, channels ]; + if (tensorDataPtr == IntPtr.Zero) + { + return empty(dimensions); + } + return CreateTensorFromPtr(tensorDataPtr, dimensions, tensorType); + } + + /// + /// Converts a Torch tensor to an OpenCV image. + /// + /// + /// + public unsafe static IplImage ToImage(Tensor tensor) + { + var height = (int)tensor.shape[0]; + var width = (int)tensor.shape[1]; + var channels = (int)tensor.shape[2]; + + var tensorType = tensor.dtype; + var iplDepth = bitDepthLookup[tensorType].IplDepth; + + var new_tensor = zeros(new long[] { height, width, channels }, tensorType).copy_(tensor); + + var res = THSTensor_data(new_tensor.Handle); + var image = new IplImage(new OpenCV.Net.Size(width, height), iplDepth, channels, res); + + return image; + } + + /// + /// Converts a Torch tensor to an OpenCV mat. + /// + /// + /// + public unsafe static Mat ToMat(Tensor tensor) + { + var height = (int)tensor.shape[0]; + var width = (int)tensor.shape[1]; + var channels = (int)tensor.shape[2]; + + var tensorType = tensor.dtype; + var depth = bitDepthLookup[tensorType].Depth; + + var new_tensor = zeros(new long[] { height, width, channels }, tensorType).copy_(tensor); + + var res = THSTensor_data(new_tensor.Handle); + var mat = new Mat(new OpenCV.Net.Size(width, height), depth, channels, res); + + return mat; + } + } +} \ No newline at end of file From 05c12ce96f84ce38893a5ee50e801c6870b30dd9 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 26 Sep 2024 15:06:20 +0100 Subject: [PATCH 039/131] Removed opencv helper from helpers subfolder --- src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs | 174 -------------------- 1 file changed, 174 deletions(-) delete mode 100644 src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs diff --git a/src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs b/src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs deleted file mode 100644 index b0938f42..00000000 --- a/src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs +++ /dev/null @@ -1,174 +0,0 @@ -using System; -using System.Runtime.InteropServices; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Linq; -using OpenCV.Net; -using static TorchSharp.torch; - -namespace Bonsai.ML.Torch.Helpers -{ - /// - /// Helper class to convert between OpenCV mats, images and Torch tensors. - /// - public static class OpenCVHelper - { - private static Dictionary bitDepthLookup = new Dictionary - { - { ScalarType.Byte, (IplDepth.U8, Depth.U8) }, - { ScalarType.Int16, (IplDepth.S16, Depth.S16) }, - { ScalarType.Int32, (IplDepth.S32, Depth.S32) }, - { ScalarType.Float32, (IplDepth.F32, Depth.F32) }, - { ScalarType.Float64, (IplDepth.F64, Depth.F64) }, - { ScalarType.Int8, (IplDepth.S8, Depth.S8) } - }; - - private static ConcurrentDictionary deleters = new ConcurrentDictionary(); - - internal delegate void GCHandleDeleter(IntPtr memory); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSTensor_data(IntPtr handle); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSTensor_new(IntPtr rawArray, GCHandleDeleter deleter, IntPtr dimensions, int numDimensions, sbyte type, sbyte dtype, int deviceType, int deviceIndex, [MarshalAs(UnmanagedType.U1)] bool requires_grad); - - /// - /// Creates a tensor from a pointer to the data and the dimensions of the tensor. - /// - /// - /// - /// - /// - public static unsafe Tensor CreateTensorFromPtr(IntPtr tensorDataPtr, long[] dimensions, ScalarType dtype = ScalarType.Byte) - { - var dataHandle = GCHandle.Alloc(tensorDataPtr, GCHandleType.Pinned); - var gchp = GCHandle.ToIntPtr(dataHandle); - GCHandleDeleter deleter = null; - - deleter = new GCHandleDeleter((IntPtr ptrHandler) => - { - GCHandle.FromIntPtr(gchp).Free(); - deleters.TryRemove(deleter, out deleter); - }); - deleters.TryAdd(deleter, deleter); - - fixed (long* dimensionsPtr = dimensions) - { - IntPtr tensorHandle = THSTensor_new(tensorDataPtr, deleter, (IntPtr)dimensionsPtr, dimensions.Length, (sbyte)dtype, (sbyte)dtype, 0, 0, false); - if (tensorHandle == IntPtr.Zero) - { - GC.Collect(); - GC.WaitForPendingFinalizers(); - tensorHandle = THSTensor_new(tensorDataPtr, deleter, (IntPtr)dimensionsPtr, dimensions.Length, (sbyte)dtype, (sbyte)dtype, 0, 0, false); - } - if (tensorHandle == IntPtr.Zero) - { - CheckForErrors(); - } - var output = Tensor.UnsafeCreateTensor(tensorHandle); - return output; - } - } - - /// - /// Converts an OpenCV image to a Torch tensor. - /// - /// - /// - public static Tensor ToTensor(IplImage image) - { - if (image == null) - { - return empty([ 0, 0, 0 ]); - } - - int width = image.Width; - int height = image.Height; - int channels = image.Channels; - - var iplDepth = image.Depth; - var tensorType = bitDepthLookup.FirstOrDefault(x => x.Value.IplDepth == iplDepth).Key; - - IntPtr tensorDataPtr = image.ImageData; - long[] dimensions = [ height, width, channels ]; - if (tensorDataPtr == IntPtr.Zero) - { - return empty(dimensions); - } - return CreateTensorFromPtr(tensorDataPtr, dimensions, tensorType); - } - - /// - /// Converts an OpenCV mat to a Torch tensor. - /// - /// - /// - public static Tensor ToTensor(Mat mat) - { - if (mat == null) - { - return empty([0, 0, 0 ]); - } - - int width = mat.Size.Width; - int height = mat.Size.Height; - int channels = mat.Channels; - - var depth = mat.Depth; - var tensorType = bitDepthLookup.FirstOrDefault(x => x.Value.Depth == depth).Key; - - IntPtr tensorDataPtr = mat.Data; - long[] dimensions = [ height, width, channels ]; - if (tensorDataPtr == IntPtr.Zero) - { - return empty(dimensions); - } - return CreateTensorFromPtr(tensorDataPtr, dimensions, tensorType); - } - - /// - /// Converts a Torch tensor to an OpenCV image. - /// - /// - /// - public unsafe static IplImage ToImage(Tensor tensor) - { - var height = (int)tensor.shape[0]; - var width = (int)tensor.shape[1]; - var channels = (int)tensor.shape[2]; - - var tensorType = tensor.dtype; - var iplDepth = bitDepthLookup[tensorType].IplDepth; - - var new_tensor = zeros(new long[] { height, width, channels }, tensorType).copy_(tensor); - - var res = THSTensor_data(new_tensor.Handle); - var image = new IplImage(new OpenCV.Net.Size(width, height), iplDepth, channels, res); - - return image; - } - - /// - /// Converts a Torch tensor to an OpenCV mat. - /// - /// - /// - public unsafe static Mat ToMat(Tensor tensor) - { - var height = (int)tensor.shape[0]; - var width = (int)tensor.shape[1]; - var channels = (int)tensor.shape[2]; - - var tensorType = tensor.dtype; - var depth = bitDepthLookup[tensorType].Depth; - - var new_tensor = zeros(new long[] { height, width, channels }, tensorType).copy_(tensor); - - var res = THSTensor_data(new_tensor.Handle); - var mat = new Mat(new OpenCV.Net.Size(width, height), depth, channels, res); - - return mat; - } - } -} \ No newline at end of file From 57901ae31e7774511de5c9e97796b3d65dfe2c42 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 26 Sep 2024 15:06:39 +0100 Subject: [PATCH 040/131] Updated with new index helper --- src/Bonsai.ML.Torch/Set.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Bonsai.ML.Torch/Set.cs b/src/Bonsai.ML.Torch/Set.cs index 14bf3dad..6b0fd86b 100644 --- a/src/Bonsai.ML.Torch/Set.cs +++ b/src/Bonsai.ML.Torch/Set.cs @@ -21,8 +21,8 @@ public class Set /// public string Index { - get => Helpers.IndexHelper.SerializeIndexes(indexes); - set => indexes = Helpers.IndexHelper.ParseString(value); + get => IndexHelper.Serialize(indexes); + set => indexes = IndexHelper.Parse(value); } private TensorIndex[] indexes; From 51ab15b599e62bd468c015a723c18e6bf96f4b7b Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 26 Sep 2024 15:07:31 +0100 Subject: [PATCH 041/131] Updated with opencv helper --- src/Bonsai.ML.Torch/ToImage.cs | 2 +- src/Bonsai.ML.Torch/ToMat.cs | 2 +- src/Bonsai.ML.Torch/ToTensor.cs | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Bonsai.ML.Torch/ToImage.cs b/src/Bonsai.ML.Torch/ToImage.cs index 894a9602..0b9d8ccd 100644 --- a/src/Bonsai.ML.Torch/ToImage.cs +++ b/src/Bonsai.ML.Torch/ToImage.cs @@ -22,7 +22,7 @@ public class ToImage /// public IObservable Process(IObservable source) { - return source.Select(Helpers.OpenCVHelper.ToImage); + return source.Select(OpenCVHelper.ToImage); } } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/ToMat.cs b/src/Bonsai.ML.Torch/ToMat.cs index fa50020c..1b1746ed 100644 --- a/src/Bonsai.ML.Torch/ToMat.cs +++ b/src/Bonsai.ML.Torch/ToMat.cs @@ -22,7 +22,7 @@ public class ToMat /// public IObservable Process(IObservable source) { - return source.Select(Helpers.OpenCVHelper.ToMat); + return source.Select(OpenCVHelper.ToMat); } } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/ToTensor.cs b/src/Bonsai.ML.Torch/ToTensor.cs index 5bb460de..061a13cb 100644 --- a/src/Bonsai.ML.Torch/ToTensor.cs +++ b/src/Bonsai.ML.Torch/ToTensor.cs @@ -118,7 +118,7 @@ public IObservable Process(IObservable source) /// public IObservable Process(IObservable source) { - return source.Select(Helpers.OpenCVHelper.ToTensor); + return source.Select(OpenCVHelper.ToTensor); } /// @@ -128,7 +128,7 @@ public IObservable Process(IObservable source) /// public IObservable Process(IObservable source) { - return source.Select(Helpers.OpenCVHelper.ToTensor); + return source.Select(OpenCVHelper.ToTensor); } } } \ No newline at end of file From ac4b398ceca001ff4f37aa418f69f486f63f6bdc Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 26 Sep 2024 15:07:49 +0100 Subject: [PATCH 042/131] Added process overload for generating on input --- src/Bonsai.ML.Torch/Zeros.cs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/Bonsai.ML.Torch/Zeros.cs b/src/Bonsai.ML.Torch/Zeros.cs index 5af526d6..69673d4a 100644 --- a/src/Bonsai.ML.Torch/Zeros.cs +++ b/src/Bonsai.ML.Torch/Zeros.cs @@ -26,5 +26,17 @@ public IObservable Process() { return Observable.Defer(() => Observable.Return(ones(Size))); } + + /// + /// Generates an observable sequence of tensors filled with zeros for each element of the input sequence. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return ones(Size); + }); + } } } \ No newline at end of file From e69923af81025ef973ec5a7421f47ceeb5645b31 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 27 Sep 2024 12:39:06 +0100 Subject: [PATCH 043/131] Added xml ignore tag on device --- src/Bonsai.ML.Torch/ToDevice.cs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/Bonsai.ML.Torch/ToDevice.cs b/src/Bonsai.ML.Torch/ToDevice.cs index cb73f733..531ff585 100644 --- a/src/Bonsai.ML.Torch/ToDevice.cs +++ b/src/Bonsai.ML.Torch/ToDevice.cs @@ -2,6 +2,7 @@ using System.ComponentModel; using System.Linq; using System.Reactive.Linq; +using System.Xml.Serialization; using static TorchSharp.torch; namespace Bonsai.ML.Torch @@ -17,6 +18,7 @@ public class ToDevice /// /// The device to which the input tensor should be moved. /// + [XmlIgnore] public Device Device { get; set; } /// From 9099cf958e8f5ebacabc6d2f44444aacabcd9860 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 16 Oct 2024 14:47:40 +0100 Subject: [PATCH 044/131] Updated to use shared module interface --- src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs index fb7722f2..28a3d57b 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs @@ -3,7 +3,6 @@ using System.Reactive.Linq; using static TorchSharp.torch; using System.Xml.Serialization; -using static TorchSharp.torch.nn; using Bonsai.Expressions; namespace Bonsai.ML.Torch.NeuralNets @@ -18,9 +17,9 @@ public class LoadPretrainedModel private int numClasses = 10; - public IObservable Process() + public IObservable> Process() { - Module model = null; + nn.Module model = null; var modelName = ModelName.ToString().ToLower(); var device = Device; From f76181e29c8f459a46df10fcfe8d5071f4c131e2 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 30 Oct 2024 13:12:54 +0000 Subject: [PATCH 045/131] Modified to use Module interface --- src/Bonsai.ML.Torch/NeuralNets/Forward.cs | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 src/Bonsai.ML.Torch/NeuralNets/Forward.cs diff --git a/src/Bonsai.ML.Torch/NeuralNets/Forward.cs b/src/Bonsai.ML.Torch/NeuralNets/Forward.cs new file mode 100644 index 00000000..03b2e10e --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Forward.cs @@ -0,0 +1,23 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using System.Xml.Serialization; +using TorchSharp.Modules; + +namespace Bonsai.ML.Torch.NeuralNets +{ + [Combinator] + [Description("")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Forward + { + [XmlIgnore] + public nn.Module Model { get; set; } + + public IObservable Process(IObservable source) + { + return source.Select(Model.forward); + } + } +} \ No newline at end of file From befe81bea79205f63a69ffde516593611745fee1 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 30 Oct 2024 13:13:31 +0000 Subject: [PATCH 046/131] Removed unnecessary null string --- .../NeuralNets/LoadPretrainedModel.cs | 5 +++- .../NeuralNets/LoadScriptModule.cs | 26 +++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) create mode 100644 src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs index 28a3d57b..a87bd744 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs @@ -15,9 +15,12 @@ public class LoadPretrainedModel public Models.PretrainedModels ModelName { get; set; } public Device Device { get; set; } + [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + public string ModelWeightsPath { get; set; } + private int numClasses = 10; - public IObservable> Process() + public IObservable> Process() { nn.Module model = null; var modelName = ModelName.ToString().ToLower(); diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs new file mode 100644 index 00000000..9c73031f --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs @@ -0,0 +1,26 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using System.Xml.Serialization; + +namespace Bonsai.ML.Torch.NeuralNets +{ + [Combinator] + [Description("")] + [WorkflowElementCategory(ElementCategory.Source)] + public class LoadScriptModule + { + + [XmlIgnore] + public Device Device { get; set; } = CPU; + + [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + public string ModelPath { get; set; } + + public IObservable> Process() + { + return Observable.Return((nn.IModule)jit.load(ModelPath, Device)); + } + } +} \ No newline at end of file From d4c4cc019015339a307628c7d37924b05c57b32c Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 1 Nov 2024 13:27:03 +0000 Subject: [PATCH 047/131] Added a common interface --- src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj | 3 ++- src/Bonsai.ML.Torch/NeuralNets/Forward.cs | 2 +- src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs | 9 +++++++++ src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs | 4 ++-- src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs | 4 ++-- 5 files changed, 16 insertions(+), 6 deletions(-) create mode 100644 src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs diff --git a/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj b/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj index 2a0c1d53..97bfe18c 100644 --- a/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj +++ b/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj @@ -8,7 +8,8 @@ - + + diff --git a/src/Bonsai.ML.Torch/NeuralNets/Forward.cs b/src/Bonsai.ML.Torch/NeuralNets/Forward.cs index 03b2e10e..e1a8a283 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Forward.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Forward.cs @@ -13,7 +13,7 @@ namespace Bonsai.ML.Torch.NeuralNets public class Forward { [XmlIgnore] - public nn.Module Model { get; set; } + public ITorchModule Model { get; set; } public IObservable Process(IObservable source) { diff --git a/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs b/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs new file mode 100644 index 00000000..1bfcdab3 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs @@ -0,0 +1,9 @@ +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.NeuralNets +{ + public interface ITorchModule + { + public Tensor forward(Tensor tensor); + } +} diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs index a87bd744..43443c24 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs @@ -20,7 +20,7 @@ public class LoadPretrainedModel private int numClasses = 10; - public IObservable> Process() + public IObservable Process() { nn.Module model = null; var modelName = ModelName.ToString().ToLower(); @@ -42,7 +42,7 @@ public class LoadPretrainedModel } return Observable.Defer(() => { - return Observable.Return(model); + return Observable.Return((ITorchModule)model); }); } } diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs index 9c73031f..7e5d53b0 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs @@ -18,9 +18,9 @@ public class LoadScriptModule [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] public string ModelPath { get; set; } - public IObservable> Process() + public IObservable Process() { - return Observable.Return((nn.IModule)jit.load(ModelPath, Device)); + return Observable.Return((ITorchModule)jit.load(ModelPath, Device)); } } } \ No newline at end of file From b9bacdd23e5b462f139600479c3f60b5af5008d1 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 6 Nov 2024 12:26:56 +0000 Subject: [PATCH 048/131] Added swap axes function --- src/Bonsai.ML.Torch/Swapaxes.cs | 40 +++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 src/Bonsai.ML.Torch/Swapaxes.cs diff --git a/src/Bonsai.ML.Torch/Swapaxes.cs b/src/Bonsai.ML.Torch/Swapaxes.cs new file mode 100644 index 00000000..4777e882 --- /dev/null +++ b/src/Bonsai.ML.Torch/Swapaxes.cs @@ -0,0 +1,40 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using System.Xml.Serialization; +using TorchSharp; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + [Combinator] + [Description("Swaps the axes of the input tensor.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Swapaxes + { + + /// + /// The value of axis 1. + /// + public long Axis1 { get; set; } = 0; + + /// + /// The value of axis 2. + /// + public long Axis2 { get; set; } = 1; + + /// + /// Returns an observable sequence that sets the value of the input tensor at the specified index. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => { + return swapaxes(tensor, Axis1, Axis2); + }); + } + } +} \ No newline at end of file From 624c703d876b2d9b491be592283126e85475fa47 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 6 Nov 2024 12:27:05 +0000 Subject: [PATCH 049/131] Added tile function --- src/Bonsai.ML.Torch/Tile.cs | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 src/Bonsai.ML.Torch/Tile.cs diff --git a/src/Bonsai.ML.Torch/Tile.cs b/src/Bonsai.ML.Torch/Tile.cs new file mode 100644 index 00000000..1df78122 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tile.cs @@ -0,0 +1,22 @@ +using static TorchSharp.torch; +using System; +using System.ComponentModel; +using System.Reactive.Linq; + +namespace Bonsai.ML.Torch +{ + [Combinator] + [Description("Constructs a tensor by repeating the elements of input. The Dimensions argument specifies the number of repetitions in each dimension.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Tile + { + public long[] Dimensions { get; set; } + + public IObservable Process(IObservable source) + { + return source.Select(tensor => { + return tile(tensor, Dimensions); + }); + } + } +} From 2f799b7f511fc4e8a57f3be5bfbdb300683ccb5f Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 6 Nov 2024 12:27:39 +0000 Subject: [PATCH 050/131] Updated model loading and model forward procedure --- src/Bonsai.ML.Torch/NeuralNets/Forward.cs | 2 ++ .../NeuralNets/ITorchModule.cs | 1 + .../NeuralNets/LoadPretrainedModel.cs | 17 +++++---- .../NeuralNets/LoadScriptModule.cs | 4 ++- .../NeuralNets/TorchModuleAdapter.cs | 36 +++++++++++++++++++ src/Bonsai.ML.Torch/Vision/Normalize.cs | 28 +++++++++++++++ 6 files changed, 81 insertions(+), 7 deletions(-) create mode 100644 src/Bonsai.ML.Torch/NeuralNets/TorchModuleAdapter.cs create mode 100644 src/Bonsai.ML.Torch/Vision/Normalize.cs diff --git a/src/Bonsai.ML.Torch/NeuralNets/Forward.cs b/src/Bonsai.ML.Torch/NeuralNets/Forward.cs index e1a8a283..3aae4012 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Forward.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Forward.cs @@ -4,6 +4,7 @@ using static TorchSharp.torch; using System.Xml.Serialization; using TorchSharp.Modules; +using TorchSharp; namespace Bonsai.ML.Torch.NeuralNets { @@ -17,6 +18,7 @@ public class Forward public IObservable Process(IObservable source) { + Model.Module.eval(); return source.Select(Model.forward); } } diff --git a/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs b/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs index 1bfcdab3..e7ebf994 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs @@ -4,6 +4,7 @@ namespace Bonsai.ML.Torch.NeuralNets { public interface ITorchModule { + public nn.Module Module { get; } public Tensor forward(Tensor tensor); } } diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs index 43443c24..e4dddba1 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs @@ -3,7 +3,6 @@ using System.Reactive.Linq; using static TorchSharp.torch; using System.Xml.Serialization; -using Bonsai.Expressions; namespace Bonsai.ML.Torch.NeuralNets { @@ -13,6 +12,8 @@ namespace Bonsai.ML.Torch.NeuralNets public class LoadPretrainedModel { public Models.PretrainedModels ModelName { get; set; } + + [XmlIgnore] public Device Device { get; set; } [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] @@ -22,27 +23,31 @@ public class LoadPretrainedModel public IObservable Process() { - nn.Module model = null; + nn.Module module = null; var modelName = ModelName.ToString().ToLower(); var device = Device; switch (modelName) { case "alexnet": - model = new Models.AlexNet(modelName, numClasses, device); + module = new Models.AlexNet(modelName, numClasses, device); + if (ModelWeightsPath is not null) module.load(ModelWeightsPath); break; case "mobilenet": - model = new Models.MobileNet(modelName, numClasses, device); + module = new Models.MobileNet(modelName, numClasses, device); + if (ModelWeightsPath is not null) module.load(ModelWeightsPath); break; case "mnist": - model = new Models.MNIST(modelName, device); + module = new Models.MNIST(modelName, device); + if (ModelWeightsPath is not null) module.load(ModelWeightsPath); break; default: throw new ArgumentException($"Model {modelName} not supported."); } + var torchModule = new TorchModuleAdapter(module); return Observable.Defer(() => { - return Observable.Return((ITorchModule)model); + return Observable.Return((ITorchModule)torchModule); }); } } diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs index 7e5d53b0..7e6c73fe 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs @@ -20,7 +20,9 @@ public class LoadScriptModule public IObservable Process() { - return Observable.Return((ITorchModule)jit.load(ModelPath, Device)); + var scriptModule = jit.load(ModelPath, Device); + var torchModule = new TorchModuleAdapter(scriptModule); + return Observable.Return((ITorchModule)torchModule); } } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/TorchModuleAdapter.cs b/src/Bonsai.ML.Torch/NeuralNets/TorchModuleAdapter.cs new file mode 100644 index 00000000..a1c44d96 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/TorchModuleAdapter.cs @@ -0,0 +1,36 @@ +using System; +using System.Reflection; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.NeuralNets +{ + public class TorchModuleAdapter : ITorchModule + { + private readonly nn.Module _module = null; + + private readonly jit.ScriptModule _scriptModule = null; + + private Func forwardFunc; + + public nn.Module Module { get; } + + public TorchModuleAdapter(nn.Module module) + { + _module = module; + forwardFunc = _module.forward; + Module = _module; + } + + public TorchModuleAdapter(jit.ScriptModule scriptModule) + { + _scriptModule = scriptModule; + forwardFunc = _scriptModule.forward; + Module = _scriptModule; + } + + public Tensor forward(Tensor input) + { + return forwardFunc(input); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Vision/Normalize.cs b/src/Bonsai.ML.Torch/Vision/Normalize.cs new file mode 100644 index 00000000..fee8a3b9 --- /dev/null +++ b/src/Bonsai.ML.Torch/Vision/Normalize.cs @@ -0,0 +1,28 @@ +using Bonsai; +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using static TorchSharp.torch; +using static TorchSharp.torchvision; +using System.Xml.Serialization; + +namespace Bonsai.ML.Torch.Vision +{ + [Combinator] + [Description("")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Normalize + { + private ITransform inputTransform; + + public IObservable Process(IObservable source) + { + inputTransform = transforms.Normalize(new double[] { 0.1307 }, new double[] { 0.3081 }); + + return source.Select(tensor => { + return inputTransform.call(tensor); + }); + } + } +} \ No newline at end of file From 637020ed4361847437528f0e4ad7a3da2706da7b Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 14 Nov 2024 13:12:55 +0000 Subject: [PATCH 051/131] Added backward function for running online gradient descent with specified loss and optimization --- src/Bonsai.ML.Torch/NeuralNets/Backward.cs | 61 +++++++++++++++++++++ src/Bonsai.ML.Torch/NeuralNets/Loss.cs | 13 +++++ src/Bonsai.ML.Torch/NeuralNets/Optimizer.cs | 13 +++++ 3 files changed, 87 insertions(+) create mode 100644 src/Bonsai.ML.Torch/NeuralNets/Backward.cs create mode 100644 src/Bonsai.ML.Torch/NeuralNets/Loss.cs create mode 100644 src/Bonsai.ML.Torch/NeuralNets/Optimizer.cs diff --git a/src/Bonsai.ML.Torch/NeuralNets/Backward.cs b/src/Bonsai.ML.Torch/NeuralNets/Backward.cs new file mode 100644 index 00000000..5fa4902e --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Backward.cs @@ -0,0 +1,61 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; +using static TorchSharp.torch.nn; +using static TorchSharp.torch.optim; + +namespace Bonsai.ML.Torch.NeuralNets +{ + [Combinator] + [Description("")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Backward + { + public Optimizer Optimizer { get; set; } + + [XmlIgnore] + public ITorchModule Model { get; set; } + + public Loss Loss { get; set; } + + public IObservable Process(IObservable> source) + { + optim.Optimizer optimizer = null; + switch (Optimizer) + { + case Optimizer.Adam: + optimizer = Adam(Model.Module.parameters()); + break; + } + + Module loss = null; + switch (Loss) + { + case Loss.NLLLoss: + loss = NLLLoss(); + break; + } + + var scheduler = lr_scheduler.StepLR(optimizer, 1, 0.7); + Model.Module.train(); + + return source.Select((input) => { + var (data, target) = input; + using (_ = NewDisposeScope()) + { + optimizer.zero_grad(); + + var prediction = Model.forward(data); + var output = loss.forward(prediction, target); + + output.backward(); + + optimizer.step(); + return output.MoveToOuterDisposeScope(); + } + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/Loss.cs b/src/Bonsai.ML.Torch/NeuralNets/Loss.cs new file mode 100644 index 00000000..ff003019 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Loss.cs @@ -0,0 +1,13 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using static TorchSharp.torch.optim; + +namespace Bonsai.ML.Torch.NeuralNets +{ + public enum Loss + { + NLLLoss, + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/Optimizer.cs b/src/Bonsai.ML.Torch/NeuralNets/Optimizer.cs new file mode 100644 index 00000000..3c0d4bb7 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Optimizer.cs @@ -0,0 +1,13 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using static TorchSharp.torch.optim; + +namespace Bonsai.ML.Torch.NeuralNets +{ + public enum Optimizer + { + Adam, + } +} \ No newline at end of file From 65d8c70138ec2b2e52dfe450362f47926f98090e Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 14 Nov 2024 13:13:05 +0000 Subject: [PATCH 052/131] Added function to save model --- src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs | 28 +++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs diff --git a/src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs b/src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs new file mode 100644 index 00000000..314e9f3f --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs @@ -0,0 +1,28 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using System.Xml.Serialization; +using TorchSharp.Modules; +using TorchSharp; + +namespace Bonsai.ML.Torch.NeuralNets +{ + [Combinator] + [Description("")] + [WorkflowElementCategory(ElementCategory.Sink)] + public class SaveModel + { + [XmlIgnore] + public ITorchModule Model { get; set; } + + public string ModelPath { get; set; } + + public IObservable Process(IObservable source) + { + return source.Do(input => { + Model.Module.save(ModelPath); + }); + } + } +} \ No newline at end of file From 47dde6c00c97199e59428d047aa1081d98f5092b Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 20 Jan 2025 14:54:37 +0000 Subject: [PATCH 053/131] Changed name to indicate loading from an existing architecture --- .../NeuralNets/LoadModuleFromArchitecture.cs | 60 +++++++++++++++++++ .../NeuralNets/LoadPretrainedModel.cs | 54 ----------------- ...etrainedModels.cs => ModelArchitecture.cs} | 2 +- 3 files changed, 61 insertions(+), 55 deletions(-) create mode 100644 src/Bonsai.ML.Torch/NeuralNets/LoadModuleFromArchitecture.cs delete mode 100644 src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs rename src/Bonsai.ML.Torch/NeuralNets/Models/{PretrainedModels.cs => ModelArchitecture.cs} (76%) diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadModuleFromArchitecture.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadModuleFromArchitecture.cs new file mode 100644 index 00000000..43d74544 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/LoadModuleFromArchitecture.cs @@ -0,0 +1,60 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using System.Xml.Serialization; + +namespace Bonsai.ML.Torch.NeuralNets +{ + [Combinator] + [Description("")] + [WorkflowElementCategory(ElementCategory.Source)] + public class LoadModuleFromArchitecture + { + public Models.ModelArchitecture ModelArchitecture { get; set; } + + [XmlIgnore] + public Device Device { get; set; } + + [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + public string ModelWeightsPath { get; set; } + + private int numClasses = 10; + public int NumClasses + { + get => numClasses; + set + { + if (value <= 0) + { + numClasses = 10; + } + else + { + numClasses = value; + } + } + } + + public IObservable Process() + { + var modelArchitecture = ModelArchitecture.ToString().ToLower(); + var device = Device; + + nn.Module module = modelArchitecture switch + { + "alexnet" => new Models.AlexNet(modelArchitecture, numClasses, device), + "mobilenet" => new Models.MobileNet(modelArchitecture, numClasses, device), + "mnist" => new Models.MNIST(modelArchitecture, device), + _ => throw new ArgumentException($"Model {modelArchitecture} not supported.") + }; + + if (ModelWeightsPath is not null) module.load(ModelWeightsPath); + + var torchModule = new TorchModuleAdapter(module); + return Observable.Defer(() => { + return Observable.Return((ITorchModule)torchModule); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs deleted file mode 100644 index e4dddba1..00000000 --- a/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs +++ /dev/null @@ -1,54 +0,0 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using static TorchSharp.torch; -using System.Xml.Serialization; - -namespace Bonsai.ML.Torch.NeuralNets -{ - [Combinator] - [Description("")] - [WorkflowElementCategory(ElementCategory.Source)] - public class LoadPretrainedModel - { - public Models.PretrainedModels ModelName { get; set; } - - [XmlIgnore] - public Device Device { get; set; } - - [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] - public string ModelWeightsPath { get; set; } - - private int numClasses = 10; - - public IObservable Process() - { - nn.Module module = null; - var modelName = ModelName.ToString().ToLower(); - var device = Device; - - switch (modelName) - { - case "alexnet": - module = new Models.AlexNet(modelName, numClasses, device); - if (ModelWeightsPath is not null) module.load(ModelWeightsPath); - break; - case "mobilenet": - module = new Models.MobileNet(modelName, numClasses, device); - if (ModelWeightsPath is not null) module.load(ModelWeightsPath); - break; - case "mnist": - module = new Models.MNIST(modelName, device); - if (ModelWeightsPath is not null) module.load(ModelWeightsPath); - break; - default: - throw new ArgumentException($"Model {modelName} not supported."); - } - - var torchModule = new TorchModuleAdapter(module); - return Observable.Defer(() => { - return Observable.Return((ITorchModule)torchModule); - }); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/PretrainedModels.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/ModelArchitecture.cs similarity index 76% rename from src/Bonsai.ML.Torch/NeuralNets/Models/PretrainedModels.cs rename to src/Bonsai.ML.Torch/NeuralNets/Models/ModelArchitecture.cs index a3c65bdc..0a221b5c 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/PretrainedModels.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/ModelArchitecture.cs @@ -1,6 +1,6 @@ namespace Bonsai.ML.Torch.NeuralNets.Models { - public enum PretrainedModels + public enum ModelArchitecture { AlexNet, MobileNet, From 6d8442eee79db9566152111d01d472b2ba1b572e Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 20 Jan 2025 14:58:11 +0000 Subject: [PATCH 054/131] Added descriptions and documentation --- src/Bonsai.ML.Torch/Arange.cs | 5 +- src/Bonsai.ML.Torch/Concat.cs | 1 + src/Bonsai.ML.Torch/ConvertDataType.cs | 1 + src/Bonsai.ML.Torch/CreateTensor.cs | 33 +++++--- src/Bonsai.ML.Torch/Empty.cs | 23 ++++-- src/Bonsai.ML.Torch/InitializeTorchDevice.cs | 3 +- src/Bonsai.ML.Torch/Linspace.cs | 3 + src/Bonsai.ML.Torch/LoadTensor.cs | 33 ++++++++ src/Bonsai.ML.Torch/Mean.cs | 32 ++++++++ src/Bonsai.ML.Torch/MeshGrid.cs | 3 +- src/Bonsai.ML.Torch/Ones.cs | 13 +++ src/Bonsai.ML.Torch/Permute.cs | 1 + src/Bonsai.ML.Torch/Reshape.cs | 1 + src/Bonsai.ML.Torch/SaveTensor.cs | 34 ++++++++ src/Bonsai.ML.Torch/ScalarTypeLookup.cs | 53 +++++++++++++ src/Bonsai.ML.Torch/Set.cs | 11 +-- src/Bonsai.ML.Torch/Sum.cs | 28 +++++++ src/Bonsai.ML.Torch/Swapaxes.cs | 40 ---------- src/Bonsai.ML.Torch/TensorDataType.cs | 50 ------------ src/Bonsai.ML.Torch/TensorDataTypeLookup.cs | 52 ------------ src/Bonsai.ML.Torch/Tile.cs | 13 ++- src/Bonsai.ML.Torch/ToArray.cs | 1 + src/Bonsai.ML.Torch/ToDevice.cs | 1 + src/Bonsai.ML.Torch/ToImage.cs | 2 +- src/Bonsai.ML.Torch/ToNDArray.cs | 83 ++++++++++++++++++++ src/Bonsai.ML.Torch/ToTensor.cs | 2 +- src/Bonsai.ML.Torch/View.cs | 33 ++++++++ src/Bonsai.ML.Torch/Zeros.cs | 1 + 28 files changed, 383 insertions(+), 173 deletions(-) create mode 100644 src/Bonsai.ML.Torch/LoadTensor.cs create mode 100644 src/Bonsai.ML.Torch/Mean.cs create mode 100644 src/Bonsai.ML.Torch/SaveTensor.cs create mode 100644 src/Bonsai.ML.Torch/ScalarTypeLookup.cs create mode 100644 src/Bonsai.ML.Torch/Sum.cs delete mode 100644 src/Bonsai.ML.Torch/Swapaxes.cs delete mode 100644 src/Bonsai.ML.Torch/TensorDataType.cs delete mode 100644 src/Bonsai.ML.Torch/TensorDataTypeLookup.cs create mode 100644 src/Bonsai.ML.Torch/ToNDArray.cs create mode 100644 src/Bonsai.ML.Torch/View.cs diff --git a/src/Bonsai.ML.Torch/Arange.cs b/src/Bonsai.ML.Torch/Arange.cs index 14e3259b..fa80c08e 100644 --- a/src/Bonsai.ML.Torch/Arange.cs +++ b/src/Bonsai.ML.Torch/Arange.cs @@ -17,16 +17,19 @@ public class Arange /// /// The start of the range. /// + [Description("The start of the range.")] public int Start { get; set; } = 0; /// /// The end of the range. /// + [Description("The end of the range.")] public int End { get; set; } = 10; /// - /// The step of the range. + /// The step size between values. /// + [Description("The step size between values.")] public int Step { get; set; } = 1; /// diff --git a/src/Bonsai.ML.Torch/Concat.cs b/src/Bonsai.ML.Torch/Concat.cs index 34adf731..45402621 100644 --- a/src/Bonsai.ML.Torch/Concat.cs +++ b/src/Bonsai.ML.Torch/Concat.cs @@ -18,6 +18,7 @@ public class Concat /// /// The dimension along which to concatenate the tensors. /// + [Description("The dimension along which to concatenate the tensors.")] public long Dimension { get; set; } = 0; /// diff --git a/src/Bonsai.ML.Torch/ConvertDataType.cs b/src/Bonsai.ML.Torch/ConvertDataType.cs index 59981adc..efe3496b 100644 --- a/src/Bonsai.ML.Torch/ConvertDataType.cs +++ b/src/Bonsai.ML.Torch/ConvertDataType.cs @@ -16,6 +16,7 @@ public class ConvertDataType /// /// The scalar type to which to convert the input tensor. /// + [Description("The scalar type to which to convert the input tensor.")] public ScalarType Type { get; set; } = ScalarType.Float32; /// diff --git a/src/Bonsai.ML.Torch/CreateTensor.cs b/src/Bonsai.ML.Torch/CreateTensor.cs index 66509bbc..2bddaa46 100644 --- a/src/Bonsai.ML.Torch/CreateTensor.cs +++ b/src/Bonsai.ML.Torch/CreateTensor.cs @@ -10,19 +10,21 @@ using static TorchSharp.torch; using Bonsai.ML.Data; using Bonsai.ML.Python; -using Bonsai.ML.Torch.Helpers; +using TorchSharp; namespace Bonsai.ML.Torch { /// - /// Creates a tensor from the specified values. Uses Python-like syntax to specify the tensor values. For example, a 2x2 tensor can be created with the following values: "[[1, 2], [3, 4]]". + /// Creates a tensor from the specified values. + /// Uses Python-like syntax to specify the tensor values. + /// For example, a 2x2 tensor can be created with the following values: "[[1, 2], [3, 4]]". /// [Combinator] [Description("Creates a tensor from the specified values. Uses Python-like syntax to specify the tensor values. For example, a 2x2 tensor can be created with the following values: \"[[1, 2], [3, 4]]\".")] [WorkflowElementCategory(ElementCategory.Source)] public class CreateTensor : ExpressionBuilder { - Range argumentRange = new Range(0, 1); + readonly Range argumentRange = new Range(0, 1); /// public override Range ArgumentRange => argumentRange; @@ -30,17 +32,21 @@ public class CreateTensor : ExpressionBuilder /// /// The data type of the tensor elements. /// - public TensorDataType Type + [Description("The data type of the tensor elements.")] + public ScalarType Type { get => scalarType; set => scalarType = value; } - private TensorDataType scalarType = TensorDataType.Float32; + private ScalarType scalarType = ScalarType.Float32; /// - /// The values of the tensor elements. Uses Python-like syntax to specify the tensor values. + /// The values of the tensor elements. + /// Uses Python-like syntax to specify the tensor values. + /// For example: "[[1, 2], [3, 4]]". /// + [Description("The values of the tensor elements. Uses Python-like syntax to specify the tensor values. For example: \"[[1, 2], [3, 4]]\".")] public string Values { get => values; @@ -56,6 +62,7 @@ public string Values /// The device on which to create the tensor. /// [XmlIgnore] + [Description("The device on which to create the tensor.")] public Device Device { get => device; @@ -98,7 +105,7 @@ private Expression BuildTensorFromArray(Array arrayValues, Type returnType) arrayVariable ); - var tensorCreationMethodInfo = typeof(TorchSharp.torch).GetMethod( + var tensorCreationMethodInfo = typeof(torch).GetMethod( "tensor", [ arrayVariable.Type, typeof(ScalarType?), @@ -111,7 +118,7 @@ private Expression BuildTensorFromArray(Array arrayValues, Type returnType) var tensorAssignment = Expression.Call( tensorCreationMethodInfo, tensorDataInitializationBlock, - Expression.Constant((ScalarType)scalarType, typeof(ScalarType?)), + Expression.Constant(scalarType, typeof(ScalarType?)), Expression.Constant(device, typeof(Device)), Expression.Constant(false, typeof(bool)), Expression.Constant(null, typeof(string).MakeArrayType()) @@ -140,7 +147,7 @@ private Expression BuildTensorFromScalarValue(object scalarValue, Type returnTyp valueVariable ); - var tensorCreationMethodInfo = typeof(TorchSharp.torch).GetMethod( + var tensorCreationMethodInfo = typeof(torch).GetMethod( "tensor", [ valueVariable.Type, typeof(Device), @@ -155,7 +162,7 @@ private Expression BuildTensorFromScalarValue(object scalarValue, Type returnTyp if (tensorCreationMethodInfo == null) { - tensorCreationMethodInfo = typeof(TorchSharp.torch).GetMethod( + tensorCreationMethodInfo = typeof(torch).GetMethod( "tensor", [ valueVariable.Type, typeof(ScalarType?), @@ -193,7 +200,7 @@ private Expression BuildTensorFromScalarValue(object scalarValue, Type returnTyp /// public override Expression Build(IEnumerable arguments) { - var returnType = TensorDataTypeLookup.GetTypeFromTensorDataType(scalarType); + var returnType = ScalarTypeLookup.GetTypeFromScalarType(scalarType); var argTypes = arguments.Select(arg => arg.Type).ToArray(); Type[] methodInfoArgumentTypes = [typeof(Tensor)]; @@ -225,7 +232,7 @@ public override Expression Build(IEnumerable arguments) finally { values = StringFormatter.FormatToPython(tensorValues).ToLower(); - scalarType = TensorDataTypeLookup.GetTensorDataTypeFromType(returnType); + scalarType = ScalarTypeLookup.GetScalarTypeFromType(returnType); } } @@ -242,7 +249,7 @@ public IObservable Process(Tensor tensor) /// public IObservable Process(IObservable source, Tensor tensor) { - return Observable.Select(source, (_) => tensor); + return source.Select(_ => tensor); } } } diff --git a/src/Bonsai.ML.Torch/Empty.cs b/src/Bonsai.ML.Torch/Empty.cs index 1c4f6af5..dafcee05 100644 --- a/src/Bonsai.ML.Torch/Empty.cs +++ b/src/Bonsai.ML.Torch/Empty.cs @@ -6,26 +6,27 @@ namespace Bonsai.ML.Torch { /// - /// Returns an empty tensor with the given data type and size. + /// Creates an empty tensor with the given data type and size. /// [Combinator] - [Description("Converts the input tensor into an OpenCV mat.")] - [WorkflowElementCategory(ElementCategory.Transform)] + [Description("Creates an empty tensor with the given data type and size.")] + [WorkflowElementCategory(ElementCategory.Source)] public class Empty { - /// /// The size of the tensor. /// + [Description("The size of the tensor.")] public long[] Size { get; set; } = [0]; /// /// The data type of the tensor elements. /// + [Description("The data type of the tensor elements.")] public ScalarType Type { get; set; } = ScalarType.Float32; /// - /// Returns an empty tensor with the given data type and size. + /// Creates an empty tensor with the given data type and size. /// public IObservable Process() { @@ -34,5 +35,17 @@ public IObservable Process() return Observable.Return(empty(Size, Type)); }); } + + /// + /// Generates an observable sequence of empty tensors for each element of the input sequence. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return empty(Size, Type); + }); + } } } diff --git a/src/Bonsai.ML.Torch/InitializeTorchDevice.cs b/src/Bonsai.ML.Torch/InitializeTorchDevice.cs index e82daa36..b8ce574c 100644 --- a/src/Bonsai.ML.Torch/InitializeTorchDevice.cs +++ b/src/Bonsai.ML.Torch/InitializeTorchDevice.cs @@ -11,12 +11,13 @@ namespace Bonsai.ML.Torch /// [Combinator] [Description("Initializes the Torch device with the specified device type.")] - [WorkflowElementCategory(ElementCategory.Transform)] + [WorkflowElementCategory(ElementCategory.Source)] public class InitializeTorchDevice { /// /// The device type to initialize. /// + [Description("The device type to initialize.")] public DeviceType DeviceType { get; set; } /// diff --git a/src/Bonsai.ML.Torch/Linspace.cs b/src/Bonsai.ML.Torch/Linspace.cs index ee6516cf..f7e27887 100644 --- a/src/Bonsai.ML.Torch/Linspace.cs +++ b/src/Bonsai.ML.Torch/Linspace.cs @@ -16,16 +16,19 @@ public class Linspace /// /// The start of the range. /// + [Description("The start of the range.")] public int Start { get; set; } = 0; /// /// The end of the range. /// + [Description("The end of the range.")] public int End { get; set; } = 1; /// /// The number of points to generate. /// + [Description("The number of points to generate.")] public int Count { get; set; } = 10; /// diff --git a/src/Bonsai.ML.Torch/LoadTensor.cs b/src/Bonsai.ML.Torch/LoadTensor.cs new file mode 100644 index 00000000..af1e7f05 --- /dev/null +++ b/src/Bonsai.ML.Torch/LoadTensor.cs @@ -0,0 +1,33 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Loads a tensor from the specified file. + /// + [Combinator] + [Description("Loads a tensor from the specified file.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class LoadTensor + { + /// + /// The path to the file containing the tensor. + /// + [FileNameFilter("Binary files(*.bin)|*.bin|All files|*.*")] + [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + [Description("The path to the file containing the tensor.")] + public string Path { get; set; } + + /// + /// Loads a tensor from the specified file. + /// + /// + public IObservable Process() + { + return Observable.Return(Tensor.Load(Path)); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Mean.cs b/src/Bonsai.ML.Torch/Mean.cs new file mode 100644 index 00000000..294edf31 --- /dev/null +++ b/src/Bonsai.ML.Torch/Mean.cs @@ -0,0 +1,32 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Takes the mean of the tensor along the specified dimensions. + /// + [Combinator] + [Description("Takes the mean of the tensor along the specified dimensions.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Mean + { + /// + /// The dimensions along which to compute the mean. + /// + [Description("The dimensions along which to compute the mean.")] + public long[] Dimensions { get; set; } + + /// + /// Takes the mean of the tensor along the specified dimensions. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(input => input.mean(Dimensions)); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/MeshGrid.cs b/src/Bonsai.ML.Torch/MeshGrid.cs index 725b12a9..a32f9eca 100644 --- a/src/Bonsai.ML.Torch/MeshGrid.cs +++ b/src/Bonsai.ML.Torch/MeshGrid.cs @@ -11,13 +11,14 @@ namespace Bonsai.ML.Torch /// Creates a mesh grid from an observable sequence of enumerable of 1-D tensors. /// [Combinator] - [Description("")] + [Description("Creates a mesh grid from an observable sequence of enumerable of 1-D tensors.")] [WorkflowElementCategory(ElementCategory.Source)] public class MeshGrid { /// /// The indexing mode to use for the mesh grid. /// + [Description("The indexing mode to use for the mesh grid.")] public string Indexing { get; set; } = "ij"; /// diff --git a/src/Bonsai.ML.Torch/Ones.cs b/src/Bonsai.ML.Torch/Ones.cs index 52bf8732..77d26577 100644 --- a/src/Bonsai.ML.Torch/Ones.cs +++ b/src/Bonsai.ML.Torch/Ones.cs @@ -16,6 +16,7 @@ public class Ones /// /// The size of the tensor. /// + [Description("The size of the tensor.")] public long[] Size { get; set; } = [0]; /// @@ -26,5 +27,17 @@ public IObservable Process() { return Observable.Defer(() => Observable.Return(ones(Size))); } + + /// + /// Generates an observable sequence of tensors filled with ones for each element of the input sequence. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return ones(Size); + }); + } } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Permute.cs b/src/Bonsai.ML.Torch/Permute.cs index a82107ba..507d31d2 100644 --- a/src/Bonsai.ML.Torch/Permute.cs +++ b/src/Bonsai.ML.Torch/Permute.cs @@ -16,6 +16,7 @@ public class Permute /// /// The permutation of the dimensions. /// + [Description("The permutation of the dimensions.")] public long[] Dimensions { get; set; } = [0]; /// diff --git a/src/Bonsai.ML.Torch/Reshape.cs b/src/Bonsai.ML.Torch/Reshape.cs index ebdc8e41..fdd07fa5 100644 --- a/src/Bonsai.ML.Torch/Reshape.cs +++ b/src/Bonsai.ML.Torch/Reshape.cs @@ -17,6 +17,7 @@ public class Reshape /// /// The dimensions of the reshaped tensor. /// + [Description("The dimensions of the reshaped tensor.")] public long[] Dimensions { get; set; } = [0]; /// diff --git a/src/Bonsai.ML.Torch/SaveTensor.cs b/src/Bonsai.ML.Torch/SaveTensor.cs new file mode 100644 index 00000000..1a3c4772 --- /dev/null +++ b/src/Bonsai.ML.Torch/SaveTensor.cs @@ -0,0 +1,34 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Saves the input tensor to the specified file. + /// + [Combinator] + [Description("Saves the input tensor to the specified file.")] + [WorkflowElementCategory(ElementCategory.Sink)] + public class SaveTensor + { + /// + /// The path to the file where the tensor will be saved. + /// + [FileNameFilter("Binary files(*.bin)|*.bin|All files|*.*")] + [Editor("Bonsai.Design.SaveFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + [Description("The path to the file where the tensor will be saved.")] + public string Path { get; set; } = string.Empty; + + /// + /// Saves the input tensor to the specified file. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Do(tensor => tensor.save(Path)); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/ScalarTypeLookup.cs b/src/Bonsai.ML.Torch/ScalarTypeLookup.cs new file mode 100644 index 00000000..1e4c6c57 --- /dev/null +++ b/src/Bonsai.ML.Torch/ScalarTypeLookup.cs @@ -0,0 +1,53 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Provides methods to look up tensor data types. + /// + public static class ScalarTypeLookup + { + private static readonly Dictionary _lookup = new() + { + { ScalarType.Byte, (typeof(byte), "byte") }, + { ScalarType.Int16, (typeof(short), "short") }, + { ScalarType.Int32, (typeof(int), "int") }, + { ScalarType.Int64, (typeof(long), "long") }, + { ScalarType.Float32, (typeof(float), "float") }, + { ScalarType.Float64, (typeof(double), "double") }, + { ScalarType.Bool, (typeof(bool), "bool") }, + { ScalarType.Int8, (typeof(sbyte), "sbyte") }, + }; + + /// + /// Returns the type corresponding to the specified tensor data type. + /// + /// + /// + public static Type GetTypeFromScalarType(ScalarType type) => _lookup[type].Type; + + /// + /// Returns the string representation corresponding to the specified tensor data type. + /// + /// + /// + public static string GetStringFromScalarType(ScalarType type) => _lookup[type].StringValue; + + /// + /// Returns the tensor data type corresponding to the specified string representation. + /// + /// + /// + public static ScalarType GetScalarTypeFromString(string value) => _lookup.First(x => x.Value.StringValue == value).Key; + + /// + /// Returns the tensor data type corresponding to the specified type. + /// + /// + /// + public static ScalarType GetScalarTypeFromType(Type type) => _lookup.First(x => x.Value.Type == type).Key; + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Set.cs b/src/Bonsai.ML.Torch/Set.cs index 6b0fd86b..0e2965c6 100644 --- a/src/Bonsai.ML.Torch/Set.cs +++ b/src/Bonsai.ML.Torch/Set.cs @@ -19,18 +19,14 @@ public class Set /// /// The index at which to set the value. /// - public string Index - { - get => IndexHelper.Serialize(indexes); - set => indexes = IndexHelper.Parse(value); - } - - private TensorIndex[] indexes; + [Description("The index at which to set the value.")] + public string Index { get; set; } = string.Empty; /// /// The value to set at the specified index. /// [XmlIgnore] + [Description("The value to set at the specified index.")] public Tensor Value { get; set; } = null; /// @@ -41,6 +37,7 @@ public string Index public IObservable Process(IObservable source) { return source.Select(tensor => { + var indexes = IndexHelper.Parse(Index); return tensor.index_put_(Value, indexes); }); } diff --git a/src/Bonsai.ML.Torch/Sum.cs b/src/Bonsai.ML.Torch/Sum.cs new file mode 100644 index 00000000..d01efb95 --- /dev/null +++ b/src/Bonsai.ML.Torch/Sum.cs @@ -0,0 +1,28 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + [Combinator] + [Description("Computes the sum of the input tensor elements along the specified dimensions.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Sum + { + /// + /// The dimensions along which to compute the sum. + /// + public long[] Dimensions { get; set; } + + /// + /// Computes the sum of the input tensor elements along the specified dimensions. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(input => input.sum(Dimensions)); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Swapaxes.cs b/src/Bonsai.ML.Torch/Swapaxes.cs deleted file mode 100644 index 4777e882..00000000 --- a/src/Bonsai.ML.Torch/Swapaxes.cs +++ /dev/null @@ -1,40 +0,0 @@ -using System; -using System.Collections.Generic; -using System.ComponentModel; -using System.Linq; -using System.Reactive.Linq; -using System.Xml.Serialization; -using TorchSharp; -using static TorchSharp.torch; - -namespace Bonsai.ML.Torch -{ - [Combinator] - [Description("Swaps the axes of the input tensor.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class Swapaxes - { - - /// - /// The value of axis 1. - /// - public long Axis1 { get; set; } = 0; - - /// - /// The value of axis 2. - /// - public long Axis2 { get; set; } = 1; - - /// - /// Returns an observable sequence that sets the value of the input tensor at the specified index. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(tensor => { - return swapaxes(tensor, Axis1, Axis2); - }); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/TensorDataType.cs b/src/Bonsai.ML.Torch/TensorDataType.cs deleted file mode 100644 index f76a04c1..00000000 --- a/src/Bonsai.ML.Torch/TensorDataType.cs +++ /dev/null @@ -1,50 +0,0 @@ -using static TorchSharp.torch; - -namespace Bonsai.ML.Torch -{ - /// - /// Represents the data type of the tensor elements. Contains currently supported data types. A subset of the available ScalarType data types in TorchSharp. - /// - public enum TensorDataType - { - /// - /// 8-bit unsigned integer. - /// - Byte = ScalarType.Byte, - - /// - /// 8-bit signed integer. - /// - Int8 = ScalarType.Int8, - - /// - /// 16-bit signed integer. - /// - Int16 = ScalarType.Int16, - - /// - /// 32-bit signed integer. - /// - Int32 = ScalarType.Int32, - - /// - /// 64-bit signed integer. - /// - Int64 = ScalarType.Int64, - - /// - /// 32-bit floating point. - /// - Float32 = ScalarType.Float32, - - /// - /// 64-bit floating point. - /// - Float64 = ScalarType.Float64, - - /// - /// Boolean. - /// - Bool = ScalarType.Bool - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/TensorDataTypeLookup.cs b/src/Bonsai.ML.Torch/TensorDataTypeLookup.cs deleted file mode 100644 index 6e2b1be0..00000000 --- a/src/Bonsai.ML.Torch/TensorDataTypeLookup.cs +++ /dev/null @@ -1,52 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; - -namespace Bonsai.ML.Torch.Helpers -{ - /// - /// Provides helper methods for working with tensor data types. - /// - public class TensorDataTypeLookup - { - private static readonly Dictionary _lookup = new Dictionary - { - { TensorDataType.Byte, (typeof(byte), "byte") }, - { TensorDataType.Int16, (typeof(short), "short") }, - { TensorDataType.Int32, (typeof(int), "int") }, - { TensorDataType.Int64, (typeof(long), "long") }, - { TensorDataType.Float32, (typeof(float), "float") }, - { TensorDataType.Float64, (typeof(double), "double") }, - { TensorDataType.Bool, (typeof(bool), "bool") }, - { TensorDataType.Int8, (typeof(sbyte), "sbyte") }, - }; - - /// - /// Returns the type corresponding to the specified tensor data type. - /// - /// - /// - public static Type GetTypeFromTensorDataType(TensorDataType type) => _lookup[type].Type; - - /// - /// Returns the string representation corresponding to the specified tensor data type. - /// - /// - /// - public static string GetStringFromTensorDataType(TensorDataType type) => _lookup[type].StringValue; - - /// - /// Returns the tensor data type corresponding to the specified string representation. - /// - /// - /// - public static TensorDataType GetTensorDataTypeFromString(string value) => _lookup.First(x => x.Value.StringValue == value).Key; - - /// - /// Returns the tensor data type corresponding to the specified type. - /// - /// - /// - public static TensorDataType GetTensorDataTypeFromType(Type type) => _lookup.First(x => x.Value.Type == type).Key; - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tile.cs b/src/Bonsai.ML.Torch/Tile.cs index 1df78122..df25b8ac 100644 --- a/src/Bonsai.ML.Torch/Tile.cs +++ b/src/Bonsai.ML.Torch/Tile.cs @@ -5,13 +5,24 @@ namespace Bonsai.ML.Torch { + /// + /// Constructs a tensor by repeating the elements of input. + /// [Combinator] - [Description("Constructs a tensor by repeating the elements of input. The Dimensions argument specifies the number of repetitions in each dimension.")] + [Description("Constructs a tensor by repeating the elements of input.")] [WorkflowElementCategory(ElementCategory.Transform)] public class Tile { + /// + /// The number of repetitions in each dimension. + /// public long[] Dimensions { get; set; } + /// + /// Constructs a tensor by repeating the elements of input along the specified dimensions. + /// + /// + /// public IObservable Process(IObservable source) { return source.Select(tensor => { diff --git a/src/Bonsai.ML.Torch/ToArray.cs b/src/Bonsai.ML.Torch/ToArray.cs index 1c2c721a..e9ca21f1 100644 --- a/src/Bonsai.ML.Torch/ToArray.cs +++ b/src/Bonsai.ML.Torch/ToArray.cs @@ -38,6 +38,7 @@ public ToArray() /// /// Gets or sets the type mapping used to convert the input tensor into an array. /// + [Description("Gets or sets the type mapping used to convert the input tensor into an array.")] public TypeMapping Type { get; set; } /// diff --git a/src/Bonsai.ML.Torch/ToDevice.cs b/src/Bonsai.ML.Torch/ToDevice.cs index 531ff585..0377df46 100644 --- a/src/Bonsai.ML.Torch/ToDevice.cs +++ b/src/Bonsai.ML.Torch/ToDevice.cs @@ -19,6 +19,7 @@ public class ToDevice /// The device to which the input tensor should be moved. /// [XmlIgnore] + [Description("The device to which the input tensor should be moved.")] public Device Device { get; set; } /// diff --git a/src/Bonsai.ML.Torch/ToImage.cs b/src/Bonsai.ML.Torch/ToImage.cs index 0b9d8ccd..70c8227e 100644 --- a/src/Bonsai.ML.Torch/ToImage.cs +++ b/src/Bonsai.ML.Torch/ToImage.cs @@ -11,7 +11,7 @@ namespace Bonsai.ML.Torch /// Converts the input tensor into an OpenCV image. /// [Combinator] - [Description("")] + [Description("Converts the input tensor into an OpenCV image.")] [WorkflowElementCategory(ElementCategory.Transform)] public class ToImage { diff --git a/src/Bonsai.ML.Torch/ToNDArray.cs b/src/Bonsai.ML.Torch/ToNDArray.cs new file mode 100644 index 00000000..89b7b1e1 --- /dev/null +++ b/src/Bonsai.ML.Torch/ToNDArray.cs @@ -0,0 +1,83 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using System.Xml.Serialization; +using System.Linq.Expressions; +using System.Reflection; +using Bonsai.Expressions; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Converts the input tensor into an array of the specified element type and rank. + /// + [Combinator] + [Description("Converts the input tensor into an array of the specified element type.")] + [WorkflowElementCategory(ElementCategory.Transform)] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + public class ToNDArray : SingleArgumentExpressionBuilder + { + /// + /// Initializes a new instance of the class. + /// + public ToNDArray() + { + Type = new TypeMapping(); + } + + /// + /// Gets or sets the type mapping used to convert the input tensor into an array. + /// + [Description("Gets or sets the type mapping used to convert the input tensor into an array.")] + public TypeMapping Type { get; set; } + + /// + /// Gets or sets the rank of the output array. Must be greater than or equal to 1. + /// + [Description("Gets or sets the rank of the output array. Must be greater than or equal to 1.")] + public int Rank { get; set; } = 1; + + /// + public override Expression Build(IEnumerable arguments) + { + TypeMapping typeMapping = Type; + var returnType = typeMapping.GetType().GetGenericArguments()[0]; + MethodInfo methodInfo = GetType().GetMethod("Process", BindingFlags.Public | BindingFlags.Instance); + var lengths = new int[Rank]; + Type arrayType = Array.CreateInstance(returnType, lengths).GetType(); + methodInfo = methodInfo.MakeGenericMethod(returnType, arrayType); + Expression sourceExpression = arguments.First(); + + return Expression.Call( + Expression.Constant(this), + methodInfo, + sourceExpression + ); + } + + /// + /// Converts the input tensor into an array of the specified element type. + /// + /// + /// + /// + /// + public IObservable Process(IObservable source) where T : unmanaged + { + return source.Select(tensor => + { + return (TResult)(object)tensor.data().ToNDArray(); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/ToTensor.cs b/src/Bonsai.ML.Torch/ToTensor.cs index 061a13cb..7af26dc9 100644 --- a/src/Bonsai.ML.Torch/ToTensor.cs +++ b/src/Bonsai.ML.Torch/ToTensor.cs @@ -11,7 +11,7 @@ namespace Bonsai.ML.Torch /// Converts the input value into a tensor. /// [Combinator] - [Description("")] + [Description("Converts the input value into a tensor.")] [WorkflowElementCategory(ElementCategory.Transform)] public class ToTensor { diff --git a/src/Bonsai.ML.Torch/View.cs b/src/Bonsai.ML.Torch/View.cs new file mode 100644 index 00000000..65a409be --- /dev/null +++ b/src/Bonsai.ML.Torch/View.cs @@ -0,0 +1,33 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Creates a new view of the input tensor with the specified dimensions. + /// + [Combinator] + [Description("Creates a new view of the input tensor with the specified dimensions.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class View + { + /// + /// The dimensions of the reshaped tensor. + /// + [Description("The dimensions of the reshaped tensor.")] + public long[] Dimensions { get; set; } = [0]; + + /// + /// Creates a new view of the input tensor with the specified dimensions. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(input => input.view(Dimensions)); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Zeros.cs b/src/Bonsai.ML.Torch/Zeros.cs index 69673d4a..e4fb3c7a 100644 --- a/src/Bonsai.ML.Torch/Zeros.cs +++ b/src/Bonsai.ML.Torch/Zeros.cs @@ -16,6 +16,7 @@ public class Zeros /// /// The size of the tensor. /// + [Description("The size of the tensor.")] public long[] Size { get; set; } = [0]; /// From f40aa4fb5882d4503a7e5998f7c862662a613f33 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 20 Jan 2025 17:21:29 +0000 Subject: [PATCH 055/131] Added some useful classes for linear algebra --- src/Bonsai.ML.Torch/LinearAlgebra/Cholesky.cs | 26 +++++++++++++ src/Bonsai.ML.Torch/LinearAlgebra/Det.cs | 26 +++++++++++++ src/Bonsai.ML.Torch/LinearAlgebra/Eig.cs | 29 +++++++++++++++ src/Bonsai.ML.Torch/LinearAlgebra/Inv.cs | 27 ++++++++++++++ src/Bonsai.ML.Torch/LinearAlgebra/Norm.cs | 37 +++++++++++++++++++ src/Bonsai.ML.Torch/LinearAlgebra/SVD.cs | 34 +++++++++++++++++ src/Bonsai.ML.Torch/Vision/Normalize.cs | 14 +++---- 7 files changed, 186 insertions(+), 7 deletions(-) create mode 100644 src/Bonsai.ML.Torch/LinearAlgebra/Cholesky.cs create mode 100644 src/Bonsai.ML.Torch/LinearAlgebra/Det.cs create mode 100644 src/Bonsai.ML.Torch/LinearAlgebra/Eig.cs create mode 100644 src/Bonsai.ML.Torch/LinearAlgebra/Inv.cs create mode 100644 src/Bonsai.ML.Torch/LinearAlgebra/Norm.cs create mode 100644 src/Bonsai.ML.Torch/LinearAlgebra/SVD.cs diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/Cholesky.cs b/src/Bonsai.ML.Torch/LinearAlgebra/Cholesky.cs new file mode 100644 index 00000000..6843779a --- /dev/null +++ b/src/Bonsai.ML.Torch/LinearAlgebra/Cholesky.cs @@ -0,0 +1,26 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.LinearAlgebra +{ + /// + /// Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix. + /// + [Combinator] + [Description("Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Cholesky + { + /// + /// Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(linalg.cholesky); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/Det.cs b/src/Bonsai.ML.Torch/LinearAlgebra/Det.cs new file mode 100644 index 00000000..90c5a45d --- /dev/null +++ b/src/Bonsai.ML.Torch/LinearAlgebra/Det.cs @@ -0,0 +1,26 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.LinearAlgebra +{ + /// + /// Computes the determinant of a square matrix. + /// + [Combinator] + [Description("Computes the determinant of a square matrix.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Det + { + /// + /// Computes the determinant of a square matrix. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(linalg.det); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/Eig.cs b/src/Bonsai.ML.Torch/LinearAlgebra/Eig.cs new file mode 100644 index 00000000..a94c8eb8 --- /dev/null +++ b/src/Bonsai.ML.Torch/LinearAlgebra/Eig.cs @@ -0,0 +1,29 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.LinearAlgebra +{ + /// + /// Computes the eigenvalue decomposition of a square matrix if it exists. + /// + [Combinator] + [Description("Computes the eigenvalue decomposition of a square matrix if it exists.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Eig + { + /// + /// Computes the eigenvalue decomposition of a square matrix if it exists. + /// + /// + /// + public IObservable> Process(IObservable source) + { + return source.Select(tensor => { + var (eigvals, eigvecs) = linalg.eig(tensor); + return Tuple.Create(eigvals, eigvecs); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/Inv.cs b/src/Bonsai.ML.Torch/LinearAlgebra/Inv.cs new file mode 100644 index 00000000..58bb4ce3 --- /dev/null +++ b/src/Bonsai.ML.Torch/LinearAlgebra/Inv.cs @@ -0,0 +1,27 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using static TorchSharp.torch.linalg; + +namespace Bonsai.ML.Torch.LinearAlgebra +{ + /// + /// Computes the inverse of the input matrix. + /// + [Combinator] + [Description("Computes the inverse of the input matrix.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Inv + { + /// + /// Computes the inverse of the input matrix. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(inv); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/Norm.cs b/src/Bonsai.ML.Torch/LinearAlgebra/Norm.cs new file mode 100644 index 00000000..82914d39 --- /dev/null +++ b/src/Bonsai.ML.Torch/LinearAlgebra/Norm.cs @@ -0,0 +1,37 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.LinearAlgebra +{ + /// + /// Computes a matrix norm. + /// + [Combinator] + [Description("Computes a matrix norm.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class MatrixNorm + { + + /// + /// The dimensions along which to compute the matrix norm. + /// + public long[] Dimensions { get; set; } = null; + + /// + /// If true, the reduced dimensions are retained in the result as dimensions with size one. + /// + public bool Keepdim { get; set; } = false; + + /// + /// Computes a matrix norm. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => linalg.norm(tensor, dims: Dimensions, keepdim: Keepdim)); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/SVD.cs b/src/Bonsai.ML.Torch/LinearAlgebra/SVD.cs new file mode 100644 index 00000000..c722f53b --- /dev/null +++ b/src/Bonsai.ML.Torch/LinearAlgebra/SVD.cs @@ -0,0 +1,34 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.LinearAlgebra +{ + /// + /// Computes the singular value decomposition (SVD) of a matrix. + /// + [Combinator] + [Description("Computes the singular value decomposition (SVD) of a matrix.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class SVD + { + /// + /// Whether to compute the full or reduced SVD. + /// + public bool FullMatrices { get; set; } = false; + + /// + /// Computes the singular value decomposition (SVD) of a matrix. + /// + /// + /// + public IObservable> Process(IObservable source) + { + return source.Select(tensor => { + var (u, s, v) = linalg.svd(tensor, fullMatrices: FullMatrices); + return Tuple.Create(u, s, v); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Vision/Normalize.cs b/src/Bonsai.ML.Torch/Vision/Normalize.cs index fee8a3b9..f4fc80a3 100644 --- a/src/Bonsai.ML.Torch/Vision/Normalize.cs +++ b/src/Bonsai.ML.Torch/Vision/Normalize.cs @@ -5,23 +5,23 @@ using System.Reactive.Linq; using static TorchSharp.torch; using static TorchSharp.torchvision; -using System.Xml.Serialization; namespace Bonsai.ML.Torch.Vision { [Combinator] - [Description("")] + [Description("Normalizes the input tensor with the mean and standard deviation.")] [WorkflowElementCategory(ElementCategory.Transform)] public class Normalize - { - private ITransform inputTransform; + { + public double[] Means { get; set; } = [ 0.1307 ]; + public double[] StdDevs { get; set; } = [ 0.3081 ]; + private ITransform transform = null; public IObservable Process(IObservable source) { - inputTransform = transforms.Normalize(new double[] { 0.1307 }, new double[] { 0.3081 }); - return source.Select(tensor => { - return inputTransform.call(tensor); + transform ??= transforms.Normalize(Means, StdDevs, tensor.dtype, tensor.device); + return transform.call(tensor); }); } } From abb108b6c14ea2c25c05010500a60fb207ec853a Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 20 Jan 2025 17:48:20 +0000 Subject: [PATCH 056/131] Adding classes for creating tensor indexes --- src/Bonsai.ML.Torch/Index.cs | 35 -------- src/Bonsai.ML.Torch/Index/BooleanIndex.cs | 42 ++++++++++ src/Bonsai.ML.Torch/Index/ColonIndex.cs | 36 ++++++++ src/Bonsai.ML.Torch/Index/EllipsesIndex.cs | 37 +++++++++ src/Bonsai.ML.Torch/Index/Index.cs | 37 +++++++++ src/Bonsai.ML.Torch/Index/IndexHelper.cs | 97 ++++++++++++++++++++++ src/Bonsai.ML.Torch/Index/NoneIndex.cs | 36 ++++++++ src/Bonsai.ML.Torch/Index/SingleIndex.cs | 42 ++++++++++ src/Bonsai.ML.Torch/Index/SliceIndex.cs | 54 ++++++++++++ src/Bonsai.ML.Torch/Index/TensorIndex.cs | 26 ++++++ src/Bonsai.ML.Torch/IndexHelper.cs | 91 -------------------- 11 files changed, 407 insertions(+), 126 deletions(-) delete mode 100644 src/Bonsai.ML.Torch/Index.cs create mode 100644 src/Bonsai.ML.Torch/Index/BooleanIndex.cs create mode 100644 src/Bonsai.ML.Torch/Index/ColonIndex.cs create mode 100644 src/Bonsai.ML.Torch/Index/EllipsesIndex.cs create mode 100644 src/Bonsai.ML.Torch/Index/Index.cs create mode 100644 src/Bonsai.ML.Torch/Index/IndexHelper.cs create mode 100644 src/Bonsai.ML.Torch/Index/NoneIndex.cs create mode 100644 src/Bonsai.ML.Torch/Index/SingleIndex.cs create mode 100644 src/Bonsai.ML.Torch/Index/SliceIndex.cs create mode 100644 src/Bonsai.ML.Torch/Index/TensorIndex.cs delete mode 100644 src/Bonsai.ML.Torch/IndexHelper.cs diff --git a/src/Bonsai.ML.Torch/Index.cs b/src/Bonsai.ML.Torch/Index.cs deleted file mode 100644 index 818bb401..00000000 --- a/src/Bonsai.ML.Torch/Index.cs +++ /dev/null @@ -1,35 +0,0 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using static TorchSharp.torch; - -namespace Bonsai.ML.Torch -{ - /// - /// Indexes a tensor with the specified indices. Indices are specified as a comma-separated values. - /// Currently supports Python-style slicing syntax. This includes numeric indices, None, slices, and ellipsis. - /// - [Combinator] - [Description("Indexes a tensor with the specified indices. Indices are specified as a comma-separated values. Currently supports Python-style slicing syntax. This includes numeric indices, None, slices, and ellipsis.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class Index - { - /// - /// The indices to use for indexing the tensor. - /// - public string Indexes { get; set; } = string.Empty; - - /// - /// Indexes the input tensor with the specified indices. - /// - /// - /// - public IObservable Process(IObservable source) - { - var index = IndexHelper.Parse(Indexes); - return source.Select(tensor => { - return tensor.index(index); - }); - } - } -} diff --git a/src/Bonsai.ML.Torch/Index/BooleanIndex.cs b/src/Bonsai.ML.Torch/Index/BooleanIndex.cs new file mode 100644 index 00000000..f854aa56 --- /dev/null +++ b/src/Bonsai.ML.Torch/Index/BooleanIndex.cs @@ -0,0 +1,42 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; + +namespace Bonsai.ML.Torch.Index; + +/// +/// Represents a boolean index that can be used to select elements from a tensor. +/// +[Combinator] +[Description("Represents a boolean index that can be used to select elements from a tensor.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class BooleanIndex +{ + /// + /// Gets or sets the boolean value used to select elements from a tensor. + /// + [Description("The boolean value used to select elements from a tensor.")] + public bool Value { get; set; } = false; + + /// + /// Generates the boolean index. + /// + /// + public IObservable Process() + { + return Observable.Return(torch.TensorIndex.Bool(Value)); + } + + /// + /// Processes the input sequence and generates the boolean index. + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select((_) => torch.TensorIndex.Bool(Value)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Index/ColonIndex.cs b/src/Bonsai.ML.Torch/Index/ColonIndex.cs new file mode 100644 index 00000000..bfd9ca7b --- /dev/null +++ b/src/Bonsai.ML.Torch/Index/ColonIndex.cs @@ -0,0 +1,36 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; + +namespace Bonsai.ML.Torch.Index; + +/// +/// Represents the colon index used to select all elements along a given dimension. +/// +[Combinator] +[Description("Represents the colon index used to select all elements along a given dimension.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class ColonIndex +{ + /// + /// Generates the colon index. + /// + /// + public IObservable Process() + { + return Observable.Return(torch.TensorIndex.Colon); + } + + /// + /// Processes the input sequence and generates the colon index. + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select((_) => torch.TensorIndex.Colon); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Index/EllipsesIndex.cs b/src/Bonsai.ML.Torch/Index/EllipsesIndex.cs new file mode 100644 index 00000000..06207a8e --- /dev/null +++ b/src/Bonsai.ML.Torch/Index/EllipsesIndex.cs @@ -0,0 +1,37 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; + +namespace Bonsai.ML.Torch.Index; + +/// +/// Represents an index that selects all dimensions of a tensor. +/// +[Combinator] +[Description("Represents an index that selects all dimensions of a tensor.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class EllipsisIndex +{ + + /// + /// Generates the ellipsis index. + /// + /// + public IObservable Process() + { + return Observable.Return(torch.TensorIndex.Ellipsis); + } + + /// + /// Processes the input sequence and generates the ellipsis index. + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select((_) => torch.TensorIndex.Ellipsis); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Index/Index.cs b/src/Bonsai.ML.Torch/Index/Index.cs new file mode 100644 index 00000000..6846b8c7 --- /dev/null +++ b/src/Bonsai.ML.Torch/Index/Index.cs @@ -0,0 +1,37 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Index; + +/// +/// Indexes a tensor by parsing the specified indices. +/// Indices are specified as a comma-separated values. +/// Currently supports Python-style slicing syntax. +/// This includes numeric indices, None, slices, and ellipsis. +/// +[Combinator] +[Description("Indexes a tensor by parsing the specified indices. Indices are specified as a comma-separated values. Currently supports Python-style slicing syntax. This includes numeric indices, None, slices, and ellipsis.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class Index +{ + /// + /// The indices to use for indexing the tensor. + /// + [Description("The indices to use for indexing the tensor. For example, '...,3:5,:'")] + public string Indexes { get; set; } = string.Empty; + + /// + /// Indexes the input tensor with the specified indices. + /// + /// + /// + public IObservable Process(IObservable source) + { + var index = IndexHelper.Parse(Indexes); + return source.Select(tensor => { + return tensor.index(index); + }); + } +} diff --git a/src/Bonsai.ML.Torch/Index/IndexHelper.cs b/src/Bonsai.ML.Torch/Index/IndexHelper.cs new file mode 100644 index 00000000..a6f60fb3 --- /dev/null +++ b/src/Bonsai.ML.Torch/Index/IndexHelper.cs @@ -0,0 +1,97 @@ +using System; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; + +namespace Bonsai.ML.Torch.Index; + +/// +/// Provides helper methods to parse tensor indexes. +/// +public static class IndexHelper +{ + + /// + /// Parses the input string into an array of tensor indexes. + /// + /// + public static torch.TensorIndex[] Parse(string input) + { + if (string.IsNullOrEmpty(input)) + { + return [0]; + } + + var indexStrings = input.Split(','); + var indices = new torch.TensorIndex[indexStrings.Length]; + + for (int i = 0; i < indexStrings.Length; i++) + { + var indexString = indexStrings[i].Trim(); + if (int.TryParse(indexString, out int intIndex)) + { + indices[i] = torch.TensorIndex.Single(intIndex); + } + else if (indexString == ":") + { + indices[i] = torch.TensorIndex.Colon; + } + else if (indexString == "None") + { + indices[i] = torch.TensorIndex.None; + } + else if (indexString == "...") + { + indices[i] = torch.TensorIndex.Ellipsis; + } + else if (indexString.ToLower() == "false" || indexString.ToLower() == "true") + { + indices[i] = torch.TensorIndex.Bool(indexString.ToLower() == "true"); + } + else if (indexString.Contains(":")) + { + var rangeParts = indexString.Split(':'); + rangeParts = [.. rangeParts.Where(p => { + p = p.Trim(); + return !string.IsNullOrEmpty(p); + })]; + + if (rangeParts.Length == 0) + { + indices[i] = torch.TensorIndex.Slice(); + } + else if (rangeParts.Length == 1) + { + indices[i] = torch.TensorIndex.Slice(int.Parse(rangeParts[0])); + } + else if (rangeParts.Length == 2) + { + indices[i] = torch.TensorIndex.Slice(int.Parse(rangeParts[0]), int.Parse(rangeParts[1])); + } + else if (rangeParts.Length == 3) + { + indices[i] = torch.TensorIndex.Slice(int.Parse(rangeParts[0]), int.Parse(rangeParts[1]), int.Parse(rangeParts[2])); + } + else + { + throw new Exception($"Invalid index format: {indexString}"); + } + } + else + { + throw new Exception($"Invalid index format: {indexString}"); + } + } + return indices; + } + + /// + /// Serializes the input array of tensor indexes into a string representation. + /// + /// + /// + public static string Serialize(torch.TensorIndex[] indexes) + { + return string.Join(", ", indexes); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Index/NoneIndex.cs b/src/Bonsai.ML.Torch/Index/NoneIndex.cs new file mode 100644 index 00000000..b10c9d86 --- /dev/null +++ b/src/Bonsai.ML.Torch/Index/NoneIndex.cs @@ -0,0 +1,36 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; + +namespace Bonsai.ML.Torch.Index; + +/// +/// Represents an index that selects no elements of a tensor. +/// +[Combinator] +[Description("Represents an index that selects no elements of a tensor.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class NoneIndex +{ + /// + /// Generates the none index. + /// + /// + public IObservable Process() + { + return Observable.Return(torch.TensorIndex.None); + } + + /// + /// Processes the input sequence and generates the none index. + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select((_) => torch.TensorIndex.None); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Index/SingleIndex.cs b/src/Bonsai.ML.Torch/Index/SingleIndex.cs new file mode 100644 index 00000000..9b3ec641 --- /dev/null +++ b/src/Bonsai.ML.Torch/Index/SingleIndex.cs @@ -0,0 +1,42 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; + +namespace Bonsai.ML.Torch.Index; + +/// +/// Represents an index that selects a single value of a tensor. +/// +[Combinator] +[Description("Represents an index that selects a single value of a tensor.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class CreateTensorIndexSingle +{ + /// + /// Gets or sets the index value used to select a single element from a tensor. + /// + [Description("The index value used to select a single element from a tensor.")] + public long Index { get; set; } = 0; + + /// + /// Generates the single index. + /// + /// + public IObservable Process() + { + return Observable.Return(torch.TensorIndex.Single(Index)); + } + + /// + /// Processes the input sequence and generates the single index. + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select((_) => torch.TensorIndex.Single(Index)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Index/SliceIndex.cs b/src/Bonsai.ML.Torch/Index/SliceIndex.cs new file mode 100644 index 00000000..b31802a4 --- /dev/null +++ b/src/Bonsai.ML.Torch/Index/SliceIndex.cs @@ -0,0 +1,54 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; + +namespace Bonsai.ML.Torch.Index; + +/// +/// Represents an index that selects a range of elements from a tensor. +/// +[Combinator] +[Description("Represents an index that selects a range of elements from a tensor.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class SliceIndex +{ + /// + /// Gets or sets the start index of the slice. + /// + [Description("The start index of the slice.")] + public long? Start { get; set; } = null; + + /// + /// Gets or sets the end index of the slice. + /// + [Description("The end index of the slice.")] + public long? End { get; set; } = null; + + /// + /// Gets or sets the step size of the slice. + /// + [Description("The step size of the slice.")] + public long? Step { get; set; } = null; + + /// + /// Generates the slice index. + /// + /// + public IObservable Process() + { + return Observable.Return(torch.TensorIndex.Slice(Start, End, Step)); + } + + /// + /// Processes the input sequence and generates the slice index. + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select((_) => torch.TensorIndex.Slice(Start, End, Step)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Index/TensorIndex.cs b/src/Bonsai.ML.Torch/Index/TensorIndex.cs new file mode 100644 index 00000000..e3f6612d --- /dev/null +++ b/src/Bonsai.ML.Torch/Index/TensorIndex.cs @@ -0,0 +1,26 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Index; + +/// +/// Represents an index that is created from a tensor. +/// +[Combinator] +[Description("Represents an index that is created from a tensor.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class TensorIndex +{ + /// + /// Converts the input tensor into a tensor index. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(TorchSharp.torch.TensorIndex.Tensor); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/IndexHelper.cs b/src/Bonsai.ML.Torch/IndexHelper.cs deleted file mode 100644 index 2af466a0..00000000 --- a/src/Bonsai.ML.Torch/IndexHelper.cs +++ /dev/null @@ -1,91 +0,0 @@ -using System; -using static TorchSharp.torch; - -namespace Bonsai.ML.Torch -{ - /// - /// Provides helper methods to parse tensor indexes. - /// - public static class IndexHelper - { - - /// - /// Parses the input string into an array of tensor indexes. - /// - /// - public static TensorIndex[] Parse(string input) - { - if (string.IsNullOrEmpty(input)) - { - return [0]; - } - - var indexStrings = input.Split(','); - var indices = new TensorIndex[indexStrings.Length]; - - for (int i = 0; i < indexStrings.Length; i++) - { - var indexString = indexStrings[i].Trim(); - if (int.TryParse(indexString, out int intIndex)) - { - indices[i] = TensorIndex.Single(intIndex); - } - else if (indexString == ":") - { - indices[i] = TensorIndex.Colon; - } - else if (indexString == "None") - { - indices[i] = TensorIndex.None; - } - else if (indexString == "...") - { - indices[i] = TensorIndex.Ellipsis; - } - else if (indexString.ToLower() == "false" || indexString.ToLower() == "true") - { - indices[i] = TensorIndex.Bool(indexString.ToLower() == "true"); - } - else if (indexString.Contains(":")) - { - var rangeParts = indexString.Split(':'); - if (rangeParts.Length == 0) - { - indices[i] = TensorIndex.Slice(); - } - else if (rangeParts.Length == 1) - { - indices[i] = TensorIndex.Slice(int.Parse(rangeParts[0])); - } - else if (rangeParts.Length == 2) - { - indices[i] = TensorIndex.Slice(int.Parse(rangeParts[0]), int.Parse(rangeParts[1])); - } - else if (rangeParts.Length == 3) - { - indices[i] = TensorIndex.Slice(int.Parse(rangeParts[0]), int.Parse(rangeParts[1]), int.Parse(rangeParts[2])); - } - else - { - throw new Exception($"Invalid index format: {indexString}"); - } - } - else - { - throw new Exception($"Invalid index format: {indexString}"); - } - } - return indices; - } - - /// - /// Serializes the input array of tensor indexes into a string representation. - /// - /// - /// - public static string Serialize(TensorIndex[] indexes) - { - return string.Join(", ", indexes); - } - } -} \ No newline at end of file From d8f19ff06375968329741014c6aa3eeec8fe5b41 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 20 Jan 2025 17:49:02 +0000 Subject: [PATCH 057/131] Updated MNIST model architecture with correct fully connected layer size --- src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs index b707e2d5..994aca73 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs @@ -15,9 +15,9 @@ public class MNIST : Module private Module conv1 = Conv2d(1, 32, 3); private Module conv2 = Conv2d(32, 64, 3); private Module fc1 = Linear(9216, 128); - private Module fc2 = Linear(128, 10); - - private Module pool1 = MaxPool2d(kernelSize: new long[] { 2, 2 }); + private Module fc2 = Linear(128, 128); + + private Module pool1 = MaxPool2d(kernelSize: [2, 2]); private Module relu1 = ReLU(); private Module relu2 = ReLU(); From 6331273b6bba56aa26447d4f90a53d5ebd41bc50 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 20 Jan 2025 18:09:43 +0000 Subject: [PATCH 058/131] Added class to explicitly create a clone of a tensor --- src/Bonsai.ML.Torch/Clone.cs | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 src/Bonsai.ML.Torch/Clone.cs diff --git a/src/Bonsai.ML.Torch/Clone.cs b/src/Bonsai.ML.Torch/Clone.cs new file mode 100644 index 00000000..b8dc15fd --- /dev/null +++ b/src/Bonsai.ML.Torch/Clone.cs @@ -0,0 +1,25 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch; + +/// +/// Clones the input tensor. +/// +[Combinator] +[Description("Clones the input tensor.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class Clone +{ + /// + /// Clones the input tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => tensor.clone()); + } +} \ No newline at end of file From 470850729c8fab7aac84575fdb4a66404c55fe01 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 20 Jan 2025 18:10:39 +0000 Subject: [PATCH 059/131] Update to use correct width for an IplImage based on widthstep rather than width --- src/Bonsai.ML.Torch/OpenCVHelper.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Bonsai.ML.Torch/OpenCVHelper.cs b/src/Bonsai.ML.Torch/OpenCVHelper.cs index 1ca049c9..a45e2228 100644 --- a/src/Bonsai.ML.Torch/OpenCVHelper.cs +++ b/src/Bonsai.ML.Torch/OpenCVHelper.cs @@ -82,10 +82,10 @@ public static Tensor ToTensor(IplImage image) { return empty([ 0, 0, 0 ]); } - - int width = image.Width; + // int width = image.Width; int height = image.Height; int channels = image.Channels; + var width = image.WidthStep / channels; var iplDepth = image.Depth; var tensorType = bitDepthLookup.FirstOrDefault(x => x.Value.IplDepth == iplDepth).Key; From 1926bcd6993b1ef7570df749723583b7b3406c84 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 20 Jan 2025 18:16:14 +0000 Subject: [PATCH 060/131] Explicitly use static torch.Tensor type for defining expressions --- src/Bonsai.ML.Torch/CreateTensor.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Bonsai.ML.Torch/CreateTensor.cs b/src/Bonsai.ML.Torch/CreateTensor.cs index 2bddaa46..4436db0b 100644 --- a/src/Bonsai.ML.Torch/CreateTensor.cs +++ b/src/Bonsai.ML.Torch/CreateTensor.cs @@ -124,7 +124,7 @@ private Expression BuildTensorFromArray(Array arrayValues, Type returnType) Expression.Constant(null, typeof(string).MakeArrayType()) ); - var tensorVariable = Expression.Variable(typeof(Tensor), "tensor"); + var tensorVariable = Expression.Variable(typeof(torch.Tensor), "tensor"); var assignTensor = Expression.Assign(tensorVariable, tensorAssignment); var buildTensor = Expression.Block( @@ -185,7 +185,7 @@ private Expression BuildTensorFromScalarValue(object scalarValue, Type returnTyp tensorCreationMethodArguments ); - var tensorVariable = Expression.Variable(typeof(Tensor), "tensor"); + var tensorVariable = Expression.Variable(typeof(torch.Tensor), "tensor"); var assignTensor = Expression.Assign(tensorVariable, tensorAssignment); var buildTensor = Expression.Block( From 597265f721b5a1b0ffbafd30eb91a63331dc1192 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 20 Jan 2025 18:18:21 +0000 Subject: [PATCH 061/131] Update set with process overloads to handle passing in tensor index --- src/Bonsai.ML.Torch/Set.cs | 85 +++++++++++++++++++++++++------------- 1 file changed, 56 insertions(+), 29 deletions(-) diff --git a/src/Bonsai.ML.Torch/Set.cs b/src/Bonsai.ML.Torch/Set.cs index 0e2965c6..18dcc02a 100644 --- a/src/Bonsai.ML.Torch/Set.cs +++ b/src/Bonsai.ML.Torch/Set.cs @@ -6,40 +6,67 @@ using System.Xml.Serialization; using static TorchSharp.torch; -namespace Bonsai.ML.Torch +namespace Bonsai.ML.Torch; + +/// +/// Sets the value of the input tensor at the specified index. +/// +[Combinator] +[Description("Sets the value of the input tensor at the specified index.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class Set { /// - /// Sets the value of the input tensor at the specified index. + /// The index at which to set the value. + /// + [Description("The index at which to set the value.")] + public string Index { get; set; } = string.Empty; + + /// + /// The value to set at the specified index. /// - [Combinator] - [Description("Sets the value of the input tensor at the specified index.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class Set + [XmlIgnore] + [Description("The value to set at the specified index.")] + public Tensor Value { get; set; } = null; + + /// + /// Returns an observable sequence that sets the value of the input tensor at the specified index. + /// + /// + /// + public IObservable Process(IObservable source) { - /// - /// The index at which to set the value. - /// - [Description("The index at which to set the value.")] - public string Index { get; set; } = string.Empty; + return source.Select(tensor => { + var indexes = Torch.Index.IndexHelper.Parse(Index); + return tensor.index_put_(Value, indexes); + }); + } - /// - /// The value to set at the specified index. - /// - [XmlIgnore] - [Description("The value to set at the specified index.")] - public Tensor Value { get; set; } = null; + /// + /// Returns an observable sequence that sets the value of the input tensor at the specified index. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(input => { + var tensor = input.Item1; + var index = input.Item2; + return tensor.index_put_(Value, index); + }); + } - /// - /// Returns an observable sequence that sets the value of the input tensor at the specified index. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(tensor => { - var indexes = IndexHelper.Parse(Index); - return tensor.index_put_(Value, indexes); - }); - } + /// + /// Returns an observable sequence that sets the value of the input tensor at the specified index. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(input => { + var tensor = input.Item1; + var indexes = input.Item2; + return tensor.index_put_(Value, indexes); + }); } } \ No newline at end of file From 5d2cb1e66c38c32b2d3ce4c8c20fd54c52d89d4e Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 20 Jan 2025 18:18:57 +0000 Subject: [PATCH 062/131] Fixed incorrectly generating ones instead of zeros --- src/Bonsai.ML.Torch/Zeros.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Bonsai.ML.Torch/Zeros.cs b/src/Bonsai.ML.Torch/Zeros.cs index e4fb3c7a..e99bdce6 100644 --- a/src/Bonsai.ML.Torch/Zeros.cs +++ b/src/Bonsai.ML.Torch/Zeros.cs @@ -25,7 +25,7 @@ public class Zeros /// public IObservable Process() { - return Observable.Defer(() => Observable.Return(ones(Size))); + return Observable.Defer(() => Observable.Return(zeros(Size))); } /// @@ -36,7 +36,7 @@ public IObservable Process() public IObservable Process(IObservable source) { return source.Select(value => { - return ones(Size); + return zeros(Size); }); } } From 04a1f5685592ea7893eee83ef4c6c155b897ca9a Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 20 Jan 2025 18:21:37 +0000 Subject: [PATCH 063/131] Updated to correctly parse colons in string --- src/Bonsai.ML.Torch/Index/IndexHelper.cs | 38 +++++++++--------------- 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/src/Bonsai.ML.Torch/Index/IndexHelper.cs b/src/Bonsai.ML.Torch/Index/IndexHelper.cs index a6f60fb3..b62c1c2c 100644 --- a/src/Bonsai.ML.Torch/Index/IndexHelper.cs +++ b/src/Bonsai.ML.Torch/Index/IndexHelper.cs @@ -1,6 +1,5 @@ using System; -using System.Linq; -using System.Reactive.Linq; +using System.Collections.Generic; using TorchSharp; namespace Bonsai.ML.Torch.Index; @@ -50,32 +49,23 @@ public static torch.TensorIndex[] Parse(string input) } else if (indexString.Contains(":")) { - var rangeParts = indexString.Split(':'); - rangeParts = [.. rangeParts.Where(p => { - p = p.Trim(); - return !string.IsNullOrEmpty(p); - })]; - - if (rangeParts.Length == 0) - { - indices[i] = torch.TensorIndex.Slice(); - } - else if (rangeParts.Length == 1) - { - indices[i] = torch.TensorIndex.Slice(int.Parse(rangeParts[0])); - } - else if (rangeParts.Length == 2) - { - indices[i] = torch.TensorIndex.Slice(int.Parse(rangeParts[0]), int.Parse(rangeParts[1])); - } - else if (rangeParts.Length == 3) + string[] rangeParts = [.. indexString.Split(':')]; + var argsList = new List([null, null, null]); + try { - indices[i] = torch.TensorIndex.Slice(int.Parse(rangeParts[0]), int.Parse(rangeParts[1]), int.Parse(rangeParts[2])); + for (int j = 0; j < rangeParts.Length; j++) + { + if (!string.IsNullOrEmpty(rangeParts[j])) + { + argsList[j] = long.Parse(rangeParts[j]); + } + } } - else + catch (Exception) { throw new Exception($"Invalid index format: {indexString}"); } + indices[i] = torch.TensorIndex.Slice(argsList[0], argsList[1], argsList[2]); } else { @@ -84,7 +74,7 @@ public static torch.TensorIndex[] Parse(string input) } return indices; } - + /// /// Serializes the input array of tensor indexes into a string representation. /// From 2dd1b919974e5f2dc3d0454ab1ba3706541d2932 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 20 Jan 2025 18:30:36 +0000 Subject: [PATCH 064/131] Updated to use collection expressions --- src/Bonsai.ML.Torch/CreateTensor.cs | 12 ++++++------ src/Bonsai.ML.Torch/OpenCVHelper.cs | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/Bonsai.ML.Torch/CreateTensor.cs b/src/Bonsai.ML.Torch/CreateTensor.cs index 4436db0b..46e1f611 100644 --- a/src/Bonsai.ML.Torch/CreateTensor.cs +++ b/src/Bonsai.ML.Torch/CreateTensor.cs @@ -24,7 +24,7 @@ namespace Bonsai.ML.Torch [WorkflowElementCategory(ElementCategory.Source)] public class CreateTensor : ExpressionBuilder { - readonly Range argumentRange = new Range(0, 1); + readonly Range argumentRange = new(0, 1); /// public override Range ArgumentRange => argumentRange; @@ -171,14 +171,14 @@ private Expression BuildTensorFromScalarValue(object scalarValue, Type returnTyp ] ); - tensorCreationMethodArguments = tensorCreationMethodArguments.Prepend( + tensorCreationMethodArguments = [.. tensorCreationMethodArguments.Prepend( Expression.Constant(scalarType, typeof(ScalarType?)) - ).ToArray(); + )]; } - tensorCreationMethodArguments = tensorCreationMethodArguments.Prepend( - tensorDataInitializationBlock - ).ToArray(); + tensorCreationMethodArguments = [.. tensorCreationMethodArguments.Prepend( + tensorDataInitializationBlock + )]; var tensorAssignment = Expression.Call( tensorCreationMethodInfo, diff --git a/src/Bonsai.ML.Torch/OpenCVHelper.cs b/src/Bonsai.ML.Torch/OpenCVHelper.cs index a45e2228..9c7fc0e8 100644 --- a/src/Bonsai.ML.Torch/OpenCVHelper.cs +++ b/src/Bonsai.ML.Torch/OpenCVHelper.cs @@ -141,7 +141,7 @@ public unsafe static IplImage ToImage(Tensor tensor) var tensorType = tensor.dtype; var iplDepth = bitDepthLookup[tensorType].IplDepth; - var new_tensor = zeros(new long[] { height, width, channels }, tensorType).copy_(tensor); + var new_tensor = zeros([height, width, channels], tensorType).copy_(tensor); var res = THSTensor_data(new_tensor.Handle); var image = new IplImage(new OpenCV.Net.Size(width, height), iplDepth, channels, res); @@ -163,7 +163,7 @@ public unsafe static Mat ToMat(Tensor tensor) var tensorType = tensor.dtype; var depth = bitDepthLookup[tensorType].Depth; - var new_tensor = zeros(new long[] { height, width, channels }, tensorType).copy_(tensor); + var new_tensor = zeros([height, width, channels], tensorType).copy_(tensor); var res = THSTensor_data(new_tensor.Handle); var mat = new Mat(new OpenCV.Net.Size(width, height), depth, channels, res); From f725c586424d083e56324119197d11424624d98d Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 20 Jan 2025 18:31:04 +0000 Subject: [PATCH 065/131] Added documentation --- src/Bonsai.ML.Torch/Sum.cs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/Bonsai.ML.Torch/Sum.cs b/src/Bonsai.ML.Torch/Sum.cs index d01efb95..1e4c1a2c 100644 --- a/src/Bonsai.ML.Torch/Sum.cs +++ b/src/Bonsai.ML.Torch/Sum.cs @@ -5,6 +5,9 @@ namespace Bonsai.ML.Torch { + /// + /// Computes the sum of the input tensor elements along the specified dimensions. + /// [Combinator] [Description("Computes the sum of the input tensor elements along the specified dimensions.")] [WorkflowElementCategory(ElementCategory.Transform)] From 4f0776ee4b247ab5cd7eb08c127b3e6c0373b4d8 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 21 Jan 2025 11:50:59 +0000 Subject: [PATCH 066/131] Updated to use collection expressions --- src/Bonsai.ML.Torch/CreateTensor.cs | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/Bonsai.ML.Torch/CreateTensor.cs b/src/Bonsai.ML.Torch/CreateTensor.cs index 46e1f611..52d8de1a 100644 --- a/src/Bonsai.ML.Torch/CreateTensor.cs +++ b/src/Bonsai.ML.Torch/CreateTensor.cs @@ -74,15 +74,13 @@ public Device Device private Expression BuildTensorFromArray(Array arrayValues, Type returnType) { var rank = arrayValues.Rank; - var lengths = Enumerable.Range(0, rank) - .Select(arrayValues.GetLength) - .ToArray(); + int[] lengths = [.. Enumerable.Range(0, rank).Select(arrayValues.GetLength)]; - var arrayCreationExpression = Expression.NewArrayBounds(returnType, lengths.Select(len => Expression.Constant(len)).ToArray()); + var arrayCreationExpression = Expression.NewArrayBounds(returnType, [.. lengths.Select(len => Expression.Constant(len))]); var arrayVariable = Expression.Variable(arrayCreationExpression.Type, "array"); var assignArray = Expression.Assign(arrayVariable, arrayCreationExpression); - var assignments = new List(); + List assignments = []; for (int i = 0; i < values.Length; i++) { var indices = new Expression[rank]; @@ -201,13 +199,11 @@ private Expression BuildTensorFromScalarValue(object scalarValue, Type returnTyp public override Expression Build(IEnumerable arguments) { var returnType = ScalarTypeLookup.GetTypeFromScalarType(scalarType); - var argTypes = arguments.Select(arg => arg.Type).ToArray(); + Type[] argTypes = [.. arguments.Select(arg => arg.Type)]; Type[] methodInfoArgumentTypes = [typeof(Tensor)]; - var methods = typeof(CreateTensor).GetMethods(BindingFlags.Public | BindingFlags.Instance) - .Where(m => m.Name == "Process") - .ToArray(); + MethodInfo[] methods = [.. typeof(CreateTensor).GetMethods(BindingFlags.Public | BindingFlags.Instance).Where(m => m.Name == "Process")]; var methodInfo = arguments.Count() > 0 ? methods.FirstOrDefault(m => m.IsGenericMethod) .MakeGenericMethod( From 5b04e35cdf8ad01c2dae6c63e600750de2bb9b8b Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 21 Jan 2025 11:51:24 +0000 Subject: [PATCH 067/131] Add process overload to initialize device on input --- src/Bonsai.ML.Torch/InitializeTorchDevice.cs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/Bonsai.ML.Torch/InitializeTorchDevice.cs b/src/Bonsai.ML.Torch/InitializeTorchDevice.cs index b8ce574c..a598b794 100644 --- a/src/Bonsai.ML.Torch/InitializeTorchDevice.cs +++ b/src/Bonsai.ML.Torch/InitializeTorchDevice.cs @@ -32,5 +32,17 @@ public IObservable Process() return Observable.Return(new Device(DeviceType)); }); } + + /// + /// Initializes the Torch device when the input sequence produces an element. + /// + /// + public IObservable Process(IObservable source) + { + return source.Select((_) => { + InitializeDeviceType(DeviceType); + return new Device(DeviceType); + }); + } } } \ No newline at end of file From 2a6190d44278d6e8c445736155a73a8a306ae864 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 21 Jan 2025 11:52:22 +0000 Subject: [PATCH 068/131] Added documentation --- src/Bonsai.ML.Torch/NeuralNets/Backward.cs | 21 ++++++++- src/Bonsai.ML.Torch/NeuralNets/Forward.cs | 17 ++++++-- .../NeuralNets/ITorchModule.cs | 14 +++++- .../NeuralNets/LoadModuleFromArchitecture.cs | 26 ++++++++++- .../NeuralNets/LoadScriptModule.cs | 17 +++++++- src/Bonsai.ML.Torch/NeuralNets/Loss.cs | 12 +++--- .../NeuralNets/Models/AlexNet.cs | 28 ++++++------ .../NeuralNets/Models/MNIST.cs | 43 +++++++++++-------- .../NeuralNets/Models/MobileNet.cs | 38 +++++++++------- .../NeuralNets/Models/ModelArchitecture.cs | 14 ++++++ src/Bonsai.ML.Torch/NeuralNets/Optimizer.cs | 12 +++--- src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs | 22 ++++++++-- .../NeuralNets/TorchModuleAdapter.cs | 25 ++++++++--- src/Bonsai.ML.Torch/Vision/Normalize.cs | 20 ++++++++- 14 files changed, 232 insertions(+), 77 deletions(-) diff --git a/src/Bonsai.ML.Torch/NeuralNets/Backward.cs b/src/Bonsai.ML.Torch/NeuralNets/Backward.cs index 5fa4902e..328c35ba 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Backward.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Backward.cs @@ -8,18 +8,35 @@ namespace Bonsai.ML.Torch.NeuralNets { + /// + /// Trains a model using backpropagation. + /// [Combinator] - [Description("")] + [Description("Trains a model using backpropagation.")] [WorkflowElementCategory(ElementCategory.Transform)] public class Backward { + /// + /// The optimizer to use for training. + /// public Optimizer Optimizer { get; set; } + /// + /// The model to train. + /// [XmlIgnore] public ITorchModule Model { get; set; } + /// + /// The loss function to use for training. + /// public Loss Loss { get; set; } + /// + /// Trains the model using backpropagation. + /// + /// + /// public IObservable Process(IObservable> source) { optim.Optimizer optimizer = null; @@ -47,7 +64,7 @@ public IObservable Process(IObservable> source) { optimizer.zero_grad(); - var prediction = Model.forward(data); + var prediction = Model.Forward(data); var output = loss.forward(prediction, target); output.backward(); diff --git a/src/Bonsai.ML.Torch/NeuralNets/Forward.cs b/src/Bonsai.ML.Torch/NeuralNets/Forward.cs index 3aae4012..175ed3c0 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Forward.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Forward.cs @@ -3,23 +3,32 @@ using System.Reactive.Linq; using static TorchSharp.torch; using System.Xml.Serialization; -using TorchSharp.Modules; -using TorchSharp; namespace Bonsai.ML.Torch.NeuralNets { + /// + /// Runs forward inference on the input tensor using the specified model. + /// [Combinator] - [Description("")] + [Description("Runs forward inference on the input tensor using the specified model.")] [WorkflowElementCategory(ElementCategory.Transform)] public class Forward { + /// + /// The model to use for inference. + /// [XmlIgnore] public ITorchModule Model { get; set; } + /// + /// Runs forward inference on the input tensor using the specified model. + /// + /// + /// public IObservable Process(IObservable source) { Model.Module.eval(); - return source.Select(Model.forward); + return source.Select(Model.Forward); } } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs b/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs index e7ebf994..5cde6f73 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs @@ -2,9 +2,21 @@ namespace Bonsai.ML.Torch.NeuralNets { + /// + /// Represents an interface for a Torch module. + /// public interface ITorchModule { + /// + /// The module. + /// public nn.Module Module { get; } - public Tensor forward(Tensor tensor); + + /// + /// Runs forward inference on the input tensor using the specified model. + /// + /// + /// + public Tensor Forward(Tensor tensor); } } diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadModuleFromArchitecture.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadModuleFromArchitecture.cs index 43d74544..8276156f 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/LoadModuleFromArchitecture.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/LoadModuleFromArchitecture.cs @@ -6,20 +6,39 @@ namespace Bonsai.ML.Torch.NeuralNets { + /// + /// Loads a neural network module from a specified architecture. + /// [Combinator] - [Description("")] + [Description("Loads a neural network module from a specified architecture.")] [WorkflowElementCategory(ElementCategory.Source)] public class LoadModuleFromArchitecture { + /// + /// The model architecture to load. + /// + [Description("The model architecture to load.")] public Models.ModelArchitecture ModelArchitecture { get; set; } + /// + /// The device on which to load the model. + /// + [Description("The device on which to load the model.")] [XmlIgnore] public Device Device { get; set; } + /// + /// The optional path to the model weights file. + /// + [Description("The optional path to the model weights file.")] [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] public string ModelWeightsPath { get; set; } private int numClasses = 10; + /// + /// The number of classes in the dataset. + /// + [Description("The number of classes in the dataset.")] public int NumClasses { get => numClasses; @@ -36,6 +55,11 @@ public int NumClasses } } + /// + /// Loads the neural network module from the specified architecture. + /// + /// + /// public IObservable Process() { var modelArchitecture = ModelArchitecture.ToString().ToLower(); diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs index 7e6c73fe..fb3b2b78 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs @@ -6,18 +6,33 @@ namespace Bonsai.ML.Torch.NeuralNets { + /// + /// Loads a TorchScript module from the specified file path. + /// [Combinator] - [Description("")] + [Description("Loads a TorchScript module from the specified file path.")] [WorkflowElementCategory(ElementCategory.Source)] public class LoadScriptModule { + /// + /// The device on which to load the model. + /// + [Description("The device on which to load the model.")] [XmlIgnore] public Device Device { get; set; } = CPU; + /// + /// The path to the TorchScript model file. + /// + [Description("The path to the TorchScript model file.")] [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] public string ModelPath { get; set; } + /// + /// Loads the TorchScript module from the specified file path. + /// + /// public IObservable Process() { var scriptModule = jit.load(ModelPath, Device); diff --git a/src/Bonsai.ML.Torch/NeuralNets/Loss.cs b/src/Bonsai.ML.Torch/NeuralNets/Loss.cs index ff003019..376139c1 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Loss.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Loss.cs @@ -1,13 +1,13 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using static TorchSharp.torch; -using static TorchSharp.torch.optim; - namespace Bonsai.ML.Torch.NeuralNets { + /// + /// Represents a loss function that computes the loss value for a given input and target tensor. + /// public enum Loss { + /// + /// Computes the negative log likelihood loss. + /// NLLLoss, } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs index 4ca9f79c..2ded685d 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs @@ -1,13 +1,6 @@ -using System; -using System.IO; -using System.Linq; -using System.Collections.Generic; -using System.Diagnostics; - using TorchSharp; using static TorchSharp.torch; using static TorchSharp.torch.nn; -using static TorchSharp.torch.nn.functional; namespace Bonsai.ML.Torch.NeuralNets.Models { @@ -20,24 +13,30 @@ public class AlexNet : Module private readonly Module avgPool; private readonly Module classifier; + /// + /// Constructs a new AlexNet model. + /// + /// + /// + /// public AlexNet(string name, int numClasses, Device device = null) : base(name) { features = Sequential( ("c1", Conv2d(3, 64, kernelSize: 3, stride: 2, padding: 1)), ("r1", ReLU(inplace: true)), - ("mp1", MaxPool2d(kernelSize: new long[] { 2, 2 })), + ("mp1", MaxPool2d(kernelSize: [ 2, 2 ])), ("c2", Conv2d(64, 192, kernelSize: 3, padding: 1)), ("r2", ReLU(inplace: true)), - ("mp2", MaxPool2d(kernelSize: new long[] { 2, 2 })), + ("mp2", MaxPool2d(kernelSize: [ 2, 2 ])), ("c3", Conv2d(192, 384, kernelSize: 3, padding: 1)), ("r3", ReLU(inplace: true)), ("c4", Conv2d(384, 256, kernelSize: 3, padding: 1)), ("r4", ReLU(inplace: true)), ("c5", Conv2d(256, 256, kernelSize: 3, padding: 1)), ("r5", ReLU(inplace: true)), - ("mp3", MaxPool2d(kernelSize: new long[] { 2, 2 }))); + ("mp3", MaxPool2d(kernelSize: [ 2, 2 ]))); - avgPool = AdaptiveAvgPool2d(new long[] { 2, 2 }); + avgPool = AdaptiveAvgPool2d([ 2, 2 ]); classifier = Sequential( ("d1", Dropout()), @@ -56,12 +55,17 @@ public AlexNet(string name, int numClasses, Device device = null) : base(name) this.to(device); } + /// + /// Forward pass of the AlexNet model. + /// + /// + /// public override Tensor forward(Tensor input) { var f = features.forward(input); var avg = avgPool.forward(f); - var x = avg.view(new long[] { avg.shape[0], 256 * 2 * 2 }); + var x = avg.view([ avg.shape[0], 256 * 2 * 2 ]); return classifier.forward(x); } diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs index 994aca73..32d4bf8a 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs @@ -1,34 +1,36 @@ -using System; -using System.IO; -using System.Collections.Generic; -using System.Diagnostics; using TorchSharp; using static TorchSharp.torch; - using static TorchSharp.torch.nn; -using static TorchSharp.torch.nn.functional; namespace Bonsai.ML.Torch.NeuralNets.Models { + /// + /// Represents a simple convolutional neural network for the MNIST dataset. + /// public class MNIST : Module { - private Module conv1 = Conv2d(1, 32, 3); - private Module conv2 = Conv2d(32, 64, 3); - private Module fc1 = Linear(9216, 128); - private Module fc2 = Linear(128, 128); + private readonly Module conv1 = Conv2d(1, 32, 3); + private readonly Module conv2 = Conv2d(32, 64, 3); + private readonly Module fc1 = Linear(9216, 128); + private readonly Module fc2 = Linear(128, 128); - private Module pool1 = MaxPool2d(kernelSize: [2, 2]); + private readonly Module pool1 = MaxPool2d(kernelSize: [2, 2]); - private Module relu1 = ReLU(); - private Module relu2 = ReLU(); - private Module relu3 = ReLU(); + private readonly Module relu1 = ReLU(); + private readonly Module relu2 = ReLU(); + private readonly Module relu3 = ReLU(); - private Module dropout1 = Dropout(0.25); - private Module dropout2 = Dropout(0.5); + private readonly Module dropout1 = Dropout(0.25); + private readonly Module dropout2 = Dropout(0.5); - private Module flatten = Flatten(); - private Module logsm = LogSoftmax(1); + private readonly Module flatten = Flatten(); + private readonly Module logsm = LogSoftmax(1); + /// + /// Constructs a new MNIST model. + /// + /// + /// public MNIST(string name, Device device = null) : base(name) { RegisterComponents(); @@ -37,6 +39,11 @@ public MNIST(string name, Device device = null) : base(name) this.to(device); } + /// + /// Forward pass of the MNIST model. + /// + /// + /// public override Tensor forward(Tensor input) { var l11 = conv1.forward(input); diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs index f82a33f9..6ede9818 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs @@ -7,33 +7,34 @@ namespace Bonsai.ML.Torch.NeuralNets.Models { /// - /// Modified version of MobileNet to classify CIFAR10 32x32 images. + /// MobileNet model. /// - /// - /// With an unaugmented CIFAR-10 data set, the author of this saw training converge - /// at roughly 75% accuracy on the test set, over the course of 1500 epochs. - /// public class MobileNet : Module { - // The code here is is loosely based on https://github.com/kuangliu/pytorch-cifar/blob/master/models/mobilenet.py - // Licence and copypright notice at: https://github.com/kuangliu/pytorch-cifar/blob/master/LICENSE - - private readonly long[] planes = new long[] { 64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024 }; - private readonly long[] strides = new long[] { 1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1 }; + private readonly long[] planes = [ 64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024 ]; + private readonly long[] strides = [ 1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1 ]; private readonly Module layers; + /// + /// Constructs a new MobileNet model. + /// + /// + /// + /// + /// public MobileNet(string name, int numClasses, Device device = null) : base(name) { if (planes.Length != strides.Length) throw new ArgumentException("'planes' and 'strides' must have the same length."); - var modules = new List<(string, Module)>(); - - modules.Add(($"conv2d-first", Conv2d(3, 32, kernelSize: 3, stride: 1, padding: 1, bias: false))); - modules.Add(($"bnrm2d-first", BatchNorm2d(32))); - modules.Add(($"relu-first", ReLU())); + var modules = new List<(string, Module)> + { + ($"conv2d-first", Conv2d(3, 32, kernelSize: 3, stride: 1, padding: 1, bias: false)), + ($"bnrm2d-first", BatchNorm2d(32)), + ($"relu-first", ReLU()) + }; MakeLayers(modules, 32); - modules.Add(("avgpool", AvgPool2d(new long[] { 2, 2 }))); + modules.Add(("avgpool", AvgPool2d([2, 2]))); modules.Add(("flatten", Flatten())); modules.Add(($"linear", Linear(planes[planes.Length-1], numClasses))); @@ -63,6 +64,11 @@ private void MakeLayers(List<(string, Module)> modules, long in_ } } + /// + /// Forward pass of the MobileNet model. + /// + /// + /// public override Tensor forward(Tensor input) { return layers.forward(input); diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/ModelArchitecture.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/ModelArchitecture.cs index 0a221b5c..98a30216 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/ModelArchitecture.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/ModelArchitecture.cs @@ -1,9 +1,23 @@ namespace Bonsai.ML.Torch.NeuralNets.Models { + /// + /// Represents the architecture of a neural network model. + /// public enum ModelArchitecture { + /// + /// The AlexNet model architecture. + /// AlexNet, + + /// + /// The MobileNet model architecture. + /// MobileNet, + + /// + /// The MNIST model architecture. + /// MNIST } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/Optimizer.cs b/src/Bonsai.ML.Torch/NeuralNets/Optimizer.cs index 3c0d4bb7..4ab09dbd 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Optimizer.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Optimizer.cs @@ -1,13 +1,13 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using static TorchSharp.torch; -using static TorchSharp.torch.optim; - namespace Bonsai.ML.Torch.NeuralNets { + /// + /// Represents an optimizer that updates the parameters of a neural network. + /// public enum Optimizer { + /// + /// Adam optimizer. + /// Adam, } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs b/src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs index 314e9f3f..3d5ffd97 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs @@ -1,23 +1,37 @@ using System; using System.ComponentModel; using System.Reactive.Linq; -using static TorchSharp.torch; using System.Xml.Serialization; -using TorchSharp.Modules; -using TorchSharp; namespace Bonsai.ML.Torch.NeuralNets { + /// + /// Saves the model to a file. + /// [Combinator] - [Description("")] + [Description("Saves the model to a file.")] [WorkflowElementCategory(ElementCategory.Sink)] public class SaveModel { + /// + /// The model to save. + /// + [Description("The model to save.")] [XmlIgnore] public ITorchModule Model { get; set; } + /// + /// The path to save the model. + /// + [Description("The path to save the model.")] public string ModelPath { get; set; } + /// + /// Saves the model to the specified file path. + /// + /// + /// + /// public IObservable Process(IObservable source) { return source.Do(input => { diff --git a/src/Bonsai.ML.Torch/NeuralNets/TorchModuleAdapter.cs b/src/Bonsai.ML.Torch/NeuralNets/TorchModuleAdapter.cs index a1c44d96..3ec35071 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/TorchModuleAdapter.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/TorchModuleAdapter.cs @@ -4,33 +4,48 @@ namespace Bonsai.ML.Torch.NeuralNets { + /// + /// Represents a torch module adapter that wraps a torch module or script module. + /// public class TorchModuleAdapter : ITorchModule { private readonly nn.Module _module = null; private readonly jit.ScriptModule _scriptModule = null; - private Func forwardFunc; + private readonly Func _forwardFunc; + /// + /// The module. + /// public nn.Module Module { get; } + /// + /// Initializes a new instance of the class. + /// + /// public TorchModuleAdapter(nn.Module module) { _module = module; - forwardFunc = _module.forward; + _forwardFunc = _module.forward; Module = _module; } + /// + /// Initializes a new instance of the class. + /// + /// public TorchModuleAdapter(jit.ScriptModule scriptModule) { _scriptModule = scriptModule; - forwardFunc = _scriptModule.forward; + _forwardFunc = _scriptModule.forward; Module = _scriptModule; } - public Tensor forward(Tensor input) + /// + public Tensor Forward(Tensor input) { - return forwardFunc(input); + return _forwardFunc(input); } } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Vision/Normalize.cs b/src/Bonsai.ML.Torch/Vision/Normalize.cs index f4fc80a3..60b87c44 100644 --- a/src/Bonsai.ML.Torch/Vision/Normalize.cs +++ b/src/Bonsai.ML.Torch/Vision/Normalize.cs @@ -8,21 +8,39 @@ namespace Bonsai.ML.Torch.Vision { + /// + /// Normalizes the input tensor with the mean and standard deviation. + /// [Combinator] [Description("Normalizes the input tensor with the mean and standard deviation.")] [WorkflowElementCategory(ElementCategory.Transform)] public class Normalize { + /// + /// The mean values for each channel. + /// + [Description("The mean values for each channel.")] public double[] Means { get; set; } = [ 0.1307 ]; + + /// + /// The standard deviation values for each channel. + /// + [Description("The standard deviation values for each channel.")] public double[] StdDevs { get; set; } = [ 0.3081 ]; + private ITransform transform = null; + /// + /// Normalizes the input tensor with the mean and standard deviation. + /// + /// + /// public IObservable Process(IObservable source) { return source.Select(tensor => { transform ??= transforms.Normalize(Means, StdDevs, tensor.dtype, tensor.device); return transform.call(tensor); - }); + }).Finally(() => transform = null); } } } \ No newline at end of file From 96014604cb64a1abef24e0e05cb7986f4154d0b1 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 21 Jan 2025 12:57:40 +0000 Subject: [PATCH 069/131] Added file name editor attribute to model path --- src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs b/src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs index 3d5ffd97..c426aedf 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs @@ -24,6 +24,7 @@ public class SaveModel /// The path to save the model. /// [Description("The path to save the model.")] + [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] public string ModelPath { get; set; } /// From 506d8a142d695165073817e1c08a3582503a982f Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 21 Jan 2025 13:52:45 +0000 Subject: [PATCH 070/131] Updated to latest torchsharp version --- src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj | 4 ++-- src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs | 16 ++++++++-------- src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs | 2 +- .../NeuralNets/Models/MobileNet.cs | 6 +++--- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj b/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj index 97bfe18c..3a2f0298 100644 --- a/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj +++ b/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj @@ -8,8 +8,8 @@ - - + + diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs index 2ded685d..c80a3d50 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs @@ -22,19 +22,19 @@ public class AlexNet : Module public AlexNet(string name, int numClasses, Device device = null) : base(name) { features = Sequential( - ("c1", Conv2d(3, 64, kernelSize: 3, stride: 2, padding: 1)), + ("c1", Conv2d(3, 64, kernel_size: 3, stride: 2, padding: 1)), ("r1", ReLU(inplace: true)), - ("mp1", MaxPool2d(kernelSize: [ 2, 2 ])), - ("c2", Conv2d(64, 192, kernelSize: 3, padding: 1)), + ("mp1", MaxPool2d(kernel_size: [ 2, 2 ])), + ("c2", Conv2d(64, 192, kernel_size: 3, padding: 1)), ("r2", ReLU(inplace: true)), - ("mp2", MaxPool2d(kernelSize: [ 2, 2 ])), - ("c3", Conv2d(192, 384, kernelSize: 3, padding: 1)), + ("mp2", MaxPool2d(kernel_size: [ 2, 2 ])), + ("c3", Conv2d(192, 384, kernel_size: 3, padding: 1)), ("r3", ReLU(inplace: true)), - ("c4", Conv2d(384, 256, kernelSize: 3, padding: 1)), + ("c4", Conv2d(384, 256, kernel_size: 3, padding: 1)), ("r4", ReLU(inplace: true)), - ("c5", Conv2d(256, 256, kernelSize: 3, padding: 1)), + ("c5", Conv2d(256, 256, kernel_size: 3, padding: 1)), ("r5", ReLU(inplace: true)), - ("mp3", MaxPool2d(kernelSize: [ 2, 2 ]))); + ("mp3", MaxPool2d(kernel_size: [ 2, 2 ]))); avgPool = AdaptiveAvgPool2d([ 2, 2 ]); diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs index 32d4bf8a..e5895d41 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs @@ -14,7 +14,7 @@ public class MNIST : Module private readonly Module fc1 = Linear(9216, 128); private readonly Module fc2 = Linear(128, 128); - private readonly Module pool1 = MaxPool2d(kernelSize: [2, 2]); + private readonly Module pool1 = MaxPool2d(kernel_size: [2, 2]); private readonly Module relu1 = ReLU(); private readonly Module relu2 = ReLU(); diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs index 6ede9818..0faa1062 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs @@ -29,7 +29,7 @@ public MobileNet(string name, int numClasses, Device device = null) : base(name) var modules = new List<(string, Module)> { - ($"conv2d-first", Conv2d(3, 32, kernelSize: 3, stride: 1, padding: 1, bias: false)), + ($"conv2d-first", Conv2d(3, 32, kernel_size: 3, stride: 1, padding: 1, bias: false)), ($"bnrm2d-first", BatchNorm2d(32)), ($"relu-first", ReLU()) }; @@ -53,10 +53,10 @@ private void MakeLayers(List<(string, Module)> modules, long in_ var out_planes = planes[i]; var stride = strides[i]; - modules.Add(($"conv2d-{i}a", Conv2d(in_planes, in_planes, kernelSize: 3, stride: stride, padding: 1, groups: in_planes, bias: false))); + modules.Add(($"conv2d-{i}a", Conv2d(in_planes, in_planes, kernel_size: 3, stride: stride, padding: 1, groups: in_planes, bias: false))); modules.Add(($"bnrm2d-{i}a", BatchNorm2d(in_planes))); modules.Add(($"relu-{i}a", ReLU())); - modules.Add(($"conv2d-{i}b", Conv2d(in_planes, out_planes, kernelSize: 1L, stride: 1L, padding: 0L, bias: false))); + modules.Add(($"conv2d-{i}b", Conv2d(in_planes, out_planes, kernel_size: 1L, stride: 1L, padding: 0L, bias: false))); modules.Add(($"bnrm2d-{i}b", BatchNorm2d(out_planes))); modules.Add(($"relu-{i}b", ReLU())); From ccfe40050b1f92a61b14af6a67c6ddd1d5cd2b98 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 21 Jan 2025 14:09:12 +0000 Subject: [PATCH 071/131] Update class name to reflect file name --- src/Bonsai.ML.Torch/Index/SingleIndex.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Bonsai.ML.Torch/Index/SingleIndex.cs b/src/Bonsai.ML.Torch/Index/SingleIndex.cs index 9b3ec641..e2f5decd 100644 --- a/src/Bonsai.ML.Torch/Index/SingleIndex.cs +++ b/src/Bonsai.ML.Torch/Index/SingleIndex.cs @@ -12,7 +12,7 @@ namespace Bonsai.ML.Torch.Index; [Combinator] [Description("Represents an index that selects a single value of a tensor.")] [WorkflowElementCategory(ElementCategory.Source)] -public class CreateTensorIndexSingle +public class SingleIndex { /// /// Gets or sets the index value used to select a single element from a tensor. From ff8b463c7c30c374c65c04842595fbc74a4e59cc Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 23 Jan 2025 16:49:39 +0000 Subject: [PATCH 072/131] Added documentation to package --- README.md | 3 +++ docs/articles/Torch/torch-getting-started.md | 13 ++++++++++ docs/articles/Torch/torch-overview.md | 27 ++++++++++++++++++++ docs/articles/toc.yml | 7 ++++- 4 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 docs/articles/Torch/torch-getting-started.md create mode 100644 docs/articles/Torch/torch-overview.md diff --git a/README.md b/README.md index 1b27738a..508c9a05 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,9 @@ Facilitates inference using Hidden Markov Models (HMMs). It interfaces with the ### Bonsai.ML.HiddenMarkovModels.Design Visualizers and editor features for the HiddenMarkovModels package. +### Bonsai.ML.Torch +Interfaces with the [TorchSharp](https://github.com/dotnet/TorchSharp) package, a C# wrapper around the torch library. Provides tooling for manipulating tensors, performing linear algebra, training and inference with deep neural networks, and more. + > [!NOTE] > Bonsai.ML packages can be installed through Bonsai's integrated package manager and are generally ready for immediate use. However, some packages may require additional installation steps. Refer to the specific package section for detailed installation guides and documentation. diff --git a/docs/articles/Torch/torch-getting-started.md b/docs/articles/Torch/torch-getting-started.md new file mode 100644 index 00000000..6a673b9b --- /dev/null +++ b/docs/articles/Torch/torch-getting-started.md @@ -0,0 +1,13 @@ +# Getting Started + +The aim of the `Bonsai.ML.Torch` package is to integrate the [TorchSharp](https://github.com/dotnet/TorchSharp) package, a C# wrapper around the powerful libtorch library, into Bonsai. In the current version, the package primarily provides tooling and functionality for users to interact with and manipulate `Tensor`s, the core data type of libtorch which underlies many of the advanced torch operations. Additionally, the package provides some capabilities for defining neural network architectures, running forward inference, and learning via back propogation. + +## Tensor Operations +The package provides several ways to work with tensors. Users can initialize tensors, (`Ones`, `Zeros`, etc.), create tensors from .NET data types, (`ToTensor`), and define custom tensors using Python-like syntax (`CreateTensor`). Tensors can be converted back to .NET types using the `ToArray` node (for flattening tensors into a single array) or the `ToNDArray` node (for preserving multidimensional array shapes). Furthermore, the `Tensor` data types contains many extension methods which can be used via scripting, such as using `ExpressionTransform` (for example, it.sum() to sum a tensor, or it.T to transpose), and works with overloaded operators, for example, `Zip` -> `Multiply`. Thus, `ExpressionTransform` can also be used to access individual elements of a tensor, using the syntax `it.data.ReadCpuT(0)` where `T` is a primitive .NET data type. + + +## Running on the GPU +Users must be explicit about running tensors on the GPU. First, the `InitializeDeviceType` node must run with a CUDA-compatible GPU. Afterwards, tensors are moved to the GPU using the `ToDevice` node. Converting tensors back to .NET data types requires moving the tensor back to the CPU before converting. + +## Neural Networks +The package provides initial support for working with torch `Module`s, the conventional object for deep neural networks. The `LoadModuleFromArchitecture` node allows users to select from a list of common architectures, and can optionally load in pretrained weights from disk. Additionally, the package supports loading `TorchScript` modules with the `LoadScriptModule` node, which enables users to use torch modules saved in the `.pt` file format. Users can then use the `Forward` node to run inference and the `Backward` node to run back propogation. \ No newline at end of file diff --git a/docs/articles/Torch/torch-overview.md b/docs/articles/Torch/torch-overview.md new file mode 100644 index 00000000..4ca6cd08 --- /dev/null +++ b/docs/articles/Torch/torch-overview.md @@ -0,0 +1,27 @@ +# Bonsai.ML.Torch Overview + +The Torch package provides a Bonsai interface to interact with the [TorchSharp](https://github.com/dotnet/TorchSharp) package, a C# implementation of the torch library. + +## General Guide + +The Bonsai.ML.Torch package can be installed through the Bonsai Package Manager and depends on the TorchSharp library. Additionally, running the package requires installing the specific torch DLLs needed for your desired application. The steps for installing are outlined below. + +### Running on the CPU +For running the package using the CPU, the `TorchSharp-cpu` library can be installed though the `nuget` package source. + +### Running on the GPU +To run torch on the GPU, you first need to ensure that you have a CUDA compatible device installed on your system. + +Next, you must follow the [CUDA installation guide for Windows](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html) or the [guide for Linux](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html). Make sure to install the correct `CUDA v12.1` version [found here](https://developer.nvidia.com/cuda-12-1-0-download-archive). Ensure that you have the correct CUDA version (v12.1) installed, as `TorchSharp` currently only supports this version. + +Next, you need to install the `cuDNN v9` library following the [guide for Windows](https://docs.nvidia.com/deeplearning/cudnn/latest/installation/windows.html) or the [guide for Linux](https://docs.nvidia.com/deeplearning/cudnn/latest/installation/linux.html). Again, you need to ensure you have the correct version installed (v9). You should consult [nvidia's support matrix](https://docs.nvidia.com/deeplearning/cudnn/latest/reference/support-matrix.html) to ensure the versions of CUDA and cuDNN you installed are compatible with your specific OS, graphics driver, and hardware. + +Once complete, you need to install the cuda-compatible torch libraries and place them into the correct location. You can download the libraries from [the pytorch website](https://pytorch.org/get-started/locally/) with the following options selected: + +- PyTorch Build: Stable (2.5.1) +- OS: [Your OS] +- Package: LibTorch +- Language: C++/Java +- Compute Platform: CUDA 12.1 + +Finally, extract the zip folder and copy all of the DLLs into the `Extensions` folder of your bonsai installation directory. \ No newline at end of file diff --git a/docs/articles/toc.yml b/docs/articles/toc.yml index e22b0b80..625cfcc3 100644 --- a/docs/articles/toc.yml +++ b/docs/articles/toc.yml @@ -13,4 +13,9 @@ - name: Overview href: HiddenMarkovModels/hmm-overview.md - name: Getting Started - href: HiddenMarkovModels/hmm-getting-started.md \ No newline at end of file + href: HiddenMarkovModels/hmm-getting-started.md +- name: Torch +- name: Overview + href: Torch/torch-overview.md +- name: Getting Started + href: Torch/torch-getting-started.md \ No newline at end of file From 2f4d6e31e5078ed322b452f42dabe6c6705b1ff9 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 23 Jan 2025 18:38:50 +0000 Subject: [PATCH 073/131] Fixed issue with MNIST model not accepting num classes --- .../NeuralNets/LoadModuleFromArchitecture.cs | 2 +- .../NeuralNets/Models/MNIST.cs | 44 +++++++++++++------ 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadModuleFromArchitecture.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadModuleFromArchitecture.cs index 8276156f..ac791d8d 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/LoadModuleFromArchitecture.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/LoadModuleFromArchitecture.cs @@ -69,7 +69,7 @@ public IObservable Process() { "alexnet" => new Models.AlexNet(modelArchitecture, numClasses, device), "mobilenet" => new Models.MobileNet(modelArchitecture, numClasses, device), - "mnist" => new Models.MNIST(modelArchitecture, device), + "mnist" => new Models.MNIST(modelArchitecture, numClasses, device), _ => throw new ArgumentException($"Model {modelArchitecture} not supported.") }; diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs index e5895d41..8a5f84db 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs @@ -9,30 +9,48 @@ namespace Bonsai.ML.Torch.NeuralNets.Models /// public class MNIST : Module { - private readonly Module conv1 = Conv2d(1, 32, 3); - private readonly Module conv2 = Conv2d(32, 64, 3); - private readonly Module fc1 = Linear(9216, 128); - private readonly Module fc2 = Linear(128, 128); + private readonly Module conv1; + private readonly Module conv2; + private readonly Module fc1; + private readonly Module fc2; - private readonly Module pool1 = MaxPool2d(kernel_size: [2, 2]); + private readonly Module pool1; - private readonly Module relu1 = ReLU(); - private readonly Module relu2 = ReLU(); - private readonly Module relu3 = ReLU(); + private readonly Module relu1; + private readonly Module relu2; + private readonly Module relu3; - private readonly Module dropout1 = Dropout(0.25); - private readonly Module dropout2 = Dropout(0.5); + private readonly Module dropout1; + private readonly Module dropout2; - private readonly Module flatten = Flatten(); - private readonly Module logsm = LogSoftmax(1); + private readonly Module flatten; + private readonly Module logsm; /// /// Constructs a new MNIST model. /// /// + /// /// - public MNIST(string name, Device device = null) : base(name) + public MNIST(string name, int numClasses, Device device = null) : base(name) { + conv1 = Conv2d(1, 32, 3); + conv2 = Conv2d(32, 64, 3); + fc1 = Linear(9216, 128); + fc2 = Linear(128, numClasses); + + pool1 = MaxPool2d(kernel_size: [2, 2]); + + relu1 = ReLU(); + relu2 = ReLU(); + relu3 = ReLU(); + + dropout1 = Dropout(0.25); + dropout2 = Dropout(0.5); + + flatten = Flatten(); + logsm = LogSoftmax(1); + RegisterComponents(); if (device != null && device.type != DeviceType.CPU) From d17584f24dcd1b19b2a6ad1f36abd5b8da113218 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 24 Jan 2025 09:02:13 +0000 Subject: [PATCH 074/131] Made slight correction to GPU documentation --- docs/articles/Torch/torch-overview.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/articles/Torch/torch-overview.md b/docs/articles/Torch/torch-overview.md index 4ca6cd08..0884ea01 100644 --- a/docs/articles/Torch/torch-overview.md +++ b/docs/articles/Torch/torch-overview.md @@ -24,4 +24,4 @@ Once complete, you need to install the cuda-compatible torch libraries and place - Language: C++/Java - Compute Platform: CUDA 12.1 -Finally, extract the zip folder and copy all of the DLLs into the `Extensions` folder of your bonsai installation directory. \ No newline at end of file +Finally, extract the zip folder and copy the contents of the `lib` folder into the `Extensions` folder of your bonsai installation directory. \ No newline at end of file From 49eef3843c8d4cc4dd90e8156fa519d3d2e90d6f Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 29 Jan 2025 19:07:00 +0000 Subject: [PATCH 075/131] Modified torch module classes to be internel --- src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs | 2 +- src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs | 2 +- src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs index c80a3d50..c3d19d55 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs @@ -7,7 +7,7 @@ namespace Bonsai.ML.Torch.NeuralNets.Models /// /// Modified version of original AlexNet to fix CIFAR10 32x32 images. /// - public class AlexNet : Module + internal class AlexNet : Module { private readonly Module features; private readonly Module avgPool; diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs index 8a5f84db..8bd3e0a4 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs @@ -7,7 +7,7 @@ namespace Bonsai.ML.Torch.NeuralNets.Models /// /// Represents a simple convolutional neural network for the MNIST dataset. /// - public class MNIST : Module + internal class MNIST : Module { private readonly Module conv1; private readonly Module conv2; diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs index 0faa1062..a5f7701a 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs @@ -9,7 +9,7 @@ namespace Bonsai.ML.Torch.NeuralNets.Models /// /// MobileNet model. /// - public class MobileNet : Module + internal class MobileNet : Module { private readonly long[] planes = [ 64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024 ]; private readonly long[] strides = [ 1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1 ]; From 1f063a92e50a475cb9b4aa32e7d91baf2c6e310a Mon Sep 17 00:00:00 2001 From: David Maas Date: Tue, 28 Jan 2025 19:37:12 -0600 Subject: [PATCH 076/131] Reworked OpenCV <-> TorchSharp conversions to respect the garbage collector --- src/Bonsai.ML.Torch/OpenCVHelper.cs | 121 ++++++++----------------- src/Bonsai.ML.Torch/TorchSharpEx.cs | 133 ++++++++++++++++++++++++++++ 2 files changed, 169 insertions(+), 85 deletions(-) create mode 100644 src/Bonsai.ML.Torch/TorchSharpEx.cs diff --git a/src/Bonsai.ML.Torch/OpenCVHelper.cs b/src/Bonsai.ML.Torch/OpenCVHelper.cs index 9c7fc0e8..f489e4d8 100644 --- a/src/Bonsai.ML.Torch/OpenCVHelper.cs +++ b/src/Bonsai.ML.Torch/OpenCVHelper.cs @@ -1,6 +1,4 @@ -using System; -using System.Runtime.InteropServices; -using System.Collections.Concurrent; +using System; using System.Collections.Generic; using System.Linq; using OpenCV.Net; @@ -23,54 +21,6 @@ public static class OpenCVHelper { ScalarType.Int8, (IplDepth.S8, Depth.S8) } }; - private static ConcurrentDictionary deleters = new ConcurrentDictionary(); - - internal delegate void GCHandleDeleter(IntPtr memory); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSTensor_data(IntPtr handle); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSTensor_new(IntPtr rawArray, GCHandleDeleter deleter, IntPtr dimensions, int numDimensions, sbyte type, sbyte dtype, int deviceType, int deviceIndex, [MarshalAs(UnmanagedType.U1)] bool requires_grad); - - /// - /// Creates a tensor from a pointer to the data and the dimensions of the tensor. - /// - /// - /// - /// - /// - public static unsafe Tensor CreateTensorFromPtr(IntPtr tensorDataPtr, long[] dimensions, ScalarType dtype = ScalarType.Byte) - { - var dataHandle = GCHandle.Alloc(tensorDataPtr, GCHandleType.Pinned); - var gchp = GCHandle.ToIntPtr(dataHandle); - GCHandleDeleter deleter = null; - - deleter = new GCHandleDeleter((IntPtr ptrHandler) => - { - GCHandle.FromIntPtr(gchp).Free(); - deleters.TryRemove(deleter, out deleter); - }); - deleters.TryAdd(deleter, deleter); - - fixed (long* dimensionsPtr = dimensions) - { - IntPtr tensorHandle = THSTensor_new(tensorDataPtr, deleter, (IntPtr)dimensionsPtr, dimensions.Length, (sbyte)dtype, (sbyte)dtype, 0, 0, false); - if (tensorHandle == IntPtr.Zero) - { - GC.Collect(); - GC.WaitForPendingFinalizers(); - tensorHandle = THSTensor_new(tensorDataPtr, deleter, (IntPtr)dimensionsPtr, dimensions.Length, (sbyte)dtype, (sbyte)dtype, 0, 0, false); - } - if (tensorHandle == IntPtr.Zero) - { - CheckForErrors(); - } - var output = Tensor.UnsafeCreateTensor(tensorHandle); - return output; - } - } - /// /// Converts an OpenCV image to a Torch tensor. /// @@ -79,10 +29,8 @@ public static unsafe Tensor CreateTensorFromPtr(IntPtr tensorDataPtr, long[] dim public static Tensor ToTensor(IplImage image) { if (image == null) - { return empty([ 0, 0, 0 ]); - } - // int width = image.Width; + int height = image.Height; int channels = image.Channels; var width = image.WidthStep / channels; @@ -90,13 +38,12 @@ public static Tensor ToTensor(IplImage image) var iplDepth = image.Depth; var tensorType = bitDepthLookup.FirstOrDefault(x => x.Value.IplDepth == iplDepth).Key; - IntPtr tensorDataPtr = image.ImageData; - long[] dimensions = [ height, width, channels ]; - if (tensorDataPtr == IntPtr.Zero) - { - return empty(dimensions); - } - return CreateTensorFromPtr(tensorDataPtr, dimensions, tensorType); + IntPtr data = image.ImageData; + ReadOnlySpan dimensions = stackalloc long[] { height, width, channels }; + if (data == IntPtr.Zero) + return zeros(dimensions); + + return TorchSharpEx.CreateTensorFromUnmanagedMemoryWithManagedAnchor(data, image, dimensions, tensorType); } /// @@ -107,9 +54,7 @@ public static Tensor ToTensor(IplImage image) public static Tensor ToTensor(Mat mat) { if (mat == null) - { return empty([0, 0, 0 ]); - } int width = mat.Size.Width; int height = mat.Size.Height; @@ -118,13 +63,21 @@ public static Tensor ToTensor(Mat mat) var depth = mat.Depth; var tensorType = bitDepthLookup.FirstOrDefault(x => x.Value.Depth == depth).Key; - IntPtr tensorDataPtr = mat.Data; - long[] dimensions = [ height, width, channels ]; - if (tensorDataPtr == IntPtr.Zero) - { - return empty(dimensions); - } - return CreateTensorFromPtr(tensorDataPtr, dimensions, tensorType); + IntPtr data = mat.Data; + ReadOnlySpan dimensions = stackalloc long[] { height, width, channels }; + if (data == IntPtr.Zero) + return zeros(dimensions); + + return TorchSharpEx.CreateTensorFromUnmanagedMemoryWithManagedAnchor(data, mat, dimensions, tensorType); + } + + private static (int height, int width, int channels) GetImageDimensions(this Tensor tensor) + { + if (tensor.dim() != 3) + throw new ArgumentException("The tensor does not have exactly 3 dimensions."); + + checked + { return ((int)tensor.size(0), (int)tensor.size(1), (int)tensor.size(2)); } } /// @@ -134,17 +87,16 @@ public static Tensor ToTensor(Mat mat) /// public unsafe static IplImage ToImage(Tensor tensor) { - var height = (int)tensor.shape[0]; - var width = (int)tensor.shape[1]; - var channels = (int)tensor.shape[2]; + var (height, width, channels) = tensor.GetImageDimensions(); var tensorType = tensor.dtype; var iplDepth = bitDepthLookup[tensorType].IplDepth; + var image = new IplImage(new OpenCV.Net.Size(width, height), iplDepth, channels); - var new_tensor = zeros([height, width, channels], tensorType).copy_(tensor); - - var res = THSTensor_data(new_tensor.Handle); - var image = new IplImage(new OpenCV.Net.Size(width, height), iplDepth, channels, res); + // Create a temporary tensor backed by the image's memory and copy the source tensor into it + ReadOnlySpan dimensions = stackalloc long[] { height, width, channels }; + using var imageTensor = TorchSharpEx.CreateStackTensor(image.ImageData, image, dimensions, tensorType); + imageTensor.Tensor.copy_(tensor); return image; } @@ -156,19 +108,18 @@ public unsafe static IplImage ToImage(Tensor tensor) /// public unsafe static Mat ToMat(Tensor tensor) { - var height = (int)tensor.shape[0]; - var width = (int)tensor.shape[1]; - var channels = (int)tensor.shape[2]; + var (height, width, channels) = tensor.GetImageDimensions(); var tensorType = tensor.dtype; var depth = bitDepthLookup[tensorType].Depth; + var mat = new Mat(new OpenCV.Net.Size(width, height), depth, channels); - var new_tensor = zeros([height, width, channels], tensorType).copy_(tensor); - - var res = THSTensor_data(new_tensor.Handle); - var mat = new Mat(new OpenCV.Net.Size(width, height), depth, channels, res); + // Create a temporary tensor backed by the matrix's memory and copy the source tensor into it + ReadOnlySpan dimensions = stackalloc long[] { height, width, channels }; + using var matTensor = TorchSharpEx.CreateStackTensor(mat.Data, mat, dimensions, tensorType); + matTensor.Tensor.copy_(tensor); return mat; } } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Torch/TorchSharpEx.cs b/src/Bonsai.ML.Torch/TorchSharpEx.cs new file mode 100644 index 00000000..6e3bfd23 --- /dev/null +++ b/src/Bonsai.ML.Torch/TorchSharpEx.cs @@ -0,0 +1,133 @@ +#nullable enable +#pragma warning disable CS1573 // Parameter has no matching param tag in the XML comment +using System; +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Runtime.InteropServices; +using TorchSharp; + +namespace Bonsai.ML.Torch; + +internal unsafe static class TorchSharpEx +{ + [DllImport("LibTorchSharp")] + private static extern IntPtr THSTensor_new(IntPtr rawArray, DeleterCallback deleter, long* dimensions, int numDimensions, sbyte type, sbyte dtype, int deviceType, int deviceIndex, byte requires_grad); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + private delegate void DeleterCallback(IntPtr context); + + // Torch does not expect the deleter callback to be able to be null since it's a C++ reference and LibTorchSharp does not expose the functions used to create a tensor without a deleter callback, so we must use a no-op callback + private static DeleterCallback NullDeleterCallback = _ => { }; + + // Acts as GC root for unmanaged callbacks, value is unused + private static ConcurrentDictionary ActiveDeleterCallbacks = new(); + + /// Creates a from unmanaged memory that is owned by a managed object + /// The unmanaged memory that will back the tensor, must remain valid and fixed for the lifetime of the tensor + /// The managed .NET object which owns + public static torch.Tensor CreateTensorFromUnmanagedMemoryWithManagedAnchor(IntPtr data, object managedAnchor, ReadOnlySpan dimensions, torch.ScalarType dataType) + { + //PERF: Ideally this would receive the GCHandle as the context rather than the pointer to the unmanaged memory since that's what we actually want to free + // Torch itself has the ability to set the context to something else via `TensorMaker::context(void* value, ContextDeleter deleter)`, but unfortunately this method isn't exposed in LibTorchSharp + // This is the inefficient method TorchSharp uses, which has quite a lot of unecessary overhead (particularly the unmanaged delegate allocation.) + // Some overhead could be removed by looking up the GCHandle from the native pointer, but doing this without breaking the ability to create redundant tensors over the same data is overly complicated. + GCHandle handle = default; + DeleterCallback? deleter = null; + deleter = (data) => + { + if (handle.IsAllocated) + handle.Free(); + + if (!ActiveDeleterCallbacks.TryRemove(deleter!, out _)) + Debug.Fail($"The same tensor data handle deleter was called more than once!"); + }; + + if (!ActiveDeleterCallbacks.TryAdd(deleter, default)) + Debug.Fail("Unreachable"); + + handle = GCHandle.Alloc(managedAnchor); + + bool isInitialized = false; + try + { + fixed (long* dimensionsPtr = &dimensions[0]) + { + IntPtr tensorHandle = THSTensor_new(data, deleter, dimensionsPtr, dimensions.Length, (sbyte)dataType, (sbyte)dataType, 0, 0, 0); + if (tensorHandle == IntPtr.Zero) + { + GC.Collect(); + GC.WaitForPendingFinalizers(); + tensorHandle = THSTensor_new(data, deleter, dimensionsPtr, dimensions.Length, (sbyte)dataType, (sbyte)dataType, 0, 0, 0); + } + + if (tensorHandle == IntPtr.Zero) + torch.CheckForErrors(); + + isInitialized = true; + return torch.Tensor.UnsafeCreateTensor(tensorHandle); + } + } + finally + { + if (!isInitialized) + deleter(data); + } + } + + internal readonly ref struct StackTensor + { + public readonly torch.Tensor Tensor; + private readonly object? Anchor; + + internal StackTensor(torch.Tensor tensor, object? anchor) + { + Tensor = tensor; + Anchor = anchor; + } + + public void Dispose() + { + Tensor.Dispose(); + GC.KeepAlive(Anchor); + } + } + + /// Creates a tensor which is associated with a stack scope. + /// The unmanaged memory that will back the tensor, must remain valid and fixed for the lifetime of the tensor + /// An optional managed .NET object which owns + /// + /// The returned stack tensor must be disposed. The tensor it refers to will not be valid outside of the scope where it was allocated. + /// + internal static StackTensor CreateStackTensor(IntPtr data, object? managedAnchor, ReadOnlySpan dimensions, torch.ScalarType dataType) + { + fixed (long* dimensionsPtr = &dimensions[0]) + { + IntPtr tensorHandle = THSTensor_new(data, NullDeleterCallback, dimensionsPtr, dimensions.Length, (sbyte)dataType, (sbyte)dataType, 0, 0, 0); + if (tensorHandle == IntPtr.Zero) + { + GC.Collect(); + GC.WaitForPendingFinalizers(); + tensorHandle = THSTensor_new(data, NullDeleterCallback, dimensionsPtr, dimensions.Length, (sbyte)dataType, (sbyte)dataType, 0, 0, 0); + } + + if (tensorHandle == IntPtr.Zero) + torch.CheckForErrors(); + + torch.Tensor result = torch.Tensor.UnsafeCreateTensor(tensorHandle); + return new StackTensor(result, data); + } + } + + /// Gets a pointer to the tensor's backing memory + /// The data backing a tensor is not necessarily contiguous or even present on the CPU, consider other strategies before using this method. + public static IntPtr DangerousGetDataPointer(this torch.Tensor tensor) + { + [DllImport("LibTorchSharp")] + static extern IntPtr THSTensor_data(IntPtr handle); + + IntPtr data = THSTensor_data(tensor.Handle); + if (data == IntPtr.Zero) + torch.CheckForErrors(); + return data; + } +} From 6cdd764d95f382a28489cee9bea417624140d79a Mon Sep 17 00:00:00 2001 From: David Maas Date: Tue, 4 Feb 2025 12:58:31 -0600 Subject: [PATCH 077/131] Removed unnecessary GCHandle usage (the lambda itself makes the anchor reachable instead), turned `IplImage` without backing data into an exception. (Also made some comments not a thousand columns wide.) --- src/Bonsai.ML.Torch/OpenCVHelper.cs | 10 +++-- src/Bonsai.ML.Torch/TorchSharpEx.cs | 57 ++++++++++++----------------- 2 files changed, 30 insertions(+), 37 deletions(-) diff --git a/src/Bonsai.ML.Torch/OpenCVHelper.cs b/src/Bonsai.ML.Torch/OpenCVHelper.cs index f489e4d8..31a27ac4 100644 --- a/src/Bonsai.ML.Torch/OpenCVHelper.cs +++ b/src/Bonsai.ML.Torch/OpenCVHelper.cs @@ -40,8 +40,9 @@ public static Tensor ToTensor(IplImage image) IntPtr data = image.ImageData; ReadOnlySpan dimensions = stackalloc long[] { height, width, channels }; - if (data == IntPtr.Zero) - return zeros(dimensions); + + if (data == IntPtr.Zero) + throw new InvalidOperationException($"Got {nameof(IplImage)} without backing data, this isn't expected to be possible."); return TorchSharpEx.CreateTensorFromUnmanagedMemoryWithManagedAnchor(data, image, dimensions, tensorType); } @@ -65,8 +66,9 @@ public static Tensor ToTensor(Mat mat) IntPtr data = mat.Data; ReadOnlySpan dimensions = stackalloc long[] { height, width, channels }; - if (data == IntPtr.Zero) - return zeros(dimensions); + + if (data == IntPtr.Zero) + throw new InvalidOperationException($"Got {nameof(Mat)} without backing data, this isn't expected to be possible."); return TorchSharpEx.CreateTensorFromUnmanagedMemoryWithManagedAnchor(data, mat, dimensions, tensorType); } diff --git a/src/Bonsai.ML.Torch/TorchSharpEx.cs b/src/Bonsai.ML.Torch/TorchSharpEx.cs index 6e3bfd23..47077ae4 100644 --- a/src/Bonsai.ML.Torch/TorchSharpEx.cs +++ b/src/Bonsai.ML.Torch/TorchSharpEx.cs @@ -16,27 +16,30 @@ internal unsafe static class TorchSharpEx [UnmanagedFunctionPointer(CallingConvention.Cdecl)] private delegate void DeleterCallback(IntPtr context); - // Torch does not expect the deleter callback to be able to be null since it's a C++ reference and LibTorchSharp does not expose the functions used to create a tensor without a deleter callback, so we must use a no-op callback - private static DeleterCallback NullDeleterCallback = _ => { }; + // Torch does not expect the deleter callback to be able to be null since it's a C++ reference and LibTorchSharp + // does not expose the functions used to create a tensor without a deleter callback, so we must use a no-op callback + private static readonly DeleterCallback NullDeleterCallback = _ => { }; // Acts as GC root for unmanaged callbacks, value is unused - private static ConcurrentDictionary ActiveDeleterCallbacks = new(); + private static readonly ConcurrentDictionary ActiveDeleterCallbacks = new(); /// Creates a from unmanaged memory that is owned by a managed object /// The unmanaged memory that will back the tensor, must remain valid and fixed for the lifetime of the tensor /// The managed .NET object which owns public static torch.Tensor CreateTensorFromUnmanagedMemoryWithManagedAnchor(IntPtr data, object managedAnchor, ReadOnlySpan dimensions, torch.ScalarType dataType) { - //PERF: Ideally this would receive the GCHandle as the context rather than the pointer to the unmanaged memory since that's what we actually want to free - // Torch itself has the ability to set the context to something else via `TensorMaker::context(void* value, ContextDeleter deleter)`, but unfortunately this method isn't exposed in LibTorchSharp - // This is the inefficient method TorchSharp uses, which has quite a lot of unecessary overhead (particularly the unmanaged delegate allocation.) - // Some overhead could be removed by looking up the GCHandle from the native pointer, but doing this without breaking the ability to create redundant tensors over the same data is overly complicated. - GCHandle handle = default; + //PERF: Ideally the deleter would receive the GCHandle as the context rather than the pointer to the unmanaged memory since that's + // would allow us to use a GCHandle to root the anchor and free it directly rather than capturing it in the lambda. + // Torch itself has the ability to set the context to something else via `TensorMaker::context(void* value, ContextDeleter deleter)`, + // but unfortunately this method isn't exposed in LibTorchSharp. + // This is similar to the inefficient method TorchSharp uses, which has quite a lot of unecessary overhead (particularly the unmanaged + // delegate allocation), but we do skip some aspects like the GC handle allocation. + // It may be tempting to use a GCHandle and a static delegate, looking up the GC handle from the native memory pointer, but doing this + // without breaking the ability to create redundant tensors over the same data is overly complicated. DeleterCallback? deleter = null; deleter = (data) => { - if (handle.IsAllocated) - handle.Free(); + GC.KeepAlive(managedAnchor); if (!ActiveDeleterCallbacks.TryRemove(deleter!, out _)) Debug.Fail($"The same tensor data handle deleter was called more than once!"); @@ -45,32 +48,20 @@ public static torch.Tensor CreateTensorFromUnmanagedMemoryWithManagedAnchor(IntP if (!ActiveDeleterCallbacks.TryAdd(deleter, default)) Debug.Fail("Unreachable"); - handle = GCHandle.Alloc(managedAnchor); - - bool isInitialized = false; - try + fixed (long* dimensionsPtr = &dimensions[0]) { - fixed (long* dimensionsPtr = &dimensions[0]) + IntPtr tensorHandle = THSTensor_new(data, deleter, dimensionsPtr, dimensions.Length, (sbyte)dataType, (sbyte)dataType, 0, 0, 0); + if (tensorHandle == IntPtr.Zero) { - IntPtr tensorHandle = THSTensor_new(data, deleter, dimensionsPtr, dimensions.Length, (sbyte)dataType, (sbyte)dataType, 0, 0, 0); - if (tensorHandle == IntPtr.Zero) - { - GC.Collect(); - GC.WaitForPendingFinalizers(); - tensorHandle = THSTensor_new(data, deleter, dimensionsPtr, dimensions.Length, (sbyte)dataType, (sbyte)dataType, 0, 0, 0); - } - - if (tensorHandle == IntPtr.Zero) - torch.CheckForErrors(); - - isInitialized = true; - return torch.Tensor.UnsafeCreateTensor(tensorHandle); + GC.Collect(); + GC.WaitForPendingFinalizers(); + tensorHandle = THSTensor_new(data, deleter, dimensionsPtr, dimensions.Length, (sbyte)dataType, (sbyte)dataType, 0, 0, 0); } - } - finally - { - if (!isInitialized) - deleter(data); + + if (tensorHandle == IntPtr.Zero) + torch.CheckForErrors(); + + return torch.Tensor.UnsafeCreateTensor(tensorHandle); } } From 3a68d6b6cbaabd521b0d145cea4a7296376edeee Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 10 Jan 2025 11:58:37 +0000 Subject: [PATCH 078/131] Added point process model package --- Bonsai.ML.sln | 9 +- .../Bonsai.ML.PointProcessDecoder.csproj | 15 + .../CreatePointProcessModel.cs | 429 ++++++++++++++++++ src/Bonsai.ML.PointProcessDecoder/Decode.cs | 35 ++ src/Bonsai.ML.PointProcessDecoder/Encode.cs | 37 ++ .../PointProcessModelDisposable.cs | 26 ++ .../PointProcessModelManager.cs | 78 ++++ .../PointProcessModelNameConverter.cs | 43 ++ 8 files changed, 671 insertions(+), 1 deletion(-) create mode 100644 src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj create mode 100644 src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs create mode 100644 src/Bonsai.ML.PointProcessDecoder/Decode.cs create mode 100644 src/Bonsai.ML.PointProcessDecoder/Encode.cs create mode 100644 src/Bonsai.ML.PointProcessDecoder/PointProcessModelDisposable.cs create mode 100644 src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs create mode 100644 src/Bonsai.ML.PointProcessDecoder/PointProcessModelNameConverter.cs diff --git a/Bonsai.ML.sln b/Bonsai.ML.sln index 30c6b6f1..daaf9e74 100644 --- a/Bonsai.ML.sln +++ b/Bonsai.ML.sln @@ -32,6 +32,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.HiddenMarkovModel EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Torch", "src\Bonsai.ML.Torch\Bonsai.ML.Torch.csproj", "{06FCC9AF-CE38-44BB-92B3-0D451BE88537}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.PointProcessDecoder", "src\Bonsai.ML.PointProcessDecoder\Bonsai.ML.PointProcessDecoder.csproj", "{AD32C680-1E8C-4340-81B1-DA19C9104516}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -74,10 +76,14 @@ Global {FC395DDC-62A4-4E14-A198-272AB05B33C7}.Debug|Any CPU.Build.0 = Debug|Any CPU {FC395DDC-62A4-4E14-A198-272AB05B33C7}.Release|Any CPU.ActiveCfg = Release|Any CPU {FC395DDC-62A4-4E14-A198-272AB05B33C7}.Release|Any CPU.Build.0 = Release|Any CPU - {06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Debug|Any CPU.Build.0 = Debug|Any CPU {06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Release|Any CPU.ActiveCfg = Release|Any CPU {06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Release|Any CPU.Build.0 = Release|Any CPU + {AD32C680-1E8C-4340-81B1-DA19C9104516}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {AD32C680-1E8C-4340-81B1-DA19C9104516}.Debug|Any CPU.Build.0 = Debug|Any CPU + {AD32C680-1E8C-4340-81B1-DA19C9104516}.Release|Any CPU.ActiveCfg = Release|Any CPU + {AD32C680-1E8C-4340-81B1-DA19C9104516}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -93,6 +99,7 @@ Global {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13} = {12312384-8828-4786-AE19-EFCEDF968290} {17DF50BE-F481-4904-A4C8-5DF9725B2CA1} = {12312384-8828-4786-AE19-EFCEDF968290} {06FCC9AF-CE38-44BB-92B3-0D451BE88537} = {12312384-8828-4786-AE19-EFCEDF968290} + {AD32C680-1E8C-4340-81B1-DA19C9104516} = {12312384-8828-4786-AE19-EFCEDF968290} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {B6468F13-97CD-45E0-9E1E-C122D7F1E09F} diff --git a/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj b/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj new file mode 100644 index 00000000..577017c4 --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj @@ -0,0 +1,15 @@ + + + + + + + + + Bonsai.ML.PointProcessDecoder + A Bonsai package for running a neural decoder based on point processes using Bayesian state-space models. + Bonsai Rx ML Point Process Neural Decoder + netstandard2.0 + enable + + \ No newline at end of file diff --git a/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs b/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs new file mode 100644 index 00000000..ef9d8aae --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs @@ -0,0 +1,429 @@ +using System; +using System.ComponentModel; +using System.Xml.Serialization; +using System.Linq; +using System.Reactive.Linq; + +using static TorchSharp.torch; + +using PointProcessDecoder.Core; +using PointProcessDecoder.Core.Estimation; +using PointProcessDecoder.Core.Transitions; +using PointProcessDecoder.Core.Encoder; +using PointProcessDecoder.Core.Decoder; +using PointProcessDecoder.Core.StateSpace; +using PointProcessDecoder.Core.Likelihood; + +namespace Bonsai.ML.PointProcessDecoder; + +/// +/// Creates a new neural decoding model based on point processes using Bayesian state space models. +/// +[Combinator] +[WorkflowElementCategory(ElementCategory.Source)] +[Description("Creates a new neural decoding model based on point processes using Bayesian state space models.")] +public class CreatePointProcessModel +{ + private string name = "PointProcessModel"; + + /// + /// Gets or sets the name of the neural decoding model. + /// + [Category("1. Model Parameters")] + [Description("The name of the neural decoding model.")] + public string Name + { + get + { + return name; + } + set + { + name = value; + } + } + + private EstimationMethod estimationMethod = EstimationMethod.KernelDensity; + + /// + /// Gets or sets the estimation method used during the encoding process. + /// + [Category("1. Model Parameters")] + [Description("The estimation method used during the encoding process.")] + public EstimationMethod EstimationMethod + { + get + { + return estimationMethod; + } + set + { + estimationMethod = value; + } + } + + private TransitionsType transitionsType = TransitionsType.RandomWalk; + /// + /// Gets or sets the type of transition model used during the decoding process. + /// + [Category("1. Model Parameters")] + [Description("The type of transition model used during the decoding process.")] + public TransitionsType TransitionsType + { + get + { + return transitionsType; + } + set + { + transitionsType = value; + } + } + + private EncoderType encoderType = EncoderType.SortedSpikeEncoder; + /// + /// Gets or sets the type of encoder used. + /// + [Category("1. Model Parameters")] + [Description("The type of encoder used.")] + public EncoderType EncoderType + { + get + { + return encoderType; + } + set + { + encoderType = value; + } + } + + private DecoderType decoderType = DecoderType.StateSpaceDecoder; + /// + /// Gets or sets the type of decoder used. + /// + [Category("1. Model Parameters")] + [Description("The type of decoder used.")] + public DecoderType DecoderType + { + get + { + return decoderType; + } + set + { + decoderType = value; + } + } + + private StateSpaceType stateSpaceType = StateSpaceType.DiscreteUniformStateSpace; + /// + /// Gets or sets the type of state space used. + /// + [Category("1. Model Parameters")] + [Description("The type of state space used.")] + public StateSpaceType StateSpaceType + { + get + { + return stateSpaceType; + } + set + { + stateSpaceType = value; + } + } + + private LikelihoodType likelihoodType = LikelihoodType.Poisson; + /// + /// Gets or sets the type of likelihood function used. + /// + [Category("1. Model Parameters")] + [Description("The type of likelihood function used.")] + public LikelihoodType LikelihoodType + { + get + { + return likelihoodType; + } + set + { + likelihoodType = value; + } + } + + Device? device = null; + /// + /// Gets or sets the device used to run the neural decoding model. + /// + [XmlIgnore] + [Category("1. Model Parameters")] + [Description("The device used to run the neural decoding model.")] + public Device? Device + { + get + { + return device; + } + set + { + device = value; + } + } + + ScalarType? scalarType = null; + /// + /// Gets or sets the scalar type used to run the neural decoding model. + /// + [Category("1. Model Parameters")] + [Description("The scalar type used to run the neural decoding model.")] + public ScalarType? ScalarType + { + get + { + return scalarType; + } + set + { + scalarType = value; + } + } + + + private int stateSpaceDimensions = 1; + /// + /// Gets or sets the number of dimensions in the state space. + /// + [Category("2. State Space Parameters")] + [Description("The number of dimensions in the state space.")] + public int StateSpaceDimensions + { + get + { + return stateSpaceDimensions; + } + set + { + stateSpaceDimensions = value; + } + } + + private double[] minStateSpace = [0]; + /// + /// Gets or sets the minimum values of the state space. Must be the same length as the number of state space dimensions. + /// + [Category("2. State Space Parameters")] + [Description("The minimum values of the state space. Must be the same length as the number of state space dimensions.")] + public double[] MinStateSpace + { + get + { + return minStateSpace; + } + set + { + minStateSpace = value; + } + } + + private double[] maxStateSpace = [100]; + /// + /// Gets or sets the maximum values of the state space. Must be the same length as the number of state space dimensions. + /// + [Category("2. State Space Parameters")] + [Description("The maximum values of the state space. Must be the same length as the number of state space dimensions.")] + public double[] MaxStateSpace + { + get + { + return maxStateSpace; + } + set + { + maxStateSpace = value; + } + } + + private long[] stepsStateSpace = [50]; + /// + /// Gets or sets the number of steps evaluated in the state space. Must be the same length as the number of state space dimensions. + /// + [Category("2. State Space Parameters")] + [Description("The number of steps evaluated in the state space. Must be the same length as the number of state space dimensions.")] + public long[] StepsStateSpace + { + get + { + return stepsStateSpace; + } + set + { + stepsStateSpace = value; + } + } + + private double[] observationBandwidth = [1]; + /// + /// Gets or sets the bandwidth of the observation estimation method. Must be the same length as the number of state space dimensions. + /// + [Category("2. State Space Parameters")] + [Description("The bandwidth of the observation estimation method. Must be the same length as the number of state space dimensions.")] + public double[] ObservationBandwidth + { + get + { + return observationBandwidth; + } + set + { + observationBandwidth = value; + } + } + + private int? nUnits = null; + /// + /// Gets or sets the number of sorted spiking units. + /// Only used when the encoder type is set to . + /// + [Category("3. Encoder Parameters")] + [Description("The number of sorted spiking units. Only used when the encoder type is set to SortedSpikeEncoder.")] + public int? NUnits + { + get + { + return nUnits; + } + set + { + nUnits = value; + } + } + + private int? markDimensions = null; + /// + /// Gets or sets the number of dimensions or features associated with each mark. + /// Only used when the encoder type is set to . + /// + [Category("3. Encoder Parameters")] + [Description("The number of dimensions or features associated with each mark. Only used when the encoder type is set to ClusterlessMarkEncoder.")] + public int? MarkDimensions + { + get + { + return markDimensions; + } + set + { + markDimensions = value; + } + } + + private int? markChannels = null; + /// + /// Gets or sets the number of mark recording channels. + /// Only used when the encoder type is set to . + /// + [Category("3. Encoder Parameters")] + [Description("The number of mark recording channels. Only used when the encoder type is set to ClusterlessMarkEncoder.")] + public int? MarkChannels + { + get + { + return markChannels; + } + set + { + markChannels = value; + } + } + + private double[]? markBandwidth = null; + /// + /// Gets or sets the bandwidth of the mark estimation method. + /// Must be the same length as the number of mark dimensions. + /// Only used when the encoder type is set to . + /// + [Category("3. Encoder Parameters")] + [Description("The bandwidth of the mark estimation method. Must be the same length as the number of mark dimensions. Only used when the encoder type is set to ClusterlessMarkEncoder.")] + public double[]? MarkBandwidth + { + get + { + return markBandwidth; + } + set + { + markBandwidth = value; + } + } + + private double? distanceThreshold = null; + /// + /// Gets or sets the distance threshold used to determine the threshold to merge unique clusters into a single compressed cluster. + /// Only used when the estimation method is set to . + /// + [Category("4. Estimation Parameters")] + [Description("The distance threshold used to determine the threshold to merge unique clusters into a single compressed cluster. Only used when the estimation method is set to KernelCompression.")] + public double? DistanceThreshold + { + get + { + return distanceThreshold; + } + set + { + distanceThreshold = value; + } + } + + private double? sigmaRandomWalk = null; + /// + /// Gets or sets the standard deviation of the random walk transitions model. + /// Only used when the transitions type is set to . + /// + [Category("5. Transition Parameters")] + [Description("The standard deviation of the random walk transitions model. Only used when the transitions type is set to RandomWalk.")] + public double? SigmaRandomWalk + { + get + { + return sigmaRandomWalk; + } + set + { + sigmaRandomWalk = value; + } + } + + /// + /// Creates a new neural decoding model based on point processes using Bayesian state space models. + /// + /// + public IObservable Process() + { + return Observable.Using( + () => PointProcessModelManager.Reserve( + name: name, + estimationMethod: estimationMethod, + transitionsType: transitionsType, + encoderType: encoderType, + decoderType: decoderType, + stateSpaceType: stateSpaceType, + likelihoodType: likelihoodType, + minStateSpace: minStateSpace, + maxStateSpace: maxStateSpace, + stepsStateSpace: stepsStateSpace, + observationBandwidth: observationBandwidth, + stateSpaceDimensions: stateSpaceDimensions, + markDimensions: markDimensions, + markChannels: markChannels, + markBandwidth: markBandwidth, + nUnits: nUnits, + distanceThreshold: distanceThreshold, + sigmaRandomWalk: sigmaRandomWalk, + device: device, + scalarType: scalarType + ), resource => Observable.Return(resource.Model) + .Concat(Observable.Never(resource.Model))); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.PointProcessDecoder/Decode.cs b/src/Bonsai.ML.PointProcessDecoder/Decode.cs new file mode 100644 index 00000000..3e9c7369 --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder/Decode.cs @@ -0,0 +1,35 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; + +using static TorchSharp.torch; + +namespace Bonsai.ML.PointProcessDecoder; + +/// +/// Decodes the input neural data into a posterior state estimate using a point process model. +/// +[Combinator] +[WorkflowElementCategory(ElementCategory.Transform)] +[Description("Decodes the input neural data into a posterior state estimate using a point process model.")] +public class Decode +{ + /// + /// The name of the point process model to use. + /// + [TypeConverter(typeof(PointProcessModelNameConverter))] + public string Model { get; set; } = string.Empty; + + /// + /// Decodes the input neural data into a posterior state estimate using a point process model. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(input => { + var model = PointProcessModelManager.GetModel(Model); + return model.Decode(input); + }); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.PointProcessDecoder/Encode.cs b/src/Bonsai.ML.PointProcessDecoder/Encode.cs new file mode 100644 index 00000000..c79fe81e --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder/Encode.cs @@ -0,0 +1,37 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; + +using static TorchSharp.torch; + +namespace Bonsai.ML.PointProcessDecoder; + +/// +/// Encodes the combined state observation data and neural data into a point process model. +/// +[Combinator] +[WorkflowElementCategory(ElementCategory.Sink)] +[Description("Encodes the combined state observation data and neural data into a point process model.")] +public class Encode +{ + /// + /// The name of the point process model to use. + /// + [TypeConverter(typeof(PointProcessModelNameConverter))] + public string Model { get; set; } = string.Empty; + + /// + /// Encodes the combined state observation data and neural data into a point process model. + /// + /// + /// + public IObservable> Process(IObservable> source) + { + return source.Do(input => + { + var model = PointProcessModelManager.GetModel(Model); + var (neuralData, stateObservations) = input; + model.Encode(neuralData, stateObservations); + }); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.PointProcessDecoder/PointProcessModelDisposable.cs b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelDisposable.cs new file mode 100644 index 00000000..d9d938bd --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelDisposable.cs @@ -0,0 +1,26 @@ +using System; +using System.Threading; +using PointProcessDecoder.Core; + +namespace Bonsai.ML.PointProcessDecoder; + +sealed class PointProcessModelDisposable : IDisposable +{ + private IDisposable? resource; + public bool IsDispose => resource == null; + + private readonly PointProcessModel model; + public PointProcessModel Model => model; + + public PointProcessModelDisposable(PointProcessModel model, IDisposable disposable) + { + this.model = model ?? throw new ArgumentNullException(nameof(model)); + resource = disposable ?? throw new ArgumentNullException(nameof(disposable)); + } + + public void Dispose() + { + var disposable = Interlocked.Exchange(ref resource, null); + disposable?.Dispose(); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs new file mode 100644 index 00000000..36651626 --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs @@ -0,0 +1,78 @@ +using System; +using System.Collections.Generic; +using System.Reactive.Disposables; + +using static TorchSharp.torch; + +using PointProcessDecoder.Core; +using PointProcessDecoder.Core.Estimation; +using PointProcessDecoder.Core.Transitions; +using PointProcessDecoder.Core.StateSpace; +using PointProcessDecoder.Core.Encoder; +using PointProcessDecoder.Core.Decoder; +using PointProcessDecoder.Core.Likelihood; + +namespace Bonsai.ML.PointProcessDecoder; + +internal static class PointProcessModelManager +{ + private static readonly Dictionary models = []; + + internal static PointProcessModel GetModel(string name) + { + return models.TryGetValue(name, out var model) ? model : throw new InvalidOperationException($"Model with name {name} not found."); + } + + internal static PointProcessModelDisposable Reserve( + string name, + EstimationMethod estimationMethod, + TransitionsType transitionsType, + EncoderType encoderType, + DecoderType decoderType, + StateSpaceType stateSpaceType, + LikelihoodType likelihoodType, + double[] minStateSpace, + double[] maxStateSpace, + long[] stepsStateSpace, + double[] observationBandwidth, + int stateSpaceDimensions, + int? markDimensions = null, + int? markChannels = null, + double[]? markBandwidth = null, + int? nUnits = null, + double? distanceThreshold = null, + double? sigmaRandomWalk = null, + Device? device = null, + ScalarType? scalarType = null + ) + { + var model = new PointProcessModel( + estimationMethod: estimationMethod, + transitionsType: transitionsType, + encoderType: encoderType, + decoderType: decoderType, + stateSpaceType: stateSpaceType, + likelihoodType: likelihoodType, + minStateSpace: minStateSpace, + maxStateSpace: maxStateSpace, + stepsStateSpace: stepsStateSpace, + observationBandwidth: observationBandwidth, + stateSpaceDimensions: stateSpaceDimensions, + markDimensions: markDimensions, + markChannels: markChannels, + markBandwidth: markBandwidth, + nUnits: nUnits, + distanceThreshold: distanceThreshold, + sigmaRandomWalk: sigmaRandomWalk, + device: device, + scalarType: scalarType + ); + + models.Add(name, model); + + return new PointProcessModelDisposable( + model, + Disposable.Create(() => models.Remove(name)) + ); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.PointProcessDecoder/PointProcessModelNameConverter.cs b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelNameConverter.cs new file mode 100644 index 00000000..5bef2701 --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelNameConverter.cs @@ -0,0 +1,43 @@ +using Bonsai; +using Bonsai.Expressions; +using System.Linq; +using System.ComponentModel; + +namespace Bonsai.ML.PointProcessDecoder; + +/// +/// Provides a type converter to display a list of available point process models. +/// +public class PointProcessModelNameConverter : StringConverter +{ + /// + public override bool GetStandardValuesSupported(ITypeDescriptorContext context) + { + return true; + } + + /// + public override StandardValuesCollection GetStandardValues(ITypeDescriptorContext context) + { + if (context != null) + { + var workflowBuilder = (WorkflowBuilder)context.GetService(typeof(WorkflowBuilder)); + if (workflowBuilder != null) + { + var models = (from builder in workflowBuilder.Workflow.Descendants() + where builder.GetType() != typeof(DisableBuilder) + let createPointProcessModel = ExpressionBuilder.GetWorkflowElement(builder) as CreatePointProcessModel + where createPointProcessModel != null && !string.IsNullOrEmpty(createPointProcessModel.Name) + select createPointProcessModel.Name) + .Distinct() + .ToList(); + if (models.Count > 0) + { + return new StandardValuesCollection(models); + } + } + } + + return new StandardValuesCollection(new string[] { }); + } +} \ No newline at end of file From fd4ea05264967ad93a9055e3c4ae7ee4db3f300b Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 13 Jan 2025 12:10:44 +0000 Subject: [PATCH 079/131] Added get model function --- src/Bonsai.ML.PointProcessDecoder/GetModel.cs | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 src/Bonsai.ML.PointProcessDecoder/GetModel.cs diff --git a/src/Bonsai.ML.PointProcessDecoder/GetModel.cs b/src/Bonsai.ML.PointProcessDecoder/GetModel.cs new file mode 100644 index 00000000..7d5fbe41 --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder/GetModel.cs @@ -0,0 +1,34 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using PointProcessDecoder.Core; +using static TorchSharp.torch; + +namespace Bonsai.ML.PointProcessDecoder; + +/// +/// Returns the point process model. +/// +[Combinator] +[WorkflowElementCategory(ElementCategory.Transform)] +[Description("Returns the point process model.")] +public class GetModel +{ + /// + /// The name of the point process model to return. + /// + [TypeConverter(typeof(PointProcessModelNameConverter))] + public string Model { get; set; } = string.Empty; + + /// + /// Decodes the input neural data into a posterior state estimate using a point process model. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(input => { + return PointProcessModelManager.GetModel(Model); + }); + } +} \ No newline at end of file From a911045f13b96a46114fc8468944003467e4f8b7 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 13 Jan 2025 12:11:21 +0000 Subject: [PATCH 080/131] Changed "IsDispose" to correct "IsDisposed" use of past tense --- .../PointProcessModelDisposable.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Bonsai.ML.PointProcessDecoder/PointProcessModelDisposable.cs b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelDisposable.cs index d9d938bd..912b603e 100644 --- a/src/Bonsai.ML.PointProcessDecoder/PointProcessModelDisposable.cs +++ b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelDisposable.cs @@ -7,7 +7,7 @@ namespace Bonsai.ML.PointProcessDecoder; sealed class PointProcessModelDisposable : IDisposable { private IDisposable? resource; - public bool IsDispose => resource == null; + public bool IsDisposed => resource == null; private readonly PointProcessModel model; public PointProcessModel Model => model; From 4533762e263b9c20fead5ac4a343eff37b7f5385 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 13 Jan 2025 12:11:42 +0000 Subject: [PATCH 081/131] Added dispose method to model manager --- .../PointProcessModelManager.cs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs index 36651626..17e0ac5c 100644 --- a/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs +++ b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs @@ -72,7 +72,11 @@ internal static PointProcessModelDisposable Reserve( return new PointProcessModelDisposable( model, - Disposable.Create(() => models.Remove(name)) + Disposable.Create(() => { + model.Dispose(); + model = null; + models.Remove(name); + }) ); } } \ No newline at end of file From 2f2689aa1d8e1fdfeafe379493b9e71709a9ba2a Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 15 Jan 2025 18:43:11 +0000 Subject: [PATCH 082/131] Updated comments/docs --- src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs b/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs index ef9d8aae..12235ac3 100644 --- a/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs +++ b/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs @@ -27,10 +27,10 @@ public class CreatePointProcessModel private string name = "PointProcessModel"; /// - /// Gets or sets the name of the neural decoding model. + /// Gets or sets the name of the point process model. /// [Category("1. Model Parameters")] - [Description("The name of the neural decoding model.")] + [Description("The name of the point process model.")] public string Name { get From 1f7fe53c6a348561b89f262aef665976b7563b33 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 21 Jan 2025 18:40:49 +0000 Subject: [PATCH 083/131] Added Bonsai.ML.Torch dependency --- .../Bonsai.ML.PointProcessDecoder.csproj | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj b/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj index 577017c4..de14007d 100644 --- a/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj +++ b/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj @@ -1,6 +1,7 @@ + From 3979a66b3c7f1ac9e2f1535dde0126ec37485a49 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 21 Jan 2025 18:47:51 +0000 Subject: [PATCH 084/131] Moved property group to top for consistency with other packages. Also updated the target frameworks to include net framework --- .../Bonsai.ML.PointProcessDecoder.csproj | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj b/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj index de14007d..6ff57bc0 100644 --- a/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj +++ b/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj @@ -1,4 +1,11 @@ + + Bonsai.ML.PointProcessDecoder + A Bonsai package for running a neural decoder based on point processes using Bayesian state-space models. + Bonsai Rx ML Point Process Neural Decoder + net472;netstandard2.0 + enable + @@ -6,11 +13,4 @@ - - Bonsai.ML.PointProcessDecoder - A Bonsai package for running a neural decoder based on point processes using Bayesian state-space models. - Bonsai Rx ML Point Process Neural Decoder - netstandard2.0 - enable - \ No newline at end of file From 86a21744528f33710af70dd78ba6c6e6ab86c625 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 24 Jan 2025 08:58:32 +0000 Subject: [PATCH 085/131] Updated dependency version --- .../Bonsai.ML.PointProcessDecoder.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj b/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj index 6ff57bc0..95e0cabe 100644 --- a/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj +++ b/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj @@ -11,6 +11,6 @@ - + \ No newline at end of file From ff7140dfa89b2fb8dbcf3c1f89f8225d65c74397 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 24 Jan 2025 12:27:53 +0000 Subject: [PATCH 086/131] Added documentation for point process decoder package --- README.md | 3 + .../ppd-getting-started.md | 27 +++++++ docs/articles/toc.yml | 3 + docs/workflows/PointProcessDecoder.bonsai | 73 +++++++++++++++++++ 4 files changed, 106 insertions(+) create mode 100644 docs/articles/PointProcessDecoder/ppd-getting-started.md create mode 100644 docs/workflows/PointProcessDecoder.bonsai diff --git a/README.md b/README.md index 508c9a05..60864b6d 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,9 @@ Facilitates inference using Hidden Markov Models (HMMs). It interfaces with the ### Bonsai.ML.HiddenMarkovModels.Design Visualizers and editor features for the HiddenMarkovModels package. +### Bonsai.ML.PointProcessDecoder +Interfaces with the [PointProcessDecoder](https://github.com/ncguilbeault/PointProcessDecoder) package which can be used for decoding neural activity and other point processes. + ### Bonsai.ML.Torch Interfaces with the [TorchSharp](https://github.com/dotnet/TorchSharp) package, a C# wrapper around the torch library. Provides tooling for manipulating tensors, performing linear algebra, training and inference with deep neural networks, and more. diff --git a/docs/articles/PointProcessDecoder/ppd-getting-started.md b/docs/articles/PointProcessDecoder/ppd-getting-started.md new file mode 100644 index 00000000..7006f682 --- /dev/null +++ b/docs/articles/PointProcessDecoder/ppd-getting-started.md @@ -0,0 +1,27 @@ +# Getting Started + +The `Bonsai.ML.PointProcessDecoder` package provides a Bonsai interface to the [PointProcessDecoder](https://github.com/ncguilbeault/PointProcessDecoder) package used for decoding neural activity (point processes), and relies on the `Bonsai.ML.Torch` package for tensor operations. + +## Installation + +The package can be installed by going to the bonsai package manager and installing the `Bonsai.ML.PointProcessDecoder`. Additional installation steps are required for installing the CPU or GPU version of the `Bonsai.ML.Torch` package. See the [Torch installation guide](../Torch/torch-overview.md) for more information. + +## Package Overview + +The `PointProcessDecoder` package is a C# implementation of a Bayesian state space point process decoder inspired by the [replay_trajectory_classification repository](https://github.com/Eden-Kramer-Lab/replay_trajectory_classification) from the Eden-Kramer Lab. It can decode latent state observations from spike-train data or clusterless mark data based on point processes using Bayesian state space models. + +For more detailed information and documentation about the model, please see the [PointProcessDecoder repo](https://github.com/ncguilbeault/PointProcessDecoder). + +## Bonsai Implementation + +The following workflow showcases the core functionality of the `Bonsai.ML.PointProcessDecoder` package. + +:::workflow +![Point Process Decoder Implementation](~/workflows/PointProcessDecoder.bonsai) +::: + +The `CreatePointProcessModel` node is used to define a model and configure it's parameters. For details on model configuration, please see the [PointProcessDecoder documentation](https://github.com/ncguilbeault/PointProcessDecoder). Crucially, the user must specify the `Name` property in the `Model Parameters` section, as this is what allows you to reference the specific model in the `Encode` and `Decode` nodes, the two main methods that the model will use. + +During encoding, the user passes in a tuple of `Observation`s and `SpikeCounts`. Observations are variables that the user measures (for example, the animal's position), represented as a (M, N) tensor, where M is the number of samples in a batch and N is the dimensionality of the observations. Spike counts are the data you will use for decoding. For instance, spike counts might be sorted spike data or clusterless marks. If the data are sorted spiking units, then the `SpikeCounts` tensor will be an (M, U) tensor, where U is the number of sorted units. If the data are clusterles marks, then the `SpikeCounts` tensor will be an (M, F, C) tensor, where F is the number of features computed for each mark (for instance, the maximum spike amplitude across electrodes), and C is the number of independant recording channels (for example, individual tetrodes). Internally, the model fits the data as a point process, which will be used during decoding. + +Decoding is the process of taking just the `SpikeCounts` and inferring what is the latent `Observation`. To do this, the model uses a bayesian state space model to predict a posterior distribution over the latent observation space using the information contained in the spike counts data. The output of the `Decode` node will provide you with an (M x D*) tensor, with D* representing a discrete latent state space determined by the parameters defined in the model. \ No newline at end of file diff --git a/docs/articles/toc.yml b/docs/articles/toc.yml index 625cfcc3..ddb774cc 100644 --- a/docs/articles/toc.yml +++ b/docs/articles/toc.yml @@ -14,6 +14,9 @@ href: HiddenMarkovModels/hmm-overview.md - name: Getting Started href: HiddenMarkovModels/hmm-getting-started.md +- name: PointProcessDecoder +- name: Getting Started + href: PointProcessDecoder/ppd-getting-started.md - name: Torch - name: Overview href: Torch/torch-overview.md diff --git a/docs/workflows/PointProcessDecoder.bonsai b/docs/workflows/PointProcessDecoder.bonsai new file mode 100644 index 00000000..b13680cd --- /dev/null +++ b/docs/workflows/PointProcessDecoder.bonsai @@ -0,0 +1,73 @@ + + + + + + + PointProcessModel + KernelDensity + RandomWalk + SortedSpikeEncoder + StateSpaceDecoder + DiscreteUniformStateSpace + Poisson + + 2 + + 0 + 0 + + + 120 + 120 + + + 50 + 50 + + + 5 + 5 + + 104 + + + 1.5 + + + + + Observation + + + SpikeCounts + + + + + + + PointProcessModel + + + + SortedSpikeCounts + + + + PointProcessModel + + + + + + + + + + + \ No newline at end of file From 63ce3bed169acef8ca9d5b4c8f00e9d020e7a70e Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 27 Jan 2025 16:28:09 +0000 Subject: [PATCH 087/131] Updated point process model to include ignore no spikes parameter --- .../CreatePointProcessModel.cs | 232 ++++++++++-------- .../PointProcessModelManager.cs | 2 + 2 files changed, 127 insertions(+), 107 deletions(-) diff --git a/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs b/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs index 12235ac3..c4b5e36b 100644 --- a/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs +++ b/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs @@ -43,76 +43,40 @@ public string Name } } - private EstimationMethod estimationMethod = EstimationMethod.KernelDensity; - - /// - /// Gets or sets the estimation method used during the encoding process. - /// - [Category("1. Model Parameters")] - [Description("The estimation method used during the encoding process.")] - public EstimationMethod EstimationMethod - { - get - { - return estimationMethod; - } - set - { - estimationMethod = value; - } - } - - private TransitionsType transitionsType = TransitionsType.RandomWalk; - /// - /// Gets or sets the type of transition model used during the decoding process. - /// - [Category("1. Model Parameters")] - [Description("The type of transition model used during the decoding process.")] - public TransitionsType TransitionsType - { - get - { - return transitionsType; - } - set - { - transitionsType = value; - } - } - - private EncoderType encoderType = EncoderType.SortedSpikeEncoder; + Device? device = null; /// - /// Gets or sets the type of encoder used. + /// Gets or sets the device used to run the neural decoding model. /// + [XmlIgnore] [Category("1. Model Parameters")] - [Description("The type of encoder used.")] - public EncoderType EncoderType + [Description("The device used to run the neural decoding model.")] + public Device? Device { get { - return encoderType; + return device; } set { - encoderType = value; + device = value; } } - private DecoderType decoderType = DecoderType.StateSpaceDecoder; + ScalarType? scalarType = null; /// - /// Gets or sets the type of decoder used. + /// Gets or sets the scalar type used to run the neural decoding model. /// [Category("1. Model Parameters")] - [Description("The type of decoder used.")] - public DecoderType DecoderType + [Description("The scalar type used to run the neural decoding model.")] + public ScalarType? ScalarType { get { - return decoderType; + return scalarType; } set { - decoderType = value; + scalarType = value; } } @@ -120,7 +84,7 @@ public DecoderType DecoderType /// /// Gets or sets the type of state space used. /// - [Category("1. Model Parameters")] + [Category("2. State Space Parameters")] [Description("The type of state space used.")] public StateSpaceType StateSpaceType { @@ -134,62 +98,6 @@ public StateSpaceType StateSpaceType } } - private LikelihoodType likelihoodType = LikelihoodType.Poisson; - /// - /// Gets or sets the type of likelihood function used. - /// - [Category("1. Model Parameters")] - [Description("The type of likelihood function used.")] - public LikelihoodType LikelihoodType - { - get - { - return likelihoodType; - } - set - { - likelihoodType = value; - } - } - - Device? device = null; - /// - /// Gets or sets the device used to run the neural decoding model. - /// - [XmlIgnore] - [Category("1. Model Parameters")] - [Description("The device used to run the neural decoding model.")] - public Device? Device - { - get - { - return device; - } - set - { - device = value; - } - } - - ScalarType? scalarType = null; - /// - /// Gets or sets the scalar type used to run the neural decoding model. - /// - [Category("1. Model Parameters")] - [Description("The scalar type used to run the neural decoding model.")] - public ScalarType? ScalarType - { - get - { - return scalarType; - } - set - { - scalarType = value; - } - } - - private int stateSpaceDimensions = 1; /// /// Gets or sets the number of dimensions in the state space. @@ -280,6 +188,24 @@ public double[] ObservationBandwidth } } + private EncoderType encoderType = EncoderType.SortedSpikeEncoder; + /// + /// Gets or sets the type of encoder used. + /// + [Category("3. Encoder Parameters")] + [Description("The type of encoder used.")] + public EncoderType EncoderType + { + get + { + return encoderType; + } + set + { + encoderType = value; + } + } + private int? nUnits = null; /// /// Gets or sets the number of sorted spiking units. @@ -357,6 +283,43 @@ public double[]? MarkBandwidth } } + private bool ignoreNoSpikes = false; + /// + /// Gets or sets a value indicating whether to ignore contributions from joint probability distributions with no spikes. + /// Only used when the encoder type is set to . + /// + [Category("3. Encoder Parameters")] + [Description("Indicates whether to ignore contributions from joint probability distributions with no spikes. Only used when the encoder type is set to ClusterlessMarkEncoder.")] + public bool IgnoreNoSpikes + { + get + { + return ignoreNoSpikes; + } + set + { + ignoreNoSpikes = value; + } + } + + private EstimationMethod estimationMethod = EstimationMethod.KernelDensity; + /// + /// Gets or sets the estimation method used during the encoding process. + /// + [Category("4. Estimation Parameters")] + [Description("The estimation method used during the encoding process.")] + public EstimationMethod EstimationMethod + { + get + { + return estimationMethod; + } + set + { + estimationMethod = value; + } + } + private double? distanceThreshold = null; /// /// Gets or sets the distance threshold used to determine the threshold to merge unique clusters into a single compressed cluster. @@ -376,12 +339,48 @@ public double? DistanceThreshold } } + private LikelihoodType likelihoodType = LikelihoodType.Poisson; + /// + /// Gets or sets the type of likelihood function used. + /// + [Category("5. Likelihood Parameters")] + [Description("The type of likelihood function used.")] + public LikelihoodType LikelihoodType + { + get + { + return likelihoodType; + } + set + { + likelihoodType = value; + } + } + + private TransitionsType transitionsType = TransitionsType.RandomWalk; + /// + /// Gets or sets the type of transition model used during the decoding process. + /// + [Category("6. Transition Parameters")] + [Description("The type of transition model used during the decoding process.")] + public TransitionsType TransitionsType + { + get + { + return transitionsType; + } + set + { + transitionsType = value; + } + } + private double? sigmaRandomWalk = null; /// /// Gets or sets the standard deviation of the random walk transitions model. /// Only used when the transitions type is set to . /// - [Category("5. Transition Parameters")] + [Category("6. Transition Parameters")] [Description("The standard deviation of the random walk transitions model. Only used when the transitions type is set to RandomWalk.")] public double? SigmaRandomWalk { @@ -395,6 +394,24 @@ public double? SigmaRandomWalk } } + private DecoderType decoderType = DecoderType.StateSpaceDecoder; + /// + /// Gets or sets the type of decoder used. + /// + [Category("7. Decoder Parameters")] + [Description("The type of decoder used.")] + public DecoderType DecoderType + { + get + { + return decoderType; + } + set + { + decoderType = value; + } + } + /// /// Creates a new neural decoding model based on point processes using Bayesian state space models. /// @@ -418,6 +435,7 @@ public IObservable Process() markDimensions: markDimensions, markChannels: markChannels, markBandwidth: markBandwidth, + ignoreNoSpikes: ignoreNoSpikes, nUnits: nUnits, distanceThreshold: distanceThreshold, sigmaRandomWalk: sigmaRandomWalk, diff --git a/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs index 17e0ac5c..2403b458 100644 --- a/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs +++ b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs @@ -39,6 +39,7 @@ internal static PointProcessModelDisposable Reserve( int? markDimensions = null, int? markChannels = null, double[]? markBandwidth = null, + bool ignoreNoSpikes = false, int? nUnits = null, double? distanceThreshold = null, double? sigmaRandomWalk = null, @@ -61,6 +62,7 @@ internal static PointProcessModelDisposable Reserve( markDimensions: markDimensions, markChannels: markChannels, markBandwidth: markBandwidth, + ignoreNoSpikes: ignoreNoSpikes, nUnits: nUnits, distanceThreshold: distanceThreshold, sigmaRandomWalk: sigmaRandomWalk, From 094964dafc8ebbf4f4004f8dfdda6868e4479930 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 28 Jan 2025 11:55:36 +0000 Subject: [PATCH 088/131] Updated point process decoder core package version --- .../Bonsai.ML.PointProcessDecoder.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj b/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj index 95e0cabe..ceefa4db 100644 --- a/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj +++ b/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj @@ -11,6 +11,6 @@ - + \ No newline at end of file From 094173b6f43386b0bf58a0eed32f26b9de72a2a1 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 28 Jan 2025 16:02:41 +0000 Subject: [PATCH 089/131] Update decode method to allow modifying the ignore no spikes property in the clusterless likelihood calcultion --- src/Bonsai.ML.PointProcessDecoder/Decode.cs | 24 ++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/src/Bonsai.ML.PointProcessDecoder/Decode.cs b/src/Bonsai.ML.PointProcessDecoder/Decode.cs index 3e9c7369..6a6340a8 100644 --- a/src/Bonsai.ML.PointProcessDecoder/Decode.cs +++ b/src/Bonsai.ML.PointProcessDecoder/Decode.cs @@ -1,7 +1,8 @@ using System; using System.ComponentModel; using System.Reactive.Linq; - +using PointProcessDecoder.Core; +using PointProcessDecoder.Core.Likelihood; using static TorchSharp.torch; namespace Bonsai.ML.PointProcessDecoder; @@ -18,8 +19,25 @@ public class Decode /// The name of the point process model to use. /// [TypeConverter(typeof(PointProcessModelNameConverter))] + [Description("The name of the point process model to use.")] public string Model { get; set; } = string.Empty; + private bool _ignoreNoSpikes = false; + private bool _updateIgnoreNoSpikes = false; + /// + /// Gets or sets a value indicating whether to ignore contributions from no spike events. + /// + [Description("Indicates whether to ignore contributions from no spike events.")] + public bool IgnoreNoSpikes + { + get => _ignoreNoSpikes; + set + { + _ignoreNoSpikes = value; + _updateIgnoreNoSpikes = true; + } + } + /// /// Decodes the input neural data into a posterior state estimate using a point process model. /// @@ -29,6 +47,10 @@ public IObservable Process(IObservable source) { return source.Select(input => { var model = PointProcessModelManager.GetModel(Model); + if (_updateIgnoreNoSpikes && model.Likelihood is ClusterlessLikelihood likelihood) { + likelihood.IgnoreNoSpikes = _ignoreNoSpikes; + _updateIgnoreNoSpikes = false; + } return model.Decode(input); }); } From b0fd4db3f0fb23bc6a948be3ef49790d6655e4d8 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 28 Jan 2025 16:03:45 +0000 Subject: [PATCH 090/131] Updated create model node to dispose model using the observable.finally method --- src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs b/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs index c4b5e36b..6c53fb6d 100644 --- a/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs +++ b/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs @@ -442,6 +442,7 @@ public IObservable Process() device: device, scalarType: scalarType ), resource => Observable.Return(resource.Model) - .Concat(Observable.Never(resource.Model))); + .Concat(Observable.Never(resource.Model)) + .Finally(resource.Dispose)); } } \ No newline at end of file From 1254dfecb2aef0826a825b74fb71455da1226fc3 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 28 Jan 2025 16:04:24 +0000 Subject: [PATCH 091/131] Changed get model to source and add process function for no input --- src/Bonsai.ML.PointProcessDecoder/GetModel.cs | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/Bonsai.ML.PointProcessDecoder/GetModel.cs b/src/Bonsai.ML.PointProcessDecoder/GetModel.cs index 7d5fbe41..2590d154 100644 --- a/src/Bonsai.ML.PointProcessDecoder/GetModel.cs +++ b/src/Bonsai.ML.PointProcessDecoder/GetModel.cs @@ -10,7 +10,7 @@ namespace Bonsai.ML.PointProcessDecoder; /// Returns the point process model. /// [Combinator] -[WorkflowElementCategory(ElementCategory.Transform)] +[WorkflowElementCategory(ElementCategory.Source)] [Description("Returns the point process model.")] public class GetModel { @@ -21,11 +21,23 @@ public class GetModel public string Model { get; set; } = string.Empty; /// - /// Decodes the input neural data into a posterior state estimate using a point process model. + /// Returns the point process model. /// + /// + public IObservable Process() + { + return Observable.Defer(() => + Observable.Return(PointProcessModelManager.GetModel(Model)) + ); + } + + /// + /// Returns the point process model. + /// + /// /// /// - public IObservable Process(IObservable source) + public IObservable Process(IObservable source) { return source.Select(input => { return PointProcessModelManager.GetModel(Model); From b1e0ca0dda54fc278aec451363d87f14dd30a361 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 28 Jan 2025 16:07:22 +0000 Subject: [PATCH 092/131] Moved ignore no spikes property to likelihood parameters category --- .../CreatePointProcessModel.cs | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs b/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs index 6c53fb6d..cef032c1 100644 --- a/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs +++ b/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs @@ -283,25 +283,6 @@ public double[]? MarkBandwidth } } - private bool ignoreNoSpikes = false; - /// - /// Gets or sets a value indicating whether to ignore contributions from joint probability distributions with no spikes. - /// Only used when the encoder type is set to . - /// - [Category("3. Encoder Parameters")] - [Description("Indicates whether to ignore contributions from joint probability distributions with no spikes. Only used when the encoder type is set to ClusterlessMarkEncoder.")] - public bool IgnoreNoSpikes - { - get - { - return ignoreNoSpikes; - } - set - { - ignoreNoSpikes = value; - } - } - private EstimationMethod estimationMethod = EstimationMethod.KernelDensity; /// /// Gets or sets the estimation method used during the encoding process. @@ -357,6 +338,25 @@ public LikelihoodType LikelihoodType } } + private bool ignoreNoSpikes = false; + /// + /// Gets or sets a value indicating whether to ignore contributions from channels with no spikes. + /// Only used when the likelihood type is set to . + /// + [Category("5. Likelihood Parameters")] + [Description("Indicates whether to ignore contributions from channels with no spikes. Only used when the likelihood type is set to Clusterless.")] + public bool IgnoreNoSpikes + { + get + { + return ignoreNoSpikes; + } + set + { + ignoreNoSpikes = value; + } + } + private TransitionsType transitionsType = TransitionsType.RandomWalk; /// /// Gets or sets the type of transition model used during the decoding process. From 29058e2f76103ba1f7b92dc3655e97fdc633e2aa Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 28 Jan 2025 16:13:48 +0000 Subject: [PATCH 093/131] Added property descriptions --- src/Bonsai.ML.PointProcessDecoder/Encode.cs | 1 + src/Bonsai.ML.PointProcessDecoder/GetModel.cs | 1 + .../PointProcessModelDisposable.cs | 18 +++++++++--------- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/Bonsai.ML.PointProcessDecoder/Encode.cs b/src/Bonsai.ML.PointProcessDecoder/Encode.cs index c79fe81e..9a3bf6c8 100644 --- a/src/Bonsai.ML.PointProcessDecoder/Encode.cs +++ b/src/Bonsai.ML.PointProcessDecoder/Encode.cs @@ -18,6 +18,7 @@ public class Encode /// The name of the point process model to use. /// [TypeConverter(typeof(PointProcessModelNameConverter))] + [Description("The name of the point process model to use.")] public string Model { get; set; } = string.Empty; /// diff --git a/src/Bonsai.ML.PointProcessDecoder/GetModel.cs b/src/Bonsai.ML.PointProcessDecoder/GetModel.cs index 2590d154..a9312f4d 100644 --- a/src/Bonsai.ML.PointProcessDecoder/GetModel.cs +++ b/src/Bonsai.ML.PointProcessDecoder/GetModel.cs @@ -18,6 +18,7 @@ public class GetModel /// The name of the point process model to return. /// [TypeConverter(typeof(PointProcessModelNameConverter))] + [Description("The name of the point process model to return.")] public string Model { get; set; } = string.Empty; /// diff --git a/src/Bonsai.ML.PointProcessDecoder/PointProcessModelDisposable.cs b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelDisposable.cs index 912b603e..af9461d0 100644 --- a/src/Bonsai.ML.PointProcessDecoder/PointProcessModelDisposable.cs +++ b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelDisposable.cs @@ -4,20 +4,20 @@ namespace Bonsai.ML.PointProcessDecoder; -sealed class PointProcessModelDisposable : IDisposable +internal sealed class PointProcessModelDisposable(PointProcessModel model, IDisposable disposable) : IDisposable { - private IDisposable? resource; + private IDisposable? resource = disposable ?? throw new ArgumentNullException(nameof(disposable)); + /// + /// Gets a value indicating whether the object has been disposed. + /// public bool IsDisposed => resource == null; - private readonly PointProcessModel model; + private readonly PointProcessModel model = model ?? throw new ArgumentNullException(nameof(model)); + /// + /// Gets the point process model. + /// public PointProcessModel Model => model; - public PointProcessModelDisposable(PointProcessModel model, IDisposable disposable) - { - this.model = model ?? throw new ArgumentNullException(nameof(model)); - resource = disposable ?? throw new ArgumentNullException(nameof(disposable)); - } - public void Dispose() { var disposable = Interlocked.Exchange(ref resource, null); From edeed968fe8b0a5cd9695186ffaf9fa9547f91b8 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 29 Jan 2025 13:28:24 +0000 Subject: [PATCH 094/131] Update package version --- .../Bonsai.ML.PointProcessDecoder.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj b/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj index ceefa4db..d1675f4f 100644 --- a/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj +++ b/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj @@ -11,6 +11,6 @@ - + \ No newline at end of file From 93c24ad5730a4d64ba96bb333d00fd2f44ddc4e0 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 4 Feb 2025 16:35:27 +0000 Subject: [PATCH 095/131] Update package dependency version --- .../Bonsai.ML.PointProcessDecoder.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj b/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj index d1675f4f..69301982 100644 --- a/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj +++ b/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj @@ -11,6 +11,6 @@ - + \ No newline at end of file From 1b3d3028f78c7d969a2f716f6337dd70dd0b6997 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 4 Feb 2025 16:35:47 +0000 Subject: [PATCH 096/131] Add kernel limit parameter to model --- .../CreatePointProcessModel.cs | 19 +++++++++++ .../PointProcessModelManager.cs | 33 ++++++++++++++++++- 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs b/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs index cef032c1..ca07fcb8 100644 --- a/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs +++ b/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs @@ -206,6 +206,24 @@ public EncoderType EncoderType } } + private int? kernelLimit = null; + /// + /// Gets or sets the kernel limit. + /// + [Category("3. Encoder Parameters")] + [Description("The kernel limit.")] + public int? KernelLimit + { + get + { + return kernelLimit; + } + set + { + kernelLimit = value; + } + } + private int? nUnits = null; /// /// Gets or sets the number of sorted spiking units. @@ -439,6 +457,7 @@ public IObservable Process() nUnits: nUnits, distanceThreshold: distanceThreshold, sigmaRandomWalk: sigmaRandomWalk, + kernelLimit: kernelLimit, device: device, scalarType: scalarType ), resource => Observable.Return(resource.Model) diff --git a/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs index 2403b458..7d6cd444 100644 --- a/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs +++ b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs @@ -43,11 +43,17 @@ internal static PointProcessModelDisposable Reserve( int? nUnits = null, double? distanceThreshold = null, double? sigmaRandomWalk = null, + int? kernelLimit = null, Device? device = null, ScalarType? scalarType = null ) { - var model = new PointProcessModel( + if (models.TryGetValue(name, out var model)) + { + throw new ArgumentException($"Model with name {nameof(name)} already exists."); + } + + model = new PointProcessModel( estimationMethod: estimationMethod, transitionsType: transitionsType, encoderType: encoderType, @@ -66,6 +72,7 @@ internal static PointProcessModelDisposable Reserve( nUnits: nUnits, distanceThreshold: distanceThreshold, sigmaRandomWalk: sigmaRandomWalk, + kernelLimit: kernelLimit, device: device, scalarType: scalarType ); @@ -81,4 +88,28 @@ internal static PointProcessModelDisposable Reserve( }) ); } + + internal static PointProcessModelDisposable Load( + string name, + string path, + Device? device = null + ) + { + if (models.TryGetValue(name, out var model)) + { + throw new ArgumentException($"Model with name {nameof(name)} already exists."); + } + + model = PointProcessModel.Load(path, device) as PointProcessModel ?? throw new InvalidOperationException("The model could not be loaded."); + models.Add(name, model); + + return new PointProcessModelDisposable( + model, + Disposable.Create(() => { + model.Dispose(); + model = null; + models.Remove(name); + }) + ); + } } \ No newline at end of file From 2141c27cf0b1cd185b3cbca5d8a92da3bbb3b7bf Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 4 Feb 2025 16:37:02 +0000 Subject: [PATCH 097/131] Added load model class --- .../LoadPointProcessModel.cs | 100 ++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 src/Bonsai.ML.PointProcessDecoder/LoadPointProcessModel.cs diff --git a/src/Bonsai.ML.PointProcessDecoder/LoadPointProcessModel.cs b/src/Bonsai.ML.PointProcessDecoder/LoadPointProcessModel.cs new file mode 100644 index 00000000..803bbff9 --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder/LoadPointProcessModel.cs @@ -0,0 +1,100 @@ +using System; +using System.ComponentModel; +using System.Xml.Serialization; +using System.Linq; +using System.Reactive.Linq; + +using static TorchSharp.torch; + +using PointProcessDecoder.Core; +using PointProcessDecoder.Core.Estimation; +using PointProcessDecoder.Core.Transitions; +using PointProcessDecoder.Core.Encoder; +using PointProcessDecoder.Core.Decoder; +using PointProcessDecoder.Core.StateSpace; +using PointProcessDecoder.Core.Likelihood; +using System.IO; + +namespace Bonsai.ML.PointProcessDecoder; + +/// +/// Loads a point process model from a saved state. +/// +[Combinator] +[WorkflowElementCategory(ElementCategory.Source)] +[Description("Loads a point process model from a saved state.")] +public class LoadPointProcessModel +{ + private string name = "PointProcessModel"; + + /// + /// Gets or sets the name of the point process model. + /// + [Description("The name of the point process model.")] + public string Name + { + get + { + return name; + } + set + { + name = value; + } + } + + Device? device = null; + /// + /// Gets or sets the device used to run the neural decoding model. + /// + [XmlIgnore] + [Description("The device used to run the neural decoding model.")] + public Device? Device + { + get + { + return device; + } + set + { + device = value; + } + } + + /// + /// The path to the folder where the state of the point process model was saved. + /// + [Editor("Bonsai.Design.FolderNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + [Description("The path to the folder where the state of the point process model was saved.")] + public string Path { get; set; } = string.Empty; + + + /// + /// Creates a new neural decoding model based on point processes using Bayesian state space models. + /// + /// + public IObservable Process() + { + return Observable.Using( + () => { + + if (string.IsNullOrEmpty(Path)) + { + throw new InvalidOperationException("The save path is not specified."); + } + + if (!Directory.Exists(Path)) + { + throw new InvalidOperationException("The save path does not exist."); + } + + return PointProcessModelManager.Load( + name: name, + path: Path, + device: device + ); + }, resource => Observable.Return(resource.Model) + .Concat(Observable.Never(resource.Model)) + .Finally(resource.Dispose)); + } +} \ No newline at end of file From 8d680e6d70f3a9de1852f6d17d1ea55b63a9a65d Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 4 Feb 2025 16:40:08 +0000 Subject: [PATCH 098/131] Added save model --- .../SavePointProcessModel.cs | 104 ++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 src/Bonsai.ML.PointProcessDecoder/SavePointProcessModel.cs diff --git a/src/Bonsai.ML.PointProcessDecoder/SavePointProcessModel.cs b/src/Bonsai.ML.PointProcessDecoder/SavePointProcessModel.cs new file mode 100644 index 00000000..60285c05 --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder/SavePointProcessModel.cs @@ -0,0 +1,104 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using PointProcessDecoder.Core; +using static TorchSharp.torch; + +namespace Bonsai.ML.PointProcessDecoder; + +/// +/// Saves the state of the point process model. +/// +[Combinator] +[WorkflowElementCategory(ElementCategory.Sink)] +[Description("Saves the state of the point process model.")] +public class SavePointProcessModel +{ + /// + /// The path to the folder where the state of the point process model will be saved. + /// + [Editor("Bonsai.Design.FolderNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + [Description("The path to the folder where the state of the point process model will be saved.")] + public string Path { get; set; } = string.Empty; + + /// + /// If true, the contents of the folder will be overwritten if it already exists. + /// + [Description("If true, the contents of the folder will be overwritten if it already exists.")] + public bool Overwrite { get; set; } = false; + + /// + /// Specifies the type of suffix to add to the save path. + /// + [Description("Specifies the type of suffix to add to the save path.")] + public SuffixType AddSuffix { get; set; } = SuffixType.None; + + /// + /// The name of the point process model to save. + /// + [TypeConverter(typeof(PointProcessModelNameConverter))] + [Description("The name of the point process model to save.")] + public string Model { get; set; } = string.Empty; + + /// + /// Saves the state of the point process model. + /// + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Do(_ => { + + var path = AddSuffix switch + { + SuffixType.DateTime => System.IO.Path.Combine(Path, $"{DateTime.Now}"), + SuffixType.Guid => System.IO.Path.Combine(Path, Guid.NewGuid().ToString()), + _ => Path + }; + + if (string.IsNullOrEmpty(path)) + { + throw new InvalidOperationException("The save path is not specified."); + } + + if (System.IO.Directory.Exists(path)) + { + if (Overwrite) + { + System.IO.Directory.Delete(path, true); + } + else + { + throw new InvalidOperationException("The save path already exists and overwrite is set to False."); + } + } + + var model = PointProcessModelManager.GetModel(Model); + + model.Save(path); + }); + } + + /// + /// Specifies the type of suffix to add to the save path. + /// + public enum SuffixType + { + /// + /// No suffix is added to the save path. + /// + None, + + /// + /// A suffix with the current date and time is added to the save path. + /// + DateTime, + + /// + /// A suffix with a unique identifier is added to the save path. + /// + Guid + } +} \ No newline at end of file From 8307f73d5eacdb0293a6b174abb46616e1bbcab9 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 6 Feb 2025 16:45:12 +0000 Subject: [PATCH 099/131] Removed try get call in create model --- .../PointProcessModelManager.cs | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs index 7d6cd444..c6e832a1 100644 --- a/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs +++ b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs @@ -48,12 +48,7 @@ internal static PointProcessModelDisposable Reserve( ScalarType? scalarType = null ) { - if (models.TryGetValue(name, out var model)) - { - throw new ArgumentException($"Model with name {nameof(name)} already exists."); - } - - model = new PointProcessModel( + var model = new PointProcessModel( estimationMethod: estimationMethod, transitionsType: transitionsType, encoderType: encoderType, @@ -95,12 +90,7 @@ internal static PointProcessModelDisposable Load( Device? device = null ) { - if (models.TryGetValue(name, out var model)) - { - throw new ArgumentException($"Model with name {nameof(name)} already exists."); - } - - model = PointProcessModel.Load(path, device) as PointProcessModel ?? throw new InvalidOperationException("The model could not be loaded."); + var model = PointProcessModel.Load(path, device) as PointProcessModel ?? throw new InvalidOperationException("The model could not be loaded."); models.Add(name, model); return new PointProcessModelDisposable( From 6a64264eed2dd080262737f4e64e864ee133024b Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 7 Feb 2025 10:34:50 +0000 Subject: [PATCH 100/131] Update suffix string format for datetime suffix type --- src/Bonsai.ML.PointProcessDecoder/SavePointProcessModel.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Bonsai.ML.PointProcessDecoder/SavePointProcessModel.cs b/src/Bonsai.ML.PointProcessDecoder/SavePointProcessModel.cs index 60285c05..d82c5457 100644 --- a/src/Bonsai.ML.PointProcessDecoder/SavePointProcessModel.cs +++ b/src/Bonsai.ML.PointProcessDecoder/SavePointProcessModel.cs @@ -53,7 +53,7 @@ public IObservable Process(IObservable source) var path = AddSuffix switch { - SuffixType.DateTime => System.IO.Path.Combine(Path, $"{DateTime.Now}"), + SuffixType.DateTime => System.IO.Path.Combine(Path, $"{DateTime.Now:yyyyMMddHHmmss}"), SuffixType.Guid => System.IO.Path.Combine(Path, Guid.NewGuid().ToString()), _ => Path }; From 0f5a1b4a2c7283674b33134e7c3b1269c31be201 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 6 Feb 2025 16:44:16 +0000 Subject: [PATCH 101/131] Added initial visualizers for conditional intensities and kernel estimates --- Bonsai.ML.sln | 45 ++++---- ...onsai.ML.PointProcessDecoder.Design.csproj | 13 +++ .../ConditionalIntensitiesVisualizer.cs | 106 ++++++++++++++++++ .../KernelEstimatesVisualizer.cs | 106 ++++++++++++++++++ 4 files changed, 251 insertions(+), 19 deletions(-) create mode 100644 src/Bonsai.ML.PointProcessDecoder.Design/Bonsai.ML.PointProcessDecoder.Design.csproj create mode 100644 src/Bonsai.ML.PointProcessDecoder.Design/ConditionalIntensitiesVisualizer.cs create mode 100644 src/Bonsai.ML.PointProcessDecoder.Design/KernelEstimatesVisualizer.cs diff --git a/Bonsai.ML.sln b/Bonsai.ML.sln index daaf9e74..9090db9f 100644 --- a/Bonsai.ML.sln +++ b/Bonsai.ML.sln @@ -1,4 +1,4 @@ - + Microsoft Visual Studio Solution File, Format Version 12.00 # Visual Studio Version 17 VisualStudioVersion = 17.0.31903.59 @@ -34,6 +34,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Torch", "src\Bons EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.PointProcessDecoder", "src\Bonsai.ML.PointProcessDecoder\Bonsai.ML.PointProcessDecoder.csproj", "{AD32C680-1E8C-4340-81B1-DA19C9104516}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.PointProcessDecoder.Design", "src\Bonsai.ML.PointProcessDecoder.Design\Bonsai.ML.PointProcessDecoder.Design.csproj", "{91C3E252-9457-43AB-A21A-6064E2404BAA}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -64,26 +66,30 @@ Global {39A4414F-52B1-42D7-82FA-E65DAD885264}.Debug|Any CPU.Build.0 = Debug|Any CPU {39A4414F-52B1-42D7-82FA-E65DAD885264}.Release|Any CPU.ActiveCfg = Release|Any CPU {39A4414F-52B1-42D7-82FA-E65DAD885264}.Release|Any CPU.Build.0 = Release|Any CPU - {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13}.Debug|Any CPU.Build.0 = Debug|Any CPU - {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13}.Release|Any CPU.ActiveCfg = Release|Any CPU - {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13}.Release|Any CPU.Build.0 = Release|Any CPU - {17DF50BE-F481-4904-A4C8-5DF9725B2CA1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {17DF50BE-F481-4904-A4C8-5DF9725B2CA1}.Debug|Any CPU.Build.0 = Debug|Any CPU - {17DF50BE-F481-4904-A4C8-5DF9725B2CA1}.Release|Any CPU.ActiveCfg = Release|Any CPU - {17DF50BE-F481-4904-A4C8-5DF9725B2CA1}.Release|Any CPU.Build.0 = Release|Any CPU - {FC395DDC-62A4-4E14-A198-272AB05B33C7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {FC395DDC-62A4-4E14-A198-272AB05B33C7}.Debug|Any CPU.Build.0 = Debug|Any CPU - {FC395DDC-62A4-4E14-A198-272AB05B33C7}.Release|Any CPU.ActiveCfg = Release|Any CPU - {FC395DDC-62A4-4E14-A198-272AB05B33C7}.Release|Any CPU.Build.0 = Release|Any CPU + {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13}.Debug|Any CPU.Build.0 = Debug|Any CPU + {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13}.Release|Any CPU.ActiveCfg = Release|Any CPU + {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13}.Release|Any CPU.Build.0 = Release|Any CPU + {17DF50BE-F481-4904-A4C8-5DF9725B2CA1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {17DF50BE-F481-4904-A4C8-5DF9725B2CA1}.Debug|Any CPU.Build.0 = Debug|Any CPU + {17DF50BE-F481-4904-A4C8-5DF9725B2CA1}.Release|Any CPU.ActiveCfg = Release|Any CPU + {17DF50BE-F481-4904-A4C8-5DF9725B2CA1}.Release|Any CPU.Build.0 = Release|Any CPU + {FC395DDC-62A4-4E14-A198-272AB05B33C7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {FC395DDC-62A4-4E14-A198-272AB05B33C7}.Debug|Any CPU.Build.0 = Debug|Any CPU + {FC395DDC-62A4-4E14-A198-272AB05B33C7}.Release|Any CPU.ActiveCfg = Release|Any CPU + {FC395DDC-62A4-4E14-A198-272AB05B33C7}.Release|Any CPU.Build.0 = Release|Any CPU {06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Debug|Any CPU.Build.0 = Debug|Any CPU - {06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Release|Any CPU.ActiveCfg = Release|Any CPU - {06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Release|Any CPU.Build.0 = Release|Any CPU + {06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Debug|Any CPU.Build.0 = Debug|Any CPU + {06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Release|Any CPU.ActiveCfg = Release|Any CPU + {06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Release|Any CPU.Build.0 = Release|Any CPU {AD32C680-1E8C-4340-81B1-DA19C9104516}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {AD32C680-1E8C-4340-81B1-DA19C9104516}.Debug|Any CPU.Build.0 = Debug|Any CPU {AD32C680-1E8C-4340-81B1-DA19C9104516}.Release|Any CPU.ActiveCfg = Release|Any CPU {AD32C680-1E8C-4340-81B1-DA19C9104516}.Release|Any CPU.Build.0 = Release|Any CPU + {91C3E252-9457-43AB-A21A-6064E2404BAA}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {91C3E252-9457-43AB-A21A-6064E2404BAA}.Debug|Any CPU.Build.0 = Debug|Any CPU + {91C3E252-9457-43AB-A21A-6064E2404BAA}.Release|Any CPU.ActiveCfg = Release|Any CPU + {91C3E252-9457-43AB-A21A-6064E2404BAA}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -96,10 +102,11 @@ Global {81DB65B3-EA65-4947-8CF1-0E777324C082} = {461FE3E2-21C4-47F9-8405-DF72326AAB2B} {BAD0A733-8EFB-4EAF-9648-9851656AF7FF} = {12312384-8828-4786-AE19-EFCEDF968290} {39A4414F-52B1-42D7-82FA-E65DAD885264} = {12312384-8828-4786-AE19-EFCEDF968290} - {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13} = {12312384-8828-4786-AE19-EFCEDF968290} - {17DF50BE-F481-4904-A4C8-5DF9725B2CA1} = {12312384-8828-4786-AE19-EFCEDF968290} - {06FCC9AF-CE38-44BB-92B3-0D451BE88537} = {12312384-8828-4786-AE19-EFCEDF968290} + {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13} = {12312384-8828-4786-AE19-EFCEDF968290} + {17DF50BE-F481-4904-A4C8-5DF9725B2CA1} = {12312384-8828-4786-AE19-EFCEDF968290} + {06FCC9AF-CE38-44BB-92B3-0D451BE88537} = {12312384-8828-4786-AE19-EFCEDF968290} {AD32C680-1E8C-4340-81B1-DA19C9104516} = {12312384-8828-4786-AE19-EFCEDF968290} + {91C3E252-9457-43AB-A21A-6064E2404BAA} = {12312384-8828-4786-AE19-EFCEDF968290} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {B6468F13-97CD-45E0-9E1E-C122D7F1E09F} diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/Bonsai.ML.PointProcessDecoder.Design.csproj b/src/Bonsai.ML.PointProcessDecoder.Design/Bonsai.ML.PointProcessDecoder.Design.csproj new file mode 100644 index 00000000..1a052001 --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder.Design/Bonsai.ML.PointProcessDecoder.Design.csproj @@ -0,0 +1,13 @@ + + + Bonsai.ML.PointProcessDecoder.Design + A package for visualizing data from the Bonsai.ML.PointProcessDecoder package. + Bonsai Rx ML Machine Learning Point Process Decoder Design + net472 + true + + + + + + \ No newline at end of file diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/ConditionalIntensitiesVisualizer.cs b/src/Bonsai.ML.PointProcessDecoder.Design/ConditionalIntensitiesVisualizer.cs new file mode 100644 index 00000000..8a4ba9dd --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder.Design/ConditionalIntensitiesVisualizer.cs @@ -0,0 +1,106 @@ +using System; +using System.Windows.Forms; +using System.Collections.Generic; +using System.Linq; +using Bonsai; +using Bonsai.Dag; +using Bonsai.Expressions; +using Bonsai.Design; +using Bonsai.ML.Design; +using PointProcessDecoder.Core; +using OxyPlot; +using OxyPlot.Series; +using OxyPlot.Axes; + +[assembly: TypeVisualizer(typeof(Bonsai.ML.PointProcessDecoder.Design.ConditionalIntensitiesVisualizer), + Target = typeof(PointProcessModel))] + +namespace Bonsai.ML.PointProcessDecoder.Design +{ + public class ConditionalIntensitiesVisualizer : MultidimensionalArrayVisualizer + { + private PointProcessModel _model = null; + + /// + public override void Load(IServiceProvider provider) + { + var expressionBuilderGraph = (ExpressionBuilderGraph)provider.GetService(typeof(ExpressionBuilderGraph)); + var typeVisualizerContext = (ITypeVisualizerContext)provider.GetService(typeof(ITypeVisualizerContext)); + if (expressionBuilderGraph != null && typeVisualizerContext != null) + { + _model = ExpressionBuilder.GetWorkflowElement( + expressionBuilderGraph.Where(node => node.Value == typeVisualizerContext.Source) + .FirstOrDefault().Value) as DensityCluster; + } + + if (_densityCluster == null) + { + throw new InvalidOperationException("Unable to access the density cluster workflow element."); + } + + if (_densityCluster.Dimensions != 2) + { + throw new InvalidOperationException("The density visualizer can only be used with 2 dimensional data."); + } + + base.Load(provider); + + var showDensityClusterInfoLabel = new ToolStripLabel() + { + Text = "Density Cluster Info: ", + AutoSize = true + }; + + var showDensityClusterInfoCombobox = new ToolStripComboBox() + { + Name = "densityClusterInfoComboBox", + }; + + showDensityClusterInfoCombobox.Items.AddRange([ + "Density Values", + "Cluster Ids", + "Density Labels" + ]); + + showDensityClusterInfoCombobox.SelectedIndexChanged += (sender, e) => + { + var combobox = (ToolStripComboBox)sender; + var selectedIndex = combobox.SelectedIndex; + _getDensityInfo = selectedIndex switch + { + 0 => _densityCluster.GetCellGridDensities, + 1 => _densityCluster.GetCellGridClusterIds, + 2 => _densityCluster.GetCellGridDensityLabels, + _ => throw new InvalidOperationException("Invalid density cluster info selection.") + }; + }; + + var toolStripItems = new ToolStripItem[] { + showDensityClusterInfoLabel, + showDensityClusterInfoCombobox + }; + + _getDensityInfo = _densityCluster.GetCellGridDensities; + + Plot.StatusStrip.Items.AddRange(toolStripItems); + } + + /// + public override void Show(object value) + { + var densityInfo = (double[,])_getDensityInfo(); + if (densityInfo == null || densityInfo.Length == 0) + { + return; + } + base.Show(densityInfo); + } + + /// + public override void Unload() + { + _densityCluster = null; + base.Unload(); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/KernelEstimatesVisualizer.cs b/src/Bonsai.ML.PointProcessDecoder.Design/KernelEstimatesVisualizer.cs new file mode 100644 index 00000000..9e1216a2 --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder.Design/KernelEstimatesVisualizer.cs @@ -0,0 +1,106 @@ +using System; +using System.Windows.Forms; +using System.Collections.Generic; +using System.Linq; +using Bonsai; +using Bonsai.Dag; +using Bonsai.Expressions; +using Bonsai.Design; +using Bonsai.ML.Design; +using PointProcessDecoder.Core; +using OxyPlot; +using OxyPlot.Series; +using OxyPlot.Axes; + +[assembly: TypeVisualizer(typeof(Bonsai.ML.PointProcessDecoder.Design.KernelEstimatesVisualizer), + Target = typeof(PointProcessModel))] + +namespace Bonsai.ML.PointProcessDecoder.Design +{ + public class KernelEstimatesVisualizer : DialogTypeVisualizer + { + private PointProcessModel _model = null; + + /// + public override void Load(IServiceProvider provider) + { + var expressionBuilderGraph = (ExpressionBuilderGraph)provider.GetService(typeof(ExpressionBuilderGraph)); + var typeVisualizerContext = (ITypeVisualizerContext)provider.GetService(typeof(ITypeVisualizerContext)); + if (expressionBuilderGraph != null && typeVisualizerContext != null) + { + _model = ExpressionBuilder.GetWorkflowElement( + expressionBuilderGraph.Where(node => node.Value == typeVisualizerContext.Source) + .FirstOrDefault().Value) as DensityCluster; + } + + if (_densityCluster == null) + { + throw new InvalidOperationException("Unable to access the density cluster workflow element."); + } + + if (_densityCluster.Dimensions != 2) + { + throw new InvalidOperationException("The density visualizer can only be used with 2 dimensional data."); + } + + base.Load(provider); + + var showDensityClusterInfoLabel = new ToolStripLabel() + { + Text = "Density Cluster Info: ", + AutoSize = true + }; + + var showDensityClusterInfoCombobox = new ToolStripComboBox() + { + Name = "densityClusterInfoComboBox", + }; + + showDensityClusterInfoCombobox.Items.AddRange([ + "Density Values", + "Cluster Ids", + "Density Labels" + ]); + + showDensityClusterInfoCombobox.SelectedIndexChanged += (sender, e) => + { + var combobox = (ToolStripComboBox)sender; + var selectedIndex = combobox.SelectedIndex; + _getDensityInfo = selectedIndex switch + { + 0 => _densityCluster.GetCellGridDensities, + 1 => _densityCluster.GetCellGridClusterIds, + 2 => _densityCluster.GetCellGridDensityLabels, + _ => throw new InvalidOperationException("Invalid density cluster info selection.") + }; + }; + + var toolStripItems = new ToolStripItem[] { + showDensityClusterInfoLabel, + showDensityClusterInfoCombobox + }; + + _getDensityInfo = _densityCluster.GetCellGridDensities; + + Plot.StatusStrip.Items.AddRange(toolStripItems); + } + + /// + public override void Show(object value) + { + var densityInfo = (double[,])_getDensityInfo(); + if (densityInfo == null || densityInfo.Length == 0) + { + return; + } + base.Show(densityInfo); + } + + /// + public override void Unload() + { + _densityCluster = null; + base.Unload(); + } + } +} \ No newline at end of file From aa611627c246958d52de06c1ac46ebca9373f8e9 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 11 Feb 2025 12:30:04 +0000 Subject: [PATCH 102/131] Make GetModel function public for visualizer --- .../PointProcessModelManager.cs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs index c6e832a1..d7fd34f4 100644 --- a/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs +++ b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs @@ -14,11 +14,14 @@ namespace Bonsai.ML.PointProcessDecoder; -internal static class PointProcessModelManager +/// +/// Manages the point process models. +/// +public static class PointProcessModelManager { private static readonly Dictionary models = []; - internal static PointProcessModel GetModel(string name) + public static PointProcessModel GetModel(string name) { return models.TryGetValue(name, out var model) ? model : throw new InvalidOperationException($"Model with name {name} not found."); } From 48e359f3750ee63d9062a2caf74425ff84a5803a Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 11 Feb 2025 12:31:09 +0000 Subject: [PATCH 103/131] Removed kernel estimates visualizer --- .../KernelEstimatesVisualizer.cs | 106 ------------------ 1 file changed, 106 deletions(-) delete mode 100644 src/Bonsai.ML.PointProcessDecoder.Design/KernelEstimatesVisualizer.cs diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/KernelEstimatesVisualizer.cs b/src/Bonsai.ML.PointProcessDecoder.Design/KernelEstimatesVisualizer.cs deleted file mode 100644 index 9e1216a2..00000000 --- a/src/Bonsai.ML.PointProcessDecoder.Design/KernelEstimatesVisualizer.cs +++ /dev/null @@ -1,106 +0,0 @@ -using System; -using System.Windows.Forms; -using System.Collections.Generic; -using System.Linq; -using Bonsai; -using Bonsai.Dag; -using Bonsai.Expressions; -using Bonsai.Design; -using Bonsai.ML.Design; -using PointProcessDecoder.Core; -using OxyPlot; -using OxyPlot.Series; -using OxyPlot.Axes; - -[assembly: TypeVisualizer(typeof(Bonsai.ML.PointProcessDecoder.Design.KernelEstimatesVisualizer), - Target = typeof(PointProcessModel))] - -namespace Bonsai.ML.PointProcessDecoder.Design -{ - public class KernelEstimatesVisualizer : DialogTypeVisualizer - { - private PointProcessModel _model = null; - - /// - public override void Load(IServiceProvider provider) - { - var expressionBuilderGraph = (ExpressionBuilderGraph)provider.GetService(typeof(ExpressionBuilderGraph)); - var typeVisualizerContext = (ITypeVisualizerContext)provider.GetService(typeof(ITypeVisualizerContext)); - if (expressionBuilderGraph != null && typeVisualizerContext != null) - { - _model = ExpressionBuilder.GetWorkflowElement( - expressionBuilderGraph.Where(node => node.Value == typeVisualizerContext.Source) - .FirstOrDefault().Value) as DensityCluster; - } - - if (_densityCluster == null) - { - throw new InvalidOperationException("Unable to access the density cluster workflow element."); - } - - if (_densityCluster.Dimensions != 2) - { - throw new InvalidOperationException("The density visualizer can only be used with 2 dimensional data."); - } - - base.Load(provider); - - var showDensityClusterInfoLabel = new ToolStripLabel() - { - Text = "Density Cluster Info: ", - AutoSize = true - }; - - var showDensityClusterInfoCombobox = new ToolStripComboBox() - { - Name = "densityClusterInfoComboBox", - }; - - showDensityClusterInfoCombobox.Items.AddRange([ - "Density Values", - "Cluster Ids", - "Density Labels" - ]); - - showDensityClusterInfoCombobox.SelectedIndexChanged += (sender, e) => - { - var combobox = (ToolStripComboBox)sender; - var selectedIndex = combobox.SelectedIndex; - _getDensityInfo = selectedIndex switch - { - 0 => _densityCluster.GetCellGridDensities, - 1 => _densityCluster.GetCellGridClusterIds, - 2 => _densityCluster.GetCellGridDensityLabels, - _ => throw new InvalidOperationException("Invalid density cluster info selection.") - }; - }; - - var toolStripItems = new ToolStripItem[] { - showDensityClusterInfoLabel, - showDensityClusterInfoCombobox - }; - - _getDensityInfo = _densityCluster.GetCellGridDensities; - - Plot.StatusStrip.Items.AddRange(toolStripItems); - } - - /// - public override void Show(object value) - { - var densityInfo = (double[,])_getDensityInfo(); - if (densityInfo == null || densityInfo.Length == 0) - { - return; - } - base.Show(densityInfo); - } - - /// - public override void Unload() - { - _densityCluster = null; - base.Unload(); - } - } -} \ No newline at end of file From 454bd513041969a2cedbc10152d89288cc0f90d5 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 11 Feb 2025 12:31:23 +0000 Subject: [PATCH 104/131] Updated conditional intensities visualizer --- .../ConditionalIntensitiesVisualizer.cs | 519 ++++++++++++++++-- 1 file changed, 479 insertions(+), 40 deletions(-) diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/ConditionalIntensitiesVisualizer.cs b/src/Bonsai.ML.PointProcessDecoder.Design/ConditionalIntensitiesVisualizer.cs index 8a4ba9dd..9b41ffe5 100644 --- a/src/Bonsai.ML.PointProcessDecoder.Design/ConditionalIntensitiesVisualizer.cs +++ b/src/Bonsai.ML.PointProcessDecoder.Design/ConditionalIntensitiesVisualizer.cs @@ -11,96 +11,535 @@ using OxyPlot; using OxyPlot.Series; using OxyPlot.Axes; +using static TorchSharp.torch; +using TorchSharp; [assembly: TypeVisualizer(typeof(Bonsai.ML.PointProcessDecoder.Design.ConditionalIntensitiesVisualizer), - Target = typeof(PointProcessModel))] + Target = typeof(Bonsai.ML.PointProcessDecoder.Decode))] namespace Bonsai.ML.PointProcessDecoder.Design { - public class ConditionalIntensitiesVisualizer : MultidimensionalArrayVisualizer + public class ConditionalIntensitiesVisualizer : DialogTypeVisualizer { + private int _rowCount = 1; + public int RowCount + { + get => _rowCount; + set + { + if (value < 1) + { + throw new InvalidOperationException("The number of rows must be greater than 0."); + } + _rowCount = value; + if (_container != null) + { + _container.RowCount = _rowCount; + } + } + } + + private int _columnCount = 1; + public int ColumnCount + { + get => _columnCount; + set + { + if (value < 1) + { + throw new InvalidOperationException("The number of columns must be greater than 0."); + } + _columnCount = value; + if (_container != null) + { + _container.ColumnCount = _columnCount; + } + } + } + + private int _pageCount = 1; + public int PageCount => _pageCount; + + private int _selectedPageIndex = 0; + public int SelectedPageIndex + { + get => _selectedPageIndex; + set + { + if (value < 0 || value >= _pageCount) + { + throw new InvalidOperationException("The selected page index is out of range."); + } + _selectedPageIndex = value; + } + } + private PointProcessModel _model = null; + private HeatMapSeriesOxyPlotBase[] _heatmapPlots = null; + private long _conditionalIntensitiesCount = 0; + private TableLayoutPanel _container = null; + // create a dictionary to map the index of the heatmap plot to the corresponding conditional intensity + private List _conditionalIntensitiesCumulativeIndex = []; + + private StatusStrip _statusStrip = null; + public StatusStrip StatusStrip => _statusStrip; /// public override void Load(IServiceProvider provider) { + + Decode decodeNode = null; var expressionBuilderGraph = (ExpressionBuilderGraph)provider.GetService(typeof(ExpressionBuilderGraph)); var typeVisualizerContext = (ITypeVisualizerContext)provider.GetService(typeof(ITypeVisualizerContext)); if (expressionBuilderGraph != null && typeVisualizerContext != null) { - _model = ExpressionBuilder.GetWorkflowElement( + decodeNode = ExpressionBuilder.GetWorkflowElement( expressionBuilderGraph.Where(node => node.Value == typeVisualizerContext.Source) - .FirstOrDefault().Value) as DensityCluster; + .FirstOrDefault().Value) as Decode; + } + + if (decodeNode == null) + { + Console.WriteLine("The decode node is invalid."); + throw new InvalidOperationException("The decode node is invalid."); + } + + var modelName = decodeNode.Model; + if (string.IsNullOrEmpty(modelName)) + { + Console.WriteLine("The point process model name is not set."); + throw new InvalidOperationException("The point process model name is not set."); } - if (_densityCluster == null) + _model = PointProcessModelManager.GetModel(modelName); + + if (_model == null) { - throw new InvalidOperationException("Unable to access the density cluster workflow element."); + Console.WriteLine($"The point process model with name {modelName} is not found."); + throw new InvalidOperationException($"The point process model with name {modelName} is not found."); } - if (_densityCluster.Dimensions != 2) + if (_model.StateSpace.Dimensions != 2) { - throw new InvalidOperationException("The density visualizer can only be used with 2 dimensional data."); + Console.WriteLine("For the conditional intensities visualizer to work, the state space dimensions must be 2."); + throw new InvalidOperationException("For the conditional intensities visualizer to work, the state space dimensions must be 2."); } + + _container = new TableLayoutPanel() + { + Dock = DockStyle.Fill, + ColumnCount = _columnCount, + RowCount = _rowCount, + }; + + InitializeHeatmaps(); + + _pageCount = (int)Math.Ceiling((double)_conditionalIntensitiesCount / (_rowCount * _columnCount)); + + var pageIndexLabel = new ToolStripLabel($"Page: {SelectedPageIndex}"); + var pageIndexControl = new ToolStripNumericUpDown() + { + Minimum = 0, + Maximum = _pageCount - 1, + DecimalPlaces = 0, + Value = SelectedPageIndex, + }; - base.Load(provider); + // pageIndexControl.ValueChanged += (sender, e) => + // { + // var value = Convert.ToInt32(pageIndexControl.Value); + // try { + // SelectedPageIndex = value; + // UpdatePage(); + // Show(null); + // pageIndexLabel.Text = $"Page: {SelectedPageIndex}"; + // } catch (InvalidOperationException) { + // // pageIndexControl.Value = SelectedPageIndex; + // } + // }; - var showDensityClusterInfoLabel = new ToolStripLabel() + var rowLabel = new ToolStripLabel($"Rows: {RowCount}"); + var rowControl = new ToolStripNumericUpDown() { - Text = "Density Cluster Info: ", - AutoSize = true + Minimum = 1, + DecimalPlaces = 0, + Value = RowCount, }; - var showDensityClusterInfoCombobox = new ToolStripComboBox() + rowControl.ValueChanged += (sender, e) => { - Name = "densityClusterInfoComboBox", + var value = Convert.ToInt32(rowControl.Value); + try + { + RowCount = value; + _pageCount = (int)Math.Ceiling((double)_conditionalIntensitiesCount / (_rowCount * _columnCount)); + UpdateTableLayout(); + Show(null); + rowLabel.Text = $"Rows: {RowCount}"; + } + catch (InvalidOperationException) + { + // rowControl.Value = RowCount; + } + }; + + var columnLabel = new ToolStripLabel($"Columns: {ColumnCount}"); + var columnControl = new ToolStripNumericUpDown() + { + Minimum = 1, + DecimalPlaces = 0, + Value = ColumnCount, }; - showDensityClusterInfoCombobox.Items.AddRange([ - "Density Values", - "Cluster Ids", - "Density Labels" + // columnControl.ValueChanged += (sender, e) => + // { + // var value = Convert.ToInt32(columnControl.Value); + // try + // { + // ColumnCount = value; + // _pageCount = (int)Math.Ceiling((double)_conditionalIntensitiesCount / (_rowCount * _columnCount)); + // UpdatePage(); + // Show(null); + // columnLabel.Text = $"Columns: {ColumnCount}"; + // } + // catch (InvalidOperationException) + // { + // // columnControl.Value = ColumnCount; + // } + // }; + + _statusStrip = new StatusStrip() + { + Visible = true, + }; + + _statusStrip.Items.AddRange([ + pageIndexLabel, + pageIndexControl, + rowLabel, + rowControl, + columnLabel, + columnControl ]); - showDensityClusterInfoCombobox.SelectedIndexChanged += (sender, e) => + UpdateHeatmaps(); + UpdateTableLayout(); + + // _container.Controls.Add(statusStrip); + // _container.MouseClick += (sender, e) => { + // if (e.Button == MouseButtons.Right) { + // statusStrip.Visible = !statusStrip.Visible; + // } + // }; + + var visualizerService = (IDialogTypeVisualizerService)provider.GetService(typeof(IDialogTypeVisualizerService)); + visualizerService?.AddControl(_container); + visualizerService?.AddControl(_statusStrip); + } + + private bool InitializeHeatmaps() + { + if (_model.Encoder.ConditionalIntensities.Length == 0 || (_model.Encoder.ConditionalIntensities.Length == 1 && _model.Encoder.ConditionalIntensities[0].numel() == 0)) + { + return false; + } + + _conditionalIntensitiesCount = 0; + for (int i = 0; i < _model.Encoder.ConditionalIntensities.Length; i++) { + var ci = _model.Encoder.ConditionalIntensities[i].clone(); + if (ci.IsInvalid) continue; + if (ci.numel() > 0) { + _conditionalIntensitiesCount += ci.size(0); + } + _conditionalIntensitiesCumulativeIndex.Add(ci.shape[0] + _conditionalIntensitiesCumulativeIndex.LastOrDefault()); + } + + _heatmapPlots = new HeatMapSeriesOxyPlotBase[_conditionalIntensitiesCount]; + for (int i = 0; i < _conditionalIntensitiesCount; i++) { - var combobox = (ToolStripComboBox)sender; - var selectedIndex = combobox.SelectedIndex; - _getDensityInfo = selectedIndex switch + _heatmapPlots[i] = new HeatMapSeriesOxyPlotBase(0, 0) { - 0 => _densityCluster.GetCellGridDensities, - 1 => _densityCluster.GetCellGridClusterIds, - 2 => _densityCluster.GetCellGridDensityLabels, - _ => throw new InvalidOperationException("Invalid density cluster info selection.") + Dock = DockStyle.Fill, }; - }; - - var toolStripItems = new ToolStripItem[] { - showDensityClusterInfoLabel, - showDensityClusterInfoCombobox - }; + } + + return true; + } + + private void UpdateHeatmaps() + { + _container.Controls.Clear(); + // update heatmaps in container + for (int i = 0; i < _rowCount; i++) + { + for (int j = 0; j < _columnCount; j++) + { + var index = SelectedPageIndex * _rowCount * _columnCount + i * _columnCount + j; + if (index >= _conditionalIntensitiesCount) + { + break; + } + + _container.Controls.Add(_heatmapPlots[index], j, i); + } + } + } + + private void UpdateTableLayout() + { + var oldRowCount = _container.RowCount; + var oldColumnCount = _container.ColumnCount; + + // update rows in container + if (_rowCount > oldRowCount) + { + foreach (RowStyle rowStyle in _container.RowStyles) + { + rowStyle.SizeType = SizeType.Percent; + rowStyle.Height = 100f / _rowCount; + } + + _container.RowCount = _rowCount; + for (int i = 0; i < _rowCount - oldRowCount; i++) + { + _container.RowStyles.Add(new RowStyle(SizeType.Percent, 100f / _rowCount)); + for (int j = 0; j < _columnCount; j++) + { + var index = SelectedPageIndex * _rowCount * _columnCount + (oldRowCount * oldColumnCount) + (i * _columnCount) + j; + if (index >= _conditionalIntensitiesCount) + { + break; + } + + _container.Controls.Add(_heatmapPlots[index], j, i); + } + } + } + else if (_rowCount < oldRowCount) + { + _container.RowCount = _rowCount; + for (int i = _rowCount; i < oldRowCount; i++) + { + _container.RowStyles.RemoveAt(_rowCount); + for (int j = 0; j < _columnCount; j++) + { + var index = SelectedPageIndex * _rowCount * _columnCount + (i * _columnCount) + j; + _container.Controls.Remove(_heatmapPlots[index]); + } + } + } - _getDensityInfo = _densityCluster.GetCellGridDensities; + // update columns in container + if (_columnCount > oldColumnCount) + { + foreach (ColumnStyle columnStyle in _container.ColumnStyles) + { + columnStyle.SizeType = SizeType.Percent; + columnStyle.Width = 100f / _columnCount; + } + _container.ColumnCount = _columnCount; + for (int i = 0; i < _columnCount - oldColumnCount; i++) + { + _container.ColumnStyles.Add(new ColumnStyle(SizeType.Percent, 100f / _columnCount)); + } - Plot.StatusStrip.Items.AddRange(toolStripItems); + // remove heatmaps in old rows + for (int i = 1; i < _rowCount; i++) + { + for (int j = 0; j < oldColumnCount; j++) + { + var index = SelectedPageIndex * _rowCount * _columnCount + (i * oldColumnCount) + j; + if (index >= _conditionalIntensitiesCount) + { + break; + } + + _container.Controls.Remove(_heatmapPlots[index]); + } + } + + // move heatmaps to the new columns + for (int i = 0; i < _rowCount; i++) + { + _container.RowStyles.Add(new RowStyle(SizeType.Percent, 100f / _rowCount)); + for (int j = 0; j < _columnCount; j++) + { + if (i == 0 && j < oldColumnCount) + { + continue; + } + + var index = SelectedPageIndex * _rowCount * _columnCount + (i * _columnCount) + j; + if (index >= _conditionalIntensitiesCount) + { + break; + } + + _container.Controls.Add(_heatmapPlots[index], j, i); + } + } + } + else if (_columnCount < oldColumnCount) + { + var oldColumnCount = _container.ColumnCount; + _container.ColumnCount = _columnCount; + for (int i = _columnCount; i < oldColumnCount; i++) + { + _container.ColumnStyles.RemoveAt(_columnCount); + } + } + } + + private (int ConditionalIntensitiesIndex, int ConditionalIntensitiesTensorIndex) GetConditionalIntensitiesIndex(int index) + { + + var conditionalIntensitiesIndex = 0; + for (int i = 0; i < _conditionalIntensitiesCumulativeIndex.Count; i++) + { + if (index < _conditionalIntensitiesCumulativeIndex[i]) + { + conditionalIntensitiesIndex = i; + break; + } + } + var conditionalIntensitiesTensorIndex = conditionalIntensitiesIndex == 0 ? index : index - _conditionalIntensitiesCumulativeIndex[conditionalIntensitiesIndex - 1]; + return (conditionalIntensitiesIndex, (int)conditionalIntensitiesTensorIndex); } /// public override void Show(object value) { - var densityInfo = (double[,])_getDensityInfo(); - if (densityInfo == null || densityInfo.Length == 0) + if (_heatmapPlots == null) + { + var success = InitializeHeatmaps(); + if (!success) + { + throw new InvalidOperationException("The conditional intensities are empty."); + } + } + + var startIndex = SelectedPageIndex * _rowCount * _columnCount; + var endIndex = Math.Min(startIndex + _rowCount * _columnCount, _conditionalIntensitiesCount); + + for (int i = startIndex; i < endIndex; i++) { - return; + var (conditionalIntensitiesIndex, conditionalIntensitiesTensorIndex) = GetConditionalIntensitiesIndex(i); + + Array heatmapValues; + + var conditionalIntensity = _model.Encoder.ConditionalIntensities[conditionalIntensitiesIndex][conditionalIntensitiesTensorIndex] + .clone() + .to(CPU); + + if (conditionalIntensity.IsInvalid) + { + continue; + } + + if (conditionalIntensity.Dimensions == 3) { + conditionalIntensity = conditionalIntensity.sum(dim: 1); + Console.WriteLine($"ConditionalIntensity after sum: {conditionalIntensity}"); + } + Console.WriteLine($"ConditionalIntensity: {conditionalIntensity}"); + var conditionalIntensityValues = conditionalIntensity + .to_type(ScalarType.Float64) + .data(); + + heatmapValues = conditionalIntensityValues + .ToNDArray(); + + Console.WriteLine(heatmapValues.Length); + var heatmap = new double[_model.StateSpace.Shape[0], _model.StateSpace.Shape[1]]; + Buffer.BlockCopy(heatmapValues, 0, heatmap, 0, heatmapValues.Length * sizeof(double)); + + _heatmapPlots[i].UpdateHeatMapSeries( + 0, + _model.StateSpace.Shape[0], + 0, + _model.StateSpace.Shape[1], + heatmap + ); + + _heatmapPlots[i].UpdatePlot(); } - base.Show(densityInfo); } /// public override void Unload() { - _densityCluster = null; - base.Unload(); + if (_container != null) + { + if (!_container.IsDisposed) + { + _container.Dispose(); + } + _container = null; + } + + for (int i = 0; i < _heatmapPlots.Length; i++) + { + if (!_heatmapPlots[i].IsDisposed) + { + _heatmapPlots[i].Dispose(); + } + } + _heatmapPlots = null; + + if (_model != null) + { + _model.Dispose(); + _model = null; + } + + _conditionalIntensitiesCount = 0; + _rowCount = 1; + _columnCount = 1; + _pageCount = 1; + _selectedPageIndex = 0; + _conditionalIntensitiesCumulativeIndex.Clear(); + } + } + + class ToolStripNumericUpDown : ToolStripControlHost + { + public ToolStripNumericUpDown() + : base(new NumericUpDown()) + { + } + + public NumericUpDown NumericUpDown + { + get { return Control as NumericUpDown; } + } + + public int DecimalPlaces + { + get { return NumericUpDown.DecimalPlaces; } + set { NumericUpDown.DecimalPlaces = value; } + } + + public decimal Value + { + get { return NumericUpDown.Value; } + set { NumericUpDown.Value = value; } + } + + public decimal Minimum + { + get { return NumericUpDown.Minimum; } + set { NumericUpDown.Minimum = value; } + } + + public decimal Maximum + { + get { return NumericUpDown.Maximum; } + set { NumericUpDown.Maximum = value; } + } + + public event EventHandler ValueChanged + { + add { NumericUpDown.ValueChanged += value; } + remove { NumericUpDown.ValueChanged -= value; } } } } \ No newline at end of file From 65a0bb789dad7db90ba465f2d5c2fa35736b9624 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 11 Feb 2025 16:36:31 +0000 Subject: [PATCH 105/131] Adding all changes from altering dispose methods --- .../ConditionalIntensitiesVisualizer.cs | 404 +++++++----------- .../Bonsai.ML.PointProcessDecoder.csproj | 2 +- src/Bonsai.ML.PointProcessDecoder/Decode.cs | 3 +- src/Bonsai.ML.PointProcessDecoder/Encode.cs | 3 +- src/Bonsai.ML.PointProcessDecoder/GetModel.cs | 3 +- .../PointProcessModelManager.cs | 4 - 6 files changed, 170 insertions(+), 249 deletions(-) diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/ConditionalIntensitiesVisualizer.cs b/src/Bonsai.ML.PointProcessDecoder.Design/ConditionalIntensitiesVisualizer.cs index 9b41ffe5..98108f34 100644 --- a/src/Bonsai.ML.PointProcessDecoder.Design/ConditionalIntensitiesVisualizer.cs +++ b/src/Bonsai.ML.PointProcessDecoder.Design/ConditionalIntensitiesVisualizer.cs @@ -13,13 +13,14 @@ using OxyPlot.Axes; using static TorchSharp.torch; using TorchSharp; +using System.Reactive; [assembly: TypeVisualizer(typeof(Bonsai.ML.PointProcessDecoder.Design.ConditionalIntensitiesVisualizer), Target = typeof(Bonsai.ML.PointProcessDecoder.Decode))] namespace Bonsai.ML.PointProcessDecoder.Design { - public class ConditionalIntensitiesVisualizer : DialogTypeVisualizer + public class ConditionalIntensitiesVisualizer : BufferedVisualizer { private int _rowCount = 1; public int RowCount @@ -27,15 +28,15 @@ public int RowCount get => _rowCount; set { - if (value < 1) - { - throw new InvalidOperationException("The number of rows must be greater than 0."); - } + // if (value < 1) + // { + // throw new InvalidOperationException("The number of rows must be greater than 0."); + // } _rowCount = value; - if (_container != null) - { - _container.RowCount = _rowCount; - } + // if (_container != null) + // { + // _container.RowCount = _rowCount; + // } } } @@ -45,49 +46,49 @@ public int ColumnCount get => _columnCount; set { - if (value < 1) - { - throw new InvalidOperationException("The number of columns must be greater than 0."); - } + // if (value < 1) + // { + // throw new InvalidOperationException("The number of columns must be greater than 0."); + // } _columnCount = value; - if (_container != null) - { - _container.ColumnCount = _columnCount; - } + // if (_container != null) + // { + // _container.ColumnCount = _columnCount; + // } } } - private int _pageCount = 1; - public int PageCount => _pageCount; - private int _selectedPageIndex = 0; public int SelectedPageIndex { get => _selectedPageIndex; set { - if (value < 0 || value >= _pageCount) - { - throw new InvalidOperationException("The selected page index is out of range."); - } + // if (value < 0 || value >= _pageCount) + // { + // throw new InvalidOperationException("The selected page index is out of range."); + // } _selectedPageIndex = value; } } + private int _pageCount = 1; + private string _modelName = string.Empty; private PointProcessModel _model = null; private HeatMapSeriesOxyPlotBase[] _heatmapPlots = null; private long _conditionalIntensitiesCount = 0; private TableLayoutPanel _container = null; // create a dictionary to map the index of the heatmap plot to the corresponding conditional intensity private List _conditionalIntensitiesCumulativeIndex = []; - private StatusStrip _statusStrip = null; public StatusStrip StatusStrip => _statusStrip; + private ToolStripNumericUpDown _pageIndexControl = null; + private ToolStripNumericUpDown _rowControl = null; + private ToolStripNumericUpDown _columnControl = null; /// public override void Load(IServiceProvider provider) { - Decode decodeNode = null; var expressionBuilderGraph = (ExpressionBuilderGraph)provider.GetService(typeof(ExpressionBuilderGraph)); var typeVisualizerContext = (ITypeVisualizerContext)provider.GetService(typeof(ITypeVisualizerContext)); @@ -97,47 +98,29 @@ public override void Load(IServiceProvider provider) expressionBuilderGraph.Where(node => node.Value == typeVisualizerContext.Source) .FirstOrDefault().Value) as Decode; } - + if (decodeNode == null) { Console.WriteLine("The decode node is invalid."); throw new InvalidOperationException("The decode node is invalid."); } - var modelName = decodeNode.Model; - if (string.IsNullOrEmpty(modelName)) + _modelName = decodeNode.Model; + if (string.IsNullOrEmpty(_modelName)) { Console.WriteLine("The point process model name is not set."); throw new InvalidOperationException("The point process model name is not set."); } - - _model = PointProcessModelManager.GetModel(modelName); - - if (_model == null) - { - Console.WriteLine($"The point process model with name {modelName} is not found."); - throw new InvalidOperationException($"The point process model with name {modelName} is not found."); - } - - if (_model.StateSpace.Dimensions != 2) - { - Console.WriteLine("For the conditional intensities visualizer to work, the state space dimensions must be 2."); - throw new InvalidOperationException("For the conditional intensities visualizer to work, the state space dimensions must be 2."); - } _container = new TableLayoutPanel() { Dock = DockStyle.Fill, - ColumnCount = _columnCount, - RowCount = _rowCount, + ColumnCount = ColumnCount, + RowCount = RowCount, }; - InitializeHeatmaps(); - - _pageCount = (int)Math.Ceiling((double)_conditionalIntensitiesCount / (_rowCount * _columnCount)); - var pageIndexLabel = new ToolStripLabel($"Page: {SelectedPageIndex}"); - var pageIndexControl = new ToolStripNumericUpDown() + _pageIndexControl = new ToolStripNumericUpDown() { Minimum = 0, Maximum = _pageCount - 1, @@ -145,68 +128,76 @@ public override void Load(IServiceProvider provider) Value = SelectedPageIndex, }; - // pageIndexControl.ValueChanged += (sender, e) => - // { - // var value = Convert.ToInt32(pageIndexControl.Value); - // try { - // SelectedPageIndex = value; - // UpdatePage(); - // Show(null); - // pageIndexLabel.Text = $"Page: {SelectedPageIndex}"; - // } catch (InvalidOperationException) { - // // pageIndexControl.Value = SelectedPageIndex; - // } - // }; + _pageIndexControl.ValueChanged += (sender, e) => + { + var value = Convert.ToInt32(_pageIndexControl.Value); + try { + SelectedPageIndex = value; + UpdateTableLayout(); + Show(null); + pageIndexLabel.Text = $"Page: {SelectedPageIndex}"; + } catch (InvalidOperationException) { } + }; var rowLabel = new ToolStripLabel($"Rows: {RowCount}"); - var rowControl = new ToolStripNumericUpDown() + _rowControl = new ToolStripNumericUpDown() { Minimum = 1, DecimalPlaces = 0, Value = RowCount, }; - rowControl.ValueChanged += (sender, e) => + _rowControl.ValueChanged += (sender, e) => { - var value = Convert.ToInt32(rowControl.Value); + var value = Convert.ToInt32(_rowControl.Value); try { RowCount = value; _pageCount = (int)Math.Ceiling((double)_conditionalIntensitiesCount / (_rowCount * _columnCount)); - UpdateTableLayout(); - Show(null); + if (SelectedPageIndex >= _pageCount) + { + SelectedPageIndex = _pageCount - 1; + _pageIndexControl.Maximum = _pageCount - 1; + _pageIndexControl.Value = SelectedPageIndex; + } + else + { + UpdateTableLayout(); + Show(null); + } rowLabel.Text = $"Rows: {RowCount}"; - } - catch (InvalidOperationException) - { - // rowControl.Value = RowCount; - } + } catch (InvalidOperationException) { } }; var columnLabel = new ToolStripLabel($"Columns: {ColumnCount}"); - var columnControl = new ToolStripNumericUpDown() + _columnControl = new ToolStripNumericUpDown() { Minimum = 1, DecimalPlaces = 0, Value = ColumnCount, }; - // columnControl.ValueChanged += (sender, e) => - // { - // var value = Convert.ToInt32(columnControl.Value); - // try - // { - // ColumnCount = value; - // _pageCount = (int)Math.Ceiling((double)_conditionalIntensitiesCount / (_rowCount * _columnCount)); - // UpdatePage(); - // Show(null); - // columnLabel.Text = $"Columns: {ColumnCount}"; - // } - // catch (InvalidOperationException) - // { - // // columnControl.Value = ColumnCount; - // } - // }; + _columnControl.ValueChanged += (sender, e) => + { + var value = Convert.ToInt32(_columnControl.Value); + try + { + ColumnCount = value; + _pageCount = (int)Math.Ceiling((double)_conditionalIntensitiesCount / (_rowCount * _columnCount)); + if (SelectedPageIndex >= _pageCount) + { + SelectedPageIndex = _pageCount - 1; + _pageIndexControl.Maximum = SelectedPageIndex; + _pageIndexControl.Value = SelectedPageIndex; + } + else + { + UpdateTableLayout(); + Show(null); + } + columnLabel.Text = $"Columns: {ColumnCount}"; + } catch (InvalidOperationException) { } + }; _statusStrip = new StatusStrip() { @@ -215,43 +206,63 @@ public override void Load(IServiceProvider provider) _statusStrip.Items.AddRange([ pageIndexLabel, - pageIndexControl, + _pageIndexControl, rowLabel, - rowControl, + _rowControl, columnLabel, - columnControl + _columnControl ]); - UpdateHeatmaps(); UpdateTableLayout(); - // _container.Controls.Add(statusStrip); - // _container.MouseClick += (sender, e) => { - // if (e.Button == MouseButtons.Right) { - // statusStrip.Visible = !statusStrip.Visible; - // } - // }; - var visualizerService = (IDialogTypeVisualizerService)provider.GetService(typeof(IDialogTypeVisualizerService)); visualizerService?.AddControl(_container); visualizerService?.AddControl(_statusStrip); } - private bool InitializeHeatmaps() + private bool InitializeModel() { + try { + _model = PointProcessModelManager.GetModel(_modelName); + } catch { + return false; + } + + if (_model == null) + { + return false; + } + + if (_model.StateSpace.Dimensions != 2) + { + throw new InvalidOperationException("For the conditional intensities visualizer to work, the state space dimensions must be 2."); + } + if (_model.Encoder.ConditionalIntensities.Length == 0 || (_model.Encoder.ConditionalIntensities.Length == 1 && _model.Encoder.ConditionalIntensities[0].numel() == 0)) { + _model = null; return false; } + InitializeHeatmaps(); + + return true; + } + + private bool InitializeHeatmaps() + { _conditionalIntensitiesCount = 0; - for (int i = 0; i < _model.Encoder.ConditionalIntensities.Length; i++) { - var ci = _model.Encoder.ConditionalIntensities[i].clone(); - if (ci.IsInvalid) continue; - if (ci.numel() > 0) { - _conditionalIntensitiesCount += ci.size(0); + try { + for (int i = 0; i < _model.Encoder.ConditionalIntensities.Length; i++) { + var ci = _model.Encoder.ConditionalIntensities[i]; + if (_model.Encoder.ConditionalIntensities[i].numel() > 0) { + var size = _model.Encoder.ConditionalIntensities[i].size(0); + _conditionalIntensitiesCount += size; + _conditionalIntensitiesCumulativeIndex.Add(size + _conditionalIntensitiesCumulativeIndex.LastOrDefault()); + } } - _conditionalIntensitiesCumulativeIndex.Add(ci.shape[0] + _conditionalIntensitiesCumulativeIndex.LastOrDefault()); + } catch { + return false; } _heatmapPlots = new HeatMapSeriesOxyPlotBase[_conditionalIntensitiesCount]; @@ -263,129 +274,42 @@ private bool InitializeHeatmaps() }; } + _pageCount = (int)Math.Ceiling((double)_conditionalIntensitiesCount / (_rowCount * _columnCount)); + _pageIndexControl.Maximum = _pageCount - 1; + return true; } - private void UpdateHeatmaps() + private void UpdateTableLayout() { _container.Controls.Clear(); - // update heatmaps in container - for (int i = 0; i < _rowCount; i++) - { - for (int j = 0; j < _columnCount; j++) - { - var index = SelectedPageIndex * _rowCount * _columnCount + i * _columnCount + j; - if (index >= _conditionalIntensitiesCount) - { - break; - } + _container.RowStyles.Clear(); + _container.ColumnStyles.Clear(); - _container.Controls.Add(_heatmapPlots[index], j, i); - } - } - } + _container.RowCount = RowCount; + _container.ColumnCount = ColumnCount; - private void UpdateTableLayout() - { - var oldRowCount = _container.RowCount; - var oldColumnCount = _container.ColumnCount; - - // update rows in container - if (_rowCount > oldRowCount) + for (int i = 0; i < RowCount; i++) { - foreach (RowStyle rowStyle in _container.RowStyles) - { - rowStyle.SizeType = SizeType.Percent; - rowStyle.Height = 100f / _rowCount; - } - - _container.RowCount = _rowCount; - for (int i = 0; i < _rowCount - oldRowCount; i++) - { - _container.RowStyles.Add(new RowStyle(SizeType.Percent, 100f / _rowCount)); - for (int j = 0; j < _columnCount; j++) - { - var index = SelectedPageIndex * _rowCount * _columnCount + (oldRowCount * oldColumnCount) + (i * _columnCount) + j; - if (index >= _conditionalIntensitiesCount) - { - break; - } - - _container.Controls.Add(_heatmapPlots[index], j, i); - } - } + _container.RowStyles.Add(new RowStyle(SizeType.Percent, 100f / RowCount)); } - else if (_rowCount < oldRowCount) + + for (int i = 0; i < ColumnCount; i++) { - _container.RowCount = _rowCount; - for (int i = _rowCount; i < oldRowCount; i++) - { - _container.RowStyles.RemoveAt(_rowCount); - for (int j = 0; j < _columnCount; j++) - { - var index = SelectedPageIndex * _rowCount * _columnCount + (i * _columnCount) + j; - _container.Controls.Remove(_heatmapPlots[index]); - } - } + _container.ColumnStyles.Add(new ColumnStyle(SizeType.Percent, 100f / ColumnCount)); } - // update columns in container - if (_columnCount > oldColumnCount) + for (int i = 0; i < RowCount; i++) { - foreach (ColumnStyle columnStyle in _container.ColumnStyles) - { - columnStyle.SizeType = SizeType.Percent; - columnStyle.Width = 100f / _columnCount; - } - _container.ColumnCount = _columnCount; - for (int i = 0; i < _columnCount - oldColumnCount; i++) + for (int j = 0; j < ColumnCount; j++) { - _container.ColumnStyles.Add(new ColumnStyle(SizeType.Percent, 100f / _columnCount)); - } - - // remove heatmaps in old rows - for (int i = 1; i < _rowCount; i++) - { - for (int j = 0; j < oldColumnCount; j++) + var index = SelectedPageIndex * RowCount * ColumnCount + i * ColumnCount + j; + if (index >= _conditionalIntensitiesCount) { - var index = SelectedPageIndex * _rowCount * _columnCount + (i * oldColumnCount) + j; - if (index >= _conditionalIntensitiesCount) - { - break; - } - - _container.Controls.Remove(_heatmapPlots[index]); + break; } - } - // move heatmaps to the new columns - for (int i = 0; i < _rowCount; i++) - { - _container.RowStyles.Add(new RowStyle(SizeType.Percent, 100f / _rowCount)); - for (int j = 0; j < _columnCount; j++) - { - if (i == 0 && j < oldColumnCount) - { - continue; - } - - var index = SelectedPageIndex * _rowCount * _columnCount + (i * _columnCount) + j; - if (index >= _conditionalIntensitiesCount) - { - break; - } - - _container.Controls.Add(_heatmapPlots[index], j, i); - } - } - } - else if (_columnCount < oldColumnCount) - { - var oldColumnCount = _container.ColumnCount; - _container.ColumnCount = _columnCount; - for (int i = _columnCount; i < oldColumnCount; i++) - { - _container.ColumnStyles.RemoveAt(_columnCount); + _container.Controls.Add(_heatmapPlots[index], j, i); } } } @@ -406,16 +330,23 @@ private void UpdateTableLayout() return (conditionalIntensitiesIndex, (int)conditionalIntensitiesTensorIndex); } + protected override void ShowBuffer(IList> values) + { + try { Show(values.LastOrDefault().Value); } + catch { } + } + /// public override void Show(object value) { - if (_heatmapPlots == null) + if (_model is null) { - var success = InitializeHeatmaps(); + var success = InitializeModel(); if (!success) { - throw new InvalidOperationException("The conditional intensities are empty."); + return; } + UpdateTableLayout(); } var startIndex = SelectedPageIndex * _rowCount * _columnCount; @@ -427,28 +358,25 @@ public override void Show(object value) Array heatmapValues; - var conditionalIntensity = _model.Encoder.ConditionalIntensities[conditionalIntensitiesIndex][conditionalIntensitiesTensorIndex] - .clone() - .to(CPU); + try { + var conditionalIntensity = _model.Encoder.ConditionalIntensities[conditionalIntensitiesIndex][conditionalIntensitiesTensorIndex]; - if (conditionalIntensity.IsInvalid) - { - continue; - } + if (conditionalIntensity.Dimensions == 2) { + conditionalIntensity = conditionalIntensity + .sum(dim: 0) + .exp(); + } - if (conditionalIntensity.Dimensions == 3) { - conditionalIntensity = conditionalIntensity.sum(dim: 1); - Console.WriteLine($"ConditionalIntensity after sum: {conditionalIntensity}"); - } - Console.WriteLine($"ConditionalIntensity: {conditionalIntensity}"); - var conditionalIntensityValues = conditionalIntensity - .to_type(ScalarType.Float64) - .data(); + var conditionalIntensityValues = conditionalIntensity + .to_type(ScalarType.Float64) + .data(); - heatmapValues = conditionalIntensityValues - .ToNDArray(); + heatmapValues = conditionalIntensityValues + .ToNDArray(); + } catch { + throw new InvalidOperationException("Error while updating the heatmap."); + } - Console.WriteLine(heatmapValues.Length); var heatmap = new double[_model.StateSpace.Shape[0], _model.StateSpace.Shape[1]]; Buffer.BlockCopy(heatmapValues, 0, heatmap, 0, heatmapValues.Length * sizeof(double)); @@ -484,13 +412,7 @@ public override void Unload() } } _heatmapPlots = null; - - if (_model != null) - { - _model.Dispose(); - _model = null; - } - + _model = null; _conditionalIntensitiesCount = 0; _rowCount = 1; _columnCount = 1; diff --git a/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj b/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj index 69301982..c756f16f 100644 --- a/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj +++ b/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj @@ -11,6 +11,6 @@ - + \ No newline at end of file diff --git a/src/Bonsai.ML.PointProcessDecoder/Decode.cs b/src/Bonsai.ML.PointProcessDecoder/Decode.cs index 6a6340a8..fdaf9307 100644 --- a/src/Bonsai.ML.PointProcessDecoder/Decode.cs +++ b/src/Bonsai.ML.PointProcessDecoder/Decode.cs @@ -45,8 +45,9 @@ public bool IgnoreNoSpikes /// public IObservable Process(IObservable source) { + var modelName = Model; return source.Select(input => { - var model = PointProcessModelManager.GetModel(Model); + var model = PointProcessModelManager.GetModel(modelName); if (_updateIgnoreNoSpikes && model.Likelihood is ClusterlessLikelihood likelihood) { likelihood.IgnoreNoSpikes = _ignoreNoSpikes; _updateIgnoreNoSpikes = false; diff --git a/src/Bonsai.ML.PointProcessDecoder/Encode.cs b/src/Bonsai.ML.PointProcessDecoder/Encode.cs index 9a3bf6c8..cd64404c 100644 --- a/src/Bonsai.ML.PointProcessDecoder/Encode.cs +++ b/src/Bonsai.ML.PointProcessDecoder/Encode.cs @@ -28,9 +28,10 @@ public class Encode /// public IObservable> Process(IObservable> source) { + var modelName = Model; return source.Do(input => { - var model = PointProcessModelManager.GetModel(Model); + var model = PointProcessModelManager.GetModel(modelName); var (neuralData, stateObservations) = input; model.Encode(neuralData, stateObservations); }); diff --git a/src/Bonsai.ML.PointProcessDecoder/GetModel.cs b/src/Bonsai.ML.PointProcessDecoder/GetModel.cs index a9312f4d..37ab955c 100644 --- a/src/Bonsai.ML.PointProcessDecoder/GetModel.cs +++ b/src/Bonsai.ML.PointProcessDecoder/GetModel.cs @@ -40,8 +40,9 @@ public IObservable Process() /// public IObservable Process(IObservable source) { + var modelName = Model; return source.Select(input => { - return PointProcessModelManager.GetModel(Model); + return PointProcessModelManager.GetModel(modelName); }); } } \ No newline at end of file diff --git a/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs index d7fd34f4..629594ef 100644 --- a/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs +++ b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs @@ -80,8 +80,6 @@ internal static PointProcessModelDisposable Reserve( return new PointProcessModelDisposable( model, Disposable.Create(() => { - model.Dispose(); - model = null; models.Remove(name); }) ); @@ -99,8 +97,6 @@ internal static PointProcessModelDisposable Load( return new PointProcessModelDisposable( model, Disposable.Create(() => { - model.Dispose(); - model = null; models.Remove(name); }) ); From 3d25d6bc1eb6818bc91d0cd0e312e9e772385af3 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 12 Feb 2025 17:29:14 +0000 Subject: [PATCH 106/131] Moved toolstrip numeric up down class to seperate file --- .../ToolStripNumericUpDown.cs | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 src/Bonsai.ML.PointProcessDecoder.Design/ToolStripNumericUpDown.cs diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/ToolStripNumericUpDown.cs b/src/Bonsai.ML.PointProcessDecoder.Design/ToolStripNumericUpDown.cs new file mode 100644 index 00000000..80e38e8b --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder.Design/ToolStripNumericUpDown.cs @@ -0,0 +1,47 @@ +using System; +using System.Windows.Forms; + +namespace Bonsai.ML.PointProcessDecoder.Design; + +internal class ToolStripNumericUpDown : ToolStripControlHost +{ + public ToolStripNumericUpDown() + : base(new NumericUpDown()) + { + } + + public NumericUpDown NumericUpDown + { + get { return Control as NumericUpDown; } + } + + public int DecimalPlaces + { + get { return NumericUpDown.DecimalPlaces; } + set { NumericUpDown.DecimalPlaces = value; } + } + + public decimal Value + { + get { return NumericUpDown.Value; } + set { NumericUpDown.Value = value; } + } + + public decimal Minimum + { + get { return NumericUpDown.Minimum; } + set { NumericUpDown.Minimum = value; } + } + + public decimal Maximum + { + get { return NumericUpDown.Maximum; } + set { NumericUpDown.Maximum = value; } + } + + public event EventHandler ValueChanged + { + add { NumericUpDown.ValueChanged += value; } + remove { NumericUpDown.ValueChanged -= value; } + } +} From ed2932727728ded7bf0e0eb3cadfd63554f4e2b6 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 12 Feb 2025 17:29:48 +0000 Subject: [PATCH 107/131] Fixed issue with visualizers not displaying properly --- .../ConditionalIntensitiesVisualizer.cs | 384 ++++++++---------- 1 file changed, 175 insertions(+), 209 deletions(-) diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/ConditionalIntensitiesVisualizer.cs b/src/Bonsai.ML.PointProcessDecoder.Design/ConditionalIntensitiesVisualizer.cs index 98108f34..47b02540 100644 --- a/src/Bonsai.ML.PointProcessDecoder.Design/ConditionalIntensitiesVisualizer.cs +++ b/src/Bonsai.ML.PointProcessDecoder.Design/ConditionalIntensitiesVisualizer.cs @@ -1,26 +1,25 @@ using System; +using System.Reactive.Linq; +using System.Reactive; using System.Windows.Forms; using System.Collections.Generic; using System.Linq; + using Bonsai; -using Bonsai.Dag; using Bonsai.Expressions; using Bonsai.Design; using Bonsai.ML.Design; -using PointProcessDecoder.Core; -using OxyPlot; -using OxyPlot.Series; -using OxyPlot.Axes; + using static TorchSharp.torch; -using TorchSharp; -using System.Reactive; + +using PointProcessDecoder.Core; [assembly: TypeVisualizer(typeof(Bonsai.ML.PointProcessDecoder.Design.ConditionalIntensitiesVisualizer), Target = typeof(Bonsai.ML.PointProcessDecoder.Decode))] namespace Bonsai.ML.PointProcessDecoder.Design { - public class ConditionalIntensitiesVisualizer : BufferedVisualizer + public class ConditionalIntensitiesVisualizer : DialogTypeVisualizer { private int _rowCount = 1; public int RowCount @@ -28,15 +27,11 @@ public int RowCount get => _rowCount; set { - // if (value < 1) - // { - // throw new InvalidOperationException("The number of rows must be greater than 0."); - // } + if (value < 1) + { + throw new InvalidOperationException("The number of rows must be greater than 0."); + } _rowCount = value; - // if (_container != null) - // { - // _container.RowCount = _rowCount; - // } } } @@ -46,15 +41,11 @@ public int ColumnCount get => _columnCount; set { - // if (value < 1) - // { - // throw new InvalidOperationException("The number of columns must be greater than 0."); - // } + if (value < 1) + { + throw new InvalidOperationException("The number of columns must be greater than 0."); + } _columnCount = value; - // if (_container != null) - // { - // _container.ColumnCount = _columnCount; - // } } } @@ -64,27 +55,25 @@ public int SelectedPageIndex get => _selectedPageIndex; set { - // if (value < 0 || value >= _pageCount) - // { - // throw new InvalidOperationException("The selected page index is out of range."); - // } _selectedPageIndex = value; } } + private readonly int _sampleFrequency = 30; private int _pageCount = 1; private string _modelName = string.Empty; - private PointProcessModel _model = null; - private HeatMapSeriesOxyPlotBase[] _heatmapPlots = null; - private long _conditionalIntensitiesCount = 0; + private List _heatmapPlots = null; + private int _conditionalIntensitiesCount = 0; private TableLayoutPanel _container = null; - // create a dictionary to map the index of the heatmap plot to the corresponding conditional intensity - private List _conditionalIntensitiesCumulativeIndex = []; + private readonly List _conditionalIntensitiesCumulativeIndex = []; private StatusStrip _statusStrip = null; public StatusStrip StatusStrip => _statusStrip; private ToolStripNumericUpDown _pageIndexControl = null; private ToolStripNumericUpDown _rowControl = null; private ToolStripNumericUpDown _columnControl = null; + private Tensor[] _conditionalIntensities = null; + private long _stateSpaceWidth; + private long _stateSpaceHeight; /// public override void Load(IServiceProvider provider) @@ -116,87 +105,74 @@ public override void Load(IServiceProvider provider) { Dock = DockStyle.Fill, ColumnCount = ColumnCount, - RowCount = RowCount, + RowCount = _rowCount, }; - var pageIndexLabel = new ToolStripLabel($"Page: {SelectedPageIndex}"); + var pageIndexLabel = new ToolStripLabel($"Page: {_selectedPageIndex}"); _pageIndexControl = new ToolStripNumericUpDown() { Minimum = 0, - Maximum = _pageCount - 1, DecimalPlaces = 0, - Value = SelectedPageIndex, + Value = _selectedPageIndex, }; _pageIndexControl.ValueChanged += (sender, e) => { var value = Convert.ToInt32(_pageIndexControl.Value); - try { - SelectedPageIndex = value; - UpdateTableLayout(); - Show(null); - pageIndexLabel.Text = $"Page: {SelectedPageIndex}"; - } catch (InvalidOperationException) { } + SelectedPageIndex = value; + UpdateTableLayout(); + Show(null); + pageIndexLabel.Text = $"Page: {_selectedPageIndex}"; }; - var rowLabel = new ToolStripLabel($"Rows: {RowCount}"); + var rowLabel = new ToolStripLabel($"Rows: {_rowCount}"); _rowControl = new ToolStripNumericUpDown() { Minimum = 1, DecimalPlaces = 0, - Value = RowCount, + Value = _rowCount, }; _rowControl.ValueChanged += (sender, e) => { - var value = Convert.ToInt32(_rowControl.Value); - try + RowCount = Convert.ToInt32(_rowControl.Value); + UpdatePages(); + if (_selectedPageIndex >= _pageCount) { - RowCount = value; - _pageCount = (int)Math.Ceiling((double)_conditionalIntensitiesCount / (_rowCount * _columnCount)); - if (SelectedPageIndex >= _pageCount) - { - SelectedPageIndex = _pageCount - 1; - _pageIndexControl.Maximum = _pageCount - 1; - _pageIndexControl.Value = SelectedPageIndex; - } - else - { - UpdateTableLayout(); - Show(null); - } - rowLabel.Text = $"Rows: {RowCount}"; - } catch (InvalidOperationException) { } + SelectedPageIndex = _pageCount - 1; + _pageIndexControl.Value = _selectedPageIndex; + } + else + { + UpdateTableLayout(); + Show(null); + } + rowLabel.Text = $"Rows: {_rowCount}"; }; - var columnLabel = new ToolStripLabel($"Columns: {ColumnCount}"); + var columnLabel = new ToolStripLabel($"Columns: {_columnCount}"); _columnControl = new ToolStripNumericUpDown() { Minimum = 1, DecimalPlaces = 0, - Value = ColumnCount, + Value = _columnCount, }; _columnControl.ValueChanged += (sender, e) => { - var value = Convert.ToInt32(_columnControl.Value); - try + ColumnCount = Convert.ToInt32(_columnControl.Value); + UpdatePages(); + if (_selectedPageIndex >= _pageCount) { - ColumnCount = value; - _pageCount = (int)Math.Ceiling((double)_conditionalIntensitiesCount / (_rowCount * _columnCount)); - if (SelectedPageIndex >= _pageCount) - { - SelectedPageIndex = _pageCount - 1; - _pageIndexControl.Maximum = SelectedPageIndex; - _pageIndexControl.Value = SelectedPageIndex; - } - else - { - UpdateTableLayout(); - Show(null); - } - columnLabel.Text = $"Columns: {ColumnCount}"; - } catch (InvalidOperationException) { } + SelectedPageIndex = _pageCount - 1; + _pageIndexControl.Value = _selectedPageIndex; + } + else + { + UpdateTableLayout(); + Show(null); + } + columnLabel.Text = $"Columns: {_columnCount}"; }; _statusStrip = new StatusStrip() @@ -213,69 +189,97 @@ public override void Load(IServiceProvider provider) _columnControl ]); - UpdateTableLayout(); - var visualizerService = (IDialogTypeVisualizerService)provider.GetService(typeof(IDialogTypeVisualizerService)); visualizerService?.AddControl(_container); visualizerService?.AddControl(_statusStrip); } - private bool InitializeModel() + private void UpdatePages() { + _pageCount = (int)Math.Ceiling((double)_conditionalIntensitiesCount / (_rowCount * _columnCount)); + _pageIndexControl.Maximum = _pageCount - 1; + } + + private bool UpdateModel() + { + PointProcessModel model; + try { - _model = PointProcessModelManager.GetModel(_modelName); + model = PointProcessModelManager.GetModel(_modelName); } catch { return false; } - if (_model == null) + if (model == null) { return false; } - if (_model.StateSpace.Dimensions != 2) + if (model.StateSpace.Dimensions != 2) { throw new InvalidOperationException("For the conditional intensities visualizer to work, the state space dimensions must be 2."); } - if (_model.Encoder.ConditionalIntensities.Length == 0 || (_model.Encoder.ConditionalIntensities.Length == 1 && _model.Encoder.ConditionalIntensities[0].numel() == 0)) + if (model.Encoder.ConditionalIntensities.Length == 0 || (model.Encoder.ConditionalIntensities.Length == 1 && model.Encoder.ConditionalIntensities[0].numel() == 0)) { - _model = null; return false; } - InitializeHeatmaps(); + _conditionalIntensities = model.Encoder.ConditionalIntensities; + _stateSpaceWidth = model.StateSpace.Shape[0]; + _stateSpaceHeight = model.StateSpace.Shape[1]; return true; } - private bool InitializeHeatmaps() + private static int GetConditionalIntensitiesCount(Tensor[] conditionalIntensities, List conditionalIntensitiesCumulativeIndex) { - _conditionalIntensitiesCount = 0; - try { - for (int i = 0; i < _model.Encoder.ConditionalIntensities.Length; i++) { - var ci = _model.Encoder.ConditionalIntensities[i]; - if (_model.Encoder.ConditionalIntensities[i].numel() > 0) { - var size = _model.Encoder.ConditionalIntensities[i].size(0); - _conditionalIntensitiesCount += size; - _conditionalIntensitiesCumulativeIndex.Add(size + _conditionalIntensitiesCumulativeIndex.LastOrDefault()); - } + long conditionalIntensitiesCount = 0; + conditionalIntensitiesCumulativeIndex.Clear(); + for (int i = 0; i < conditionalIntensities.Length; i++) { + if (conditionalIntensities[i].numel() > 0) { + conditionalIntensitiesCount += conditionalIntensities[i].size(0); + conditionalIntensitiesCumulativeIndex.Add(conditionalIntensitiesCount); } - } catch { - return false; } + return (int)conditionalIntensitiesCount; + } - _heatmapPlots = new HeatMapSeriesOxyPlotBase[_conditionalIntensitiesCount]; - for (int i = 0; i < _conditionalIntensitiesCount; i++) + private bool UpdateHeatmaps() + { + if (_heatmapPlots is null) { - _heatmapPlots[i] = new HeatMapSeriesOxyPlotBase(0, 0) + _heatmapPlots = []; + for (int i = 0; i < _conditionalIntensitiesCount; i++) { - Dock = DockStyle.Fill, - }; + _heatmapPlots.Add(new HeatMapSeriesOxyPlotBase(0, 0) + { + Dock = DockStyle.Fill, + }); + } + } + else if (_heatmapPlots.Count > _conditionalIntensitiesCount) + { + var count = _heatmapPlots.Count - _conditionalIntensitiesCount; + for (int i = 0; i < count; i++) + { + if (!_heatmapPlots[i + _conditionalIntensitiesCount].IsDisposed) + { + _heatmapPlots[i + _conditionalIntensitiesCount].Dispose(); + } + } + _heatmapPlots.RemoveRange(_conditionalIntensitiesCount, count); + } + else if (_heatmapPlots.Count < _conditionalIntensitiesCount) + { + for (int i = _heatmapPlots.Count; i < _conditionalIntensitiesCount; i++) + { + _heatmapPlots.Add(new HeatMapSeriesOxyPlotBase(0, 0) + { + Dock = DockStyle.Fill, + }); + } } - - _pageCount = (int)Math.Ceiling((double)_conditionalIntensitiesCount / (_rowCount * _columnCount)); - _pageIndexControl.Maximum = _pageCount - 1; return true; } @@ -286,24 +290,24 @@ private void UpdateTableLayout() _container.RowStyles.Clear(); _container.ColumnStyles.Clear(); - _container.RowCount = RowCount; - _container.ColumnCount = ColumnCount; + _container.RowCount = _rowCount; + _container.ColumnCount = _columnCount; - for (int i = 0; i < RowCount; i++) + for (int i = 0; i < _rowCount; i++) { - _container.RowStyles.Add(new RowStyle(SizeType.Percent, 100f / RowCount)); + _container.RowStyles.Add(new RowStyle(SizeType.Percent, 100f / _rowCount)); } - for (int i = 0; i < ColumnCount; i++) + for (int i = 0; i < _columnCount; i++) { - _container.ColumnStyles.Add(new ColumnStyle(SizeType.Percent, 100f / ColumnCount)); + _container.ColumnStyles.Add(new ColumnStyle(SizeType.Percent, 100f / _columnCount)); } - for (int i = 0; i < RowCount; i++) + for (int i = 0; i < _rowCount; i++) { - for (int j = 0; j < ColumnCount; j++) + for (int j = 0; j < _columnCount; j++) { - var index = SelectedPageIndex * RowCount * ColumnCount + i * ColumnCount + j; + var index = SelectedPageIndex * _rowCount * _columnCount + i * _columnCount + j; if (index >= _conditionalIntensitiesCount) { break; @@ -330,25 +334,9 @@ private void UpdateTableLayout() return (conditionalIntensitiesIndex, (int)conditionalIntensitiesTensorIndex); } - protected override void ShowBuffer(IList> values) - { - try { Show(values.LastOrDefault().Value); } - catch { } - } - /// public override void Show(object value) - { - if (_model is null) - { - var success = InitializeModel(); - if (!success) - { - return; - } - UpdateTableLayout(); - } - + { var startIndex = SelectedPageIndex * _rowCount * _columnCount; var endIndex = Math.Min(startIndex + _rowCount * _columnCount, _conditionalIntensitiesCount); @@ -356,36 +344,24 @@ public override void Show(object value) { var (conditionalIntensitiesIndex, conditionalIntensitiesTensorIndex) = GetConditionalIntensitiesIndex(i); - Array heatmapValues; + var conditionalIntensity = _conditionalIntensities[conditionalIntensitiesIndex][conditionalIntensitiesTensorIndex]; - try { - var conditionalIntensity = _model.Encoder.ConditionalIntensities[conditionalIntensitiesIndex][conditionalIntensitiesTensorIndex]; - - if (conditionalIntensity.Dimensions == 2) { - conditionalIntensity = conditionalIntensity - .sum(dim: 0) - .exp(); - } + if (conditionalIntensity.Dimensions == 2) { + conditionalIntensity = conditionalIntensity + .sum(dim: 0); + } - var conditionalIntensityValues = conditionalIntensity - .to_type(ScalarType.Float64) - .data(); - heatmapValues = conditionalIntensityValues - .ToNDArray(); - } catch { - throw new InvalidOperationException("Error while updating the heatmap."); - } - var heatmap = new double[_model.StateSpace.Shape[0], _model.StateSpace.Shape[1]]; - Buffer.BlockCopy(heatmapValues, 0, heatmap, 0, heatmapValues.Length * sizeof(double)); + var conditionalIntensityValues = (double[,])conditionalIntensity + .exp() + .to_type(ScalarType.Float64) + .reshape([_stateSpaceWidth, _stateSpaceHeight]) + .data() + .ToNDArray(); _heatmapPlots[i].UpdateHeatMapSeries( - 0, - _model.StateSpace.Shape[0], - 0, - _model.StateSpace.Shape[1], - heatmap + conditionalIntensityValues ); _heatmapPlots[i].UpdatePlot(); @@ -404,64 +380,54 @@ public override void Unload() _container = null; } - for (int i = 0; i < _heatmapPlots.Length; i++) + if (_heatmapPlots != null) { - if (!_heatmapPlots[i].IsDisposed) + for (int i = 0; i < _heatmapPlots.Count; i++) { - _heatmapPlots[i].Dispose(); + if (!_heatmapPlots[i].IsDisposed) + { + _heatmapPlots[i].Dispose(); + } } - } - _heatmapPlots = null; - _model = null; + _heatmapPlots = null; + }; + _conditionalIntensitiesCount = 0; - _rowCount = 1; - _columnCount = 1; - _pageCount = 1; - _selectedPageIndex = 0; _conditionalIntensitiesCumulativeIndex.Clear(); - } - } - - class ToolStripNumericUpDown : ToolStripControlHost - { - public ToolStripNumericUpDown() - : base(new NumericUpDown()) - { - } - - public NumericUpDown NumericUpDown - { - get { return Control as NumericUpDown; } - } - - public int DecimalPlaces - { - get { return NumericUpDown.DecimalPlaces; } - set { NumericUpDown.DecimalPlaces = value; } - } - - public decimal Value - { - get { return NumericUpDown.Value; } - set { NumericUpDown.Value = value; } - } - - public decimal Minimum - { - get { return NumericUpDown.Minimum; } - set { NumericUpDown.Minimum = value; } + _conditionalIntensities = null; } - public decimal Maximum + public override IObservable Visualize(IObservable> source, IServiceProvider provider) { - get { return NumericUpDown.Maximum; } - set { NumericUpDown.Maximum = value; } - } + if (provider.GetService(typeof(IDialogTypeVisualizerService)) is not Control visualizerControl) + { + return source; + } - public event EventHandler ValueChanged - { - add { NumericUpDown.ValueChanged += value; } - remove { NumericUpDown.ValueChanged -= value; } + return source.SelectMany(input => + input.Sample(TimeSpan.FromMilliseconds(_sampleFrequency)) + .ObserveOn(visualizerControl) + .Do(value => + { + var success = UpdateModel(); + if (!success) + { + return; + } + + var newConditionalIntensitiesCount = GetConditionalIntensitiesCount(_conditionalIntensities, _conditionalIntensitiesCumulativeIndex); + if (_conditionalIntensitiesCount != newConditionalIntensitiesCount) + { + _conditionalIntensitiesCount = newConditionalIntensitiesCount; + UpdatePages(); + UpdateHeatmaps(); + UpdateTableLayout(); + } + + Show(value); + } + ) + ); } } } \ No newline at end of file From 5faf9a78cdd6ce10cccb0c7ca30304f998eb7695 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 21 Feb 2025 13:16:04 +0000 Subject: [PATCH 108/131] Update heatmap series plot to make visualizer drop down public --- src/Bonsai.ML.Design/HeatMapSeriesOxyPlotBase.cs | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/Bonsai.ML.Design/HeatMapSeriesOxyPlotBase.cs b/src/Bonsai.ML.Design/HeatMapSeriesOxyPlotBase.cs index d6944816..74d82202 100644 --- a/src/Bonsai.ML.Design/HeatMapSeriesOxyPlotBase.cs +++ b/src/Bonsai.ML.Design/HeatMapSeriesOxyPlotBase.cs @@ -36,6 +36,12 @@ public class HeatMapSeriesOxyPlotBase : UserControl private ToolStripTextBox minValueTextBox; private ToolStripLabel minValueLabel; + private ToolStripDropDownButton _visualizerPropertiesDropDown; + /// + /// Gets the visualizer properties drop down button. + /// + public ToolStripDropDownButton VisualizerPropertiesDropDown => _visualizerPropertiesDropDown; + private int _numColors = 100; /// @@ -126,14 +132,14 @@ private void Initialize() minValueTextBox }; - ToolStripDropDownButton visualizerPropertiesButton = new ToolStripDropDownButton("Visualizer Properties"); + _visualizerPropertiesDropDown = new ToolStripDropDownButton("Visualizer Properties"); foreach (var item in toolStripItems) { - visualizerPropertiesButton.DropDownItems.Add(item); + _visualizerPropertiesDropDown.DropDownItems.Add(item); } - statusStrip.Items.Add(visualizerPropertiesButton); + statusStrip.Items.Add(_visualizerPropertiesDropDown); Controls.Add(statusStrip); view.MouseClick += new MouseEventHandler(onMouseClick); From e15854bab875072d1d3fbe0c995c10e9f8cdd7cf Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 21 Feb 2025 13:18:53 +0000 Subject: [PATCH 109/131] Update heatmap series to expose plot model and plot view publically. --- .../HeatMapSeriesOxyPlotBase.cs | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/src/Bonsai.ML.Design/HeatMapSeriesOxyPlotBase.cs b/src/Bonsai.ML.Design/HeatMapSeriesOxyPlotBase.cs index 74d82202..cdc80150 100644 --- a/src/Bonsai.ML.Design/HeatMapSeriesOxyPlotBase.cs +++ b/src/Bonsai.ML.Design/HeatMapSeriesOxyPlotBase.cs @@ -14,8 +14,18 @@ namespace Bonsai.ML.Design /// public class HeatMapSeriesOxyPlotBase : UserControl { - private PlotView view; - private PlotModel model; + private PlotView _view; + /// + /// Gets the plot view of the control. + /// + public PlotView View => _view; + + private PlotModel _model; + /// + /// Gets the plot model of the control. + /// + public PlotModel Model => _model; + private HeatMapSeries heatMapSeries; private LinearColorAxis colorAxis; @@ -85,12 +95,12 @@ public HeatMapSeriesOxyPlotBase(int paletteSelectedIndex, int renderMethodSelect private void Initialize() { - view = new PlotView + _view = new PlotView { Dock = DockStyle.Fill, }; - model = new PlotModel(); + _model = new PlotModel(); heatMapSeries = new HeatMapSeries { X0 = 0, @@ -106,11 +116,11 @@ private void Initialize() Position = AxisPosition.Right, }; - model.Axes.Add(colorAxis); - model.Series.Add(heatMapSeries); + _model.Axes.Add(colorAxis); + _model.Series.Add(heatMapSeries); - view.Model = model; - Controls.Add(view); + _view.Model = _model; + Controls.Add(_view); InitializeColorPalette(); InitializeRenderMethod(); @@ -332,7 +342,7 @@ public void UpdateHeatMapSeries(double x0, double x1, double y0, double y1, doub /// public void UpdatePlot() { - model.InvalidatePlot(true); + _model.InvalidatePlot(true); } private static readonly Dictionary> paletteLookup = new Dictionary> From 1659700c979c1e9052e83b118db46cd0f15a7bda Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 21 Feb 2025 13:21:39 +0000 Subject: [PATCH 110/131] Changed formatting. Used a lambda expression for on mouse view changed instead of explicit mouse event handler function. --- .../HeatMapSeriesOxyPlotBase.cs | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/Bonsai.ML.Design/HeatMapSeriesOxyPlotBase.cs b/src/Bonsai.ML.Design/HeatMapSeriesOxyPlotBase.cs index cdc80150..f009a13a 100644 --- a/src/Bonsai.ML.Design/HeatMapSeriesOxyPlotBase.cs +++ b/src/Bonsai.ML.Design/HeatMapSeriesOxyPlotBase.cs @@ -85,7 +85,11 @@ public class HeatMapSeriesOxyPlotBase : UserControl /// Data source is optional, since pasing it to the constructor will populate the combobox and leave it empty otherwise. /// The selected index is only needed when the data source is provided. /// - public HeatMapSeriesOxyPlotBase(int paletteSelectedIndex, int renderMethodSelectedIndex, int numColors = 100) + public HeatMapSeriesOxyPlotBase( + int paletteSelectedIndex, + int renderMethodSelectedIndex, + int numColors = 100 + ) { _paletteSelectedIndex = paletteSelectedIndex; _renderMethodSelectedIndex = renderMethodSelectedIndex; @@ -152,7 +156,13 @@ private void Initialize() statusStrip.Items.Add(_visualizerPropertiesDropDown); Controls.Add(statusStrip); - view.MouseClick += new MouseEventHandler(onMouseClick); + _view.MouseClick += (sender, e) => { + if (e.Button == MouseButtons.Right) + { + statusStrip.Visible = !statusStrip.Visible; + } + }; + AutoScaleDimensions = new SizeF(6F, 13F); } @@ -303,14 +313,6 @@ private void UpdateRenderMethod() heatMapSeries.RenderMethod = renderMethod; } - private void onMouseClick(object sender, MouseEventArgs e) - { - if (e.Button == MouseButtons.Right) - { - statusStrip.Visible = !statusStrip.Visible; - } - } - /// /// Method to update the heatmap series with new data. /// From 4327dbd862589d002f3172ab0be39a1e3ebeae16 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 21 Feb 2025 13:22:46 +0000 Subject: [PATCH 111/131] Make heatmap plot field public. --- .../MultidimensionalArrayVisualizer.cs | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/Bonsai.ML.Design/MultidimensionalArrayVisualizer.cs b/src/Bonsai.ML.Design/MultidimensionalArrayVisualizer.cs index a2ea15ab..1101c601 100644 --- a/src/Bonsai.ML.Design/MultidimensionalArrayVisualizer.cs +++ b/src/Bonsai.ML.Design/MultidimensionalArrayVisualizer.cs @@ -23,24 +23,25 @@ public class MultidimensionalArrayVisualizer : DialogTypeVisualizer /// public int RenderMethodSelectedIndex { get; set; } - private HeatMapSeriesOxyPlotBase Plot; + private HeatMapSeriesOxyPlotBase _plot; + /// + /// Gets the HeatMapSeriesOxyPlotBase control used to display the heatmap. + /// + public HeatMapSeriesOxyPlotBase Plot => _plot; /// public override void Load(IServiceProvider provider) { - Plot = new HeatMapSeriesOxyPlotBase(PaletteSelectedIndex, RenderMethodSelectedIndex) + _plot = new HeatMapSeriesOxyPlotBase(PaletteSelectedIndex, RenderMethodSelectedIndex) { Dock = DockStyle.Fill, }; - Plot.PaletteComboBoxValueChanged += PaletteIndexChanged; - Plot.RenderMethodComboBoxValueChanged += RenderMethodIndexChanged; + _plot.PaletteComboBoxValueChanged += PaletteIndexChanged; + _plot.RenderMethodComboBoxValueChanged += RenderMethodIndexChanged; var visualizerService = (IDialogTypeVisualizerService)provider.GetService(typeof(IDialogTypeVisualizerService)); - if (visualizerService != null) - { - visualizerService.AddControl(Plot); - } + visualizerService?.AddControl(_plot); } /// From 21ae0ceaced443bb707ea2dbeac4564dc6b3f817 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 21 Feb 2025 13:23:35 +0000 Subject: [PATCH 112/131] Update decoder package to v0.3.0 --- .../Bonsai.ML.PointProcessDecoder.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj b/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj index c756f16f..3608f916 100644 --- a/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj +++ b/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj @@ -11,6 +11,6 @@ - + \ No newline at end of file From bca52d8f0f5aef9454bd19df5fe0e0e87aa70365 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 21 Feb 2025 13:24:40 +0000 Subject: [PATCH 113/131] Update documentation for ignore no spikes property --- src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs b/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs index ca07fcb8..2e5015f3 100644 --- a/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs +++ b/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs @@ -358,11 +358,10 @@ public LikelihoodType LikelihoodType private bool ignoreNoSpikes = false; /// - /// Gets or sets a value indicating whether to ignore contributions from channels with no spikes. - /// Only used when the likelihood type is set to . + /// Gets or sets a value indicating whether to ignore contributions from units or channels with no spikes. /// [Category("5. Likelihood Parameters")] - [Description("Indicates whether to ignore contributions from channels with no spikes. Only used when the likelihood type is set to Clusterless.")] + [Description("Indicates whether to ignore contributions from units or channels with no spikes.")] public bool IgnoreNoSpikes { get From db968f3cf278065af01953a718a7954ba36f90bd Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 21 Feb 2025 13:25:42 +0000 Subject: [PATCH 114/131] Ensured ignore no spikes property is applied to both clusterless and poisson likelihood types --- src/Bonsai.ML.PointProcessDecoder/Decode.cs | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/Bonsai.ML.PointProcessDecoder/Decode.cs b/src/Bonsai.ML.PointProcessDecoder/Decode.cs index fdaf9307..361d1c12 100644 --- a/src/Bonsai.ML.PointProcessDecoder/Decode.cs +++ b/src/Bonsai.ML.PointProcessDecoder/Decode.cs @@ -48,8 +48,16 @@ public IObservable Process(IObservable source) var modelName = Model; return source.Select(input => { var model = PointProcessModelManager.GetModel(modelName); - if (_updateIgnoreNoSpikes && model.Likelihood is ClusterlessLikelihood likelihood) { - likelihood.IgnoreNoSpikes = _ignoreNoSpikes; + if (_updateIgnoreNoSpikes) { + if (model.Likelihood is ClusterlessLikelihood clusterlessLikelihood) + { + clusterlessLikelihood.IgnoreNoSpikes = _ignoreNoSpikes; + } + else if (model.Likelihood is PoissonLikelihood poissonLikelihood) + { + poissonLikelihood.IgnoreNoSpikes = _ignoreNoSpikes; + } + _updateIgnoreNoSpikes = false; } return model.Decode(input); From a3661d223bd92747dfb2261cdd98d782696f9335 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 21 Feb 2025 13:26:07 +0000 Subject: [PATCH 115/131] Added documentation to public get model function --- .../PointProcessModelManager.cs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs index 629594ef..d756c1b5 100644 --- a/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs +++ b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs @@ -21,6 +21,12 @@ public static class PointProcessModelManager { private static readonly Dictionary models = []; + /// + /// Gets the point process model with the specified name. + /// + /// + /// + /// public static PointProcessModel GetModel(string name) { return models.TryGetValue(name, out var model) ? model : throw new InvalidOperationException($"Model with name {name} not found."); From b0144333714a623cf14afa593636acb415359115 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 21 Feb 2025 13:27:05 +0000 Subject: [PATCH 116/131] Updated documentation for using Datetime suffix and used high resolution scheduler for extracting time information --- src/Bonsai.ML.PointProcessDecoder/SavePointProcessModel.cs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Bonsai.ML.PointProcessDecoder/SavePointProcessModel.cs b/src/Bonsai.ML.PointProcessDecoder/SavePointProcessModel.cs index d82c5457..20746345 100644 --- a/src/Bonsai.ML.PointProcessDecoder/SavePointProcessModel.cs +++ b/src/Bonsai.ML.PointProcessDecoder/SavePointProcessModel.cs @@ -29,6 +29,7 @@ public class SavePointProcessModel /// /// Specifies the type of suffix to add to the save path. + /// If DateTime, a suffix with the current date and time is added to the save path in the format 'yyyyMMddHHmmss'. /// [Description("Specifies the type of suffix to add to the save path.")] public SuffixType AddSuffix { get; set; } = SuffixType.None; @@ -53,7 +54,7 @@ public IObservable Process(IObservable source) var path = AddSuffix switch { - SuffixType.DateTime => System.IO.Path.Combine(Path, $"{DateTime.Now:yyyyMMddHHmmss}"), + SuffixType.DateTime => System.IO.Path.Combine(Path, $"{HighResolutionScheduler.Now:yyyyMMddHHmmss}"), SuffixType.Guid => System.IO.Path.Combine(Path, Guid.NewGuid().ToString()), _ => Path }; From dbf2949d8923448cab68772c14acdec16d91a486 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 21 Feb 2025 13:30:37 +0000 Subject: [PATCH 117/131] Renamed `ConditionalIntensitiesVisualizer` to just `IntensitiesVisualizer` and made changes to display heatmaps correctly. --- ...Visualizer.cs => IntensitiesVisualizer.cs} | 165 +++++++++++------- 1 file changed, 106 insertions(+), 59 deletions(-) rename src/Bonsai.ML.PointProcessDecoder.Design/{ConditionalIntensitiesVisualizer.cs => IntensitiesVisualizer.cs} (70%) diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/ConditionalIntensitiesVisualizer.cs b/src/Bonsai.ML.PointProcessDecoder.Design/IntensitiesVisualizer.cs similarity index 70% rename from src/Bonsai.ML.PointProcessDecoder.Design/ConditionalIntensitiesVisualizer.cs rename to src/Bonsai.ML.PointProcessDecoder.Design/IntensitiesVisualizer.cs index 47b02540..0776a3f9 100644 --- a/src/Bonsai.ML.PointProcessDecoder.Design/ConditionalIntensitiesVisualizer.cs +++ b/src/Bonsai.ML.PointProcessDecoder.Design/IntensitiesVisualizer.cs @@ -14,14 +14,17 @@ using PointProcessDecoder.Core; -[assembly: TypeVisualizer(typeof(Bonsai.ML.PointProcessDecoder.Design.ConditionalIntensitiesVisualizer), +[assembly: TypeVisualizer(typeof(Bonsai.ML.PointProcessDecoder.Design.IntensitiesVisualizer), Target = typeof(Bonsai.ML.PointProcessDecoder.Decode))] namespace Bonsai.ML.PointProcessDecoder.Design { - public class ConditionalIntensitiesVisualizer : DialogTypeVisualizer + public class IntensitiesVisualizer : DialogTypeVisualizer { private int _rowCount = 1; + /// + /// The number of rows in the visualizer. + /// public int RowCount { get => _rowCount; @@ -36,6 +39,9 @@ public int RowCount } private int _columnCount = 1; + /// + /// The number of columns in the visualizer. + /// public int ColumnCount { get => _columnCount; @@ -50,6 +56,9 @@ public int ColumnCount } private int _selectedPageIndex = 0; + /// + /// The index of the current page displayed in the visualizer. + /// public int SelectedPageIndex { get => _selectedPageIndex; @@ -59,21 +68,28 @@ public int SelectedPageIndex } } + private StatusStrip _statusStrip = null; + /// + /// The status strip control that displays the visualizer options. + /// + public StatusStrip StatusStrip => _statusStrip; + private readonly int _sampleFrequency = 30; private int _pageCount = 1; private string _modelName = string.Empty; private List _heatmapPlots = null; - private int _conditionalIntensitiesCount = 0; + private int _intensitiesCount = 0; private TableLayoutPanel _container = null; - private readonly List _conditionalIntensitiesCumulativeIndex = []; - private StatusStrip _statusStrip = null; - public StatusStrip StatusStrip => _statusStrip; + private readonly List _intensitiesCumulativeIndex = []; private ToolStripNumericUpDown _pageIndexControl = null; private ToolStripNumericUpDown _rowControl = null; private ToolStripNumericUpDown _columnControl = null; - private Tensor[] _conditionalIntensities = null; + private Tensor[] _intensities = null; private long _stateSpaceWidth; private long _stateSpaceHeight; + private double[] _stateSpaceMin; + private double[] _stateSpaceMax; + private bool _isProcessing = false; /// public override void Load(IServiceProvider provider) @@ -90,20 +106,19 @@ public override void Load(IServiceProvider provider) if (decodeNode == null) { - Console.WriteLine("The decode node is invalid."); throw new InvalidOperationException("The decode node is invalid."); } _modelName = decodeNode.Model; if (string.IsNullOrEmpty(_modelName)) { - Console.WriteLine("The point process model name is not set."); throw new InvalidOperationException("The point process model name is not set."); } _container = new TableLayoutPanel() { Dock = DockStyle.Fill, + AutoSize = true, ColumnCount = ColumnCount, RowCount = _rowCount, }; @@ -118,6 +133,10 @@ public override void Load(IServiceProvider provider) _pageIndexControl.ValueChanged += (sender, e) => { + if (_heatmapPlots is null) + { + return; + } var value = Convert.ToInt32(_pageIndexControl.Value); SelectedPageIndex = value; UpdateTableLayout(); @@ -135,6 +154,11 @@ public override void Load(IServiceProvider provider) _rowControl.ValueChanged += (sender, e) => { + if (_heatmapPlots is null) + { + return; + } + RowCount = Convert.ToInt32(_rowControl.Value); UpdatePages(); if (_selectedPageIndex >= _pageCount) @@ -160,6 +184,11 @@ public override void Load(IServiceProvider provider) _columnControl.ValueChanged += (sender, e) => { + if (_heatmapPlots is null) + { + return; + } + ColumnCount = Convert.ToInt32(_columnControl.Value); UpdatePages(); if (_selectedPageIndex >= _pageCount) @@ -196,12 +225,13 @@ public override void Load(IServiceProvider provider) private void UpdatePages() { - _pageCount = (int)Math.Ceiling((double)_conditionalIntensitiesCount / (_rowCount * _columnCount)); + _pageCount = (int)Math.Ceiling((double)_intensitiesCount / (_rowCount * _columnCount)); _pageIndexControl.Maximum = _pageCount - 1; } private bool UpdateModel() { + _isProcessing = true; PointProcessModel model; try { @@ -217,32 +247,48 @@ private bool UpdateModel() if (model.StateSpace.Dimensions != 2) { - throw new InvalidOperationException("For the conditional intensities visualizer to work, the state space dimensions must be 2."); + throw new InvalidOperationException("For the intensities visualizer to work, the state space dimensions must be 2."); } - if (model.Encoder.ConditionalIntensities.Length == 0 || (model.Encoder.ConditionalIntensities.Length == 1 && model.Encoder.ConditionalIntensities[0].numel() == 0)) + if (model.Encoder.Intensities.Length == 0 || (model.Encoder.Intensities.Length == 1 && model.Encoder.Intensities[0].NumberOfElements == 0)) { return false; } - - _conditionalIntensities = model.Encoder.ConditionalIntensities; + + _intensities = model.Encoder.Intensities; _stateSpaceWidth = model.StateSpace.Shape[0]; _stateSpaceHeight = model.StateSpace.Shape[1]; + _stateSpaceMin = [.. model.StateSpace.Points + .min(dim: 0) + .values + .to_type(ScalarType.Float64) + .data() + ]; + + _stateSpaceMax = [.. model.StateSpace.Points + .max(dim: 0) + .values + .to_type(ScalarType.Float64) + .data() + ]; + + _isProcessing = false; + return true; } - private static int GetConditionalIntensitiesCount(Tensor[] conditionalIntensities, List conditionalIntensitiesCumulativeIndex) + private static int GetIntensitiesCount(Tensor[] intensities, List intensitiesCumulativeIndex) { - long conditionalIntensitiesCount = 0; - conditionalIntensitiesCumulativeIndex.Clear(); - for (int i = 0; i < conditionalIntensities.Length; i++) { - if (conditionalIntensities[i].numel() > 0) { - conditionalIntensitiesCount += conditionalIntensities[i].size(0); - conditionalIntensitiesCumulativeIndex.Add(conditionalIntensitiesCount); + long intensitiesCount = 0; + intensitiesCumulativeIndex.Clear(); + for (int i = 0; i < intensities.Length; i++) { + if (intensities[i].NumberOfElements > 0) { + intensitiesCount += intensities[i].size(0); + intensitiesCumulativeIndex.Add(intensitiesCount); } } - return (int)conditionalIntensitiesCount; + return (int)intensitiesCount; } private bool UpdateHeatmaps() @@ -250,31 +296,31 @@ private bool UpdateHeatmaps() if (_heatmapPlots is null) { _heatmapPlots = []; - for (int i = 0; i < _conditionalIntensitiesCount; i++) + for (int i = 0; i < _intensitiesCount; i++) { - _heatmapPlots.Add(new HeatMapSeriesOxyPlotBase(0, 0) + _heatmapPlots.Add(new HeatMapSeriesOxyPlotBase(1, 0) { Dock = DockStyle.Fill, }); } } - else if (_heatmapPlots.Count > _conditionalIntensitiesCount) + else if (_heatmapPlots.Count > _intensitiesCount) { - var count = _heatmapPlots.Count - _conditionalIntensitiesCount; + var count = _heatmapPlots.Count - _intensitiesCount; for (int i = 0; i < count; i++) { - if (!_heatmapPlots[i + _conditionalIntensitiesCount].IsDisposed) + if (!_heatmapPlots[i + _intensitiesCount].IsDisposed) { - _heatmapPlots[i + _conditionalIntensitiesCount].Dispose(); + _heatmapPlots[i + _intensitiesCount].Dispose(); } } - _heatmapPlots.RemoveRange(_conditionalIntensitiesCount, count); + _heatmapPlots.RemoveRange(_intensitiesCount, count); } - else if (_heatmapPlots.Count < _conditionalIntensitiesCount) + else if (_heatmapPlots.Count < _intensitiesCount) { - for (int i = _heatmapPlots.Count; i < _conditionalIntensitiesCount; i++) + for (int i = _heatmapPlots.Count; i < _intensitiesCount; i++) { - _heatmapPlots.Add(new HeatMapSeriesOxyPlotBase(0, 0) + _heatmapPlots.Add(new HeatMapSeriesOxyPlotBase(1, 0) { Dock = DockStyle.Fill, }); @@ -308,7 +354,7 @@ private void UpdateTableLayout() for (int j = 0; j < _columnCount; j++) { var index = SelectedPageIndex * _rowCount * _columnCount + i * _columnCount + j; - if (index >= _conditionalIntensitiesCount) + if (index >= _intensitiesCount) { break; } @@ -318,42 +364,40 @@ private void UpdateTableLayout() } } - private (int ConditionalIntensitiesIndex, int ConditionalIntensitiesTensorIndex) GetConditionalIntensitiesIndex(int index) + private (int intensitiesIndex, int intensitiesTensorIndex) GetIntensitiesIndex(int index) { - var conditionalIntensitiesIndex = 0; - for (int i = 0; i < _conditionalIntensitiesCumulativeIndex.Count; i++) + var intensitiesIndex = 0; + for (int i = 0; i < _intensitiesCumulativeIndex.Count; i++) { - if (index < _conditionalIntensitiesCumulativeIndex[i]) + if (index < _intensitiesCumulativeIndex[i]) { - conditionalIntensitiesIndex = i; + intensitiesIndex = i; break; } } - var conditionalIntensitiesTensorIndex = conditionalIntensitiesIndex == 0 ? index : index - _conditionalIntensitiesCumulativeIndex[conditionalIntensitiesIndex - 1]; - return (conditionalIntensitiesIndex, (int)conditionalIntensitiesTensorIndex); + var intensitiesTensorIndex = intensitiesIndex == 0 ? index : index - _intensitiesCumulativeIndex[intensitiesIndex - 1]; + return (intensitiesIndex, (int)intensitiesTensorIndex); } /// public override void Show(object value) - { - var startIndex = SelectedPageIndex * _rowCount * _columnCount; - var endIndex = Math.Min(startIndex + _rowCount * _columnCount, _conditionalIntensitiesCount); + { + var startIndex = _selectedPageIndex * _rowCount * _columnCount; + var endIndex = Math.Min(startIndex + _rowCount * _columnCount, _intensitiesCount); for (int i = startIndex; i < endIndex; i++) { - var (conditionalIntensitiesIndex, conditionalIntensitiesTensorIndex) = GetConditionalIntensitiesIndex(i); + var (intensitiesIndex, intensitiesTensorIndex) = GetIntensitiesIndex(i); - var conditionalIntensity = _conditionalIntensities[conditionalIntensitiesIndex][conditionalIntensitiesTensorIndex]; + var intensity = _intensities[intensitiesIndex][intensitiesTensorIndex]; - if (conditionalIntensity.Dimensions == 2) { - conditionalIntensity = conditionalIntensity + if (intensity.Dimensions == 2) { + intensity = intensity .sum(dim: 0); } - - - var conditionalIntensityValues = (double[,])conditionalIntensity + var intensityValues = (double[,])intensity .exp() .to_type(ScalarType.Float64) .reshape([_stateSpaceWidth, _stateSpaceHeight]) @@ -361,7 +405,11 @@ public override void Show(object value) .ToNDArray(); _heatmapPlots[i].UpdateHeatMapSeries( - conditionalIntensityValues + _stateSpaceMin[0], + _stateSpaceMax[0], + _stateSpaceMin[1], + _stateSpaceMax[1], + intensityValues ); _heatmapPlots[i].UpdatePlot(); @@ -392,9 +440,9 @@ public override void Unload() _heatmapPlots = null; }; - _conditionalIntensitiesCount = 0; - _conditionalIntensitiesCumulativeIndex.Clear(); - _conditionalIntensities = null; + _intensitiesCount = 0; + _intensitiesCumulativeIndex.Clear(); + _intensities = null; } public override IObservable Visualize(IObservable> source, IServiceProvider provider) @@ -409,16 +457,15 @@ public override IObservable Visualize(IObservable> s .ObserveOn(visualizerControl) .Do(value => { - var success = UpdateModel(); - if (!success) + if (!UpdateModel() && !_isProcessing) { return; } - var newConditionalIntensitiesCount = GetConditionalIntensitiesCount(_conditionalIntensities, _conditionalIntensitiesCumulativeIndex); - if (_conditionalIntensitiesCount != newConditionalIntensitiesCount) + var newIntensitiesCount = GetIntensitiesCount(_intensities, _intensitiesCumulativeIndex); + if (_intensitiesCount != newIntensitiesCount) { - _conditionalIntensitiesCount = newConditionalIntensitiesCount; + _intensitiesCount = newIntensitiesCount; UpdatePages(); UpdateHeatmaps(); UpdateTableLayout(); From 5af05b12da8fe308f26e5b3f683680ec9c64761f Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 21 Feb 2025 13:31:34 +0000 Subject: [PATCH 118/131] Added a density estimations visualizer --- .../DensityEstimationsVisualizer.cs | 460 ++++++++++++++++++ 1 file changed, 460 insertions(+) create mode 100644 src/Bonsai.ML.PointProcessDecoder.Design/DensityEstimationsVisualizer.cs diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/DensityEstimationsVisualizer.cs b/src/Bonsai.ML.PointProcessDecoder.Design/DensityEstimationsVisualizer.cs new file mode 100644 index 00000000..63669df2 --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder.Design/DensityEstimationsVisualizer.cs @@ -0,0 +1,460 @@ +using System; +using System.Reactive.Linq; +using System.Reactive; +using System.Windows.Forms; +using System.Collections.Generic; +using System.Linq; + +using Bonsai; +using Bonsai.Expressions; +using Bonsai.Design; +using Bonsai.ML.Design; + +using static TorchSharp.torch; + +using PointProcessDecoder.Core; + +[assembly: TypeVisualizer(typeof(Bonsai.ML.PointProcessDecoder.Design.DensityEstimationsVisualizer), + Target = typeof(Bonsai.ML.PointProcessDecoder.Decode))] + +namespace Bonsai.ML.PointProcessDecoder.Design +{ + public class DensityEstimationsVisualizer : DialogTypeVisualizer + { + private int _rowCount = 1; + public int RowCount + { + get => _rowCount; + set + { + if (value < 1) + { + throw new InvalidOperationException("The number of rows must be greater than 0."); + } + _rowCount = value; + } + } + + private int _columnCount = 1; + public int ColumnCount + { + get => _columnCount; + set + { + if (value < 1) + { + throw new InvalidOperationException("The number of columns must be greater than 0."); + } + _columnCount = value; + } + } + + private int _selectedPageIndex = 0; + public int SelectedPageIndex + { + get => _selectedPageIndex; + set + { + _selectedPageIndex = value; + } + } + + private readonly int _sampleFrequency = 5; + private int _pageCount = 1; + private string _modelName = string.Empty; + private List _heatmapPlots = null; + private int _estimationsCount = 0; + private TableLayoutPanel _container = null; + private StatusStrip _statusStrip = null; + public StatusStrip StatusStrip => _statusStrip; + private ToolStripNumericUpDown _pageIndexControl = null; + private ToolStripNumericUpDown _rowControl = null; + private ToolStripNumericUpDown _columnControl = null; + private Tensor[] _estimations = null; + private long _stateSpaceWidth; + private long _stateSpaceHeight; + private double[] _stateSpaceMin; + private double[] _stateSpaceMax; + private bool _isProcessing = false; + + /// + public override void Load(IServiceProvider provider) + { + Decode decodeNode = null; + var expressionBuilderGraph = (ExpressionBuilderGraph)provider.GetService(typeof(ExpressionBuilderGraph)); + var typeVisualizerContext = (ITypeVisualizerContext)provider.GetService(typeof(ITypeVisualizerContext)); + if (expressionBuilderGraph != null && typeVisualizerContext != null) + { + decodeNode = ExpressionBuilder.GetWorkflowElement( + expressionBuilderGraph.Where(node => node.Value == typeVisualizerContext.Source) + .FirstOrDefault().Value) as Decode; + } + + if (decodeNode == null) + { + throw new InvalidOperationException("The decode node is invalid."); + } + + _modelName = decodeNode.Model; + if (string.IsNullOrEmpty(_modelName)) + { + throw new InvalidOperationException("The point process model name is not set."); + } + + _container = new TableLayoutPanel() + { + Dock = DockStyle.Fill, + AutoSize = true, + ColumnCount = ColumnCount, + RowCount = _rowCount, + }; + + var pageIndexLabel = new ToolStripLabel($"Page: {_selectedPageIndex}"); + _pageIndexControl = new ToolStripNumericUpDown() + { + Minimum = 0, + DecimalPlaces = 0, + Value = _selectedPageIndex, + }; + + _pageIndexControl.ValueChanged += (sender, e) => + { + if (_heatmapPlots is null) + { + return; + } + + var value = Convert.ToInt32(_pageIndexControl.Value); + SelectedPageIndex = value; + UpdateTableLayout(); + if (UpdateModel()) + { + Show(null); + } + pageIndexLabel.Text = $"Page: {_selectedPageIndex}"; + }; + + var rowLabel = new ToolStripLabel($"Rows: {_rowCount}"); + _rowControl = new ToolStripNumericUpDown() + { + Minimum = 1, + DecimalPlaces = 0, + Value = _rowCount, + }; + + _rowControl.ValueChanged += (sender, e) => + { + if (_heatmapPlots is null) + { + return; + } + + RowCount = Convert.ToInt32(_rowControl.Value); + UpdatePages(); + if (_selectedPageIndex >= _pageCount) + { + SelectedPageIndex = _pageCount - 1; + _pageIndexControl.Value = _selectedPageIndex; + } + else + { + UpdateTableLayout(); + if (UpdateModel()) + { + Show(null); + } + } + rowLabel.Text = $"Rows: {_rowCount}"; + }; + + var columnLabel = new ToolStripLabel($"Columns: {_columnCount}"); + _columnControl = new ToolStripNumericUpDown() + { + Minimum = 1, + DecimalPlaces = 0, + Value = _columnCount, + }; + + _columnControl.ValueChanged += (sender, e) => + { + if (_heatmapPlots is null) + { + return; + } + + ColumnCount = Convert.ToInt32(_columnControl.Value); + UpdatePages(); + if (_selectedPageIndex >= _pageCount) + { + SelectedPageIndex = _pageCount - 1; + _pageIndexControl.Value = _selectedPageIndex; + } + else + { + UpdateTableLayout(); + if (UpdateModel()) + { + Show(null); + } + } + columnLabel.Text = $"Columns: {_columnCount}"; + }; + + _statusStrip = new StatusStrip() + { + Visible = true, + }; + + _statusStrip.Items.AddRange([ + pageIndexLabel, + _pageIndexControl, + rowLabel, + _rowControl, + columnLabel, + _columnControl + ]); + + var visualizerService = (IDialogTypeVisualizerService)provider.GetService(typeof(IDialogTypeVisualizerService)); + visualizerService?.AddControl(_container); + visualizerService?.AddControl(_statusStrip); + } + + private void UpdatePages() + { + _pageCount = (int)Math.Ceiling((double)_estimationsCount / (_rowCount * _columnCount)); + _pageIndexControl.Maximum = _pageCount - 1; + } + + private bool UpdateModel() + { + _isProcessing = true; + PointProcessModel model; + + try { + model = PointProcessModelManager.GetModel(_modelName); + } catch { + return false; + } + + if (model == null) + { + return false; + } + + if (model.StateSpace.Dimensions != 2) + { + throw new InvalidOperationException("For the conditional intensities visualizer to work, the state space dimensions must be 2."); + } + + if (model.Encoder.Estimations.Length == 0) + { + return false; + } + + + _estimationsCount = model.Encoder.Estimations.Length; + _estimations = new Tensor[_estimationsCount]; + + var startIndex = SelectedPageIndex * _rowCount * _columnCount; + var endIndex = Math.Min(startIndex + _rowCount * _columnCount, _estimationsCount); + + for (int i = startIndex; i < endIndex; i++) + { + var estimate = model.Encoder.Estimations[i].Estimate( + model.StateSpace.Points, + null, + model.StateSpace.Dimensions + ); + + if (estimate.NumberOfElements == 0) { + _estimations[i] = ones([model.StateSpace.Points.size(0), 1]) * double.NaN; + } else { + _estimations[i] = model.Encoder.Estimations[i].Normalize(estimate); + } + } + + _stateSpaceWidth = model.StateSpace.Shape[0]; + _stateSpaceHeight = model.StateSpace.Shape[1]; + + _stateSpaceMin = [.. model.StateSpace.Points + .min(dim: 0) + .values + .to_type(ScalarType.Float64) + .data() + ]; + _stateSpaceMax = [.. model.StateSpace.Points + .max(dim: 0) + .values + .to_type(ScalarType.Float64) + .data() + ]; + + // GC.KeepAlive(model); + _isProcessing = false; + + return true; + } + + private bool UpdateHeatmaps() + { + if (_heatmapPlots is null) + { + _heatmapPlots = []; + for (int i = 0; i < _estimationsCount; i++) + { + _heatmapPlots.Add(new HeatMapSeriesOxyPlotBase(1, 0) + { + Dock = DockStyle.Fill, + }); + } + } + else if (_heatmapPlots.Count > _estimationsCount) + { + var count = _heatmapPlots.Count - _estimationsCount; + for (int i = 0; i < count; i++) + { + if (!_heatmapPlots[i + _estimationsCount].IsDisposed) + { + _heatmapPlots[i + _estimationsCount].Dispose(); + } + } + _heatmapPlots.RemoveRange(_estimationsCount, count); + } + else if (_heatmapPlots.Count < _estimationsCount) + { + for (int i = _heatmapPlots.Count; i < _estimationsCount; i++) + { + _heatmapPlots.Add(new HeatMapSeriesOxyPlotBase(1, 0) + { + Dock = DockStyle.Fill, + }); + } + } + + return true; + } + + private void UpdateTableLayout() + { + _container.Controls.Clear(); + _container.RowStyles.Clear(); + _container.ColumnStyles.Clear(); + + _container.RowCount = _rowCount; + _container.ColumnCount = _columnCount; + + for (int i = 0; i < _rowCount; i++) + { + _container.RowStyles.Add(new RowStyle(SizeType.Percent, 100f / _rowCount)); + } + + for (int i = 0; i < _columnCount; i++) + { + _container.ColumnStyles.Add(new ColumnStyle(SizeType.Percent, 100f / _columnCount)); + } + + for (int i = 0; i < _rowCount; i++) + { + for (int j = 0; j < _columnCount; j++) + { + var index = SelectedPageIndex * _rowCount * _columnCount + i * _columnCount + j; + if (index >= _estimationsCount) + { + break; + } + + _container.Controls.Add(_heatmapPlots[index], j, i); + } + } + } + + /// + public override void Show(object value) + { + var startIndex = SelectedPageIndex * _rowCount * _columnCount; + var endIndex = Math.Min(startIndex + _rowCount * _columnCount, _estimationsCount); + + for (int i = startIndex; i < endIndex; i++) + { + var estimation = _estimations[i]; + + var estimationValues = (double[,])estimation + .to_type(ScalarType.Float64) + .reshape([_stateSpaceWidth, _stateSpaceHeight]) + .data() + .ToNDArray(); + + _heatmapPlots[i].UpdateHeatMapSeries( + _stateSpaceMin[0], + _stateSpaceMax[0], + _stateSpaceMin[1], + _stateSpaceMax[1], + estimationValues + ); + + _heatmapPlots[i].UpdatePlot(); + } + } + + /// + public override void Unload() + { + if (_container != null) + { + if (!_container.IsDisposed) + { + _container.Dispose(); + } + _container = null; + } + + if (_heatmapPlots != null) + { + for (int i = 0; i < _heatmapPlots.Count; i++) + { + if (!_heatmapPlots[i].IsDisposed) + { + _heatmapPlots[i].Dispose(); + } + } + _heatmapPlots = null; + }; + + _estimationsCount = 0; + _estimations = null; + } + + public override IObservable Visualize(IObservable> source, IServiceProvider provider) + { + if (provider.GetService(typeof(IDialogTypeVisualizerService)) is not Control visualizerControl) + { + return source; + } + + var timer = Observable.Interval( + TimeSpan.FromMilliseconds(100), + HighResolutionScheduler.Default + ); + + return source.SelectMany(input => + input.Buffer(timer) + .Where(buffer => buffer.Count > 0 && !_isProcessing) + .ObserveOn(visualizerControl) + .Do(buffer => + { + if (!UpdateModel()) + { + return; + } + + UpdatePages(); + UpdateHeatmaps(); + UpdateTableLayout(); + + Show(buffer.LastOrDefault()); + } + ) + ); + } + } +} \ No newline at end of file From dfd1c3e294814d964d4e5377f9f29fc7e997207d Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 21 Feb 2025 13:31:54 +0000 Subject: [PATCH 119/131] Added a likelihood visualizer --- .../LikelihoodVisualizer.cs | 244 ++++++++++++++++++ 1 file changed, 244 insertions(+) create mode 100644 src/Bonsai.ML.PointProcessDecoder.Design/LikelihoodVisualizer.cs diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/LikelihoodVisualizer.cs b/src/Bonsai.ML.PointProcessDecoder.Design/LikelihoodVisualizer.cs new file mode 100644 index 00000000..827240dd --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder.Design/LikelihoodVisualizer.cs @@ -0,0 +1,244 @@ +using System; +using System.Reactive.Linq; +using System.Reactive; +using System.Windows.Forms; +using System.Collections.Generic; +using System.Linq; + +using Bonsai; +using Bonsai.Expressions; +using Bonsai.Design; +using Bonsai.ML.Design; + +using OxyPlot.Series; + +using static TorchSharp.torch; + +using PointProcessDecoder.Core; +using TorchSharp; +using Bonsai.Dag; +using System.Linq.Expressions; + +[assembly: TypeVisualizer(typeof(Bonsai.ML.PointProcessDecoder.Design.LikelihoodVisualizer), + Target = typeof(Bonsai.ML.PointProcessDecoder.Decode))] + +namespace Bonsai.ML.PointProcessDecoder.Design +{ + public class LikelihoodVisualizer : MashupVisualizer + { + private MultidimensionalArrayVisualizer _visualizer; + + /// + /// Gets the underlying heatmap plot. + /// + public HeatMapSeriesOxyPlotBase Plot => _visualizer.Plot; + + private int _capacity = 100; + /// + /// Gets or sets the integer value that determines how many data points should be shown along the x axis if the posterior is a 1D tensor. + /// + public int Capacity + { + get => _capacity; + set + { + _capacity = value; + } + } + + private double[,] _data = null; + private long _stateSpaceWidth; + private long _stateSpaceHeight; + private double[] _stateSpaceMin = null; + private double[] _stateSpaceMax = null; + private string _modelName; + private ILikelihood _likelihood; + private Tensor[] _intensities; + private IObservable> _inputSource; + private Tensor _dataTensor; + + /// + public override void Load(IServiceProvider provider) + { + Node visualizerNode = null; + var expressionBuilderGraph = (ExpressionBuilderGraph)provider.GetService(typeof(ExpressionBuilderGraph)); + var typeVisualizerContext = (ITypeVisualizerContext)provider.GetService(typeof(ITypeVisualizerContext)); + if (expressionBuilderGraph != null && typeVisualizerContext != null) + { + // decodeNode = ExpressionBuilder.GetWorkflowElement( + // expressionBuilderGraph.Where(node => node.Value == typeVisualizerContext.Source) + // .FirstOrDefault()); + visualizerNode = (from node in expressionBuilderGraph + where node.Value == typeVisualizerContext.Source + select node).FirstOrDefault(); + } + + if (visualizerNode == null) + { + throw new InvalidOperationException("The visualizer node is invalid."); + } + + var inspector = (InspectBuilder)expressionBuilderGraph + .Predecessors(visualizerNode) + .First(p => !p.Value.IsBuildDependency()) + .Value; + + _inputSource = inspector.Output; + + if (ExpressionBuilder.GetWorkflowElement(visualizerNode.Value) is not Decode decodeNode) + { + throw new InvalidOperationException("The decode node is invalid."); + } + + _modelName = decodeNode.Model; + + _visualizer = new MultidimensionalArrayVisualizer() + { + PaletteSelectedIndex = 1, + RenderMethodSelectedIndex = 0 + }; + + _visualizer.Load(provider); + + var capacityLabel = new ToolStripLabel + { + Text = "Capacity:", + AutoSize = true + }; + var capacityValue = new ToolStripLabel + { + Text = Capacity.ToString(), + AutoSize = true + }; + + _visualizer.Plot.StatusStrip.Items.AddRange([ + capacityLabel, + capacityValue + ]); + + base.Load(provider); + } + + private bool UpdateModel() + { + PointProcessModel model; + + try + { + model = PointProcessModelManager.GetModel(_modelName); + } catch { + return false; + } + + if (model == null) + { + return false; + } + + _stateSpaceWidth = model.StateSpace.Shape[0]; + _stateSpaceHeight = model.StateSpace.Shape[1]; + + _stateSpaceMin ??= [.. model.StateSpace.Points + .min(dim: 0) + .values + .to_type(ScalarType.Float64) + .data() + ]; + + _stateSpaceMax ??= [.. model.StateSpace.Points + .max(dim: 0) + .values + .to_type(ScalarType.Float64) + .data() + ]; + + _likelihood = model.Likelihood; + _intensities = model.Encoder.Intensities; + + return true; + } + + /// + public override void Show(object value) + { + Tensor inputs = (Tensor)value; + Tensor likelihood = _likelihood.Likelihood(inputs, _intensities); + + if (likelihood.Dimensions == 2) { + likelihood = likelihood + .sum(dim: 0); + } + + _data = (double[,])likelihood + .to_type(ScalarType.Float64) + .reshape([_stateSpaceWidth, _stateSpaceHeight]) + .data() + .ToNDArray(); + + + _visualizer.Plot.UpdateHeatMapSeries( + _stateSpaceMin[0], + _stateSpaceMax[0], + _stateSpaceMin[1], + _stateSpaceMax[1], + _data + ); + + _visualizer.Plot.UpdatePlot(); + } + + /// + public override void Unload() + { + _visualizer.Unload(); + base.Unload(); + } + + /// + public override IObservable Visualize(IObservable> source, IServiceProvider provider) + { + if (provider.GetService(typeof(IDialogTypeVisualizerService)) is not Control visualizerControl) + { + return source; + } + + var timer = Observable.Interval( + TimeSpan.FromMilliseconds(100), + HighResolutionScheduler.Default + ); + + // var mergedSource = source.SelectMany(xs => + // xs.Buffer(timer) + // .Where(buffer => buffer.Count > 0) + // .Do(buffer => { + // if (!UpdateModel()) + // { + // return; + // } + // Show(buffer.LastOrDefault()); + // })); + + var mergedSource = _inputSource.SelectMany(xs => + xs.Buffer(timer) + .Where(buffer => buffer.Count > 0) + .Sample(source.Merge()) + .Do(buffer => { + if (!UpdateModel()) + { + return; + } + Show(buffer.LastOrDefault()); + })); + + var mashupSourceStreams = Observable.Merge( + MashupSources.Select(mashupSource => + mashupSource.Source.Output.SelectMany(xs => + xs.Buffer(timer) + .Where(buffer => buffer.Count > 0) + .Do(buffer => mashupSource.Visualizer.Show(buffer.LastOrDefault())) + ))); + + return Observable.Merge(mergedSource, mashupSourceStreams); + } + } +} \ No newline at end of file From 0514fcf4780b7ca238b1757a19187acc5e0b85d8 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 21 Feb 2025 13:32:03 +0000 Subject: [PATCH 120/131] Added a posterior visualizer --- .../PosteriorVisualizer.cs | 243 ++++++++++++++++++ 1 file changed, 243 insertions(+) create mode 100644 src/Bonsai.ML.PointProcessDecoder.Design/PosteriorVisualizer.cs diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/PosteriorVisualizer.cs b/src/Bonsai.ML.PointProcessDecoder.Design/PosteriorVisualizer.cs new file mode 100644 index 00000000..9230c734 --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder.Design/PosteriorVisualizer.cs @@ -0,0 +1,243 @@ +using System; +using System.Reactive.Linq; +using System.Reactive; +using System.Windows.Forms; +using System.Collections.Generic; +using System.Linq; + +using Bonsai; +using Bonsai.Expressions; +using Bonsai.Design; +using Bonsai.ML.Design; + +using OxyPlot.Series; + +using static TorchSharp.torch; + +using PointProcessDecoder.Core; +using TorchSharp; + +[assembly: TypeVisualizer(typeof(Bonsai.ML.PointProcessDecoder.Design.PosteriorVisualizer), + Target = typeof(Bonsai.ML.PointProcessDecoder.Decode))] + +namespace Bonsai.ML.PointProcessDecoder.Design +{ + public class PosteriorVisualizer : MashupVisualizer + { + private MultidimensionalArrayVisualizer _visualizer; + + /// + /// Gets the underlying heatmap plot. + /// + public HeatMapSeriesOxyPlotBase Plot => _visualizer.Plot; + + private int _capacity = 100; + /// + /// Gets or sets the integer value that determines how many data points should be shown along the x axis if the posterior is a 1D tensor. + /// + public int Capacity + { + get => _capacity; + set + { + _capacity = value; + } + } + + private double[,] _data = null; + private double[] _stateSpaceMin; + private double[] _stateSpaceMax; + private string _modelName; + private bool _success = false; + private Tensor _dataTensor; + + /// + public override void Load(IServiceProvider provider) + { + Decode decodeNode = null; + var expressionBuilderGraph = (ExpressionBuilderGraph)provider.GetService(typeof(ExpressionBuilderGraph)); + var typeVisualizerContext = (ITypeVisualizerContext)provider.GetService(typeof(ITypeVisualizerContext)); + if (expressionBuilderGraph != null && typeVisualizerContext != null) + { + decodeNode = ExpressionBuilder.GetWorkflowElement( + expressionBuilderGraph.Where(node => node.Value == typeVisualizerContext.Source) + .FirstOrDefault().Value) as Decode; + } + + if (decodeNode == null) + { + throw new InvalidOperationException("The decode node is invalid."); + } + + _modelName = decodeNode.Model; + + _visualizer = new MultidimensionalArrayVisualizer() + { + PaletteSelectedIndex = 1, + RenderMethodSelectedIndex = 0 + }; + + _visualizer.Load(provider); + + var capacityLabel = new ToolStripLabel + { + Text = "Capacity:", + AutoSize = true + }; + var capacityValue = new ToolStripLabel + { + Text = Capacity.ToString(), + AutoSize = true + }; + + _visualizer.Plot.StatusStrip.Items.AddRange([ + capacityLabel, + capacityValue + ]); + + base.Load(provider); + } + + private void UpdateModel() + { + PointProcessModel model; + + try + { + model = PointProcessModelManager.GetModel(_modelName); + } catch { + _success = false; + return; + } + + if (model == null) + { + _success = false; + return; + } + + _stateSpaceMin = [.. model.StateSpace.Points + .min(dim: 0) + .values + .to_type(ScalarType.Float64) + .data() + ]; + + _stateSpaceMax = [.. model.StateSpace.Points + .max(dim: 0) + .values + .to_type(ScalarType.Float64) + .data() + ]; + + _success = true; + } + + /// + public override void Show(object value) + { + Tensor posterior = (Tensor)value; + if (posterior.NumberOfElements == 0 || !_success) + { + return; + } + + posterior = posterior.sum(dim: 0); + + double xMin; + double xMax; + double yMin; + double yMax; + + if (posterior.Dimensions == 1) + { + + if (_data == null) + { + _dataTensor = zeros(_capacity, posterior.size(0), dtype: ScalarType.Float64, device: posterior.device); + } + + _dataTensor = _dataTensor[TensorIndex.Slice(1)]; + _dataTensor = concat([_dataTensor, + posterior.to_type(ScalarType.Float64) + .unsqueeze(0) + ], dim: 0); + + _data = (double[,])_dataTensor + .data() + .ToNDArray(); + + xMin = 0; + xMax = _capacity; + yMin = _stateSpaceMin[0]; + yMax = _stateSpaceMax[0]; + } + else + { + + while (posterior.Dimensions > 2) + { + posterior = posterior.sum(dim: 0); + } + + _data = (double[,])posterior + .to_type(ScalarType.Float64) + .data() + .ToNDArray(); + + xMin = _stateSpaceMin[_stateSpaceMin.Length - 2]; + xMax = _stateSpaceMax[_stateSpaceMax.Length - 2]; + yMin = _stateSpaceMin[_stateSpaceMin.Length - 1]; + yMax = _stateSpaceMax[_stateSpaceMax.Length - 1]; + } + + _visualizer.Plot.UpdateHeatMapSeries( + xMin, xMax, yMin, yMax, _data + ); + + _visualizer.Plot.UpdatePlot(); + } + + /// + public override void Unload() + { + _visualizer.Unload(); + base.Unload(); + } + + /// + public override IObservable Visualize(IObservable> source, IServiceProvider provider) + { + if (provider.GetService(typeof(IDialogTypeVisualizerService)) is not Control visualizerControl) + { + return source; + } + + var timer = Observable.Interval( + TimeSpan.FromMilliseconds(100), + HighResolutionScheduler.Default + ); + + var mergedSource = source.SelectMany(xs => + xs.Buffer(timer) + .Where(buffer => buffer.Count > 0) + .Do(buffer => { + if (!_success) + { + UpdateModel(); + } + Show(buffer.LastOrDefault()); + })); + + var mashupSourceStreams = Observable.Merge( + MashupSources.Select(mashupSource => + mashupSource.Source.Output.SelectMany(xs => + xs.Buffer(timer) + .Where(buffer => buffer.Count > 0) + .Do(buffer => mashupSource.Visualizer.Show(buffer.LastOrDefault())) + ))); + + return Observable.Merge(mergedSource, mashupSourceStreams); + } + } +} \ No newline at end of file From e18af86b15b968a3e268298bd8f24c9e8bc94931 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 21 Feb 2025 13:32:31 +0000 Subject: [PATCH 121/131] Added a point2D overlay visualizer for posterior and likelihood mashup visualizers --- .../Point2DOverlay.cs | 86 +++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 src/Bonsai.ML.PointProcessDecoder.Design/Point2DOverlay.cs diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/Point2DOverlay.cs b/src/Bonsai.ML.PointProcessDecoder.Design/Point2DOverlay.cs new file mode 100644 index 00000000..537b389e --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder.Design/Point2DOverlay.cs @@ -0,0 +1,86 @@ +using Bonsai; +using Bonsai.Design; +using Bonsai.Design.Visualizers; +using System; +using System.Collections.Generic; +using OxyPlot.Series; +using OxyPlot.Axes; +using OxyPlot; +using Bonsai.Vision.Design; +using Bonsai.ML.Design; +using System.Linq; +using TorchSharp; +using OpenCV.Net; + +[assembly: TypeVisualizer(typeof(Bonsai.ML.PointProcessDecoder.Design.Point2DOverlay), + Target = typeof(MashupSource))] + +[assembly: TypeVisualizer(typeof(Bonsai.ML.PointProcessDecoder.Design.Point2DOverlay), + Target = typeof(MashupSource))] + +namespace Bonsai.ML.PointProcessDecoder.Design +{ + /// + /// Class that overlays the true + /// + public class Point2DOverlay : DialogTypeVisualizer + { + internal LineSeries _lineSeries; + internal ScatterSeries _scatterSeries; + private int _capacity; + private int _dataCount; + + /// + public override void Load(IServiceProvider provider) + { + dynamic service = provider.GetService(typeof(MashupVisualizer)); + _capacity = service.Capacity; + + _lineSeries = new LineSeries() + { + Color = OxyColors.LimeGreen, + StrokeThickness = 2 + }; + + var colorAxis = new LinearColorAxis() + { + IsAxisVisible = false, + Key = "Point2DOverlayColorAxis" + }; + + _scatterSeries = new ScatterSeries() + { + MarkerType = MarkerType.Circle, + MarkerSize = 10, + MarkerFill = OxyColors.LimeGreen, + ColorAxisKey = "Point2DOverlayColorAxis" + }; + + + service.Plot.Model.Series.Add(_scatterSeries); + service.Plot.Model.Series.Add(_lineSeries); + service.Plot.Model.Axes.Add(colorAxis); + } + + /// + public override void Show(object value) + { + dynamic point = value; + _dataCount++; + _lineSeries.Points.Add(new DataPoint(point.X, point.Y)); + _scatterSeries.Points.Clear(); + _scatterSeries.Points.Add(new ScatterPoint(point.X, point.Y, value: 1)); + + while (_dataCount > _capacity) + { + _lineSeries.Points.RemoveAt(0); + _dataCount--; + } + } + + /// + public override void Unload() + { + } + } +} \ No newline at end of file From 4f8e96203c4aadff1f0dc6c52017fa7c75f7a1a0 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 21 Feb 2025 13:32:57 +0000 Subject: [PATCH 122/131] Added `Bonsai.Vision.Design` package to project reference --- .../Bonsai.ML.PointProcessDecoder.Design.csproj | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/Bonsai.ML.PointProcessDecoder.Design.csproj b/src/Bonsai.ML.PointProcessDecoder.Design/Bonsai.ML.PointProcessDecoder.Design.csproj index 1a052001..8b3fb9a1 100644 --- a/src/Bonsai.ML.PointProcessDecoder.Design/Bonsai.ML.PointProcessDecoder.Design.csproj +++ b/src/Bonsai.ML.PointProcessDecoder.Design/Bonsai.ML.PointProcessDecoder.Design.csproj @@ -10,4 +10,7 @@ + + + \ No newline at end of file From aea32e2971f97b9eab122cd31a27ee6b1ac737d6 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 25 Feb 2025 15:14:18 +0000 Subject: [PATCH 123/131] Updated heatmap series oxyplot base to enable setting x and y axis ranges from external class --- .../HeatMapSeriesOxyPlotBase.cs | 81 ++++++++++++++++--- 1 file changed, 71 insertions(+), 10 deletions(-) diff --git a/src/Bonsai.ML.Design/HeatMapSeriesOxyPlotBase.cs b/src/Bonsai.ML.Design/HeatMapSeriesOxyPlotBase.cs index f009a13a..0f23f2ae 100644 --- a/src/Bonsai.ML.Design/HeatMapSeriesOxyPlotBase.cs +++ b/src/Bonsai.ML.Design/HeatMapSeriesOxyPlotBase.cs @@ -27,7 +27,7 @@ public class HeatMapSeriesOxyPlotBase : UserControl public PlotModel Model => _model; private HeatMapSeries heatMapSeries; - private LinearColorAxis colorAxis; + private LinearColorAxis colorAxis = null; private ToolStripComboBox paletteComboBox; private ToolStripLabel paletteLabel; @@ -40,10 +40,10 @@ public class HeatMapSeriesOxyPlotBase : UserControl private HeatMapRenderMethod renderMethod = HeatMapRenderMethod.Bitmap; private StatusStrip statusStrip; - private ToolStripTextBox maxValueTextBox; + private ToolStripTextBox maxValueTextBox = null; private ToolStripLabel maxValueLabel; - private ToolStripTextBox minValueTextBox; + private ToolStripTextBox minValueTextBox = null; private ToolStripLabel minValueLabel; private ToolStripDropDownButton _visualizerPropertiesDropDown; @@ -79,6 +79,40 @@ public class HeatMapSeriesOxyPlotBase : UserControl /// public StatusStrip StatusStrip => statusStrip; + private double? _valueMin = null; + /// + /// Gets or sets the minimum value of the color axis. + /// + public double? ValueMin + { + get => _valueMin; + set + { + _valueMin = value; + if (minValueTextBox != null) + minValueTextBox.Text = value?.ToString(); + if (colorAxis != null) + colorAxis.Maximum = value ?? double.NaN; + } + } + + private double? _valueMax = null; + /// + /// Gets or sets the maximum value of the color axis. + /// + public double? ValueMax + { + get => _valueMax; + set + { + _valueMax = value; + if (maxValueTextBox != null) + maxValueTextBox.Text = value?.ToString(); + if (colorAxis != null) + colorAxis.Maximum = value ?? double.NaN; + } + } + /// /// Constructor of the TimeSeriesOxyPlotBase class. /// Requires a line series name and an area series name. @@ -178,7 +212,7 @@ private void InitializeColorAxisValues() { Name = "maxValue", AutoSize = true, - Text = "auto", + Text = _valueMax.HasValue ? _valueMax.ToString() : "auto", }; maxValueTextBox.TextChanged += (sender, e) => @@ -186,15 +220,18 @@ private void InitializeColorAxisValues() if (double.TryParse(maxValueTextBox.Text, out double maxValue)) { colorAxis.Maximum = maxValue; + ValueMax = maxValue; } else if (maxValueTextBox.Text.ToLower() == "auto") { colorAxis.Maximum = double.NaN; maxValueTextBox.Text = "auto"; + ValueMax = null; } else { colorAxis.Maximum = heatMapSeries.MaxValue; + ValueMax = heatMapSeries.MaxValue; } UpdatePlot(); }; @@ -209,7 +246,7 @@ private void InitializeColorAxisValues() { Name = "minValue", AutoSize = true, - Text = "auto", + Text = _valueMin.HasValue ? _valueMin.ToString() : "auto", }; minValueTextBox.TextChanged += (sender, e) => @@ -217,15 +254,18 @@ private void InitializeColorAxisValues() if (double.TryParse(minValueTextBox.Text, out double minValue)) { colorAxis.Minimum = minValue; + ValueMin = minValue; } else if (minValueTextBox.Text.ToLower() == "auto") { colorAxis.Minimum = double.NaN; minValueTextBox.Text = "auto"; + ValueMin = null; } else { colorAxis.Minimum = heatMapSeries.MinValue; + ValueMin = heatMapSeries.MinValue; } UpdatePlot(); }; @@ -322,6 +362,29 @@ public void UpdateHeatMapSeries(double[,] data) heatMapSeries.Data = data; } + + /// + /// Method to update the heatmap series X axis range. + /// + /// + /// + public void UpdateXRange(double x0, double x1) + { + heatMapSeries.X0 = x0; + heatMapSeries.X1 = x1; + } + + /// + /// Method to update the heatmap series Y axis range. + /// + /// + /// + public void UpdateYRange(double y0, double y1) + { + heatMapSeries.Y0 = y0; + heatMapSeries.Y1 = y1; + } + /// /// Method to update the heatmap series with new data. /// @@ -332,11 +395,9 @@ public void UpdateHeatMapSeries(double[,] data) /// The data to be displayed. public void UpdateHeatMapSeries(double x0, double x1, double y0, double y1, double[,] data) { - heatMapSeries.X0 = x0; - heatMapSeries.X1 = x1; - heatMapSeries.Y0 = y0; - heatMapSeries.Y1 = y1; - heatMapSeries.Data = data; + UpdateXRange(x0, x1); + UpdateYRange(y0, y1); + UpdateHeatMapSeries(data); } /// From 68258184b4345d8d8722334ba2aa49e5c4a5eed4 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 25 Feb 2025 15:15:49 +0000 Subject: [PATCH 124/131] Added a decoder visualizer interface --- .../IDecoderVisualizer.cs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 src/Bonsai.ML.PointProcessDecoder.Design/IDecoderVisualizer.cs diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/IDecoderVisualizer.cs b/src/Bonsai.ML.PointProcessDecoder.Design/IDecoderVisualizer.cs new file mode 100644 index 00000000..ee80e44b --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder.Design/IDecoderVisualizer.cs @@ -0,0 +1,19 @@ +using Bonsai.ML.Design; + +namespace Bonsai.ML.PointProcessDecoder.Design; + +/// +/// Interface for visualizing the output of a point process decoder. +/// +public interface IDecoderVisualizer +{ + /// + /// Gets or sets the capacity of the visualizer. + /// + public int Capacity { get; set; } + + /// + /// Gets the heatmap that visualizes a component's output. + /// + public HeatMapSeriesOxyPlotBase Plot { get; } +} From cb7fc07a4b79efca672a95ff7d763f1f9946b76d Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 25 Feb 2025 15:20:58 +0000 Subject: [PATCH 125/131] Added a class which will automatically cycle through different oxyplot colors --- .../OxyColorPresetCycle.cs | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 src/Bonsai.ML.PointProcessDecoder.Design/OxyColorPresetCycle.cs diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/OxyColorPresetCycle.cs b/src/Bonsai.ML.PointProcessDecoder.Design/OxyColorPresetCycle.cs new file mode 100644 index 00000000..e2976cf4 --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder.Design/OxyColorPresetCycle.cs @@ -0,0 +1,40 @@ + +using OxyPlot; + +namespace Bonsai.ML.PointProcessDecoder.Design +{ + /// + /// Enumerates the colors and provides a preset collection of colors to cycle through. + /// + public class OxyColorPresetCycle + { + private static readonly OxyColor[] Colors = + [ + OxyColors.LimeGreen, + OxyColors.Red, + OxyColors.Blue, + OxyColors.Orange, + OxyColors.Purple, + OxyColors.Yellow, + OxyColors.Pink, + OxyColors.Brown, + OxyColors.Cyan, + OxyColors.Magenta, + OxyColors.Green, + OxyColors.Gray, + OxyColors.Black + ]; + + private int _index; + + /// + /// Gets the next color in the cycle. + /// + public OxyColor Next() + { + var color = Colors[_index]; + _index = (_index + 1) % Colors.Length; + return color; + } + } +} \ No newline at end of file From 2e2f21689672477a9b58429d732c7521e66060d3 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 25 Feb 2025 15:24:16 +0000 Subject: [PATCH 126/131] Updated to use decoder visualizer interface and to use value min/max to set heatmap min/max from parent visualizer class --- .../LikelihoodVisualizer.cs | 45 ++++++++++--------- .../PosteriorVisualizer.cs | 43 +++++++++++++++--- 2 files changed, 62 insertions(+), 26 deletions(-) diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/LikelihoodVisualizer.cs b/src/Bonsai.ML.PointProcessDecoder.Design/LikelihoodVisualizer.cs index 827240dd..52e73fe5 100644 --- a/src/Bonsai.ML.PointProcessDecoder.Design/LikelihoodVisualizer.cs +++ b/src/Bonsai.ML.PointProcessDecoder.Design/LikelihoodVisualizer.cs @@ -24,7 +24,10 @@ namespace Bonsai.ML.PointProcessDecoder.Design { - public class LikelihoodVisualizer : MashupVisualizer + /// + /// Visualizer for the likelihood of a point process model. + /// + public class LikelihoodVisualizer : MashupVisualizer, IDecoderVisualizer { private MultidimensionalArrayVisualizer _visualizer; @@ -46,6 +49,16 @@ public int Capacity } } + /// + /// Gets or sets the minimum value of the likelihood. + /// + public double? ValueMin { get; set; } = null; + + /// + /// Gets or sets the maximum value of the likelihood. + /// + public double? ValueMax { get; set; } = null; + private double[,] _data = null; private long _stateSpaceWidth; private long _stateSpaceHeight; @@ -55,7 +68,6 @@ public int Capacity private ILikelihood _likelihood; private Tensor[] _intensities; private IObservable> _inputSource; - private Tensor _dataTensor; /// public override void Load(IServiceProvider provider) @@ -65,9 +77,6 @@ public override void Load(IServiceProvider provider) var typeVisualizerContext = (ITypeVisualizerContext)provider.GetService(typeof(ITypeVisualizerContext)); if (expressionBuilderGraph != null && typeVisualizerContext != null) { - // decodeNode = ExpressionBuilder.GetWorkflowElement( - // expressionBuilderGraph.Where(node => node.Value == typeVisualizerContext.Source) - // .FirstOrDefault()); visualizerNode = (from node in expressionBuilderGraph where node.Value == typeVisualizerContext.Source select node).FirstOrDefault(); @@ -202,22 +211,13 @@ public override IObservable Visualize(IObservable> s return source; } + var colorCycler = new OxyColorPresetCycle(); + var timer = Observable.Interval( TimeSpan.FromMilliseconds(100), HighResolutionScheduler.Default ); - // var mergedSource = source.SelectMany(xs => - // xs.Buffer(timer) - // .Where(buffer => buffer.Count > 0) - // .Do(buffer => { - // if (!UpdateModel()) - // { - // return; - // } - // Show(buffer.LastOrDefault()); - // })); - var mergedSource = _inputSource.SelectMany(xs => xs.Buffer(timer) .Where(buffer => buffer.Count > 0) @@ -227,16 +227,21 @@ public override IObservable Visualize(IObservable> s { return; } + ValueMin = _visualizer.Plot.ValueMin; + ValueMax = _visualizer.Plot.ValueMax; Show(buffer.LastOrDefault()); })); var mashupSourceStreams = Observable.Merge( MashupSources.Select(mashupSource => - mashupSource.Source.Output.SelectMany(xs => - xs.Buffer(timer) + mashupSource.Source.Output.SelectMany(xs => { + var color = colorCycler.Next(); + var visualizer = mashupSource.Visualizer as Point2DOverlay; + visualizer.Color = color; + return xs.Buffer(timer) .Where(buffer => buffer.Count > 0) - .Do(buffer => mashupSource.Visualizer.Show(buffer.LastOrDefault())) - ))); + .Do(buffer => mashupSource.Visualizer.Show(buffer.LastOrDefault())); + }))); return Observable.Merge(mergedSource, mashupSourceStreams); } diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/PosteriorVisualizer.cs b/src/Bonsai.ML.PointProcessDecoder.Design/PosteriorVisualizer.cs index 9230c734..e26266d1 100644 --- a/src/Bonsai.ML.PointProcessDecoder.Design/PosteriorVisualizer.cs +++ b/src/Bonsai.ML.PointProcessDecoder.Design/PosteriorVisualizer.cs @@ -22,7 +22,7 @@ namespace Bonsai.ML.PointProcessDecoder.Design { - public class PosteriorVisualizer : MashupVisualizer + public class PosteriorVisualizer : MashupVisualizer, IDecoderVisualizer { private MultidimensionalArrayVisualizer _visualizer; @@ -44,6 +44,16 @@ public int Capacity } } + /// + /// Gets or sets the minimum value of the likelihood. + /// + public double? ValueMin { get; set; } = null; + + /// + /// Gets or sets the maximum value of the likelihood. + /// + public double? ValueMax { get; set; } = null; + private double[,] _data = null; private double[] _stateSpaceMin; private double[] _stateSpaceMax; @@ -79,18 +89,30 @@ public override void Load(IServiceProvider provider) _visualizer.Load(provider); + _visualizer.Plot.ValueMin = ValueMin; + _visualizer.Plot.ValueMax = ValueMax; + var capacityLabel = new ToolStripLabel { Text = "Capacity:", AutoSize = true }; - var capacityValue = new ToolStripLabel + + var capacityValue = new ToolStripTextBox { Text = Capacity.ToString(), AutoSize = true }; - _visualizer.Plot.StatusStrip.Items.AddRange([ + capacityValue.TextChanged += (sender, e) => + { + if (int.TryParse(capacityValue.Text, out int capacity)) + { + Capacity = capacity; + } + }; + + _visualizer.Plot.VisualizerPropertiesDropDown.DropDownItems.AddRange([ capacityLabel, capacityValue ]); @@ -213,6 +235,8 @@ public override IObservable Visualize(IObservable> s return source; } + var colorCycler = new OxyColorPresetCycle(); + var timer = Observable.Interval( TimeSpan.FromMilliseconds(100), HighResolutionScheduler.Default @@ -226,16 +250,23 @@ public override IObservable Visualize(IObservable> s { UpdateModel(); } + ValueMin = _visualizer.Plot.ValueMin; + ValueMax = _visualizer.Plot.ValueMax; Show(buffer.LastOrDefault()); })); var mashupSourceStreams = Observable.Merge( - MashupSources.Select(mashupSource => - mashupSource.Source.Output.SelectMany(xs => + MashupSources.Select(mashupSource => { + var color = colorCycler.Next(); + var visualizer = mashupSource.Visualizer as Point2DOverlay; + visualizer.Color = color; + return mashupSource.Source.Output.SelectMany(xs => xs.Buffer(timer) .Where(buffer => buffer.Count > 0) .Do(buffer => mashupSource.Visualizer.Show(buffer.LastOrDefault())) - ))); + ); + }) + ); return Observable.Merge(mergedSource, mashupSourceStreams); } From 7d091d3f94854e8662aff7f92c1b727598db2bab Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 25 Feb 2025 15:25:32 +0000 Subject: [PATCH 127/131] Updated point overlay class to allow mashup visualizer class to handle setting the color --- .../Point2DOverlay.cs | 41 +++++++++++++------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/Point2DOverlay.cs b/src/Bonsai.ML.PointProcessDecoder.Design/Point2DOverlay.cs index 537b389e..a4f3f089 100644 --- a/src/Bonsai.ML.PointProcessDecoder.Design/Point2DOverlay.cs +++ b/src/Bonsai.ML.PointProcessDecoder.Design/Point2DOverlay.cs @@ -25,20 +25,38 @@ namespace Bonsai.ML.PointProcessDecoder.Design /// public class Point2DOverlay : DialogTypeVisualizer { - internal LineSeries _lineSeries; - internal ScatterSeries _scatterSeries; - private int _capacity; + private LineSeries _lineSeries; + private ScatterSeries _scatterSeries; private int _dataCount; + private IDecoderVisualizer decoderVisualizer; + + private OxyColor _color = OxyColors.LimeGreen; + + /// + /// Gets or sets the color of the overlay. + /// + public OxyColor Color + { + get => _color; + set + { + if (_lineSeries != null && _scatterSeries != null) + { + _lineSeries.Color = value; + _scatterSeries.MarkerFill = value; + _color = value; + } + } + } /// public override void Load(IServiceProvider provider) { - dynamic service = provider.GetService(typeof(MashupVisualizer)); - _capacity = service.Capacity; + decoderVisualizer = provider.GetService(typeof(MashupVisualizer)) as IDecoderVisualizer; _lineSeries = new LineSeries() { - Color = OxyColors.LimeGreen, + Color = _color, StrokeThickness = 2 }; @@ -52,14 +70,13 @@ public override void Load(IServiceProvider provider) { MarkerType = MarkerType.Circle, MarkerSize = 10, - MarkerFill = OxyColors.LimeGreen, + MarkerFill = _color, ColorAxisKey = "Point2DOverlayColorAxis" }; - - service.Plot.Model.Series.Add(_scatterSeries); - service.Plot.Model.Series.Add(_lineSeries); - service.Plot.Model.Axes.Add(colorAxis); + decoderVisualizer.Plot.Model.Series.Add(_scatterSeries); + decoderVisualizer.Plot.Model.Series.Add(_lineSeries); + decoderVisualizer.Plot.Model.Axes.Add(colorAxis); } /// @@ -71,7 +88,7 @@ public override void Show(object value) _scatterSeries.Points.Clear(); _scatterSeries.Points.Add(new ScatterPoint(point.X, point.Y, value: 1)); - while (_dataCount > _capacity) + while (_dataCount > decoderVisualizer.Capacity) { _lineSeries.Points.RemoveAt(0); _dataCount--; From d7b16c09c70c52f1513ccdbd7f301239957459af Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 25 Feb 2025 17:15:28 +0000 Subject: [PATCH 128/131] Add `SumAcrossBatch` parameter to model constructor and decode function --- .../CreatePointProcessModel.cs | 20 ++++++++- src/Bonsai.ML.PointProcessDecoder/Decode.cs | 41 ++++++++++++++----- .../PointProcessModelManager.cs | 2 + 3 files changed, 51 insertions(+), 12 deletions(-) diff --git a/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs b/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs index 2e5015f3..9eec439a 100644 --- a/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs +++ b/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs @@ -25,7 +25,6 @@ namespace Bonsai.ML.PointProcessDecoder; public class CreatePointProcessModel { private string name = "PointProcessModel"; - /// /// Gets or sets the name of the point process model. /// @@ -374,6 +373,24 @@ public bool IgnoreNoSpikes } } + private bool sumAcrossBatch = true; + /// + /// Gets or sets a value indicating whether to sum across the batched likelihood. + /// + [Category("5. Likelihood Parameters")] + [Description("Indicates whether to sum across the batched likelihood.")] + public bool SumAcrossBatch + { + get + { + return sumAcrossBatch; + } + set + { + sumAcrossBatch = value; + } + } + private TransitionsType transitionsType = TransitionsType.RandomWalk; /// /// Gets or sets the type of transition model used during the decoding process. @@ -453,6 +470,7 @@ public IObservable Process() markChannels: markChannels, markBandwidth: markBandwidth, ignoreNoSpikes: ignoreNoSpikes, + sumAcrossBatch: sumAcrossBatch, nUnits: nUnits, distanceThreshold: distanceThreshold, sigmaRandomWalk: sigmaRandomWalk, diff --git a/src/Bonsai.ML.PointProcessDecoder/Decode.cs b/src/Bonsai.ML.PointProcessDecoder/Decode.cs index 361d1c12..7c6d215a 100644 --- a/src/Bonsai.ML.PointProcessDecoder/Decode.cs +++ b/src/Bonsai.ML.PointProcessDecoder/Decode.cs @@ -38,6 +38,23 @@ public bool IgnoreNoSpikes } } + private bool _sumAcrossBatch = true; + private bool _updateSumAcrossBatch = false; + /// + /// Gets or sets a value indicating whether to ignore contributions from no spike events. + /// + [Description("Indicates whether to ignore contributions from no spike events.")] + public bool SumAcrossBatch + { + get => _sumAcrossBatch; + set + { + _sumAcrossBatch = value; + _updateSumAcrossBatch = true; + } + } + + /// /// Decodes the input neural data into a posterior state estimate using a point process model. /// @@ -46,20 +63,22 @@ public bool IgnoreNoSpikes public IObservable Process(IObservable source) { var modelName = Model; - return source.Select(input => { + return source.Select(input => + { var model = PointProcessModelManager.GetModel(modelName); - if (_updateIgnoreNoSpikes) { - if (model.Likelihood is ClusterlessLikelihood clusterlessLikelihood) - { - clusterlessLikelihood.IgnoreNoSpikes = _ignoreNoSpikes; - } - else if (model.Likelihood is PoissonLikelihood poissonLikelihood) - { - poissonLikelihood.IgnoreNoSpikes = _ignoreNoSpikes; - } - + if (_updateIgnoreNoSpikes) + { + model.Likelihood.IgnoreNoSpikes = _ignoreNoSpikes; _updateIgnoreNoSpikes = false; } + + if (_updateSumAcrossBatch) + { + + model.Likelihood.SumAcrossBatch = _sumAcrossBatch; + _updateSumAcrossBatch = false; + } + return model.Decode(input); }); } diff --git a/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs index d756c1b5..c647d674 100644 --- a/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs +++ b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs @@ -49,6 +49,7 @@ internal static PointProcessModelDisposable Reserve( int? markChannels = null, double[]? markBandwidth = null, bool ignoreNoSpikes = false, + bool sumAcrossBatch = true, int? nUnits = null, double? distanceThreshold = null, double? sigmaRandomWalk = null, @@ -73,6 +74,7 @@ internal static PointProcessModelDisposable Reserve( markChannels: markChannels, markBandwidth: markBandwidth, ignoreNoSpikes: ignoreNoSpikes, + sumAcrossBatch: sumAcrossBatch, nUnits: nUnits, distanceThreshold: distanceThreshold, sigmaRandomWalk: sigmaRandomWalk, From 0173558a1eacfde3056b6c5263f6a2d5e32f47ea Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 28 Feb 2025 08:58:04 +0000 Subject: [PATCH 129/131] Fixed issue with not being able to set the min/max heatmap values properly --- .../HeatMapSeriesOxyPlotBase.cs | 46 +++++++++++-------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/src/Bonsai.ML.Design/HeatMapSeriesOxyPlotBase.cs b/src/Bonsai.ML.Design/HeatMapSeriesOxyPlotBase.cs index 0f23f2ae..e37c0567 100644 --- a/src/Bonsai.ML.Design/HeatMapSeriesOxyPlotBase.cs +++ b/src/Bonsai.ML.Design/HeatMapSeriesOxyPlotBase.cs @@ -91,8 +91,6 @@ public double? ValueMin _valueMin = value; if (minValueTextBox != null) minValueTextBox.Text = value?.ToString(); - if (colorAxis != null) - colorAxis.Maximum = value ?? double.NaN; } } @@ -108,8 +106,6 @@ public double? ValueMax _valueMax = value; if (maxValueTextBox != null) maxValueTextBox.Text = value?.ToString(); - if (colorAxis != null) - colorAxis.Maximum = value ?? double.NaN; } } @@ -215,23 +211,30 @@ private void InitializeColorAxisValues() Text = _valueMax.HasValue ? _valueMax.ToString() : "auto", }; + var updateMaxValueText = true; + maxValueTextBox.TextChanged += (sender, e) => { + if (!updateMaxValueText) + { + updateMaxValueText = true; + return; + } + if (double.TryParse(maxValueTextBox.Text, out double maxValue)) { + _valueMax = maxValue; colorAxis.Maximum = maxValue; - ValueMax = maxValue; } - else if (maxValueTextBox.Text.ToLower() == "auto") + else if (string.IsNullOrEmpty(maxValueTextBox.Text)) { + _valueMax = null; colorAxis.Maximum = double.NaN; - maxValueTextBox.Text = "auto"; - ValueMax = null; } else { - colorAxis.Maximum = heatMapSeries.MaxValue; - ValueMax = heatMapSeries.MaxValue; + updateMaxValueText = false; + maxValueTextBox.Text = ""; } UpdatePlot(); }; @@ -249,23 +252,30 @@ private void InitializeColorAxisValues() Text = _valueMin.HasValue ? _valueMin.ToString() : "auto", }; + var updateMinValueText = true; + minValueTextBox.TextChanged += (sender, e) => { + if (!updateMinValueText) + { + updateMinValueText = true; + return; + } + if (double.TryParse(minValueTextBox.Text, out double minValue)) { + _valueMin = minValue; colorAxis.Minimum = minValue; - ValueMin = minValue; } - else if (minValueTextBox.Text.ToLower() == "auto") + else if (string.IsNullOrEmpty(minValueTextBox.Text)) { + _valueMin = null; colorAxis.Minimum = double.NaN; - minValueTextBox.Text = "auto"; - ValueMin = null; } else { - colorAxis.Minimum = heatMapSeries.MinValue; - ValueMin = heatMapSeries.MinValue; + updateMinValueText = false; + minValueTextBox.Text = ""; } UpdatePlot(); }; @@ -331,12 +341,12 @@ private void InitializeRenderMethod() renderMethodComboBox.Items.Add(value); } - renderMethodComboBox.SelectedIndexChanged += renderMethodComboBoxSelectedIndexChanged; + renderMethodComboBox.SelectedIndexChanged += RenderMethodComboBoxSelectedIndexChanged; renderMethodComboBox.SelectedIndex = _renderMethodSelectedIndex; UpdateRenderMethod(); } - private void renderMethodComboBoxSelectedIndexChanged(object sender, EventArgs e) + private void RenderMethodComboBoxSelectedIndexChanged(object sender, EventArgs e) { if (renderMethodComboBox.SelectedIndex != _renderMethodSelectedIndex) { From 0ddb210c1ae978b9b807971b118af3fa1017916c Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 28 Feb 2025 08:58:35 +0000 Subject: [PATCH 130/131] Updated `PointProcessDecoder.Core` package version --- .../Bonsai.ML.PointProcessDecoder.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj b/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj index 3608f916..3b9d62b1 100644 --- a/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj +++ b/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj @@ -11,6 +11,6 @@ - + \ No newline at end of file From 72f43aa8f2f9a17984ce5f6ff9cee6e08469e5bf Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 3 Mar 2025 17:54:24 +0000 Subject: [PATCH 131/131] Updated likelihood visualizer to place capacity property into visualizer properties dropdown menu --- .../LikelihoodVisualizer.cs | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/LikelihoodVisualizer.cs b/src/Bonsai.ML.PointProcessDecoder.Design/LikelihoodVisualizer.cs index 52e73fe5..9e09ab91 100644 --- a/src/Bonsai.ML.PointProcessDecoder.Design/LikelihoodVisualizer.cs +++ b/src/Bonsai.ML.PointProcessDecoder.Design/LikelihoodVisualizer.cs @@ -114,13 +114,22 @@ public override void Load(IServiceProvider provider) Text = "Capacity:", AutoSize = true }; - var capacityValue = new ToolStripLabel + + var capacityValue = new ToolStripTextBox { Text = Capacity.ToString(), AutoSize = true }; - _visualizer.Plot.StatusStrip.Items.AddRange([ + capacityValue.TextChanged += (sender, e) => + { + if (int.TryParse(capacityValue.Text, out int capacity)) + { + Capacity = capacity; + } + }; + + _visualizer.Plot.VisualizerPropertiesDropDown.DropDownItems.AddRange([ capacityLabel, capacityValue ]);