diff --git a/Bonsai.ML.sln b/Bonsai.ML.sln index c5a91b13..9090db9f 100644 --- a/Bonsai.ML.sln +++ b/Bonsai.ML.sln @@ -1,4 +1,4 @@ - + Microsoft Visual Studio Solution File, Format Version 12.00 # Visual Studio Version 17 VisualStudioVersion = 17.0.31903.59 @@ -30,6 +30,12 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.LinearDynamicalSy EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.HiddenMarkovModels.Design", "src\Bonsai.ML.HiddenMarkovModels.Design\Bonsai.ML.HiddenMarkovModels.Design.csproj", "{FC395DDC-62A4-4E14-A198-272AB05B33C7}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Torch", "src\Bonsai.ML.Torch\Bonsai.ML.Torch.csproj", "{06FCC9AF-CE38-44BB-92B3-0D451BE88537}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.PointProcessDecoder", "src\Bonsai.ML.PointProcessDecoder\Bonsai.ML.PointProcessDecoder.csproj", "{AD32C680-1E8C-4340-81B1-DA19C9104516}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.PointProcessDecoder.Design", "src\Bonsai.ML.PointProcessDecoder.Design\Bonsai.ML.PointProcessDecoder.Design.csproj", "{91C3E252-9457-43AB-A21A-6064E2404BAA}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -60,18 +66,30 @@ Global {39A4414F-52B1-42D7-82FA-E65DAD885264}.Debug|Any CPU.Build.0 = Debug|Any CPU {39A4414F-52B1-42D7-82FA-E65DAD885264}.Release|Any CPU.ActiveCfg = Release|Any CPU {39A4414F-52B1-42D7-82FA-E65DAD885264}.Release|Any CPU.Build.0 = Release|Any CPU - {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13}.Debug|Any CPU.Build.0 = Debug|Any CPU - {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13}.Release|Any CPU.ActiveCfg = Release|Any CPU - {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13}.Release|Any CPU.Build.0 = Release|Any CPU - {17DF50BE-F481-4904-A4C8-5DF9725B2CA1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {17DF50BE-F481-4904-A4C8-5DF9725B2CA1}.Debug|Any CPU.Build.0 = Debug|Any CPU - {17DF50BE-F481-4904-A4C8-5DF9725B2CA1}.Release|Any CPU.ActiveCfg = Release|Any CPU - {17DF50BE-F481-4904-A4C8-5DF9725B2CA1}.Release|Any CPU.Build.0 = Release|Any CPU - {FC395DDC-62A4-4E14-A198-272AB05B33C7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {FC395DDC-62A4-4E14-A198-272AB05B33C7}.Debug|Any CPU.Build.0 = Debug|Any CPU - {FC395DDC-62A4-4E14-A198-272AB05B33C7}.Release|Any CPU.ActiveCfg = Release|Any CPU - {FC395DDC-62A4-4E14-A198-272AB05B33C7}.Release|Any CPU.Build.0 = Release|Any CPU + {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13}.Debug|Any CPU.Build.0 = Debug|Any CPU + {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13}.Release|Any CPU.ActiveCfg = Release|Any CPU + {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13}.Release|Any CPU.Build.0 = Release|Any CPU + {17DF50BE-F481-4904-A4C8-5DF9725B2CA1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {17DF50BE-F481-4904-A4C8-5DF9725B2CA1}.Debug|Any CPU.Build.0 = Debug|Any CPU + {17DF50BE-F481-4904-A4C8-5DF9725B2CA1}.Release|Any CPU.ActiveCfg = Release|Any CPU + {17DF50BE-F481-4904-A4C8-5DF9725B2CA1}.Release|Any CPU.Build.0 = Release|Any CPU + {FC395DDC-62A4-4E14-A198-272AB05B33C7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {FC395DDC-62A4-4E14-A198-272AB05B33C7}.Debug|Any CPU.Build.0 = Debug|Any CPU + {FC395DDC-62A4-4E14-A198-272AB05B33C7}.Release|Any CPU.ActiveCfg = Release|Any CPU + {FC395DDC-62A4-4E14-A198-272AB05B33C7}.Release|Any CPU.Build.0 = Release|Any CPU + 
{06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Debug|Any CPU.Build.0 = Debug|Any CPU + {06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Release|Any CPU.ActiveCfg = Release|Any CPU + {06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Release|Any CPU.Build.0 = Release|Any CPU + {AD32C680-1E8C-4340-81B1-DA19C9104516}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {AD32C680-1E8C-4340-81B1-DA19C9104516}.Debug|Any CPU.Build.0 = Debug|Any CPU + {AD32C680-1E8C-4340-81B1-DA19C9104516}.Release|Any CPU.ActiveCfg = Release|Any CPU + {AD32C680-1E8C-4340-81B1-DA19C9104516}.Release|Any CPU.Build.0 = Release|Any CPU + {91C3E252-9457-43AB-A21A-6064E2404BAA}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {91C3E252-9457-43AB-A21A-6064E2404BAA}.Debug|Any CPU.Build.0 = Debug|Any CPU + {91C3E252-9457-43AB-A21A-6064E2404BAA}.Release|Any CPU.ActiveCfg = Release|Any CPU + {91C3E252-9457-43AB-A21A-6064E2404BAA}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -84,8 +102,11 @@ Global {81DB65B3-EA65-4947-8CF1-0E777324C082} = {461FE3E2-21C4-47F9-8405-DF72326AAB2B} {BAD0A733-8EFB-4EAF-9648-9851656AF7FF} = {12312384-8828-4786-AE19-EFCEDF968290} {39A4414F-52B1-42D7-82FA-E65DAD885264} = {12312384-8828-4786-AE19-EFCEDF968290} - {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13} = {12312384-8828-4786-AE19-EFCEDF968290} - {17DF50BE-F481-4904-A4C8-5DF9725B2CA1} = {12312384-8828-4786-AE19-EFCEDF968290} + {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13} = {12312384-8828-4786-AE19-EFCEDF968290} + {17DF50BE-F481-4904-A4C8-5DF9725B2CA1} = {12312384-8828-4786-AE19-EFCEDF968290} + {06FCC9AF-CE38-44BB-92B3-0D451BE88537} = {12312384-8828-4786-AE19-EFCEDF968290} + {AD32C680-1E8C-4340-81B1-DA19C9104516} = {12312384-8828-4786-AE19-EFCEDF968290} + {91C3E252-9457-43AB-A21A-6064E2404BAA} = {12312384-8828-4786-AE19-EFCEDF968290} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {B6468F13-97CD-45E0-9E1E-C122D7F1E09F} diff --git a/README.md b/README.md index 1b27738a..60864b6d 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,12 @@ Facilitates inference using Hidden Markov Models (HMMs). It interfaces with the ### Bonsai.ML.HiddenMarkovModels.Design Visualizers and editor features for the HiddenMarkovModels package. +### Bonsai.ML.PointProcessDecoder +Interfaces with the [PointProcessDecoder](https://github.com/ncguilbeault/PointProcessDecoder) package which can be used for decoding neural activity and other point processes. + +### Bonsai.ML.Torch +Interfaces with the [TorchSharp](https://github.com/dotnet/TorchSharp) package, a C# wrapper around the torch library. Provides tooling for manipulating tensors, performing linear algebra, training and inference with deep neural networks, and more. + > [!NOTE] > Bonsai.ML packages can be installed through Bonsai's integrated package manager and are generally ready for immediate use. However, some packages may require additional installation steps. Refer to the specific package section for detailed installation guides and documentation. 
diff --git a/docs/articles/PointProcessDecoder/ppd-getting-started.md b/docs/articles/PointProcessDecoder/ppd-getting-started.md
new file mode 100644
index 00000000..7006f682
--- /dev/null
+++ b/docs/articles/PointProcessDecoder/ppd-getting-started.md
@@ -0,0 +1,27 @@
+# Getting Started
+
+The `Bonsai.ML.PointProcessDecoder` package provides a Bonsai interface to the [PointProcessDecoder](https://github.com/ncguilbeault/PointProcessDecoder) package used for decoding neural activity (point processes), and relies on the `Bonsai.ML.Torch` package for tensor operations.
+
+## Installation
+
+The package can be installed from the Bonsai package manager by installing `Bonsai.ML.PointProcessDecoder`. Additional installation steps are required to install the CPU or GPU version of the `Bonsai.ML.Torch` package. See the [Torch installation guide](../Torch/torch-overview.md) for more information.
+
+## Package Overview
+
+The `PointProcessDecoder` package is a C# implementation of a Bayesian state space point process decoder inspired by the [replay_trajectory_classification repository](https://github.com/Eden-Kramer-Lab/replay_trajectory_classification) from the Eden-Kramer Lab. It can decode latent state observations from spike-train data or clusterless mark data using Bayesian state space models of point processes.
+
+For more detailed information and documentation about the model, please see the [PointProcessDecoder repo](https://github.com/ncguilbeault/PointProcessDecoder).
+
+## Bonsai Implementation
+
+The following workflow showcases the core functionality of the `Bonsai.ML.PointProcessDecoder` package.
+
+:::workflow
+![Point Process Decoder Implementation](~/workflows/PointProcessDecoder.bonsai)
+:::
+
+The `CreatePointProcessModel` node is used to define a model and configure its parameters. For details on model configuration, please see the [PointProcessDecoder documentation](https://github.com/ncguilbeault/PointProcessDecoder). Crucially, you must specify the `Name` property in the `Model Parameters` section, as this is what allows you to reference the specific model in the `Encode` and `Decode` nodes, the two main operations performed with the model.
+
+During encoding, you pass in a tuple of `Observation`s and `SpikeCounts`. Observations are variables that you measure (for example, the animal's position), represented as an (M, N) tensor, where M is the number of samples in a batch and N is the dimensionality of the observations. Spike counts are the data you will use for decoding. For instance, spike counts might be sorted spike data or clusterless marks. If the data are sorted spiking units, then the `SpikeCounts` tensor will be an (M, U) tensor, where U is the number of sorted units. If the data are clusterless marks, then the `SpikeCounts` tensor will be an (M, F, C) tensor, where F is the number of features computed for each mark (for instance, the maximum spike amplitude across electrodes), and C is the number of independent recording channels (for example, individual tetrodes). Internally, the model fits the data as a point process, which is then used during decoding.
+
+Decoding is the process of taking just the `SpikeCounts` and inferring the latent `Observation`. To do this, the model uses a Bayesian state space model to predict a posterior distribution over the latent observation space using the information contained in the spike counts data. The output of the `Decode` node is an (M, D*) tensor, with D* representing the discrete latent state space determined by the parameters defined in the model.
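+
+Below is a minimal C# sketch (using TorchSharp directly, outside of Bonsai) of how tensors with the shapes described above could be constructed; the sizes and random values are illustrative assumptions only, not part of the workflow.
+
+```csharp
+using System;
+using static TorchSharp.torch;
+
+// Hypothetical sizes: M samples per batch, N observation dimensions, U sorted units.
+const int M = 100, N = 2, U = 40;
+
+// Observations: an (M, N) tensor, e.g. the animal's 2D position over M samples.
+Tensor observations = rand(M, N);
+
+// Sorted-unit spike counts: an (M, U) tensor of non-negative integer counts.
+Tensor spikeCounts = (rand(M, U) * 5).to_type(ScalarType.Int64);
+
+// These shapes follow the (M, N) and (M, U) conventions expected by Encode.
+Console.WriteLine($"observations: [{string.Join(", ", observations.shape)}]");
+Console.WriteLine($"spike counts: [{string.Join(", ", spikeCounts.shape)}]");
+```
+
+In the Bonsai workflow, these tensors are produced by upstream nodes and combined into the (`Observation`, `SpikeCounts`) tuple consumed by `Encode`.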
\ No newline at end of file
diff --git a/docs/articles/Torch/torch-getting-started.md b/docs/articles/Torch/torch-getting-started.md
new file mode 100644
index 00000000..6a673b9b
--- /dev/null
+++ b/docs/articles/Torch/torch-getting-started.md
@@ -0,0 +1,13 @@
+# Getting Started
+
+The aim of the `Bonsai.ML.Torch` package is to integrate the [TorchSharp](https://github.com/dotnet/TorchSharp) package, a C# wrapper around the powerful libtorch library, into Bonsai. In the current version, the package primarily provides tooling and functionality for users to interact with and manipulate `Tensor`s, the core data type of libtorch which underlies many of the advanced torch operations. Additionally, the package provides some capabilities for defining neural network architectures, running forward inference, and learning via back propagation.
+
+## Tensor Operations
+The package provides several ways to work with tensors. Users can initialize tensors (`Ones`, `Zeros`, etc.), create tensors from .NET data types (`ToTensor`), and define custom tensors using Python-like syntax (`CreateTensor`). Tensors can be converted back to .NET types using the `ToArray` node (for flattening tensors into a single array) or the `ToNDArray` node (for preserving multidimensional array shapes). Furthermore, the `Tensor` data type contains many extension methods which can be used via scripting, such as with `ExpressionTransform` (for example, `it.sum()` to sum a tensor, or `it.T` to transpose), and works with overloaded operators, for example, `Zip` -> `Multiply`. `ExpressionTransform` can also be used to access individual elements of a tensor, using the syntax `it.data.ReadCpuT(0)`, where `T` is a primitive .NET data type.
+
+## Running on the GPU
+Users must be explicit about running tensors on the GPU. First, the `InitializeDeviceType` node must be run with a CUDA-compatible GPU. Afterwards, tensors are moved to the GPU using the `ToDevice` node. Converting tensors back to .NET data types requires moving the tensor back to the CPU before converting.
+
+## Neural Networks
+The package provides initial support for working with torch `Module`s, the conventional object for deep neural networks. The `LoadModuleFromArchitecture` node allows users to select from a list of common architectures, and can optionally load pretrained weights from disk. Additionally, the package supports loading `TorchScript` modules with the `LoadScriptModule` node, which enables users to use torch modules saved in the `.pt` file format. Users can then use the `Forward` node to run inference and the `Backward` node to run back propagation.
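+
+As a rough illustration of what these operators wrap, the following TorchSharp snippet performs the same kinds of operations directly in C#; it is only a sketch of the underlying library calls, not the Bonsai node API itself.
+
+```csharp
+using System;
+using static TorchSharp.torch;
+
+// Create tensors, mirroring the Ones and CreateTensor/ToTensor nodes.
+Tensor a = ones(2, 3);
+Tensor b = tensor(new float[] { 1, 2, 3, 4, 5, 6 }).reshape(2, 3);
+
+// Overloaded operators (e.g. Zip -> Multiply) and scripting-style members.
+Tensor c = a * b;
+Tensor total = c.sum();      // same operation as it.sum() in ExpressionTransform
+Tensor transposed = c.T;     // same operation as it.T
+
+// Convert back to .NET types, as the ToArray node does (flattened copy).
+float[] flat = c.data<float>().ToArray();
+Console.WriteLine($"sum = {total.ToSingle()}, transposed shape = [{string.Join(", ", transposed.shape)}], flat length = {flat.Length}");
+```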
\ No newline at end of file
diff --git a/docs/articles/Torch/torch-overview.md b/docs/articles/Torch/torch-overview.md
new file mode 100644
index 00000000..0884ea01
--- /dev/null
+++ b/docs/articles/Torch/torch-overview.md
@@ -0,0 +1,27 @@
+# Bonsai.ML.Torch Overview
+
+The Torch package provides a Bonsai interface to the [TorchSharp](https://github.com/dotnet/TorchSharp) package, a C# wrapper around the torch library.
+
+## General Guide
+
+The Bonsai.ML.Torch package can be installed through the Bonsai Package Manager and depends on the TorchSharp library. Additionally, running the package requires installing the specific torch DLLs needed for your desired application. The steps for installing them are outlined below.
+
+### Running on the CPU
+To run the package on the CPU, install the `TorchSharp-cpu` library through the NuGet package source.
+
+### Running on the GPU
+To run torch on the GPU, you first need to ensure that you have a CUDA-compatible device installed on your system.
+
+Next, follow the [CUDA installation guide for Windows](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html) or the [guide for Linux](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html). Make sure to install `CUDA v12.1` [found here](https://developer.nvidia.com/cuda-12-1-0-download-archive), as `TorchSharp` currently only supports this version.
+
+Next, install the `cuDNN v9` library following the [guide for Windows](https://docs.nvidia.com/deeplearning/cudnn/latest/installation/windows.html) or the [guide for Linux](https://docs.nvidia.com/deeplearning/cudnn/latest/installation/linux.html). Again, ensure you have the correct version installed (v9), and consult [NVIDIA's support matrix](https://docs.nvidia.com/deeplearning/cudnn/latest/reference/support-matrix.html) to verify that the versions of CUDA and cuDNN you installed are compatible with your specific OS, graphics driver, and hardware.
+
+Once complete, install the CUDA-compatible torch libraries and place them in the correct location. You can download the libraries from [the PyTorch website](https://pytorch.org/get-started/locally/) with the following options selected:
+
+- PyTorch Build: Stable (2.5.1)
+- OS: [Your OS]
+- Package: LibTorch
+- Language: C++/Java
+- Compute Platform: CUDA 12.1
+
+Finally, extract the zip folder and copy the contents of the `lib` folder into the `Extensions` folder of your Bonsai installation directory.
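+
+Once the libraries are in place, a short TorchSharp snippet such as the one below can be used to check that the GPU is visible. This is only a sketch of the underlying calls; in a Bonsai workflow, the `InitializeDeviceType` and `ToDevice` nodes play this role.
+
+```csharp
+using System;
+using TorchSharp;
+using static TorchSharp.torch;
+
+// Initialize the CUDA backend; this requires the CUDA libtorch binaries installed above.
+InitializeDeviceType(DeviceType.CUDA);
+
+// Report whether libtorch can see a CUDA device.
+Console.WriteLine($"CUDA available: {cuda.is_available()}");
+
+// Move a small tensor to the GPU and back; tensors must return to the CPU
+// before being converted to .NET types.
+Tensor t = ones(2, 2);
+if (cuda.is_available())
+{
+    t = t.cuda();
+    t = t.cpu();
+}
+```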
\ No newline at end of file diff --git a/docs/articles/toc.yml b/docs/articles/toc.yml index e22b0b80..ddb774cc 100644 --- a/docs/articles/toc.yml +++ b/docs/articles/toc.yml @@ -13,4 +13,12 @@ - name: Overview href: HiddenMarkovModels/hmm-overview.md - name: Getting Started - href: HiddenMarkovModels/hmm-getting-started.md \ No newline at end of file + href: HiddenMarkovModels/hmm-getting-started.md +- name: PointProcessDecoder +- name: Getting Started + href: PointProcessDecoder/ppd-getting-started.md +- name: Torch +- name: Overview + href: Torch/torch-overview.md +- name: Getting Started + href: Torch/torch-getting-started.md \ No newline at end of file diff --git a/docs/workflows/PointProcessDecoder.bonsai b/docs/workflows/PointProcessDecoder.bonsai new file mode 100644 index 00000000..b13680cd --- /dev/null +++ b/docs/workflows/PointProcessDecoder.bonsai @@ -0,0 +1,73 @@ + + + + + + + PointProcessModel + KernelDensity + RandomWalk + SortedSpikeEncoder + StateSpaceDecoder + DiscreteUniformStateSpace + Poisson + + 2 + + 0 + 0 + + + 120 + 120 + + + 50 + 50 + + + 5 + 5 + + 104 + + + 1.5 + + + + + Observation + + + SpikeCounts + + + + + + + PointProcessModel + + + + SortedSpikeCounts + + + + PointProcessModel + + + + + + + + + + + \ No newline at end of file diff --git a/src/Bonsai.ML.Design/HeatMapSeriesOxyPlotBase.cs b/src/Bonsai.ML.Design/HeatMapSeriesOxyPlotBase.cs index d6944816..e37c0567 100644 --- a/src/Bonsai.ML.Design/HeatMapSeriesOxyPlotBase.cs +++ b/src/Bonsai.ML.Design/HeatMapSeriesOxyPlotBase.cs @@ -14,10 +14,20 @@ namespace Bonsai.ML.Design /// public class HeatMapSeriesOxyPlotBase : UserControl { - private PlotView view; - private PlotModel model; + private PlotView _view; + /// + /// Gets the plot view of the control. + /// + public PlotView View => _view; + + private PlotModel _model; + /// + /// Gets the plot model of the control. + /// + public PlotModel Model => _model; + private HeatMapSeries heatMapSeries; - private LinearColorAxis colorAxis; + private LinearColorAxis colorAxis = null; private ToolStripComboBox paletteComboBox; private ToolStripLabel paletteLabel; @@ -30,12 +40,18 @@ public class HeatMapSeriesOxyPlotBase : UserControl private HeatMapRenderMethod renderMethod = HeatMapRenderMethod.Bitmap; private StatusStrip statusStrip; - private ToolStripTextBox maxValueTextBox; + private ToolStripTextBox maxValueTextBox = null; private ToolStripLabel maxValueLabel; - private ToolStripTextBox minValueTextBox; + private ToolStripTextBox minValueTextBox = null; private ToolStripLabel minValueLabel; + private ToolStripDropDownButton _visualizerPropertiesDropDown; + /// + /// Gets the visualizer properties drop down button. + /// + public ToolStripDropDownButton VisualizerPropertiesDropDown => _visualizerPropertiesDropDown; + private int _numColors = 100; /// @@ -63,13 +79,47 @@ public class HeatMapSeriesOxyPlotBase : UserControl /// public StatusStrip StatusStrip => statusStrip; + private double? _valueMin = null; + /// + /// Gets or sets the minimum value of the color axis. + /// + public double? ValueMin + { + get => _valueMin; + set + { + _valueMin = value; + if (minValueTextBox != null) + minValueTextBox.Text = value?.ToString(); + } + } + + private double? _valueMax = null; + /// + /// Gets or sets the maximum value of the color axis. + /// + public double? ValueMax + { + get => _valueMax; + set + { + _valueMax = value; + if (maxValueTextBox != null) + maxValueTextBox.Text = value?.ToString(); + } + } + /// /// Constructor of the TimeSeriesOxyPlotBase class. 
/// Requires a line series name and an area series name. /// Data source is optional, since pasing it to the constructor will populate the combobox and leave it empty otherwise. /// The selected index is only needed when the data source is provided. /// - public HeatMapSeriesOxyPlotBase(int paletteSelectedIndex, int renderMethodSelectedIndex, int numColors = 100) + public HeatMapSeriesOxyPlotBase( + int paletteSelectedIndex, + int renderMethodSelectedIndex, + int numColors = 100 + ) { _paletteSelectedIndex = paletteSelectedIndex; _renderMethodSelectedIndex = renderMethodSelectedIndex; @@ -79,12 +129,12 @@ public HeatMapSeriesOxyPlotBase(int paletteSelectedIndex, int renderMethodSelect private void Initialize() { - view = new PlotView + _view = new PlotView { Dock = DockStyle.Fill, }; - model = new PlotModel(); + _model = new PlotModel(); heatMapSeries = new HeatMapSeries { X0 = 0, @@ -100,11 +150,11 @@ private void Initialize() Position = AxisPosition.Right, }; - model.Axes.Add(colorAxis); - model.Series.Add(heatMapSeries); + _model.Axes.Add(colorAxis); + _model.Series.Add(heatMapSeries); - view.Model = model; - Controls.Add(view); + _view.Model = _model; + Controls.Add(_view); InitializeColorPalette(); InitializeRenderMethod(); @@ -126,17 +176,23 @@ private void Initialize() minValueTextBox }; - ToolStripDropDownButton visualizerPropertiesButton = new ToolStripDropDownButton("Visualizer Properties"); + _visualizerPropertiesDropDown = new ToolStripDropDownButton("Visualizer Properties"); foreach (var item in toolStripItems) { - visualizerPropertiesButton.DropDownItems.Add(item); + _visualizerPropertiesDropDown.DropDownItems.Add(item); } - statusStrip.Items.Add(visualizerPropertiesButton); + statusStrip.Items.Add(_visualizerPropertiesDropDown); Controls.Add(statusStrip); - view.MouseClick += new MouseEventHandler(onMouseClick); + _view.MouseClick += (sender, e) => { + if (e.Button == MouseButtons.Right) + { + statusStrip.Visible = !statusStrip.Visible; + } + }; + AutoScaleDimensions = new SizeF(6F, 13F); } @@ -152,23 +208,33 @@ private void InitializeColorAxisValues() { Name = "maxValue", AutoSize = true, - Text = "auto", + Text = _valueMax.HasValue ? _valueMax.ToString() : "auto", }; + var updateMaxValueText = true; + maxValueTextBox.TextChanged += (sender, e) => { + if (!updateMaxValueText) + { + updateMaxValueText = true; + return; + } + if (double.TryParse(maxValueTextBox.Text, out double maxValue)) { + _valueMax = maxValue; colorAxis.Maximum = maxValue; } - else if (maxValueTextBox.Text.ToLower() == "auto") + else if (string.IsNullOrEmpty(maxValueTextBox.Text)) { + _valueMax = null; colorAxis.Maximum = double.NaN; - maxValueTextBox.Text = "auto"; } else { - colorAxis.Maximum = heatMapSeries.MaxValue; + updateMaxValueText = false; + maxValueTextBox.Text = ""; } UpdatePlot(); }; @@ -183,23 +249,33 @@ private void InitializeColorAxisValues() { Name = "minValue", AutoSize = true, - Text = "auto", + Text = _valueMin.HasValue ? 
_valueMin.ToString() : "auto", }; + var updateMinValueText = true; + minValueTextBox.TextChanged += (sender, e) => { + if (!updateMinValueText) + { + updateMinValueText = true; + return; + } + if (double.TryParse(minValueTextBox.Text, out double minValue)) { + _valueMin = minValue; colorAxis.Minimum = minValue; } - else if (minValueTextBox.Text.ToLower() == "auto") + else if (string.IsNullOrEmpty(minValueTextBox.Text)) { + _valueMin = null; colorAxis.Minimum = double.NaN; - minValueTextBox.Text = "auto"; } else { - colorAxis.Minimum = heatMapSeries.MinValue; + updateMinValueText = false; + minValueTextBox.Text = ""; } UpdatePlot(); }; @@ -265,12 +341,12 @@ private void InitializeRenderMethod() renderMethodComboBox.Items.Add(value); } - renderMethodComboBox.SelectedIndexChanged += renderMethodComboBoxSelectedIndexChanged; + renderMethodComboBox.SelectedIndexChanged += RenderMethodComboBoxSelectedIndexChanged; renderMethodComboBox.SelectedIndex = _renderMethodSelectedIndex; UpdateRenderMethod(); } - private void renderMethodComboBoxSelectedIndexChanged(object sender, EventArgs e) + private void RenderMethodComboBoxSelectedIndexChanged(object sender, EventArgs e) { if (renderMethodComboBox.SelectedIndex != _renderMethodSelectedIndex) { @@ -287,14 +363,6 @@ private void UpdateRenderMethod() heatMapSeries.RenderMethod = renderMethod; } - private void onMouseClick(object sender, MouseEventArgs e) - { - if (e.Button == MouseButtons.Right) - { - statusStrip.Visible = !statusStrip.Visible; - } - } - /// /// Method to update the heatmap series with new data. /// @@ -304,6 +372,29 @@ public void UpdateHeatMapSeries(double[,] data) heatMapSeries.Data = data; } + + /// + /// Method to update the heatmap series X axis range. + /// + /// + /// + public void UpdateXRange(double x0, double x1) + { + heatMapSeries.X0 = x0; + heatMapSeries.X1 = x1; + } + + /// + /// Method to update the heatmap series Y axis range. + /// + /// + /// + public void UpdateYRange(double y0, double y1) + { + heatMapSeries.Y0 = y0; + heatMapSeries.Y1 = y1; + } + /// /// Method to update the heatmap series with new data. /// @@ -314,11 +405,9 @@ public void UpdateHeatMapSeries(double[,] data) /// The data to be displayed. public void UpdateHeatMapSeries(double x0, double x1, double y0, double y1, double[,] data) { - heatMapSeries.X0 = x0; - heatMapSeries.X1 = x1; - heatMapSeries.Y0 = y0; - heatMapSeries.Y1 = y1; - heatMapSeries.Data = data; + UpdateXRange(x0, x1); + UpdateYRange(y0, y1); + UpdateHeatMapSeries(data); } /// @@ -326,7 +415,7 @@ public void UpdateHeatMapSeries(double x0, double x1, double y0, double y1, doub /// public void UpdatePlot() { - model.InvalidatePlot(true); + _model.InvalidatePlot(true); } private static readonly Dictionary> paletteLookup = new Dictionary> diff --git a/src/Bonsai.ML.Design/MultidimensionalArrayVisualizer.cs b/src/Bonsai.ML.Design/MultidimensionalArrayVisualizer.cs index a2ea15ab..1101c601 100644 --- a/src/Bonsai.ML.Design/MultidimensionalArrayVisualizer.cs +++ b/src/Bonsai.ML.Design/MultidimensionalArrayVisualizer.cs @@ -23,24 +23,25 @@ public class MultidimensionalArrayVisualizer : DialogTypeVisualizer /// public int RenderMethodSelectedIndex { get; set; } - private HeatMapSeriesOxyPlotBase Plot; + private HeatMapSeriesOxyPlotBase _plot; + /// + /// Gets the HeatMapSeriesOxyPlotBase control used to display the heatmap. 
+ /// + public HeatMapSeriesOxyPlotBase Plot => _plot; /// public override void Load(IServiceProvider provider) { - Plot = new HeatMapSeriesOxyPlotBase(PaletteSelectedIndex, RenderMethodSelectedIndex) + _plot = new HeatMapSeriesOxyPlotBase(PaletteSelectedIndex, RenderMethodSelectedIndex) { Dock = DockStyle.Fill, }; - Plot.PaletteComboBoxValueChanged += PaletteIndexChanged; - Plot.RenderMethodComboBoxValueChanged += RenderMethodIndexChanged; + _plot.PaletteComboBoxValueChanged += PaletteIndexChanged; + _plot.RenderMethodComboBoxValueChanged += RenderMethodIndexChanged; var visualizerService = (IDialogTypeVisualizerService)provider.GetService(typeof(IDialogTypeVisualizerService)); - if (visualizerService != null) - { - visualizerService.AddControl(Plot); - } + visualizerService?.AddControl(_plot); } /// diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/Bonsai.ML.PointProcessDecoder.Design.csproj b/src/Bonsai.ML.PointProcessDecoder.Design/Bonsai.ML.PointProcessDecoder.Design.csproj new file mode 100644 index 00000000..8b3fb9a1 --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder.Design/Bonsai.ML.PointProcessDecoder.Design.csproj @@ -0,0 +1,16 @@ + + + Bonsai.ML.PointProcessDecoder.Design + A package for visualizing data from the Bonsai.ML.PointProcessDecoder package. + Bonsai Rx ML Machine Learning Point Process Decoder Design + net472 + true + + + + + + + + + \ No newline at end of file diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/DensityEstimationsVisualizer.cs b/src/Bonsai.ML.PointProcessDecoder.Design/DensityEstimationsVisualizer.cs new file mode 100644 index 00000000..63669df2 --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder.Design/DensityEstimationsVisualizer.cs @@ -0,0 +1,460 @@ +using System; +using System.Reactive.Linq; +using System.Reactive; +using System.Windows.Forms; +using System.Collections.Generic; +using System.Linq; + +using Bonsai; +using Bonsai.Expressions; +using Bonsai.Design; +using Bonsai.ML.Design; + +using static TorchSharp.torch; + +using PointProcessDecoder.Core; + +[assembly: TypeVisualizer(typeof(Bonsai.ML.PointProcessDecoder.Design.DensityEstimationsVisualizer), + Target = typeof(Bonsai.ML.PointProcessDecoder.Decode))] + +namespace Bonsai.ML.PointProcessDecoder.Design +{ + public class DensityEstimationsVisualizer : DialogTypeVisualizer + { + private int _rowCount = 1; + public int RowCount + { + get => _rowCount; + set + { + if (value < 1) + { + throw new InvalidOperationException("The number of rows must be greater than 0."); + } + _rowCount = value; + } + } + + private int _columnCount = 1; + public int ColumnCount + { + get => _columnCount; + set + { + if (value < 1) + { + throw new InvalidOperationException("The number of columns must be greater than 0."); + } + _columnCount = value; + } + } + + private int _selectedPageIndex = 0; + public int SelectedPageIndex + { + get => _selectedPageIndex; + set + { + _selectedPageIndex = value; + } + } + + private readonly int _sampleFrequency = 5; + private int _pageCount = 1; + private string _modelName = string.Empty; + private List _heatmapPlots = null; + private int _estimationsCount = 0; + private TableLayoutPanel _container = null; + private StatusStrip _statusStrip = null; + public StatusStrip StatusStrip => _statusStrip; + private ToolStripNumericUpDown _pageIndexControl = null; + private ToolStripNumericUpDown _rowControl = null; + private ToolStripNumericUpDown _columnControl = null; + private Tensor[] _estimations = null; + private long _stateSpaceWidth; + private long _stateSpaceHeight; 
+ private double[] _stateSpaceMin; + private double[] _stateSpaceMax; + private bool _isProcessing = false; + + /// + public override void Load(IServiceProvider provider) + { + Decode decodeNode = null; + var expressionBuilderGraph = (ExpressionBuilderGraph)provider.GetService(typeof(ExpressionBuilderGraph)); + var typeVisualizerContext = (ITypeVisualizerContext)provider.GetService(typeof(ITypeVisualizerContext)); + if (expressionBuilderGraph != null && typeVisualizerContext != null) + { + decodeNode = ExpressionBuilder.GetWorkflowElement( + expressionBuilderGraph.Where(node => node.Value == typeVisualizerContext.Source) + .FirstOrDefault().Value) as Decode; + } + + if (decodeNode == null) + { + throw new InvalidOperationException("The decode node is invalid."); + } + + _modelName = decodeNode.Model; + if (string.IsNullOrEmpty(_modelName)) + { + throw new InvalidOperationException("The point process model name is not set."); + } + + _container = new TableLayoutPanel() + { + Dock = DockStyle.Fill, + AutoSize = true, + ColumnCount = ColumnCount, + RowCount = _rowCount, + }; + + var pageIndexLabel = new ToolStripLabel($"Page: {_selectedPageIndex}"); + _pageIndexControl = new ToolStripNumericUpDown() + { + Minimum = 0, + DecimalPlaces = 0, + Value = _selectedPageIndex, + }; + + _pageIndexControl.ValueChanged += (sender, e) => + { + if (_heatmapPlots is null) + { + return; + } + + var value = Convert.ToInt32(_pageIndexControl.Value); + SelectedPageIndex = value; + UpdateTableLayout(); + if (UpdateModel()) + { + Show(null); + } + pageIndexLabel.Text = $"Page: {_selectedPageIndex}"; + }; + + var rowLabel = new ToolStripLabel($"Rows: {_rowCount}"); + _rowControl = new ToolStripNumericUpDown() + { + Minimum = 1, + DecimalPlaces = 0, + Value = _rowCount, + }; + + _rowControl.ValueChanged += (sender, e) => + { + if (_heatmapPlots is null) + { + return; + } + + RowCount = Convert.ToInt32(_rowControl.Value); + UpdatePages(); + if (_selectedPageIndex >= _pageCount) + { + SelectedPageIndex = _pageCount - 1; + _pageIndexControl.Value = _selectedPageIndex; + } + else + { + UpdateTableLayout(); + if (UpdateModel()) + { + Show(null); + } + } + rowLabel.Text = $"Rows: {_rowCount}"; + }; + + var columnLabel = new ToolStripLabel($"Columns: {_columnCount}"); + _columnControl = new ToolStripNumericUpDown() + { + Minimum = 1, + DecimalPlaces = 0, + Value = _columnCount, + }; + + _columnControl.ValueChanged += (sender, e) => + { + if (_heatmapPlots is null) + { + return; + } + + ColumnCount = Convert.ToInt32(_columnControl.Value); + UpdatePages(); + if (_selectedPageIndex >= _pageCount) + { + SelectedPageIndex = _pageCount - 1; + _pageIndexControl.Value = _selectedPageIndex; + } + else + { + UpdateTableLayout(); + if (UpdateModel()) + { + Show(null); + } + } + columnLabel.Text = $"Columns: {_columnCount}"; + }; + + _statusStrip = new StatusStrip() + { + Visible = true, + }; + + _statusStrip.Items.AddRange([ + pageIndexLabel, + _pageIndexControl, + rowLabel, + _rowControl, + columnLabel, + _columnControl + ]); + + var visualizerService = (IDialogTypeVisualizerService)provider.GetService(typeof(IDialogTypeVisualizerService)); + visualizerService?.AddControl(_container); + visualizerService?.AddControl(_statusStrip); + } + + private void UpdatePages() + { + _pageCount = (int)Math.Ceiling((double)_estimationsCount / (_rowCount * _columnCount)); + _pageIndexControl.Maximum = _pageCount - 1; + } + + private bool UpdateModel() + { + _isProcessing = true; + PointProcessModel model; + + try { + model = 
PointProcessModelManager.GetModel(_modelName); + } catch { + return false; + } + + if (model == null) + { + return false; + } + + if (model.StateSpace.Dimensions != 2) + { + throw new InvalidOperationException("For the conditional intensities visualizer to work, the state space dimensions must be 2."); + } + + if (model.Encoder.Estimations.Length == 0) + { + return false; + } + + + _estimationsCount = model.Encoder.Estimations.Length; + _estimations = new Tensor[_estimationsCount]; + + var startIndex = SelectedPageIndex * _rowCount * _columnCount; + var endIndex = Math.Min(startIndex + _rowCount * _columnCount, _estimationsCount); + + for (int i = startIndex; i < endIndex; i++) + { + var estimate = model.Encoder.Estimations[i].Estimate( + model.StateSpace.Points, + null, + model.StateSpace.Dimensions + ); + + if (estimate.NumberOfElements == 0) { + _estimations[i] = ones([model.StateSpace.Points.size(0), 1]) * double.NaN; + } else { + _estimations[i] = model.Encoder.Estimations[i].Normalize(estimate); + } + } + + _stateSpaceWidth = model.StateSpace.Shape[0]; + _stateSpaceHeight = model.StateSpace.Shape[1]; + + _stateSpaceMin = [.. model.StateSpace.Points + .min(dim: 0) + .values + .to_type(ScalarType.Float64) + .data() + ]; + _stateSpaceMax = [.. model.StateSpace.Points + .max(dim: 0) + .values + .to_type(ScalarType.Float64) + .data() + ]; + + // GC.KeepAlive(model); + _isProcessing = false; + + return true; + } + + private bool UpdateHeatmaps() + { + if (_heatmapPlots is null) + { + _heatmapPlots = []; + for (int i = 0; i < _estimationsCount; i++) + { + _heatmapPlots.Add(new HeatMapSeriesOxyPlotBase(1, 0) + { + Dock = DockStyle.Fill, + }); + } + } + else if (_heatmapPlots.Count > _estimationsCount) + { + var count = _heatmapPlots.Count - _estimationsCount; + for (int i = 0; i < count; i++) + { + if (!_heatmapPlots[i + _estimationsCount].IsDisposed) + { + _heatmapPlots[i + _estimationsCount].Dispose(); + } + } + _heatmapPlots.RemoveRange(_estimationsCount, count); + } + else if (_heatmapPlots.Count < _estimationsCount) + { + for (int i = _heatmapPlots.Count; i < _estimationsCount; i++) + { + _heatmapPlots.Add(new HeatMapSeriesOxyPlotBase(1, 0) + { + Dock = DockStyle.Fill, + }); + } + } + + return true; + } + + private void UpdateTableLayout() + { + _container.Controls.Clear(); + _container.RowStyles.Clear(); + _container.ColumnStyles.Clear(); + + _container.RowCount = _rowCount; + _container.ColumnCount = _columnCount; + + for (int i = 0; i < _rowCount; i++) + { + _container.RowStyles.Add(new RowStyle(SizeType.Percent, 100f / _rowCount)); + } + + for (int i = 0; i < _columnCount; i++) + { + _container.ColumnStyles.Add(new ColumnStyle(SizeType.Percent, 100f / _columnCount)); + } + + for (int i = 0; i < _rowCount; i++) + { + for (int j = 0; j < _columnCount; j++) + { + var index = SelectedPageIndex * _rowCount * _columnCount + i * _columnCount + j; + if (index >= _estimationsCount) + { + break; + } + + _container.Controls.Add(_heatmapPlots[index], j, i); + } + } + } + + /// + public override void Show(object value) + { + var startIndex = SelectedPageIndex * _rowCount * _columnCount; + var endIndex = Math.Min(startIndex + _rowCount * _columnCount, _estimationsCount); + + for (int i = startIndex; i < endIndex; i++) + { + var estimation = _estimations[i]; + + var estimationValues = (double[,])estimation + .to_type(ScalarType.Float64) + .reshape([_stateSpaceWidth, _stateSpaceHeight]) + .data() + .ToNDArray(); + + _heatmapPlots[i].UpdateHeatMapSeries( + _stateSpaceMin[0], + _stateSpaceMax[0], + 
_stateSpaceMin[1], + _stateSpaceMax[1], + estimationValues + ); + + _heatmapPlots[i].UpdatePlot(); + } + } + + /// + public override void Unload() + { + if (_container != null) + { + if (!_container.IsDisposed) + { + _container.Dispose(); + } + _container = null; + } + + if (_heatmapPlots != null) + { + for (int i = 0; i < _heatmapPlots.Count; i++) + { + if (!_heatmapPlots[i].IsDisposed) + { + _heatmapPlots[i].Dispose(); + } + } + _heatmapPlots = null; + }; + + _estimationsCount = 0; + _estimations = null; + } + + public override IObservable Visualize(IObservable> source, IServiceProvider provider) + { + if (provider.GetService(typeof(IDialogTypeVisualizerService)) is not Control visualizerControl) + { + return source; + } + + var timer = Observable.Interval( + TimeSpan.FromMilliseconds(100), + HighResolutionScheduler.Default + ); + + return source.SelectMany(input => + input.Buffer(timer) + .Where(buffer => buffer.Count > 0 && !_isProcessing) + .ObserveOn(visualizerControl) + .Do(buffer => + { + if (!UpdateModel()) + { + return; + } + + UpdatePages(); + UpdateHeatmaps(); + UpdateTableLayout(); + + Show(buffer.LastOrDefault()); + } + ) + ); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/IDecoderVisualizer.cs b/src/Bonsai.ML.PointProcessDecoder.Design/IDecoderVisualizer.cs new file mode 100644 index 00000000..ee80e44b --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder.Design/IDecoderVisualizer.cs @@ -0,0 +1,19 @@ +using Bonsai.ML.Design; + +namespace Bonsai.ML.PointProcessDecoder.Design; + +/// +/// Interface for visualizing the output of a point process decoder. +/// +public interface IDecoderVisualizer +{ + /// + /// Gets or sets the capacity of the visualizer. + /// + public int Capacity { get; set; } + + /// + /// Gets the heatmap that visualizes a component's output. + /// + public HeatMapSeriesOxyPlotBase Plot { get; } +} diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/IntensitiesVisualizer.cs b/src/Bonsai.ML.PointProcessDecoder.Design/IntensitiesVisualizer.cs new file mode 100644 index 00000000..0776a3f9 --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder.Design/IntensitiesVisualizer.cs @@ -0,0 +1,480 @@ +using System; +using System.Reactive.Linq; +using System.Reactive; +using System.Windows.Forms; +using System.Collections.Generic; +using System.Linq; + +using Bonsai; +using Bonsai.Expressions; +using Bonsai.Design; +using Bonsai.ML.Design; + +using static TorchSharp.torch; + +using PointProcessDecoder.Core; + +[assembly: TypeVisualizer(typeof(Bonsai.ML.PointProcessDecoder.Design.IntensitiesVisualizer), + Target = typeof(Bonsai.ML.PointProcessDecoder.Decode))] + +namespace Bonsai.ML.PointProcessDecoder.Design +{ + public class IntensitiesVisualizer : DialogTypeVisualizer + { + private int _rowCount = 1; + /// + /// The number of rows in the visualizer. + /// + public int RowCount + { + get => _rowCount; + set + { + if (value < 1) + { + throw new InvalidOperationException("The number of rows must be greater than 0."); + } + _rowCount = value; + } + } + + private int _columnCount = 1; + /// + /// The number of columns in the visualizer. + /// + public int ColumnCount + { + get => _columnCount; + set + { + if (value < 1) + { + throw new InvalidOperationException("The number of columns must be greater than 0."); + } + _columnCount = value; + } + } + + private int _selectedPageIndex = 0; + /// + /// The index of the current page displayed in the visualizer. 
+ /// + public int SelectedPageIndex + { + get => _selectedPageIndex; + set + { + _selectedPageIndex = value; + } + } + + private StatusStrip _statusStrip = null; + /// + /// The status strip control that displays the visualizer options. + /// + public StatusStrip StatusStrip => _statusStrip; + + private readonly int _sampleFrequency = 30; + private int _pageCount = 1; + private string _modelName = string.Empty; + private List _heatmapPlots = null; + private int _intensitiesCount = 0; + private TableLayoutPanel _container = null; + private readonly List _intensitiesCumulativeIndex = []; + private ToolStripNumericUpDown _pageIndexControl = null; + private ToolStripNumericUpDown _rowControl = null; + private ToolStripNumericUpDown _columnControl = null; + private Tensor[] _intensities = null; + private long _stateSpaceWidth; + private long _stateSpaceHeight; + private double[] _stateSpaceMin; + private double[] _stateSpaceMax; + private bool _isProcessing = false; + + /// + public override void Load(IServiceProvider provider) + { + Decode decodeNode = null; + var expressionBuilderGraph = (ExpressionBuilderGraph)provider.GetService(typeof(ExpressionBuilderGraph)); + var typeVisualizerContext = (ITypeVisualizerContext)provider.GetService(typeof(ITypeVisualizerContext)); + if (expressionBuilderGraph != null && typeVisualizerContext != null) + { + decodeNode = ExpressionBuilder.GetWorkflowElement( + expressionBuilderGraph.Where(node => node.Value == typeVisualizerContext.Source) + .FirstOrDefault().Value) as Decode; + } + + if (decodeNode == null) + { + throw new InvalidOperationException("The decode node is invalid."); + } + + _modelName = decodeNode.Model; + if (string.IsNullOrEmpty(_modelName)) + { + throw new InvalidOperationException("The point process model name is not set."); + } + + _container = new TableLayoutPanel() + { + Dock = DockStyle.Fill, + AutoSize = true, + ColumnCount = ColumnCount, + RowCount = _rowCount, + }; + + var pageIndexLabel = new ToolStripLabel($"Page: {_selectedPageIndex}"); + _pageIndexControl = new ToolStripNumericUpDown() + { + Minimum = 0, + DecimalPlaces = 0, + Value = _selectedPageIndex, + }; + + _pageIndexControl.ValueChanged += (sender, e) => + { + if (_heatmapPlots is null) + { + return; + } + var value = Convert.ToInt32(_pageIndexControl.Value); + SelectedPageIndex = value; + UpdateTableLayout(); + Show(null); + pageIndexLabel.Text = $"Page: {_selectedPageIndex}"; + }; + + var rowLabel = new ToolStripLabel($"Rows: {_rowCount}"); + _rowControl = new ToolStripNumericUpDown() + { + Minimum = 1, + DecimalPlaces = 0, + Value = _rowCount, + }; + + _rowControl.ValueChanged += (sender, e) => + { + if (_heatmapPlots is null) + { + return; + } + + RowCount = Convert.ToInt32(_rowControl.Value); + UpdatePages(); + if (_selectedPageIndex >= _pageCount) + { + SelectedPageIndex = _pageCount - 1; + _pageIndexControl.Value = _selectedPageIndex; + } + else + { + UpdateTableLayout(); + Show(null); + } + rowLabel.Text = $"Rows: {_rowCount}"; + }; + + var columnLabel = new ToolStripLabel($"Columns: {_columnCount}"); + _columnControl = new ToolStripNumericUpDown() + { + Minimum = 1, + DecimalPlaces = 0, + Value = _columnCount, + }; + + _columnControl.ValueChanged += (sender, e) => + { + if (_heatmapPlots is null) + { + return; + } + + ColumnCount = Convert.ToInt32(_columnControl.Value); + UpdatePages(); + if (_selectedPageIndex >= _pageCount) + { + SelectedPageIndex = _pageCount - 1; + _pageIndexControl.Value = _selectedPageIndex; + } + else + { + UpdateTableLayout(); + 
Show(null); + } + columnLabel.Text = $"Columns: {_columnCount}"; + }; + + _statusStrip = new StatusStrip() + { + Visible = true, + }; + + _statusStrip.Items.AddRange([ + pageIndexLabel, + _pageIndexControl, + rowLabel, + _rowControl, + columnLabel, + _columnControl + ]); + + var visualizerService = (IDialogTypeVisualizerService)provider.GetService(typeof(IDialogTypeVisualizerService)); + visualizerService?.AddControl(_container); + visualizerService?.AddControl(_statusStrip); + } + + private void UpdatePages() + { + _pageCount = (int)Math.Ceiling((double)_intensitiesCount / (_rowCount * _columnCount)); + _pageIndexControl.Maximum = _pageCount - 1; + } + + private bool UpdateModel() + { + _isProcessing = true; + PointProcessModel model; + + try { + model = PointProcessModelManager.GetModel(_modelName); + } catch { + return false; + } + + if (model == null) + { + return false; + } + + if (model.StateSpace.Dimensions != 2) + { + throw new InvalidOperationException("For the intensities visualizer to work, the state space dimensions must be 2."); + } + + if (model.Encoder.Intensities.Length == 0 || (model.Encoder.Intensities.Length == 1 && model.Encoder.Intensities[0].NumberOfElements == 0)) + { + return false; + } + + _intensities = model.Encoder.Intensities; + _stateSpaceWidth = model.StateSpace.Shape[0]; + _stateSpaceHeight = model.StateSpace.Shape[1]; + + _stateSpaceMin = [.. model.StateSpace.Points + .min(dim: 0) + .values + .to_type(ScalarType.Float64) + .data() + ]; + + _stateSpaceMax = [.. model.StateSpace.Points + .max(dim: 0) + .values + .to_type(ScalarType.Float64) + .data() + ]; + + _isProcessing = false; + + return true; + } + + private static int GetIntensitiesCount(Tensor[] intensities, List intensitiesCumulativeIndex) + { + long intensitiesCount = 0; + intensitiesCumulativeIndex.Clear(); + for (int i = 0; i < intensities.Length; i++) { + if (intensities[i].NumberOfElements > 0) { + intensitiesCount += intensities[i].size(0); + intensitiesCumulativeIndex.Add(intensitiesCount); + } + } + return (int)intensitiesCount; + } + + private bool UpdateHeatmaps() + { + if (_heatmapPlots is null) + { + _heatmapPlots = []; + for (int i = 0; i < _intensitiesCount; i++) + { + _heatmapPlots.Add(new HeatMapSeriesOxyPlotBase(1, 0) + { + Dock = DockStyle.Fill, + }); + } + } + else if (_heatmapPlots.Count > _intensitiesCount) + { + var count = _heatmapPlots.Count - _intensitiesCount; + for (int i = 0; i < count; i++) + { + if (!_heatmapPlots[i + _intensitiesCount].IsDisposed) + { + _heatmapPlots[i + _intensitiesCount].Dispose(); + } + } + _heatmapPlots.RemoveRange(_intensitiesCount, count); + } + else if (_heatmapPlots.Count < _intensitiesCount) + { + for (int i = _heatmapPlots.Count; i < _intensitiesCount; i++) + { + _heatmapPlots.Add(new HeatMapSeriesOxyPlotBase(1, 0) + { + Dock = DockStyle.Fill, + }); + } + } + + return true; + } + + private void UpdateTableLayout() + { + _container.Controls.Clear(); + _container.RowStyles.Clear(); + _container.ColumnStyles.Clear(); + + _container.RowCount = _rowCount; + _container.ColumnCount = _columnCount; + + for (int i = 0; i < _rowCount; i++) + { + _container.RowStyles.Add(new RowStyle(SizeType.Percent, 100f / _rowCount)); + } + + for (int i = 0; i < _columnCount; i++) + { + _container.ColumnStyles.Add(new ColumnStyle(SizeType.Percent, 100f / _columnCount)); + } + + for (int i = 0; i < _rowCount; i++) + { + for (int j = 0; j < _columnCount; j++) + { + var index = SelectedPageIndex * _rowCount * _columnCount + i * _columnCount + j; + if (index >= 
_intensitiesCount) + { + break; + } + + _container.Controls.Add(_heatmapPlots[index], j, i); + } + } + } + + private (int intensitiesIndex, int intensitiesTensorIndex) GetIntensitiesIndex(int index) + { + + var intensitiesIndex = 0; + for (int i = 0; i < _intensitiesCumulativeIndex.Count; i++) + { + if (index < _intensitiesCumulativeIndex[i]) + { + intensitiesIndex = i; + break; + } + } + var intensitiesTensorIndex = intensitiesIndex == 0 ? index : index - _intensitiesCumulativeIndex[intensitiesIndex - 1]; + return (intensitiesIndex, (int)intensitiesTensorIndex); + } + + /// + public override void Show(object value) + { + var startIndex = _selectedPageIndex * _rowCount * _columnCount; + var endIndex = Math.Min(startIndex + _rowCount * _columnCount, _intensitiesCount); + + for (int i = startIndex; i < endIndex; i++) + { + var (intensitiesIndex, intensitiesTensorIndex) = GetIntensitiesIndex(i); + + var intensity = _intensities[intensitiesIndex][intensitiesTensorIndex]; + + if (intensity.Dimensions == 2) { + intensity = intensity + .sum(dim: 0); + } + + var intensityValues = (double[,])intensity + .exp() + .to_type(ScalarType.Float64) + .reshape([_stateSpaceWidth, _stateSpaceHeight]) + .data() + .ToNDArray(); + + _heatmapPlots[i].UpdateHeatMapSeries( + _stateSpaceMin[0], + _stateSpaceMax[0], + _stateSpaceMin[1], + _stateSpaceMax[1], + intensityValues + ); + + _heatmapPlots[i].UpdatePlot(); + } + } + + /// + public override void Unload() + { + if (_container != null) + { + if (!_container.IsDisposed) + { + _container.Dispose(); + } + _container = null; + } + + if (_heatmapPlots != null) + { + for (int i = 0; i < _heatmapPlots.Count; i++) + { + if (!_heatmapPlots[i].IsDisposed) + { + _heatmapPlots[i].Dispose(); + } + } + _heatmapPlots = null; + }; + + _intensitiesCount = 0; + _intensitiesCumulativeIndex.Clear(); + _intensities = null; + } + + public override IObservable Visualize(IObservable> source, IServiceProvider provider) + { + if (provider.GetService(typeof(IDialogTypeVisualizerService)) is not Control visualizerControl) + { + return source; + } + + return source.SelectMany(input => + input.Sample(TimeSpan.FromMilliseconds(_sampleFrequency)) + .ObserveOn(visualizerControl) + .Do(value => + { + if (!UpdateModel() && !_isProcessing) + { + return; + } + + var newIntensitiesCount = GetIntensitiesCount(_intensities, _intensitiesCumulativeIndex); + if (_intensitiesCount != newIntensitiesCount) + { + _intensitiesCount = newIntensitiesCount; + UpdatePages(); + UpdateHeatmaps(); + UpdateTableLayout(); + } + + Show(value); + } + ) + ); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/LikelihoodVisualizer.cs b/src/Bonsai.ML.PointProcessDecoder.Design/LikelihoodVisualizer.cs new file mode 100644 index 00000000..9e09ab91 --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder.Design/LikelihoodVisualizer.cs @@ -0,0 +1,258 @@ +using System; +using System.Reactive.Linq; +using System.Reactive; +using System.Windows.Forms; +using System.Collections.Generic; +using System.Linq; + +using Bonsai; +using Bonsai.Expressions; +using Bonsai.Design; +using Bonsai.ML.Design; + +using OxyPlot.Series; + +using static TorchSharp.torch; + +using PointProcessDecoder.Core; +using TorchSharp; +using Bonsai.Dag; +using System.Linq.Expressions; + +[assembly: TypeVisualizer(typeof(Bonsai.ML.PointProcessDecoder.Design.LikelihoodVisualizer), + Target = typeof(Bonsai.ML.PointProcessDecoder.Decode))] + +namespace Bonsai.ML.PointProcessDecoder.Design +{ + /// + /// Visualizer for the 
likelihood of a point process model. + /// + public class LikelihoodVisualizer : MashupVisualizer, IDecoderVisualizer + { + private MultidimensionalArrayVisualizer _visualizer; + + /// + /// Gets the underlying heatmap plot. + /// + public HeatMapSeriesOxyPlotBase Plot => _visualizer.Plot; + + private int _capacity = 100; + /// + /// Gets or sets the integer value that determines how many data points should be shown along the x axis if the posterior is a 1D tensor. + /// + public int Capacity + { + get => _capacity; + set + { + _capacity = value; + } + } + + /// + /// Gets or sets the minimum value of the likelihood. + /// + public double? ValueMin { get; set; } = null; + + /// + /// Gets or sets the maximum value of the likelihood. + /// + public double? ValueMax { get; set; } = null; + + private double[,] _data = null; + private long _stateSpaceWidth; + private long _stateSpaceHeight; + private double[] _stateSpaceMin = null; + private double[] _stateSpaceMax = null; + private string _modelName; + private ILikelihood _likelihood; + private Tensor[] _intensities; + private IObservable> _inputSource; + + /// + public override void Load(IServiceProvider provider) + { + Node visualizerNode = null; + var expressionBuilderGraph = (ExpressionBuilderGraph)provider.GetService(typeof(ExpressionBuilderGraph)); + var typeVisualizerContext = (ITypeVisualizerContext)provider.GetService(typeof(ITypeVisualizerContext)); + if (expressionBuilderGraph != null && typeVisualizerContext != null) + { + visualizerNode = (from node in expressionBuilderGraph + where node.Value == typeVisualizerContext.Source + select node).FirstOrDefault(); + } + + if (visualizerNode == null) + { + throw new InvalidOperationException("The visualizer node is invalid."); + } + + var inspector = (InspectBuilder)expressionBuilderGraph + .Predecessors(visualizerNode) + .First(p => !p.Value.IsBuildDependency()) + .Value; + + _inputSource = inspector.Output; + + if (ExpressionBuilder.GetWorkflowElement(visualizerNode.Value) is not Decode decodeNode) + { + throw new InvalidOperationException("The decode node is invalid."); + } + + _modelName = decodeNode.Model; + + _visualizer = new MultidimensionalArrayVisualizer() + { + PaletteSelectedIndex = 1, + RenderMethodSelectedIndex = 0 + }; + + _visualizer.Load(provider); + + var capacityLabel = new ToolStripLabel + { + Text = "Capacity:", + AutoSize = true + }; + + var capacityValue = new ToolStripTextBox + { + Text = Capacity.ToString(), + AutoSize = true + }; + + capacityValue.TextChanged += (sender, e) => + { + if (int.TryParse(capacityValue.Text, out int capacity)) + { + Capacity = capacity; + } + }; + + _visualizer.Plot.VisualizerPropertiesDropDown.DropDownItems.AddRange([ + capacityLabel, + capacityValue + ]); + + base.Load(provider); + } + + private bool UpdateModel() + { + PointProcessModel model; + + try + { + model = PointProcessModelManager.GetModel(_modelName); + } catch { + return false; + } + + if (model == null) + { + return false; + } + + _stateSpaceWidth = model.StateSpace.Shape[0]; + _stateSpaceHeight = model.StateSpace.Shape[1]; + + _stateSpaceMin ??= [.. model.StateSpace.Points + .min(dim: 0) + .values + .to_type(ScalarType.Float64) + .data() + ]; + + _stateSpaceMax ??= [.. 
model.StateSpace.Points + .max(dim: 0) + .values + .to_type(ScalarType.Float64) + .data() + ]; + + _likelihood = model.Likelihood; + _intensities = model.Encoder.Intensities; + + return true; + } + + /// + public override void Show(object value) + { + Tensor inputs = (Tensor)value; + Tensor likelihood = _likelihood.Likelihood(inputs, _intensities); + + if (likelihood.Dimensions == 2) { + likelihood = likelihood + .sum(dim: 0); + } + + _data = (double[,])likelihood + .to_type(ScalarType.Float64) + .reshape([_stateSpaceWidth, _stateSpaceHeight]) + .data() + .ToNDArray(); + + + _visualizer.Plot.UpdateHeatMapSeries( + _stateSpaceMin[0], + _stateSpaceMax[0], + _stateSpaceMin[1], + _stateSpaceMax[1], + _data + ); + + _visualizer.Plot.UpdatePlot(); + } + + /// + public override void Unload() + { + _visualizer.Unload(); + base.Unload(); + } + + /// + public override IObservable Visualize(IObservable> source, IServiceProvider provider) + { + if (provider.GetService(typeof(IDialogTypeVisualizerService)) is not Control visualizerControl) + { + return source; + } + + var colorCycler = new OxyColorPresetCycle(); + + var timer = Observable.Interval( + TimeSpan.FromMilliseconds(100), + HighResolutionScheduler.Default + ); + + var mergedSource = _inputSource.SelectMany(xs => + xs.Buffer(timer) + .Where(buffer => buffer.Count > 0) + .Sample(source.Merge()) + .Do(buffer => { + if (!UpdateModel()) + { + return; + } + ValueMin = _visualizer.Plot.ValueMin; + ValueMax = _visualizer.Plot.ValueMax; + Show(buffer.LastOrDefault()); + })); + + var mashupSourceStreams = Observable.Merge( + MashupSources.Select(mashupSource => + mashupSource.Source.Output.SelectMany(xs => { + var color = colorCycler.Next(); + var visualizer = mashupSource.Visualizer as Point2DOverlay; + visualizer.Color = color; + return xs.Buffer(timer) + .Where(buffer => buffer.Count > 0) + .Do(buffer => mashupSource.Visualizer.Show(buffer.LastOrDefault())); + }))); + + return Observable.Merge(mergedSource, mashupSourceStreams); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/OxyColorPresetCycle.cs b/src/Bonsai.ML.PointProcessDecoder.Design/OxyColorPresetCycle.cs new file mode 100644 index 00000000..e2976cf4 --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder.Design/OxyColorPresetCycle.cs @@ -0,0 +1,40 @@ + +using OxyPlot; + +namespace Bonsai.ML.PointProcessDecoder.Design +{ + /// + /// Enumerates the colors and provides a preset collection of colors to cycle through. + /// + public class OxyColorPresetCycle + { + private static readonly OxyColor[] Colors = + [ + OxyColors.LimeGreen, + OxyColors.Red, + OxyColors.Blue, + OxyColors.Orange, + OxyColors.Purple, + OxyColors.Yellow, + OxyColors.Pink, + OxyColors.Brown, + OxyColors.Cyan, + OxyColors.Magenta, + OxyColors.Green, + OxyColors.Gray, + OxyColors.Black + ]; + + private int _index; + + /// + /// Gets the next color in the cycle. 
+ /// + public OxyColor Next() + { + var color = Colors[_index]; + _index = (_index + 1) % Colors.Length; + return color; + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/Point2DOverlay.cs b/src/Bonsai.ML.PointProcessDecoder.Design/Point2DOverlay.cs new file mode 100644 index 00000000..a4f3f089 --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder.Design/Point2DOverlay.cs @@ -0,0 +1,103 @@ +using Bonsai; +using Bonsai.Design; +using Bonsai.Design.Visualizers; +using System; +using System.Collections.Generic; +using OxyPlot.Series; +using OxyPlot.Axes; +using OxyPlot; +using Bonsai.Vision.Design; +using Bonsai.ML.Design; +using System.Linq; +using TorchSharp; +using OpenCV.Net; + +[assembly: TypeVisualizer(typeof(Bonsai.ML.PointProcessDecoder.Design.Point2DOverlay), + Target = typeof(MashupSource))] + +[assembly: TypeVisualizer(typeof(Bonsai.ML.PointProcessDecoder.Design.Point2DOverlay), + Target = typeof(MashupSource))] + +namespace Bonsai.ML.PointProcessDecoder.Design +{ + /// + /// Class that overlays the true + /// + public class Point2DOverlay : DialogTypeVisualizer + { + private LineSeries _lineSeries; + private ScatterSeries _scatterSeries; + private int _dataCount; + private IDecoderVisualizer decoderVisualizer; + + private OxyColor _color = OxyColors.LimeGreen; + + /// + /// Gets or sets the color of the overlay. + /// + public OxyColor Color + { + get => _color; + set + { + if (_lineSeries != null && _scatterSeries != null) + { + _lineSeries.Color = value; + _scatterSeries.MarkerFill = value; + _color = value; + } + } + } + + /// + public override void Load(IServiceProvider provider) + { + decoderVisualizer = provider.GetService(typeof(MashupVisualizer)) as IDecoderVisualizer; + + _lineSeries = new LineSeries() + { + Color = _color, + StrokeThickness = 2 + }; + + var colorAxis = new LinearColorAxis() + { + IsAxisVisible = false, + Key = "Point2DOverlayColorAxis" + }; + + _scatterSeries = new ScatterSeries() + { + MarkerType = MarkerType.Circle, + MarkerSize = 10, + MarkerFill = _color, + ColorAxisKey = "Point2DOverlayColorAxis" + }; + + decoderVisualizer.Plot.Model.Series.Add(_scatterSeries); + decoderVisualizer.Plot.Model.Series.Add(_lineSeries); + decoderVisualizer.Plot.Model.Axes.Add(colorAxis); + } + + /// + public override void Show(object value) + { + dynamic point = value; + _dataCount++; + _lineSeries.Points.Add(new DataPoint(point.X, point.Y)); + _scatterSeries.Points.Clear(); + _scatterSeries.Points.Add(new ScatterPoint(point.X, point.Y, value: 1)); + + while (_dataCount > decoderVisualizer.Capacity) + { + _lineSeries.Points.RemoveAt(0); + _dataCount--; + } + } + + /// + public override void Unload() + { + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/PosteriorVisualizer.cs b/src/Bonsai.ML.PointProcessDecoder.Design/PosteriorVisualizer.cs new file mode 100644 index 00000000..e26266d1 --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder.Design/PosteriorVisualizer.cs @@ -0,0 +1,274 @@ +using System; +using System.Reactive.Linq; +using System.Reactive; +using System.Windows.Forms; +using System.Collections.Generic; +using System.Linq; + +using Bonsai; +using Bonsai.Expressions; +using Bonsai.Design; +using Bonsai.ML.Design; + +using OxyPlot.Series; + +using static TorchSharp.torch; + +using PointProcessDecoder.Core; +using TorchSharp; + +[assembly: TypeVisualizer(typeof(Bonsai.ML.PointProcessDecoder.Design.PosteriorVisualizer), + Target = typeof(Bonsai.ML.PointProcessDecoder.Decode))] 
+ +namespace Bonsai.ML.PointProcessDecoder.Design +{ + public class PosteriorVisualizer : MashupVisualizer, IDecoderVisualizer + { + private MultidimensionalArrayVisualizer _visualizer; + + /// + /// Gets the underlying heatmap plot. + /// + public HeatMapSeriesOxyPlotBase Plot => _visualizer.Plot; + + private int _capacity = 100; + /// + /// Gets or sets the integer value that determines how many data points should be shown along the x axis if the posterior is a 1D tensor. + /// + public int Capacity + { + get => _capacity; + set + { + _capacity = value; + } + } + + /// + /// Gets or sets the minimum value of the likelihood. + /// + public double? ValueMin { get; set; } = null; + + /// + /// Gets or sets the maximum value of the likelihood. + /// + public double? ValueMax { get; set; } = null; + + private double[,] _data = null; + private double[] _stateSpaceMin; + private double[] _stateSpaceMax; + private string _modelName; + private bool _success = false; + private Tensor _dataTensor; + + /// + public override void Load(IServiceProvider provider) + { + Decode decodeNode = null; + var expressionBuilderGraph = (ExpressionBuilderGraph)provider.GetService(typeof(ExpressionBuilderGraph)); + var typeVisualizerContext = (ITypeVisualizerContext)provider.GetService(typeof(ITypeVisualizerContext)); + if (expressionBuilderGraph != null && typeVisualizerContext != null) + { + decodeNode = ExpressionBuilder.GetWorkflowElement( + expressionBuilderGraph.Where(node => node.Value == typeVisualizerContext.Source) + .FirstOrDefault().Value) as Decode; + } + + if (decodeNode == null) + { + throw new InvalidOperationException("The decode node is invalid."); + } + + _modelName = decodeNode.Model; + + _visualizer = new MultidimensionalArrayVisualizer() + { + PaletteSelectedIndex = 1, + RenderMethodSelectedIndex = 0 + }; + + _visualizer.Load(provider); + + _visualizer.Plot.ValueMin = ValueMin; + _visualizer.Plot.ValueMax = ValueMax; + + var capacityLabel = new ToolStripLabel + { + Text = "Capacity:", + AutoSize = true + }; + + var capacityValue = new ToolStripTextBox + { + Text = Capacity.ToString(), + AutoSize = true + }; + + capacityValue.TextChanged += (sender, e) => + { + if (int.TryParse(capacityValue.Text, out int capacity)) + { + Capacity = capacity; + } + }; + + _visualizer.Plot.VisualizerPropertiesDropDown.DropDownItems.AddRange([ + capacityLabel, + capacityValue + ]); + + base.Load(provider); + } + + private void UpdateModel() + { + PointProcessModel model; + + try + { + model = PointProcessModelManager.GetModel(_modelName); + } catch { + _success = false; + return; + } + + if (model == null) + { + _success = false; + return; + } + + _stateSpaceMin = [.. model.StateSpace.Points + .min(dim: 0) + .values + .to_type(ScalarType.Float64) + .data() + ]; + + _stateSpaceMax = [.. 
model.StateSpace.Points + .max(dim: 0) + .values + .to_type(ScalarType.Float64) + .data() + ]; + + _success = true; + } + + /// + public override void Show(object value) + { + Tensor posterior = (Tensor)value; + if (posterior.NumberOfElements == 0 || !_success) + { + return; + } + + posterior = posterior.sum(dim: 0); + + double xMin; + double xMax; + double yMin; + double yMax; + + if (posterior.Dimensions == 1) + { + + if (_data == null) + { + _dataTensor = zeros(_capacity, posterior.size(0), dtype: ScalarType.Float64, device: posterior.device); + } + + _dataTensor = _dataTensor[TensorIndex.Slice(1)]; + _dataTensor = concat([_dataTensor, + posterior.to_type(ScalarType.Float64) + .unsqueeze(0) + ], dim: 0); + + _data = (double[,])_dataTensor + .data() + .ToNDArray(); + + xMin = 0; + xMax = _capacity; + yMin = _stateSpaceMin[0]; + yMax = _stateSpaceMax[0]; + } + else + { + + while (posterior.Dimensions > 2) + { + posterior = posterior.sum(dim: 0); + } + + _data = (double[,])posterior + .to_type(ScalarType.Float64) + .data() + .ToNDArray(); + + xMin = _stateSpaceMin[_stateSpaceMin.Length - 2]; + xMax = _stateSpaceMax[_stateSpaceMax.Length - 2]; + yMin = _stateSpaceMin[_stateSpaceMin.Length - 1]; + yMax = _stateSpaceMax[_stateSpaceMax.Length - 1]; + } + + _visualizer.Plot.UpdateHeatMapSeries( + xMin, xMax, yMin, yMax, _data + ); + + _visualizer.Plot.UpdatePlot(); + } + + /// + public override void Unload() + { + _visualizer.Unload(); + base.Unload(); + } + + /// + public override IObservable Visualize(IObservable> source, IServiceProvider provider) + { + if (provider.GetService(typeof(IDialogTypeVisualizerService)) is not Control visualizerControl) + { + return source; + } + + var colorCycler = new OxyColorPresetCycle(); + + var timer = Observable.Interval( + TimeSpan.FromMilliseconds(100), + HighResolutionScheduler.Default + ); + + var mergedSource = source.SelectMany(xs => + xs.Buffer(timer) + .Where(buffer => buffer.Count > 0) + .Do(buffer => { + if (!_success) + { + UpdateModel(); + } + ValueMin = _visualizer.Plot.ValueMin; + ValueMax = _visualizer.Plot.ValueMax; + Show(buffer.LastOrDefault()); + })); + + var mashupSourceStreams = Observable.Merge( + MashupSources.Select(mashupSource => { + var color = colorCycler.Next(); + var visualizer = mashupSource.Visualizer as Point2DOverlay; + visualizer.Color = color; + return mashupSource.Source.Output.SelectMany(xs => + xs.Buffer(timer) + .Where(buffer => buffer.Count > 0) + .Do(buffer => mashupSource.Visualizer.Show(buffer.LastOrDefault())) + ); + }) + ); + + return Observable.Merge(mergedSource, mashupSourceStreams); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/ToolStripNumericUpDown.cs b/src/Bonsai.ML.PointProcessDecoder.Design/ToolStripNumericUpDown.cs new file mode 100644 index 00000000..80e38e8b --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder.Design/ToolStripNumericUpDown.cs @@ -0,0 +1,47 @@ +using System; +using System.Windows.Forms; + +namespace Bonsai.ML.PointProcessDecoder.Design; + +internal class ToolStripNumericUpDown : ToolStripControlHost +{ + public ToolStripNumericUpDown() + : base(new NumericUpDown()) + { + } + + public NumericUpDown NumericUpDown + { + get { return Control as NumericUpDown; } + } + + public int DecimalPlaces + { + get { return NumericUpDown.DecimalPlaces; } + set { NumericUpDown.DecimalPlaces = value; } + } + + public decimal Value + { + get { return NumericUpDown.Value; } + set { NumericUpDown.Value = value; } + } + + public decimal Minimum + { + get { 
return NumericUpDown.Minimum; } + set { NumericUpDown.Minimum = value; } + } + + public decimal Maximum + { + get { return NumericUpDown.Maximum; } + set { NumericUpDown.Maximum = value; } + } + + public event EventHandler ValueChanged + { + add { NumericUpDown.ValueChanged += value; } + remove { NumericUpDown.ValueChanged -= value; } + } +} diff --git a/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj b/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj new file mode 100644 index 00000000..3b9d62b1 --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder/Bonsai.ML.PointProcessDecoder.csproj @@ -0,0 +1,16 @@ + + + Bonsai.ML.PointProcessDecoder + A Bonsai package for running a neural decoder based on point processes using Bayesian state-space models. + Bonsai Rx ML Point Process Neural Decoder + net472;netstandard2.0 + enable + + + + + + + + + \ No newline at end of file diff --git a/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs b/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs new file mode 100644 index 00000000..9eec439a --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs @@ -0,0 +1,484 @@ +using System; +using System.ComponentModel; +using System.Xml.Serialization; +using System.Linq; +using System.Reactive.Linq; + +using static TorchSharp.torch; + +using PointProcessDecoder.Core; +using PointProcessDecoder.Core.Estimation; +using PointProcessDecoder.Core.Transitions; +using PointProcessDecoder.Core.Encoder; +using PointProcessDecoder.Core.Decoder; +using PointProcessDecoder.Core.StateSpace; +using PointProcessDecoder.Core.Likelihood; + +namespace Bonsai.ML.PointProcessDecoder; + +/// +/// Creates a new neural decoding model based on point processes using Bayesian state space models. +/// +[Combinator] +[WorkflowElementCategory(ElementCategory.Source)] +[Description("Creates a new neural decoding model based on point processes using Bayesian state space models.")] +public class CreatePointProcessModel +{ + private string name = "PointProcessModel"; + /// + /// Gets or sets the name of the point process model. + /// + [Category("1. Model Parameters")] + [Description("The name of the point process model.")] + public string Name + { + get + { + return name; + } + set + { + name = value; + } + } + + Device? device = null; + /// + /// Gets or sets the device used to run the neural decoding model. + /// + [XmlIgnore] + [Category("1. Model Parameters")] + [Description("The device used to run the neural decoding model.")] + public Device? Device + { + get + { + return device; + } + set + { + device = value; + } + } + + ScalarType? scalarType = null; + /// + /// Gets or sets the scalar type used to run the neural decoding model. + /// + [Category("1. Model Parameters")] + [Description("The scalar type used to run the neural decoding model.")] + public ScalarType? ScalarType + { + get + { + return scalarType; + } + set + { + scalarType = value; + } + } + + private StateSpaceType stateSpaceType = StateSpaceType.DiscreteUniformStateSpace; + /// + /// Gets or sets the type of state space used. + /// + [Category("2. State Space Parameters")] + [Description("The type of state space used.")] + public StateSpaceType StateSpaceType + { + get + { + return stateSpaceType; + } + set + { + stateSpaceType = value; + } + } + + private int stateSpaceDimensions = 1; + /// + /// Gets or sets the number of dimensions in the state space. + /// + [Category("2. 
State Space Parameters")] + [Description("The number of dimensions in the state space.")] + public int StateSpaceDimensions + { + get + { + return stateSpaceDimensions; + } + set + { + stateSpaceDimensions = value; + } + } + + private double[] minStateSpace = [0]; + /// + /// Gets or sets the minimum values of the state space. Must be the same length as the number of state space dimensions. + /// + [Category("2. State Space Parameters")] + [Description("The minimum values of the state space. Must be the same length as the number of state space dimensions.")] + public double[] MinStateSpace + { + get + { + return minStateSpace; + } + set + { + minStateSpace = value; + } + } + + private double[] maxStateSpace = [100]; + /// + /// Gets or sets the maximum values of the state space. Must be the same length as the number of state space dimensions. + /// + [Category("2. State Space Parameters")] + [Description("The maximum values of the state space. Must be the same length as the number of state space dimensions.")] + public double[] MaxStateSpace + { + get + { + return maxStateSpace; + } + set + { + maxStateSpace = value; + } + } + + private long[] stepsStateSpace = [50]; + /// + /// Gets or sets the number of steps evaluated in the state space. Must be the same length as the number of state space dimensions. + /// + [Category("2. State Space Parameters")] + [Description("The number of steps evaluated in the state space. Must be the same length as the number of state space dimensions.")] + public long[] StepsStateSpace + { + get + { + return stepsStateSpace; + } + set + { + stepsStateSpace = value; + } + } + + private double[] observationBandwidth = [1]; + /// + /// Gets or sets the bandwidth of the observation estimation method. Must be the same length as the number of state space dimensions. + /// + [Category("2. State Space Parameters")] + [Description("The bandwidth of the observation estimation method. Must be the same length as the number of state space dimensions.")] + public double[] ObservationBandwidth + { + get + { + return observationBandwidth; + } + set + { + observationBandwidth = value; + } + } + + private EncoderType encoderType = EncoderType.SortedSpikeEncoder; + /// + /// Gets or sets the type of encoder used. + /// + [Category("3. Encoder Parameters")] + [Description("The type of encoder used.")] + public EncoderType EncoderType + { + get + { + return encoderType; + } + set + { + encoderType = value; + } + } + + private int? kernelLimit = null; + /// + /// Gets or sets the kernel limit. + /// + [Category("3. Encoder Parameters")] + [Description("The kernel limit.")] + public int? KernelLimit + { + get + { + return kernelLimit; + } + set + { + kernelLimit = value; + } + } + + private int? nUnits = null; + /// + /// Gets or sets the number of sorted spiking units. + /// Only used when the encoder type is set to . + /// + [Category("3. Encoder Parameters")] + [Description("The number of sorted spiking units. Only used when the encoder type is set to SortedSpikeEncoder.")] + public int? NUnits + { + get + { + return nUnits; + } + set + { + nUnits = value; + } + } + + private int? markDimensions = null; + /// + /// Gets or sets the number of dimensions or features associated with each mark. + /// Only used when the encoder type is set to . + /// + [Category("3. Encoder Parameters")] + [Description("The number of dimensions or features associated with each mark. Only used when the encoder type is set to ClusterlessMarkEncoder.")] + public int? 
MarkDimensions + { + get + { + return markDimensions; + } + set + { + markDimensions = value; + } + } + + private int? markChannels = null; + /// + /// Gets or sets the number of mark recording channels. + /// Only used when the encoder type is set to . + /// + [Category("3. Encoder Parameters")] + [Description("The number of mark recording channels. Only used when the encoder type is set to ClusterlessMarkEncoder.")] + public int? MarkChannels + { + get + { + return markChannels; + } + set + { + markChannels = value; + } + } + + private double[]? markBandwidth = null; + /// + /// Gets or sets the bandwidth of the mark estimation method. + /// Must be the same length as the number of mark dimensions. + /// Only used when the encoder type is set to . + /// + [Category("3. Encoder Parameters")] + [Description("The bandwidth of the mark estimation method. Must be the same length as the number of mark dimensions. Only used when the encoder type is set to ClusterlessMarkEncoder.")] + public double[]? MarkBandwidth + { + get + { + return markBandwidth; + } + set + { + markBandwidth = value; + } + } + + private EstimationMethod estimationMethod = EstimationMethod.KernelDensity; + /// + /// Gets or sets the estimation method used during the encoding process. + /// + [Category("4. Estimation Parameters")] + [Description("The estimation method used during the encoding process.")] + public EstimationMethod EstimationMethod + { + get + { + return estimationMethod; + } + set + { + estimationMethod = value; + } + } + + private double? distanceThreshold = null; + /// + /// Gets or sets the distance threshold used to determine the threshold to merge unique clusters into a single compressed cluster. + /// Only used when the estimation method is set to . + /// + [Category("4. Estimation Parameters")] + [Description("The distance threshold used to determine the threshold to merge unique clusters into a single compressed cluster. Only used when the estimation method is set to KernelCompression.")] + public double? DistanceThreshold + { + get + { + return distanceThreshold; + } + set + { + distanceThreshold = value; + } + } + + private LikelihoodType likelihoodType = LikelihoodType.Poisson; + /// + /// Gets or sets the type of likelihood function used. + /// + [Category("5. Likelihood Parameters")] + [Description("The type of likelihood function used.")] + public LikelihoodType LikelihoodType + { + get + { + return likelihoodType; + } + set + { + likelihoodType = value; + } + } + + private bool ignoreNoSpikes = false; + /// + /// Gets or sets a value indicating whether to ignore contributions from units or channels with no spikes. + /// + [Category("5. Likelihood Parameters")] + [Description("Indicates whether to ignore contributions from units or channels with no spikes.")] + public bool IgnoreNoSpikes + { + get + { + return ignoreNoSpikes; + } + set + { + ignoreNoSpikes = value; + } + } + + private bool sumAcrossBatch = true; + /// + /// Gets or sets a value indicating whether to sum across the batched likelihood. + /// + [Category("5. Likelihood Parameters")] + [Description("Indicates whether to sum across the batched likelihood.")] + public bool SumAcrossBatch + { + get + { + return sumAcrossBatch; + } + set + { + sumAcrossBatch = value; + } + } + + private TransitionsType transitionsType = TransitionsType.RandomWalk; + /// + /// Gets or sets the type of transition model used during the decoding process. + /// + [Category("6. 
Transition Parameters")] + [Description("The type of transition model used during the decoding process.")] + public TransitionsType TransitionsType + { + get + { + return transitionsType; + } + set + { + transitionsType = value; + } + } + + private double? sigmaRandomWalk = null; + /// + /// Gets or sets the standard deviation of the random walk transitions model. + /// Only used when the transitions type is set to . + /// + [Category("6. Transition Parameters")] + [Description("The standard deviation of the random walk transitions model. Only used when the transitions type is set to RandomWalk.")] + public double? SigmaRandomWalk + { + get + { + return sigmaRandomWalk; + } + set + { + sigmaRandomWalk = value; + } + } + + private DecoderType decoderType = DecoderType.StateSpaceDecoder; + /// + /// Gets or sets the type of decoder used. + /// + [Category("7. Decoder Parameters")] + [Description("The type of decoder used.")] + public DecoderType DecoderType + { + get + { + return decoderType; + } + set + { + decoderType = value; + } + } + + /// + /// Creates a new neural decoding model based on point processes using Bayesian state space models. + /// + /// + public IObservable Process() + { + return Observable.Using( + () => PointProcessModelManager.Reserve( + name: name, + estimationMethod: estimationMethod, + transitionsType: transitionsType, + encoderType: encoderType, + decoderType: decoderType, + stateSpaceType: stateSpaceType, + likelihoodType: likelihoodType, + minStateSpace: minStateSpace, + maxStateSpace: maxStateSpace, + stepsStateSpace: stepsStateSpace, + observationBandwidth: observationBandwidth, + stateSpaceDimensions: stateSpaceDimensions, + markDimensions: markDimensions, + markChannels: markChannels, + markBandwidth: markBandwidth, + ignoreNoSpikes: ignoreNoSpikes, + sumAcrossBatch: sumAcrossBatch, + nUnits: nUnits, + distanceThreshold: distanceThreshold, + sigmaRandomWalk: sigmaRandomWalk, + kernelLimit: kernelLimit, + device: device, + scalarType: scalarType + ), resource => Observable.Return(resource.Model) + .Concat(Observable.Never(resource.Model)) + .Finally(resource.Dispose)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.PointProcessDecoder/Decode.cs b/src/Bonsai.ML.PointProcessDecoder/Decode.cs new file mode 100644 index 00000000..7c6d215a --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder/Decode.cs @@ -0,0 +1,85 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using PointProcessDecoder.Core; +using PointProcessDecoder.Core.Likelihood; +using static TorchSharp.torch; + +namespace Bonsai.ML.PointProcessDecoder; + +/// +/// Decodes the input neural data into a posterior state estimate using a point process model. +/// +[Combinator] +[WorkflowElementCategory(ElementCategory.Transform)] +[Description("Decodes the input neural data into a posterior state estimate using a point process model.")] +public class Decode +{ + /// + /// The name of the point process model to use. + /// + [TypeConverter(typeof(PointProcessModelNameConverter))] + [Description("The name of the point process model to use.")] + public string Model { get; set; } = string.Empty; + + private bool _ignoreNoSpikes = false; + private bool _updateIgnoreNoSpikes = false; + /// + /// Gets or sets a value indicating whether to ignore contributions from no spike events. 
+ /// + [Description("Indicates whether to ignore contributions from no spike events.")] + public bool IgnoreNoSpikes + { + get => _ignoreNoSpikes; + set + { + _ignoreNoSpikes = value; + _updateIgnoreNoSpikes = true; + } + } + + private bool _sumAcrossBatch = true; + private bool _updateSumAcrossBatch = false; + /// + /// Gets or sets a value indicating whether to ignore contributions from no spike events. + /// + [Description("Indicates whether to ignore contributions from no spike events.")] + public bool SumAcrossBatch + { + get => _sumAcrossBatch; + set + { + _sumAcrossBatch = value; + _updateSumAcrossBatch = true; + } + } + + + /// + /// Decodes the input neural data into a posterior state estimate using a point process model. + /// + /// + /// + public IObservable Process(IObservable source) + { + var modelName = Model; + return source.Select(input => + { + var model = PointProcessModelManager.GetModel(modelName); + if (_updateIgnoreNoSpikes) + { + model.Likelihood.IgnoreNoSpikes = _ignoreNoSpikes; + _updateIgnoreNoSpikes = false; + } + + if (_updateSumAcrossBatch) + { + + model.Likelihood.SumAcrossBatch = _sumAcrossBatch; + _updateSumAcrossBatch = false; + } + + return model.Decode(input); + }); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.PointProcessDecoder/Encode.cs b/src/Bonsai.ML.PointProcessDecoder/Encode.cs new file mode 100644 index 00000000..cd64404c --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder/Encode.cs @@ -0,0 +1,39 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; + +using static TorchSharp.torch; + +namespace Bonsai.ML.PointProcessDecoder; + +/// +/// Encodes the combined state observation data and neural data into a point process model. +/// +[Combinator] +[WorkflowElementCategory(ElementCategory.Sink)] +[Description("Encodes the combined state observation data and neural data into a point process model.")] +public class Encode +{ + /// + /// The name of the point process model to use. + /// + [TypeConverter(typeof(PointProcessModelNameConverter))] + [Description("The name of the point process model to use.")] + public string Model { get; set; } = string.Empty; + + /// + /// Encodes the combined state observation data and neural data into a point process model. + /// + /// + /// + public IObservable> Process(IObservable> source) + { + var modelName = Model; + return source.Do(input => + { + var model = PointProcessModelManager.GetModel(modelName); + var (neuralData, stateObservations) = input; + model.Encode(neuralData, stateObservations); + }); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.PointProcessDecoder/GetModel.cs b/src/Bonsai.ML.PointProcessDecoder/GetModel.cs new file mode 100644 index 00000000..37ab955c --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder/GetModel.cs @@ -0,0 +1,48 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using PointProcessDecoder.Core; +using static TorchSharp.torch; + +namespace Bonsai.ML.PointProcessDecoder; + +/// +/// Returns the point process model. +/// +[Combinator] +[WorkflowElementCategory(ElementCategory.Source)] +[Description("Returns the point process model.")] +public class GetModel +{ + /// + /// The name of the point process model to return. + /// + [TypeConverter(typeof(PointProcessModelNameConverter))] + [Description("The name of the point process model to return.")] + public string Model { get; set; } = string.Empty; + + /// + /// Returns the point process model. 
+ /// + /// + public IObservable Process() + { + return Observable.Defer(() => + Observable.Return(PointProcessModelManager.GetModel(Model)) + ); + } + + /// + /// Returns the point process model. + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + var modelName = Model; + return source.Select(input => { + return PointProcessModelManager.GetModel(modelName); + }); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.PointProcessDecoder/LoadPointProcessModel.cs b/src/Bonsai.ML.PointProcessDecoder/LoadPointProcessModel.cs new file mode 100644 index 00000000..803bbff9 --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder/LoadPointProcessModel.cs @@ -0,0 +1,100 @@ +using System; +using System.ComponentModel; +using System.Xml.Serialization; +using System.Linq; +using System.Reactive.Linq; + +using static TorchSharp.torch; + +using PointProcessDecoder.Core; +using PointProcessDecoder.Core.Estimation; +using PointProcessDecoder.Core.Transitions; +using PointProcessDecoder.Core.Encoder; +using PointProcessDecoder.Core.Decoder; +using PointProcessDecoder.Core.StateSpace; +using PointProcessDecoder.Core.Likelihood; +using System.IO; + +namespace Bonsai.ML.PointProcessDecoder; + +/// +/// Loads a point process model from a saved state. +/// +[Combinator] +[WorkflowElementCategory(ElementCategory.Source)] +[Description("Loads a point process model from a saved state.")] +public class LoadPointProcessModel +{ + private string name = "PointProcessModel"; + + /// + /// Gets or sets the name of the point process model. + /// + [Description("The name of the point process model.")] + public string Name + { + get + { + return name; + } + set + { + name = value; + } + } + + Device? device = null; + /// + /// Gets or sets the device used to run the neural decoding model. + /// + [XmlIgnore] + [Description("The device used to run the neural decoding model.")] + public Device? Device + { + get + { + return device; + } + set + { + device = value; + } + } + + /// + /// The path to the folder where the state of the point process model was saved. + /// + [Editor("Bonsai.Design.FolderNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + [Description("The path to the folder where the state of the point process model was saved.")] + public string Path { get; set; } = string.Empty; + + + /// + /// Creates a new neural decoding model based on point processes using Bayesian state space models. 
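// The source operators in this package share one lifetime pattern: Observable.Using reserves a
// named model, emits it once, holds the subscription open with Observable.Never, and unregisters
// the model when the subscription is disposed. A condensed sketch of that pattern, with the
// Reserve arguments elided:
//
//     Observable.Using(
//         () => PointProcessModelManager.Reserve(name: name, ...),   // register under Name
//         resource => Observable.Return(resource.Model)              // emit the model once
//             .Concat(Observable.Never(resource.Model))              // then keep it alive for the workflow lifetime
//             .Finally(resource.Dispose));                           // remove it from the manager on teardown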
+ /// + /// + public IObservable Process() + { + return Observable.Using( + () => { + + if (string.IsNullOrEmpty(Path)) + { + throw new InvalidOperationException("The save path is not specified."); + } + + if (!Directory.Exists(Path)) + { + throw new InvalidOperationException("The save path does not exist."); + } + + return PointProcessModelManager.Load( + name: name, + path: Path, + device: device + ); + }, resource => Observable.Return(resource.Model) + .Concat(Observable.Never(resource.Model)) + .Finally(resource.Dispose)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.PointProcessDecoder/PointProcessModelDisposable.cs b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelDisposable.cs new file mode 100644 index 00000000..af9461d0 --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelDisposable.cs @@ -0,0 +1,26 @@ +using System; +using System.Threading; +using PointProcessDecoder.Core; + +namespace Bonsai.ML.PointProcessDecoder; + +internal sealed class PointProcessModelDisposable(PointProcessModel model, IDisposable disposable) : IDisposable +{ + private IDisposable? resource = disposable ?? throw new ArgumentNullException(nameof(disposable)); + /// + /// Gets a value indicating whether the object has been disposed. + /// + public bool IsDisposed => resource == null; + + private readonly PointProcessModel model = model ?? throw new ArgumentNullException(nameof(model)); + /// + /// Gets the point process model. + /// + public PointProcessModel Model => model; + + public void Dispose() + { + var disposable = Interlocked.Exchange(ref resource, null); + disposable?.Dispose(); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs new file mode 100644 index 00000000..c647d674 --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelManager.cs @@ -0,0 +1,112 @@ +using System; +using System.Collections.Generic; +using System.Reactive.Disposables; + +using static TorchSharp.torch; + +using PointProcessDecoder.Core; +using PointProcessDecoder.Core.Estimation; +using PointProcessDecoder.Core.Transitions; +using PointProcessDecoder.Core.StateSpace; +using PointProcessDecoder.Core.Encoder; +using PointProcessDecoder.Core.Decoder; +using PointProcessDecoder.Core.Likelihood; + +namespace Bonsai.ML.PointProcessDecoder; + +/// +/// Manages the point process models. +/// +public static class PointProcessModelManager +{ + private static readonly Dictionary models = []; + + /// + /// Gets the point process model with the specified name. + /// + /// + /// + /// + public static PointProcessModel GetModel(string name) + { + return models.TryGetValue(name, out var model) ? model : throw new InvalidOperationException($"Model with name {name} not found."); + } + + internal static PointProcessModelDisposable Reserve( + string name, + EstimationMethod estimationMethod, + TransitionsType transitionsType, + EncoderType encoderType, + DecoderType decoderType, + StateSpaceType stateSpaceType, + LikelihoodType likelihoodType, + double[] minStateSpace, + double[] maxStateSpace, + long[] stepsStateSpace, + double[] observationBandwidth, + int stateSpaceDimensions, + int? markDimensions = null, + int? markChannels = null, + double[]? markBandwidth = null, + bool ignoreNoSpikes = false, + bool sumAcrossBatch = true, + int? nUnits = null, + double? distanceThreshold = null, + double? sigmaRandomWalk = null, + int? kernelLimit = null, + Device? 
device = null, + ScalarType? scalarType = null + ) + { + var model = new PointProcessModel( + estimationMethod: estimationMethod, + transitionsType: transitionsType, + encoderType: encoderType, + decoderType: decoderType, + stateSpaceType: stateSpaceType, + likelihoodType: likelihoodType, + minStateSpace: minStateSpace, + maxStateSpace: maxStateSpace, + stepsStateSpace: stepsStateSpace, + observationBandwidth: observationBandwidth, + stateSpaceDimensions: stateSpaceDimensions, + markDimensions: markDimensions, + markChannels: markChannels, + markBandwidth: markBandwidth, + ignoreNoSpikes: ignoreNoSpikes, + sumAcrossBatch: sumAcrossBatch, + nUnits: nUnits, + distanceThreshold: distanceThreshold, + sigmaRandomWalk: sigmaRandomWalk, + kernelLimit: kernelLimit, + device: device, + scalarType: scalarType + ); + + models.Add(name, model); + + return new PointProcessModelDisposable( + model, + Disposable.Create(() => { + models.Remove(name); + }) + ); + } + + internal static PointProcessModelDisposable Load( + string name, + string path, + Device? device = null + ) + { + var model = PointProcessModel.Load(path, device) as PointProcessModel ?? throw new InvalidOperationException("The model could not be loaded."); + models.Add(name, model); + + return new PointProcessModelDisposable( + model, + Disposable.Create(() => { + models.Remove(name); + }) + ); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.PointProcessDecoder/PointProcessModelNameConverter.cs b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelNameConverter.cs new file mode 100644 index 00000000..5bef2701 --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder/PointProcessModelNameConverter.cs @@ -0,0 +1,43 @@ +using Bonsai; +using Bonsai.Expressions; +using System.Linq; +using System.ComponentModel; + +namespace Bonsai.ML.PointProcessDecoder; + +/// +/// Provides a type converter to display a list of available point process models. +/// +public class PointProcessModelNameConverter : StringConverter +{ + /// + public override bool GetStandardValuesSupported(ITypeDescriptorContext context) + { + return true; + } + + /// + public override StandardValuesCollection GetStandardValues(ITypeDescriptorContext context) + { + if (context != null) + { + var workflowBuilder = (WorkflowBuilder)context.GetService(typeof(WorkflowBuilder)); + if (workflowBuilder != null) + { + var models = (from builder in workflowBuilder.Workflow.Descendants() + where builder.GetType() != typeof(DisableBuilder) + let createPointProcessModel = ExpressionBuilder.GetWorkflowElement(builder) as CreatePointProcessModel + where createPointProcessModel != null && !string.IsNullOrEmpty(createPointProcessModel.Name) + select createPointProcessModel.Name) + .Distinct() + .ToList(); + if (models.Count > 0) + { + return new StandardValuesCollection(models); + } + } + } + + return new StandardValuesCollection(new string[] { }); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.PointProcessDecoder/SavePointProcessModel.cs b/src/Bonsai.ML.PointProcessDecoder/SavePointProcessModel.cs new file mode 100644 index 00000000..20746345 --- /dev/null +++ b/src/Bonsai.ML.PointProcessDecoder/SavePointProcessModel.cs @@ -0,0 +1,105 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using PointProcessDecoder.Core; +using static TorchSharp.torch; + +namespace Bonsai.ML.PointProcessDecoder; + +/// +/// Saves the state of the point process model. 
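// How the model registry above behaves, in brief: Reserve adds the model to a static dictionary
// keyed by name and returns a disposable whose Dispose removes it again, so GetModel(name) only
// succeeds while the corresponding Create/Load subscription is alive, and reserving the same name
// twice throws because the registry uses Dictionary.Add. Reserve is internal to the package; the
// sketch below only illustrates the lifetime, with the remaining arguments elided:
//
//     using (var reserved = PointProcessModelManager.Reserve(name: "PointProcessModel", ...))
//     {
//         var sameModel = PointProcessModelManager.GetModel("PointProcessModel");   // returns reserved.Model
//     }
//     // after disposal, GetModel("PointProcessModel") throws InvalidOperationException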
+/// +[Combinator] +[WorkflowElementCategory(ElementCategory.Sink)] +[Description("Saves the state of the point process model.")] +public class SavePointProcessModel +{ + /// + /// The path to the folder where the state of the point process model will be saved. + /// + [Editor("Bonsai.Design.FolderNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + [Description("The path to the folder where the state of the point process model will be saved.")] + public string Path { get; set; } = string.Empty; + + /// + /// If true, the contents of the folder will be overwritten if it already exists. + /// + [Description("If true, the contents of the folder will be overwritten if it already exists.")] + public bool Overwrite { get; set; } = false; + + /// + /// Specifies the type of suffix to add to the save path. + /// If DateTime, a suffix with the current date and time is added to the save path in the format 'yyyyMMddHHmmss'. + /// + [Description("Specifies the type of suffix to add to the save path.")] + public SuffixType AddSuffix { get; set; } = SuffixType.None; + + /// + /// The name of the point process model to save. + /// + [TypeConverter(typeof(PointProcessModelNameConverter))] + [Description("The name of the point process model to save.")] + public string Model { get; set; } = string.Empty; + + /// + /// Saves the state of the point process model. + /// + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Do(_ => { + + var path = AddSuffix switch + { + SuffixType.DateTime => System.IO.Path.Combine(Path, $"{HighResolutionScheduler.Now:yyyyMMddHHmmss}"), + SuffixType.Guid => System.IO.Path.Combine(Path, Guid.NewGuid().ToString()), + _ => Path + }; + + if (string.IsNullOrEmpty(path)) + { + throw new InvalidOperationException("The save path is not specified."); + } + + if (System.IO.Directory.Exists(path)) + { + if (Overwrite) + { + System.IO.Directory.Delete(path, true); + } + else + { + throw new InvalidOperationException("The save path already exists and overwrite is set to False."); + } + } + + var model = PointProcessModelManager.GetModel(Model); + + model.Save(path); + }); + } + + /// + /// Specifies the type of suffix to add to the save path. + /// + public enum SuffixType + { + /// + /// No suffix is added to the save path. + /// + None, + + /// + /// A suffix with the current date and time is added to the save path. + /// + DateTime, + + /// + /// A suffix with a unique identifier is added to the save path. + /// + Guid + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Arange.cs b/src/Bonsai.ML.Torch/Arange.cs new file mode 100644 index 00000000..fa80c08e --- /dev/null +++ b/src/Bonsai.ML.Torch/Arange.cs @@ -0,0 +1,43 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using TorchSharp; + +namespace Bonsai.ML.Torch +{ + /// + /// Creates a 1-D tensor of values within a given range given the start, end, and step. + /// + [Combinator] + [Description("Creates a 1-D tensor of values within a given range given the start, end, and step.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class Arange + { + /// + /// The start of the range. + /// + [Description("The start of the range.")] + public int Start { get; set; } = 0; + + /// + /// The end of the range. + /// + [Description("The end of the range.")] + public int End { get; set; } = 10; + + /// + /// The step size between values. 
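// Note on the Arange source: it defers to TorchSharp's torch.arange, which, as in PyTorch,
// includes Start and excludes End. With the default Start = 0, End = 10, Step = 1 this yields a
// ten-element tensor:
//
//     var t = arange(0, 10, 1);   // tensor([0, 1, 2, ..., 9])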
+ /// + [Description("The step size between values.")] + public int Step { get; set; } = 1; + + /// + /// Generates an observable sequence of 1-D tensors created with the function. + /// + public IObservable Process() + { + return Observable.Defer(() => Observable.Return(arange(Start, End, Step))); + } + } +} diff --git a/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj b/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj new file mode 100644 index 00000000..3a2f0298 --- /dev/null +++ b/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj @@ -0,0 +1,19 @@ + + + Bonsai.ML.Torch + A Bonsai package for TorchSharp tensor manipulations. + Bonsai Rx ML Tensors TorchSharp + net472;netstandard2.0 + true + + + + + + + + + + + + \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Clone.cs b/src/Bonsai.ML.Torch/Clone.cs new file mode 100644 index 00000000..b8dc15fd --- /dev/null +++ b/src/Bonsai.ML.Torch/Clone.cs @@ -0,0 +1,25 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch; + +/// +/// Clones the input tensor. +/// +[Combinator] +[Description("Clones the input tensor.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class Clone +{ + /// + /// Clones the input tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => tensor.clone()); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Concat.cs b/src/Bonsai.ML.Torch/Concat.cs new file mode 100644 index 00000000..45402621 --- /dev/null +++ b/src/Bonsai.ML.Torch/Concat.cs @@ -0,0 +1,100 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using System.Collections.Generic; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Concatenates tensors along a given dimension. + /// + [Combinator] + [Description("Concatenates tensors along a given dimension.")] + [WorkflowElementCategory(ElementCategory.Combinator)] + public class Concat + { + /// + /// The dimension along which to concatenate the tensors. + /// + [Description("The dimension along which to concatenate the tensors.")] + public long Dimension { get; set; } = 0; + + /// + /// Takes any number of observable sequences of tensors and concatenates the input tensors along the specified dimension by zipping each tensor together. + /// + public IObservable Process(params IObservable[] sources) + { + return sources.Aggregate((current, next) => + current.Zip(next, (tensor1, tensor2) => + cat([tensor1, tensor2], Dimension))); + } + + /// + /// Concatenates the input tensors along the specified dimension. + /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return cat([value.Item1, value.Item2], Dimension); + }); + } + + /// + /// Concatenates the input tensors along the specified dimension. + /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return cat([value.Item1, value.Item2, value.Item3], Dimension); + }); + } + + /// + /// Concatenates the input tensors along the specified dimension. + /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return cat([value.Item1, value.Item2, value.Item3, value.Item4], Dimension); + }); + } + + /// + /// Concatenates the input tensors along the specified dimension. 
+ /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return cat([value.Item1, value.Item2, value.Item3, value.Item4, value.Item5], Dimension); + }); + } + + /// + /// Concatenates the input tensors along the specified dimension. + /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return cat([value.Item1, value.Item2, value.Item3, value.Item4, value.Item5, value.Item6], Dimension); + }); + } + + /// + /// Concatenates the input tensors along the specified dimension. + /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return cat(value.ToList(), Dimension); + }); + } + } +} diff --git a/src/Bonsai.ML.Torch/ConvertDataType.cs b/src/Bonsai.ML.Torch/ConvertDataType.cs new file mode 100644 index 00000000..efe3496b --- /dev/null +++ b/src/Bonsai.ML.Torch/ConvertDataType.cs @@ -0,0 +1,33 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Converts the input tensor to the specified scalar type. + /// + [Combinator] + [Description("Converts the input tensor to the specified scalar type.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class ConvertDataType + { + /// + /// The scalar type to which to convert the input tensor. + /// + [Description("The scalar type to which to convert the input tensor.")] + public ScalarType Type { get; set; } = ScalarType.Float32; + + /// + /// Returns an observable sequence that converts the input tensor to the specified scalar type. + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => + { + return tensor.to_type(Type); + }); + } + } +} diff --git a/src/Bonsai.ML.Torch/CreateTensor.cs b/src/Bonsai.ML.Torch/CreateTensor.cs new file mode 100644 index 00000000..52d8de1a --- /dev/null +++ b/src/Bonsai.ML.Torch/CreateTensor.cs @@ -0,0 +1,251 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Linq.Expressions; +using System.Reactive.Linq; +using System.Reflection; +using System.Xml.Serialization; +using Bonsai.Expressions; +using static TorchSharp.torch; +using Bonsai.ML.Data; +using Bonsai.ML.Python; +using TorchSharp; + +namespace Bonsai.ML.Torch +{ + /// + /// Creates a tensor from the specified values. + /// Uses Python-like syntax to specify the tensor values. + /// For example, a 2x2 tensor can be created with the following values: "[[1, 2], [3, 4]]". + /// + [Combinator] + [Description("Creates a tensor from the specified values. Uses Python-like syntax to specify the tensor values. For example, a 2x2 tensor can be created with the following values: \"[[1, 2], [3, 4]]\".")] + [WorkflowElementCategory(ElementCategory.Source)] + public class CreateTensor : ExpressionBuilder + { + readonly Range argumentRange = new(0, 1); + + /// + public override Range ArgumentRange => argumentRange; + + /// + /// The data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + public ScalarType Type + { + get => scalarType; + set => scalarType = value; + } + + private ScalarType scalarType = ScalarType.Float32; + + /// + /// The values of the tensor elements. + /// Uses Python-like syntax to specify the tensor values. + /// For example: "[[1, 2], [3, 4]]". + /// + [Description("The values of the tensor elements. Uses Python-like syntax to specify the tensor values. 
For example: \"[[1, 2], [3, 4]]\".")] + public string Values + { + get => values; + set + { + values = value.ToLower(); + } + } + + private string values = "[0]"; + + /// + /// The device on which to create the tensor. + /// + [XmlIgnore] + [Description("The device on which to create the tensor.")] + public Device Device + { + get => device; + set => device = value; + } + + private Device device = null; + + private Expression BuildTensorFromArray(Array arrayValues, Type returnType) + { + var rank = arrayValues.Rank; + int[] lengths = [.. Enumerable.Range(0, rank).Select(arrayValues.GetLength)]; + + var arrayCreationExpression = Expression.NewArrayBounds(returnType, [.. lengths.Select(len => Expression.Constant(len))]); + var arrayVariable = Expression.Variable(arrayCreationExpression.Type, "array"); + var assignArray = Expression.Assign(arrayVariable, arrayCreationExpression); + + List assignments = []; + for (int i = 0; i < values.Length; i++) + { + var indices = new Expression[rank]; + int temp = i; + for (int j = rank - 1; j >= 0; j--) + { + indices[j] = Expression.Constant(temp % lengths[j]); + temp /= lengths[j]; + } + var value = Expression.Constant(arrayValues.GetValue(indices.Select(e => ((ConstantExpression)e).Value).Cast().ToArray())); + var arrayAccess = Expression.ArrayAccess(arrayVariable, indices); + var assignArrayValue = Expression.Assign(arrayAccess, value); + assignments.Add(assignArrayValue); + } + + var tensorDataInitializationBlock = Expression.Block( + [arrayVariable], + assignArray, + Expression.Block(assignments), + arrayVariable + ); + + var tensorCreationMethodInfo = typeof(torch).GetMethod( + "tensor", [ + arrayVariable.Type, + typeof(ScalarType?), + typeof(Device), + typeof(bool), + typeof(string).MakeArrayType() + ] + ); + + var tensorAssignment = Expression.Call( + tensorCreationMethodInfo, + tensorDataInitializationBlock, + Expression.Constant(scalarType, typeof(ScalarType?)), + Expression.Constant(device, typeof(Device)), + Expression.Constant(false, typeof(bool)), + Expression.Constant(null, typeof(string).MakeArrayType()) + ); + + var tensorVariable = Expression.Variable(typeof(torch.Tensor), "tensor"); + var assignTensor = Expression.Assign(tensorVariable, tensorAssignment); + + var buildTensor = Expression.Block( + [tensorVariable], + assignTensor, + tensorVariable + ); + + return buildTensor; + } + + private Expression BuildTensorFromScalarValue(object scalarValue, Type returnType) + { + var valueVariable = Expression.Variable(returnType, "value"); + var assignValue = Expression.Assign(valueVariable, Expression.Constant(scalarValue, returnType)); + + var tensorDataInitializationBlock = Expression.Block( + [valueVariable], + assignValue, + valueVariable + ); + + var tensorCreationMethodInfo = typeof(torch).GetMethod( + "tensor", [ + valueVariable.Type, + typeof(Device), + typeof(bool) + ] + ); + + Expression[] tensorCreationMethodArguments = [ + Expression.Constant(device, typeof(Device)), + Expression.Constant(false, typeof(bool)) + ]; + + if (tensorCreationMethodInfo == null) + { + tensorCreationMethodInfo = typeof(torch).GetMethod( + "tensor", [ + valueVariable.Type, + typeof(ScalarType?), + typeof(Device), + typeof(bool) + ] + ); + + tensorCreationMethodArguments = [.. tensorCreationMethodArguments.Prepend( + Expression.Constant(scalarType, typeof(ScalarType?)) + )]; + } + + tensorCreationMethodArguments = [.. 
tensorCreationMethodArguments.Prepend( + tensorDataInitializationBlock + )]; + + var tensorAssignment = Expression.Call( + tensorCreationMethodInfo, + tensorCreationMethodArguments + ); + + var tensorVariable = Expression.Variable(typeof(torch.Tensor), "tensor"); + var assignTensor = Expression.Assign(tensorVariable, tensorAssignment); + + var buildTensor = Expression.Block( + tensorVariable, + assignTensor, + tensorVariable + ); + + return buildTensor; + } + + /// + public override Expression Build(IEnumerable arguments) + { + var returnType = ScalarTypeLookup.GetTypeFromScalarType(scalarType); + Type[] argTypes = [.. arguments.Select(arg => arg.Type)]; + + Type[] methodInfoArgumentTypes = [typeof(Tensor)]; + + MethodInfo[] methods = [.. typeof(CreateTensor).GetMethods(BindingFlags.Public | BindingFlags.Instance).Where(m => m.Name == "Process")]; + + var methodInfo = arguments.Count() > 0 ? methods.FirstOrDefault(m => m.IsGenericMethod) + .MakeGenericMethod( + arguments + .First() + .Type + .GetGenericArguments()[0] + ) : methods.FirstOrDefault(m => !m.IsGenericMethod); + + var tensorValues = ArrayHelper.ParseString(values, returnType); + var buildTensor = tensorValues is Array arrayValues ? BuildTensorFromArray(arrayValues, returnType) : BuildTensorFromScalarValue(tensorValues, returnType); + var methodArguments = arguments.Count() == 0 ? [buildTensor] : arguments.Concat([buildTensor]); + + try + { + return Expression.Call( + Expression.Constant(this), + methodInfo, + methodArguments + ); + } + finally + { + values = StringFormatter.FormatToPython(tensorValues).ToLower(); + scalarType = ScalarTypeLookup.GetScalarTypeFromType(returnType); + } + } + + /// + /// Returns an observable sequence that creates a tensor from the specified values. + /// + public IObservable Process(Tensor tensor) + { + return Observable.Return(tensor); + } + + /// + /// Returns an observable sequence that creates a tensor from the specified values for each element in the input sequence. + /// + public IObservable Process(IObservable source, Tensor tensor) + { + return source.Select(_ => tensor); + } + } +} diff --git a/src/Bonsai.ML.Torch/Empty.cs b/src/Bonsai.ML.Torch/Empty.cs new file mode 100644 index 00000000..dafcee05 --- /dev/null +++ b/src/Bonsai.ML.Torch/Empty.cs @@ -0,0 +1,51 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Creates an empty tensor with the given data type and size. + /// + [Combinator] + [Description("Creates an empty tensor with the given data type and size.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class Empty + { + /// + /// The size of the tensor. + /// + [Description("The size of the tensor.")] + public long[] Size { get; set; } = [0]; + + /// + /// The data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + public ScalarType Type { get; set; } = ScalarType.Float32; + + /// + /// Creates an empty tensor with the given data type and size. + /// + public IObservable Process() + { + return Observable.Defer(() => + { + return Observable.Return(empty(Size, Type)); + }); + } + + /// + /// Generates an observable sequence of empty tensors for each element of the input sequence. 
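// Note on the Empty source: torch.empty allocates the requested shape without initialising the
// memory, so element values are arbitrary until written; prefer zeros, ones, or full when the
// initial contents matter. A small sketch using the same call as the Process methods here:
//
//     var buffer = empty(new long[] { 2, 3 }, ScalarType.Float32);   // 2x3 tensor of unspecified values
//     buffer.fill_(0.0f);                                            // explicit initialisation when defined contents are needed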
+ /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return empty(Size, Type); + }); + } + } +} diff --git a/src/Bonsai.ML.Torch/Index/BooleanIndex.cs b/src/Bonsai.ML.Torch/Index/BooleanIndex.cs new file mode 100644 index 00000000..f854aa56 --- /dev/null +++ b/src/Bonsai.ML.Torch/Index/BooleanIndex.cs @@ -0,0 +1,42 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; + +namespace Bonsai.ML.Torch.Index; + +/// +/// Represents a boolean index that can be used to select elements from a tensor. +/// +[Combinator] +[Description("Represents a boolean index that can be used to select elements from a tensor.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class BooleanIndex +{ + /// + /// Gets or sets the boolean value used to select elements from a tensor. + /// + [Description("The boolean value used to select elements from a tensor.")] + public bool Value { get; set; } = false; + + /// + /// Generates the boolean index. + /// + /// + public IObservable Process() + { + return Observable.Return(torch.TensorIndex.Bool(Value)); + } + + /// + /// Processes the input sequence and generates the boolean index. + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select((_) => torch.TensorIndex.Bool(Value)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Index/ColonIndex.cs b/src/Bonsai.ML.Torch/Index/ColonIndex.cs new file mode 100644 index 00000000..bfd9ca7b --- /dev/null +++ b/src/Bonsai.ML.Torch/Index/ColonIndex.cs @@ -0,0 +1,36 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; + +namespace Bonsai.ML.Torch.Index; + +/// +/// Represents the colon index used to select all elements along a given dimension. +/// +[Combinator] +[Description("Represents the colon index used to select all elements along a given dimension.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class ColonIndex +{ + /// + /// Generates the colon index. + /// + /// + public IObservable Process() + { + return Observable.Return(torch.TensorIndex.Colon); + } + + /// + /// Processes the input sequence and generates the colon index. + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select((_) => torch.TensorIndex.Colon); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Index/EllipsesIndex.cs b/src/Bonsai.ML.Torch/Index/EllipsesIndex.cs new file mode 100644 index 00000000..06207a8e --- /dev/null +++ b/src/Bonsai.ML.Torch/Index/EllipsesIndex.cs @@ -0,0 +1,37 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; + +namespace Bonsai.ML.Torch.Index; + +/// +/// Represents an index that selects all dimensions of a tensor. +/// +[Combinator] +[Description("Represents an index that selects all dimensions of a tensor.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class EllipsisIndex +{ + + /// + /// Generates the ellipsis index. + /// + /// + public IObservable Process() + { + return Observable.Return(torch.TensorIndex.Ellipsis); + } + + /// + /// Processes the input sequence and generates the ellipsis index. 
+ /// + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select((_) => torch.TensorIndex.Ellipsis); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Index/Index.cs b/src/Bonsai.ML.Torch/Index/Index.cs new file mode 100644 index 00000000..6846b8c7 --- /dev/null +++ b/src/Bonsai.ML.Torch/Index/Index.cs @@ -0,0 +1,37 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Index; + +/// +/// Indexes a tensor by parsing the specified indices. +/// Indices are specified as a comma-separated values. +/// Currently supports Python-style slicing syntax. +/// This includes numeric indices, None, slices, and ellipsis. +/// +[Combinator] +[Description("Indexes a tensor by parsing the specified indices. Indices are specified as a comma-separated values. Currently supports Python-style slicing syntax. This includes numeric indices, None, slices, and ellipsis.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class Index +{ + /// + /// The indices to use for indexing the tensor. + /// + [Description("The indices to use for indexing the tensor. For example, '...,3:5,:'")] + public string Indexes { get; set; } = string.Empty; + + /// + /// Indexes the input tensor with the specified indices. + /// + /// + /// + public IObservable Process(IObservable source) + { + var index = IndexHelper.Parse(Indexes); + return source.Select(tensor => { + return tensor.index(index); + }); + } +} diff --git a/src/Bonsai.ML.Torch/Index/IndexHelper.cs b/src/Bonsai.ML.Torch/Index/IndexHelper.cs new file mode 100644 index 00000000..b62c1c2c --- /dev/null +++ b/src/Bonsai.ML.Torch/Index/IndexHelper.cs @@ -0,0 +1,87 @@ +using System; +using System.Collections.Generic; +using TorchSharp; + +namespace Bonsai.ML.Torch.Index; + +/// +/// Provides helper methods to parse tensor indexes. +/// +public static class IndexHelper +{ + + /// + /// Parses the input string into an array of tensor indexes. + /// + /// + public static torch.TensorIndex[] Parse(string input) + { + if (string.IsNullOrEmpty(input)) + { + return [0]; + } + + var indexStrings = input.Split(','); + var indices = new torch.TensorIndex[indexStrings.Length]; + + for (int i = 0; i < indexStrings.Length; i++) + { + var indexString = indexStrings[i].Trim(); + if (int.TryParse(indexString, out int intIndex)) + { + indices[i] = torch.TensorIndex.Single(intIndex); + } + else if (indexString == ":") + { + indices[i] = torch.TensorIndex.Colon; + } + else if (indexString == "None") + { + indices[i] = torch.TensorIndex.None; + } + else if (indexString == "...") + { + indices[i] = torch.TensorIndex.Ellipsis; + } + else if (indexString.ToLower() == "false" || indexString.ToLower() == "true") + { + indices[i] = torch.TensorIndex.Bool(indexString.ToLower() == "true"); + } + else if (indexString.Contains(":")) + { + string[] rangeParts = [.. indexString.Split(':')]; + var argsList = new List([null, null, null]); + try + { + for (int j = 0; j < rangeParts.Length; j++) + { + if (!string.IsNullOrEmpty(rangeParts[j])) + { + argsList[j] = long.Parse(rangeParts[j]); + } + } + } + catch (Exception) + { + throw new Exception($"Invalid index format: {indexString}"); + } + indices[i] = torch.TensorIndex.Slice(argsList[0], argsList[1], argsList[2]); + } + else + { + throw new Exception($"Invalid index format: {indexString}"); + } + } + return indices; + } + + /// + /// Serializes the input array of tensor indexes into a string representation. 
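// Worked example of the parser above, using the index string suggested in the Index operator's
// description ("...,3:5,:"); input below stands for any tensor being indexed:
//
//     var index = IndexHelper.Parse("...,3:5,:");
//     // index[0] == torch.TensorIndex.Ellipsis
//     // index[1] == torch.TensorIndex.Slice(3, 5)    (start inclusive, stop exclusive, default step)
//     // index[2] == torch.TensorIndex.Colon
//     var selected = input.index(index);              // same as input[..., 3:5, :] in Python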
+ /// + /// + /// + public static string Serialize(torch.TensorIndex[] indexes) + { + return string.Join(", ", indexes); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Index/NoneIndex.cs b/src/Bonsai.ML.Torch/Index/NoneIndex.cs new file mode 100644 index 00000000..b10c9d86 --- /dev/null +++ b/src/Bonsai.ML.Torch/Index/NoneIndex.cs @@ -0,0 +1,36 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; + +namespace Bonsai.ML.Torch.Index; + +/// +/// Represents an index that selects no elements of a tensor. +/// +[Combinator] +[Description("Represents an index that selects no elements of a tensor.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class NoneIndex +{ + /// + /// Generates the none index. + /// + /// + public IObservable Process() + { + return Observable.Return(torch.TensorIndex.None); + } + + /// + /// Processes the input sequence and generates the none index. + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select((_) => torch.TensorIndex.None); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Index/SingleIndex.cs b/src/Bonsai.ML.Torch/Index/SingleIndex.cs new file mode 100644 index 00000000..e2f5decd --- /dev/null +++ b/src/Bonsai.ML.Torch/Index/SingleIndex.cs @@ -0,0 +1,42 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; + +namespace Bonsai.ML.Torch.Index; + +/// +/// Represents an index that selects a single value of a tensor. +/// +[Combinator] +[Description("Represents an index that selects a single value of a tensor.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class SingleIndex +{ + /// + /// Gets or sets the index value used to select a single element from a tensor. + /// + [Description("The index value used to select a single element from a tensor.")] + public long Index { get; set; } = 0; + + /// + /// Generates the single index. + /// + /// + public IObservable Process() + { + return Observable.Return(torch.TensorIndex.Single(Index)); + } + + /// + /// Processes the input sequence and generates the single index. + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select((_) => torch.TensorIndex.Single(Index)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Index/SliceIndex.cs b/src/Bonsai.ML.Torch/Index/SliceIndex.cs new file mode 100644 index 00000000..b31802a4 --- /dev/null +++ b/src/Bonsai.ML.Torch/Index/SliceIndex.cs @@ -0,0 +1,54 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; + +namespace Bonsai.ML.Torch.Index; + +/// +/// Represents an index that selects a range of elements from a tensor. +/// +[Combinator] +[Description("Represents an index that selects a range of elements from a tensor.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class SliceIndex +{ + /// + /// Gets or sets the start index of the slice. + /// + [Description("The start index of the slice.")] + public long? Start { get; set; } = null; + + /// + /// Gets or sets the end index of the slice. + /// + [Description("The end index of the slice.")] + public long? End { get; set; } = null; + + /// + /// Gets or sets the step size of the slice. + /// + [Description("The step size of the slice.")] + public long? Step { get; set; } = null; + + /// + /// Generates the slice index. 
+ /// + /// + public IObservable Process() + { + return Observable.Return(torch.TensorIndex.Slice(Start, End, Step)); + } + + /// + /// Processes the input sequence and generates the slice index. + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select((_) => torch.TensorIndex.Slice(Start, End, Step)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Index/TensorIndex.cs b/src/Bonsai.ML.Torch/Index/TensorIndex.cs new file mode 100644 index 00000000..e3f6612d --- /dev/null +++ b/src/Bonsai.ML.Torch/Index/TensorIndex.cs @@ -0,0 +1,26 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Index; + +/// +/// Represents an index that is created from a tensor. +/// +[Combinator] +[Description("Represents an index that is created from a tensor.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class TensorIndex +{ + /// + /// Converts the input tensor into a tensor index. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(TorchSharp.torch.TensorIndex.Tensor); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/InitializeTorchDevice.cs b/src/Bonsai.ML.Torch/InitializeTorchDevice.cs new file mode 100644 index 00000000..a598b794 --- /dev/null +++ b/src/Bonsai.ML.Torch/InitializeTorchDevice.cs @@ -0,0 +1,48 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using TorchSharp; + +namespace Bonsai.ML.Torch +{ + /// + /// Initializes the Torch device with the specified device type. + /// + [Combinator] + [Description("Initializes the Torch device with the specified device type.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class InitializeTorchDevice + { + /// + /// The device type to initialize. + /// + [Description("The device type to initialize.")] + public DeviceType DeviceType { get; set; } + + /// + /// Initializes the Torch device with the specified device type. + /// + /// + public IObservable Process() + { + return Observable.Defer(() => + { + InitializeDeviceType(DeviceType); + return Observable.Return(new Device(DeviceType)); + }); + } + + /// + /// Initializes the Torch device when the input sequence produces an element. + /// + /// + public IObservable Process(IObservable source) + { + return source.Select((_) => { + InitializeDeviceType(DeviceType); + return new Device(DeviceType); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/Cholesky.cs b/src/Bonsai.ML.Torch/LinearAlgebra/Cholesky.cs new file mode 100644 index 00000000..6843779a --- /dev/null +++ b/src/Bonsai.ML.Torch/LinearAlgebra/Cholesky.cs @@ -0,0 +1,26 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.LinearAlgebra +{ + /// + /// Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix. + /// + [Combinator] + [Description("Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Cholesky + { + /// + /// Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix. 
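InitializeTorchDevice amounts to the following TorchSharp calls; the device type below is arbitrary (a CUDA device additionally requires the corresponding native packages):

```csharp
using TorchSharp;
using static TorchSharp.torch;

// What the operator does each time it produces an element:
InitializeDeviceType(DeviceType.CPU);
var device = new Device(DeviceType.CPU);

// Downstream operators such as ToDevice then call Tensor.to(device).
var tensor = ones(3, 3).to(device);
```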
+ /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(linalg.cholesky); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/Det.cs b/src/Bonsai.ML.Torch/LinearAlgebra/Det.cs new file mode 100644 index 00000000..90c5a45d --- /dev/null +++ b/src/Bonsai.ML.Torch/LinearAlgebra/Det.cs @@ -0,0 +1,26 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.LinearAlgebra +{ + /// + /// Computes the determinant of a square matrix. + /// + [Combinator] + [Description("Computes the determinant of a square matrix.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Det + { + /// + /// Computes the determinant of a square matrix. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(linalg.det); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/Eig.cs b/src/Bonsai.ML.Torch/LinearAlgebra/Eig.cs new file mode 100644 index 00000000..a94c8eb8 --- /dev/null +++ b/src/Bonsai.ML.Torch/LinearAlgebra/Eig.cs @@ -0,0 +1,29 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.LinearAlgebra +{ + /// + /// Computes the eigenvalue decomposition of a square matrix if it exists. + /// + [Combinator] + [Description("Computes the eigenvalue decomposition of a square matrix if it exists.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Eig + { + /// + /// Computes the eigenvalue decomposition of a square matrix if it exists. + /// + /// + /// + public IObservable> Process(IObservable source) + { + return source.Select(tensor => { + var (eigvals, eigvecs) = linalg.eig(tensor); + return Tuple.Create(eigvals, eigvecs); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/Inv.cs b/src/Bonsai.ML.Torch/LinearAlgebra/Inv.cs new file mode 100644 index 00000000..58bb4ce3 --- /dev/null +++ b/src/Bonsai.ML.Torch/LinearAlgebra/Inv.cs @@ -0,0 +1,27 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using static TorchSharp.torch.linalg; + +namespace Bonsai.ML.Torch.LinearAlgebra +{ + /// + /// Computes the inverse of the input matrix. + /// + [Combinator] + [Description("Computes the inverse of the input matrix.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Inv + { + /// + /// Computes the inverse of the input matrix. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(inv); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/Norm.cs b/src/Bonsai.ML.Torch/LinearAlgebra/Norm.cs new file mode 100644 index 00000000..82914d39 --- /dev/null +++ b/src/Bonsai.ML.Torch/LinearAlgebra/Norm.cs @@ -0,0 +1,37 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.LinearAlgebra +{ + /// + /// Computes a matrix norm. + /// + [Combinator] + [Description("Computes a matrix norm.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class MatrixNorm + { + + /// + /// The dimensions along which to compute the matrix norm. + /// + public long[] Dimensions { get; set; } = null; + + /// + /// If true, the reduced dimensions are retained in the result as dimensions with size one. 
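The linear algebra transforms map directly onto `torch.linalg`. A small numeric sketch (the 2x2 matrix is arbitrary, but symmetric positive-definite so the Cholesky factor exists):

```csharp
using static TorchSharp.torch;

// [[4, 2], [2, 3]] is symmetric positive-definite.
var a = tensor(new float[] { 4, 2, 2, 3 }).reshape(2, 2);

var l = linalg.cholesky(a);              // lower-triangular Cholesky factor (a == l.matmul(l.t()))
var d = linalg.det(a);                   // 4*3 - 2*2 = 8
var inverse = linalg.inv(a);             // a.matmul(inverse) is (numerically) the identity
var (eigvals, eigvecs) = linalg.eig(a);  // eigenvalues and eigenvectors, complex-valued in general
```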
+ /// + public bool Keepdim { get; set; } = false; + + /// + /// Computes a matrix norm. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => linalg.norm(tensor, dims: Dimensions, keepdim: Keepdim)); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/SVD.cs b/src/Bonsai.ML.Torch/LinearAlgebra/SVD.cs new file mode 100644 index 00000000..c722f53b --- /dev/null +++ b/src/Bonsai.ML.Torch/LinearAlgebra/SVD.cs @@ -0,0 +1,34 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.LinearAlgebra +{ + /// + /// Computes the singular value decomposition (SVD) of a matrix. + /// + [Combinator] + [Description("Computes the singular value decomposition (SVD) of a matrix.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class SVD + { + /// + /// Whether to compute the full or reduced SVD. + /// + public bool FullMatrices { get; set; } = false; + + /// + /// Computes the singular value decomposition (SVD) of a matrix. + /// + /// + /// + public IObservable> Process(IObservable source) + { + return source.Select(tensor => { + var (u, s, v) = linalg.svd(tensor, fullMatrices: FullMatrices); + return Tuple.Create(u, s, v); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Linspace.cs b/src/Bonsai.ML.Torch/Linspace.cs new file mode 100644 index 00000000..f7e27887 --- /dev/null +++ b/src/Bonsai.ML.Torch/Linspace.cs @@ -0,0 +1,43 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Creates a 1-D tensor of linearly interpolated values within a given range given the start, end, and count. + /// + [Combinator] + [Description("Creates a 1-D tensor of linearly interpolated values within a given range given the start, end, and count.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class Linspace + { + /// + /// The start of the range. + /// + [Description("The start of the range.")] + public int Start { get; set; } = 0; + + /// + /// The end of the range. + /// + [Description("The end of the range.")] + public int End { get; set; } = 1; + + /// + /// The number of points to generate. + /// + [Description("The number of points to generate.")] + public int Count { get; set; } = 10; + + /// + /// Generates an observable sequence of 1-D tensors created with the function. + /// + /// + public IObservable Process() + { + return Observable.Defer(() => Observable.Return(linspace(Start, End, Count))); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/LoadTensor.cs b/src/Bonsai.ML.Torch/LoadTensor.cs new file mode 100644 index 00000000..af1e7f05 --- /dev/null +++ b/src/Bonsai.ML.Torch/LoadTensor.cs @@ -0,0 +1,33 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Loads a tensor from the specified file. + /// + [Combinator] + [Description("Loads a tensor from the specified file.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class LoadTensor + { + /// + /// The path to the file containing the tensor. 
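Likewise, SVD and Linspace defer to `linalg.svd` and `torch.linspace`; a sketch with arbitrary sizes:

```csharp
using static TorchSharp.torch;

// 1-D tensor of 10 evenly spaced values in [0, 1], as Linspace produces.
var x = linspace(0, 1, 10);

// Reduced SVD of a 3x2 matrix built from the first six values.
var m = x.index(TensorIndex.Slice(0, 6)).reshape(3, 2);
var (u, s, vh) = linalg.svd(m, fullMatrices: false);
// u: [3, 2], s: [2], vh: [2, 2]
```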
+ /// + [FileNameFilter("Binary files(*.bin)|*.bin|All files|*.*")] + [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + [Description("The path to the file containing the tensor.")] + public string Path { get; set; } + + /// + /// Loads a tensor from the specified file. + /// + /// + public IObservable Process() + { + return Observable.Return(Tensor.Load(Path)); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Mean.cs b/src/Bonsai.ML.Torch/Mean.cs new file mode 100644 index 00000000..294edf31 --- /dev/null +++ b/src/Bonsai.ML.Torch/Mean.cs @@ -0,0 +1,32 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Takes the mean of the tensor along the specified dimensions. + /// + [Combinator] + [Description("Takes the mean of the tensor along the specified dimensions.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Mean + { + /// + /// The dimensions along which to compute the mean. + /// + [Description("The dimensions along which to compute the mean.")] + public long[] Dimensions { get; set; } + + /// + /// Takes the mean of the tensor along the specified dimensions. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(input => input.mean(Dimensions)); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/MeshGrid.cs b/src/Bonsai.ML.Torch/MeshGrid.cs new file mode 100644 index 00000000..a32f9eca --- /dev/null +++ b/src/Bonsai.ML.Torch/MeshGrid.cs @@ -0,0 +1,34 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Collections.Generic; +using static TorchSharp.torch; +using System.Linq; + +namespace Bonsai.ML.Torch +{ + /// + /// Creates a mesh grid from an observable sequence of enumerable of 1-D tensors. + /// + [Combinator] + [Description("Creates a mesh grid from an observable sequence of enumerable of 1-D tensors.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class MeshGrid + { + /// + /// The indexing mode to use for the mesh grid. + /// + [Description("The indexing mode to use for the mesh grid.")] + public string Indexing { get; set; } = "ij"; + + /// + /// Creates a mesh grid from the input tensors. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(tensors => meshgrid(tensors, indexing: Indexing)); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/Backward.cs b/src/Bonsai.ML.Torch/NeuralNets/Backward.cs new file mode 100644 index 00000000..328c35ba --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Backward.cs @@ -0,0 +1,78 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; +using static TorchSharp.torch.nn; +using static TorchSharp.torch.optim; + +namespace Bonsai.ML.Torch.NeuralNets +{ + /// + /// Trains a model using backpropagation. + /// + [Combinator] + [Description("Trains a model using backpropagation.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Backward + { + /// + /// The optimizer to use for training. + /// + public Optimizer Optimizer { get; set; } + + /// + /// The model to train. + /// + [XmlIgnore] + public ITorchModule Model { get; set; } + + /// + /// The loss function to use for training. + /// + public Loss Loss { get; set; } + + /// + /// Trains the model using backpropagation. 
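Mean and MeshGrid are similarly thin wrappers; the underlying calls look like this (sizes picked arbitrarily):

```csharp
using static TorchSharp.torch;

var x = linspace(0, 1, 5);
var y = linspace(0, 1, 3);

// "ij" indexing, the MeshGrid operator's default.
var grids = meshgrid(new[] { x, y }, indexing: "ij");
// grids[0] and grids[1] both have shape [5, 3]

var values = ones(4, 5);
var rowMeans = values.mean(new long[] { 1 });   // mean over dimension 1, shape [4]
```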
+ /// + /// + /// + public IObservable Process(IObservable> source) + { + optim.Optimizer optimizer = null; + switch (Optimizer) + { + case Optimizer.Adam: + optimizer = Adam(Model.Module.parameters()); + break; + } + + Module loss = null; + switch (Loss) + { + case Loss.NLLLoss: + loss = NLLLoss(); + break; + } + + var scheduler = lr_scheduler.StepLR(optimizer, 1, 0.7); + Model.Module.train(); + + return source.Select((input) => { + var (data, target) = input; + using (_ = NewDisposeScope()) + { + optimizer.zero_grad(); + + var prediction = Model.Forward(data); + var output = loss.forward(prediction, target); + + output.backward(); + + optimizer.step(); + return output.MoveToOuterDisposeScope(); + } + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/Forward.cs b/src/Bonsai.ML.Torch/NeuralNets/Forward.cs new file mode 100644 index 00000000..175ed3c0 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Forward.cs @@ -0,0 +1,34 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using System.Xml.Serialization; + +namespace Bonsai.ML.Torch.NeuralNets +{ + /// + /// Runs forward inference on the input tensor using the specified model. + /// + [Combinator] + [Description("Runs forward inference on the input tensor using the specified model.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Forward + { + /// + /// The model to use for inference. + /// + [XmlIgnore] + public ITorchModule Model { get; set; } + + /// + /// Runs forward inference on the input tensor using the specified model. + /// + /// + /// + public IObservable Process(IObservable source) + { + Model.Module.eval(); + return source.Select(Model.Forward); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs b/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs new file mode 100644 index 00000000..5cde6f73 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs @@ -0,0 +1,22 @@ +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.NeuralNets +{ + /// + /// Represents an interface for a Torch module. + /// + public interface ITorchModule + { + /// + /// The module. + /// + public nn.Module Module { get; } + + /// + /// Runs forward inference on the input tensor using the specified model. + /// + /// + /// + public Tensor Forward(Tensor tensor); + } +} diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadModuleFromArchitecture.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadModuleFromArchitecture.cs new file mode 100644 index 00000000..ac791d8d --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/LoadModuleFromArchitecture.cs @@ -0,0 +1,84 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using System.Xml.Serialization; + +namespace Bonsai.ML.Torch.NeuralNets +{ + /// + /// Loads a neural network module from a specified architecture. + /// + [Combinator] + [Description("Loads a neural network module from a specified architecture.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class LoadModuleFromArchitecture + { + /// + /// The model architecture to load. + /// + [Description("The model architecture to load.")] + public Models.ModelArchitecture ModelArchitecture { get; set; } + + /// + /// The device on which to load the model. + /// + [Description("The device on which to load the model.")] + [XmlIgnore] + public Device Device { get; set; } + + /// + /// The optional path to the model weights file. 
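The Backward operator wraps a standard TorchSharp training step. Stripped of the Bonsai plumbing, each input pair drives roughly the following; the tiny stand-in model and batch below exist only to make the sketch self-contained:

```csharp
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
using static TorchSharp.torch.optim;

// Stand-in model and batch: 8 samples, 4 features, 3 classes.
var model = Sequential(("lin", Linear(4, 3)), ("logsm", LogSoftmax(1)));
var data = randn(8, 4);
var target = tensor(new long[] { 0, 1, 2, 0, 1, 2, 0, 1 });

var optimizer = Adam(model.parameters());
var criterion = NLLLoss();
model.train();

// One optimization step, mirroring the body of Backward.Process:
using (NewDisposeScope())
{
    optimizer.zero_grad();
    var prediction = model.forward(data);
    var loss = criterion.forward(prediction, target);
    loss.backward();
    optimizer.step();
}
```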
+ /// + [Description("The optional path to the model weights file.")] + [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + public string ModelWeightsPath { get; set; } + + private int numClasses = 10; + /// + /// The number of classes in the dataset. + /// + [Description("The number of classes in the dataset.")] + public int NumClasses + { + get => numClasses; + set + { + if (value <= 0) + { + numClasses = 10; + } + else + { + numClasses = value; + } + } + } + + /// + /// Loads the neural network module from the specified architecture. + /// + /// + /// + public IObservable Process() + { + var modelArchitecture = ModelArchitecture.ToString().ToLower(); + var device = Device; + + nn.Module module = modelArchitecture switch + { + "alexnet" => new Models.AlexNet(modelArchitecture, numClasses, device), + "mobilenet" => new Models.MobileNet(modelArchitecture, numClasses, device), + "mnist" => new Models.MNIST(modelArchitecture, numClasses, device), + _ => throw new ArgumentException($"Model {modelArchitecture} not supported.") + }; + + if (ModelWeightsPath is not null) module.load(ModelWeightsPath); + + var torchModule = new TorchModuleAdapter(module); + return Observable.Defer(() => { + return Observable.Return((ITorchModule)torchModule); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs new file mode 100644 index 00000000..fb3b2b78 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs @@ -0,0 +1,43 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using System.Xml.Serialization; + +namespace Bonsai.ML.Torch.NeuralNets +{ + /// + /// Loads a TorchScript module from the specified file path. + /// + [Combinator] + [Description("Loads a TorchScript module from the specified file path.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class LoadScriptModule + { + + /// + /// The device on which to load the model. + /// + [Description("The device on which to load the model.")] + [XmlIgnore] + public Device Device { get; set; } = CPU; + + /// + /// The path to the TorchScript model file. + /// + [Description("The path to the TorchScript model file.")] + [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + public string ModelPath { get; set; } + + /// + /// Loads the TorchScript module from the specified file path. + /// + /// + public IObservable Process() + { + var scriptModule = jit.load(ModelPath, Device); + var torchModule = new TorchModuleAdapter(scriptModule); + return Observable.Return((ITorchModule)torchModule); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/Loss.cs b/src/Bonsai.ML.Torch/NeuralNets/Loss.cs new file mode 100644 index 00000000..376139c1 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Loss.cs @@ -0,0 +1,13 @@ +namespace Bonsai.ML.Torch.NeuralNets +{ + /// + /// Represents a loss function that computes the loss value for a given input and target tensor. + /// + public enum Loss + { + /// + /// Computes the negative log likelihood loss. 
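LoadModuleFromArchitecture optionally restores weights with `Module.load`, and SaveModel later in this diff writes them with `Module.save`. A round-trip sketch (the file name is a placeholder):

```csharp
using static TorchSharp.torch.nn;

// Save weights in TorchSharp's own format, as SaveModel does via Model.Module.save(ModelPath).
var model = Sequential(("lin", Linear(4, 3)));
model.save("weights.bin");

// Restore them into an identically shaped module, as LoadModuleFromArchitecture
// does when ModelWeightsPath is set.
var reloaded = Sequential(("lin", Linear(4, 3)));
reloaded.load("weights.bin");
```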
+ /// + NLLLoss, + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs new file mode 100644 index 00000000..c3d19d55 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs @@ -0,0 +1,74 @@ +using TorchSharp; +using static TorchSharp.torch; +using static TorchSharp.torch.nn; + +namespace Bonsai.ML.Torch.NeuralNets.Models +{ + /// + /// Modified version of original AlexNet to fix CIFAR10 32x32 images. + /// + internal class AlexNet : Module + { + private readonly Module features; + private readonly Module avgPool; + private readonly Module classifier; + + /// + /// Constructs a new AlexNet model. + /// + /// + /// + /// + public AlexNet(string name, int numClasses, Device device = null) : base(name) + { + features = Sequential( + ("c1", Conv2d(3, 64, kernel_size: 3, stride: 2, padding: 1)), + ("r1", ReLU(inplace: true)), + ("mp1", MaxPool2d(kernel_size: [ 2, 2 ])), + ("c2", Conv2d(64, 192, kernel_size: 3, padding: 1)), + ("r2", ReLU(inplace: true)), + ("mp2", MaxPool2d(kernel_size: [ 2, 2 ])), + ("c3", Conv2d(192, 384, kernel_size: 3, padding: 1)), + ("r3", ReLU(inplace: true)), + ("c4", Conv2d(384, 256, kernel_size: 3, padding: 1)), + ("r4", ReLU(inplace: true)), + ("c5", Conv2d(256, 256, kernel_size: 3, padding: 1)), + ("r5", ReLU(inplace: true)), + ("mp3", MaxPool2d(kernel_size: [ 2, 2 ]))); + + avgPool = AdaptiveAvgPool2d([ 2, 2 ]); + + classifier = Sequential( + ("d1", Dropout()), + ("l1", Linear(256 * 2 * 2, 4096)), + ("r1", ReLU(inplace: true)), + ("d2", Dropout()), + ("l2", Linear(4096, 4096)), + ("r3", ReLU(inplace: true)), + ("d3", Dropout()), + ("l3", Linear(4096, numClasses)) + ); + + RegisterComponents(); + + if (device != null && device.type != DeviceType.CPU) + this.to(device); + } + + /// + /// Forward pass of the AlexNet model. + /// + /// + /// + public override Tensor forward(Tensor input) + { + var f = features.forward(input); + var avg = avgPool.forward(f); + + var x = avg.view([ avg.shape[0], 256 * 2 * 2 ]); + + return classifier.forward(x); + } + } + +} diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs new file mode 100644 index 00000000..8bd3e0a4 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs @@ -0,0 +1,86 @@ +using TorchSharp; +using static TorchSharp.torch; +using static TorchSharp.torch.nn; + +namespace Bonsai.ML.Torch.NeuralNets.Models +{ + /// + /// Represents a simple convolutional neural network for the MNIST dataset. + /// + internal class MNIST : Module + { + private readonly Module conv1; + private readonly Module conv2; + private readonly Module fc1; + private readonly Module fc2; + + private readonly Module pool1; + + private readonly Module relu1; + private readonly Module relu2; + private readonly Module relu3; + + private readonly Module dropout1; + private readonly Module dropout2; + + private readonly Module flatten; + private readonly Module logsm; + + /// + /// Constructs a new MNIST model. 
+ /// + /// + /// + /// + public MNIST(string name, int numClasses, Device device = null) : base(name) + { + conv1 = Conv2d(1, 32, 3); + conv2 = Conv2d(32, 64, 3); + fc1 = Linear(9216, 128); + fc2 = Linear(128, numClasses); + + pool1 = MaxPool2d(kernel_size: [2, 2]); + + relu1 = ReLU(); + relu2 = ReLU(); + relu3 = ReLU(); + + dropout1 = Dropout(0.25); + dropout2 = Dropout(0.5); + + flatten = Flatten(); + logsm = LogSoftmax(1); + + RegisterComponents(); + + if (device != null && device.type != DeviceType.CPU) + this.to(device); + } + + /// + /// Forward pass of the MNIST model. + /// + /// + /// + public override Tensor forward(Tensor input) + { + var l11 = conv1.forward(input); + var l12 = relu1.forward(l11); + + var l21 = conv2.forward(l12); + var l22 = relu2.forward(l21); + var l23 = pool1.forward(l22); + var l24 = dropout1.forward(l23); + + var x = flatten.forward(l24); + + var l31 = fc1.forward(x); + var l32 = relu3.forward(l31); + var l33 = dropout2.forward(l32); + + var l41 = fc2.forward(l33); + + return logsm.forward(l41); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs new file mode 100644 index 00000000..a5f7701a --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs @@ -0,0 +1,77 @@ +using System; +using System.Collections.Generic; +using TorchSharp; +using static TorchSharp.torch; +using static TorchSharp.torch.nn; + +namespace Bonsai.ML.Torch.NeuralNets.Models +{ + /// + /// MobileNet model. + /// + internal class MobileNet : Module + { + private readonly long[] planes = [ 64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024 ]; + private readonly long[] strides = [ 1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1 ]; + + private readonly Module layers; + + /// + /// Constructs a new MobileNet model. + /// + /// + /// + /// + /// + public MobileNet(string name, int numClasses, Device device = null) : base(name) + { + if (planes.Length != strides.Length) throw new ArgumentException("'planes' and 'strides' must have the same length."); + + var modules = new List<(string, Module)> + { + ($"conv2d-first", Conv2d(3, 32, kernel_size: 3, stride: 1, padding: 1, bias: false)), + ($"bnrm2d-first", BatchNorm2d(32)), + ($"relu-first", ReLU()) + }; + MakeLayers(modules, 32); + modules.Add(("avgpool", AvgPool2d([2, 2]))); + modules.Add(("flatten", Flatten())); + modules.Add(($"linear", Linear(planes[planes.Length-1], numClasses))); + + layers = Sequential(modules); + + RegisterComponents(); + + if (device != null && device.type != DeviceType.CPU) + this.to(device); + } + + private void MakeLayers(List<(string, Module)> modules, long in_planes) + { + + for (var i = 0; i < strides.Length; i++) { + var out_planes = planes[i]; + var stride = strides[i]; + + modules.Add(($"conv2d-{i}a", Conv2d(in_planes, in_planes, kernel_size: 3, stride: stride, padding: 1, groups: in_planes, bias: false))); + modules.Add(($"bnrm2d-{i}a", BatchNorm2d(in_planes))); + modules.Add(($"relu-{i}a", ReLU())); + modules.Add(($"conv2d-{i}b", Conv2d(in_planes, out_planes, kernel_size: 1L, stride: 1L, padding: 0L, bias: false))); + modules.Add(($"bnrm2d-{i}b", BatchNorm2d(out_planes))); + modules.Add(($"relu-{i}b", ReLU())); + + in_planes = out_planes; + } + } + + /// + /// Forward pass of the MobileNet model. 
+ /// + /// + /// + public override Tensor forward(Tensor input) + { + return layers.forward(input); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/ModelArchitecture.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/ModelArchitecture.cs new file mode 100644 index 00000000..98a30216 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/ModelArchitecture.cs @@ -0,0 +1,23 @@ +namespace Bonsai.ML.Torch.NeuralNets.Models +{ + /// + /// Represents the architecture of a neural network model. + /// + public enum ModelArchitecture + { + /// + /// The AlexNet model architecture. + /// + AlexNet, + + /// + /// The MobileNet model architecture. + /// + MobileNet, + + /// + /// The MNIST model architecture. + /// + MNIST + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/Optimizer.cs b/src/Bonsai.ML.Torch/NeuralNets/Optimizer.cs new file mode 100644 index 00000000..4ab09dbd --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Optimizer.cs @@ -0,0 +1,13 @@ +namespace Bonsai.ML.Torch.NeuralNets +{ + /// + /// Represents an optimizer that updates the parameters of a neural network. + /// + public enum Optimizer + { + /// + /// Adam optimizer. + /// + Adam, + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs b/src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs new file mode 100644 index 00000000..c426aedf --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs @@ -0,0 +1,43 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; + +namespace Bonsai.ML.Torch.NeuralNets +{ + /// + /// Saves the model to a file. + /// + [Combinator] + [Description("Saves the model to a file.")] + [WorkflowElementCategory(ElementCategory.Sink)] + public class SaveModel + { + /// + /// The model to save. + /// + [Description("The model to save.")] + [XmlIgnore] + public ITorchModule Model { get; set; } + + /// + /// The path to save the model. + /// + [Description("The path to save the model.")] + [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + public string ModelPath { get; set; } + + /// + /// Saves the model to the specified file path. + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Do(input => { + Model.Module.save(ModelPath); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/TorchModuleAdapter.cs b/src/Bonsai.ML.Torch/NeuralNets/TorchModuleAdapter.cs new file mode 100644 index 00000000..3ec35071 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/TorchModuleAdapter.cs @@ -0,0 +1,51 @@ +using System; +using System.Reflection; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.NeuralNets +{ + /// + /// Represents a torch module adapter that wraps a torch module or script module. + /// + public class TorchModuleAdapter : ITorchModule + { + private readonly nn.Module _module = null; + + private readonly jit.ScriptModule _scriptModule = null; + + private readonly Func _forwardFunc; + + /// + /// The module. + /// + public nn.Module Module { get; } + + /// + /// Initializes a new instance of the class. + /// + /// + public TorchModuleAdapter(nn.Module module) + { + _module = module; + _forwardFunc = _module.forward; + Module = _module; + } + + /// + /// Initializes a new instance of the class. 
+ /// + /// + public TorchModuleAdapter(jit.ScriptModule scriptModule) + { + _scriptModule = scriptModule; + _forwardFunc = _scriptModule.forward; + Module = _scriptModule; + } + + /// + public Tensor Forward(Tensor input) + { + return _forwardFunc(input); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Ones.cs b/src/Bonsai.ML.Torch/Ones.cs new file mode 100644 index 00000000..77d26577 --- /dev/null +++ b/src/Bonsai.ML.Torch/Ones.cs @@ -0,0 +1,43 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Creates a tensor filled with ones. + /// + [Combinator] + [Description("Creates a tensor filled with ones.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class Ones + { + /// + /// The size of the tensor. + /// + [Description("The size of the tensor.")] + public long[] Size { get; set; } = [0]; + + /// + /// Generates an observable sequence of tensors filled with ones. + /// + /// + public IObservable Process() + { + return Observable.Defer(() => Observable.Return(ones(Size))); + } + + /// + /// Generates an observable sequence of tensors filled with ones for each element of the input sequence. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return ones(Size); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/OpenCVHelper.cs b/src/Bonsai.ML.Torch/OpenCVHelper.cs new file mode 100644 index 00000000..31a27ac4 --- /dev/null +++ b/src/Bonsai.ML.Torch/OpenCVHelper.cs @@ -0,0 +1,127 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using OpenCV.Net; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Helper class to convert between OpenCV mats, images and Torch tensors. + /// + public static class OpenCVHelper + { + private static Dictionary bitDepthLookup = new Dictionary + { + { ScalarType.Byte, (IplDepth.U8, Depth.U8) }, + { ScalarType.Int16, (IplDepth.S16, Depth.S16) }, + { ScalarType.Int32, (IplDepth.S32, Depth.S32) }, + { ScalarType.Float32, (IplDepth.F32, Depth.F32) }, + { ScalarType.Float64, (IplDepth.F64, Depth.F64) }, + { ScalarType.Int8, (IplDepth.S8, Depth.S8) } + }; + + /// + /// Converts an OpenCV image to a Torch tensor. + /// + /// + /// + public static Tensor ToTensor(IplImage image) + { + if (image == null) + return empty([ 0, 0, 0 ]); + + int height = image.Height; + int channels = image.Channels; + var width = image.WidthStep / channels; + + var iplDepth = image.Depth; + var tensorType = bitDepthLookup.FirstOrDefault(x => x.Value.IplDepth == iplDepth).Key; + + IntPtr data = image.ImageData; + ReadOnlySpan dimensions = stackalloc long[] { height, width, channels }; + + if (data == IntPtr.Zero) + throw new InvalidOperationException($"Got {nameof(IplImage)} without backing data, this isn't expected to be possible."); + + return TorchSharpEx.CreateTensorFromUnmanagedMemoryWithManagedAnchor(data, image, dimensions, tensorType); + } + + /// + /// Converts an OpenCV mat to a Torch tensor. 
+ /// + /// + /// + public static Tensor ToTensor(Mat mat) + { + if (mat == null) + return empty([0, 0, 0 ]); + + int width = mat.Size.Width; + int height = mat.Size.Height; + int channels = mat.Channels; + + var depth = mat.Depth; + var tensorType = bitDepthLookup.FirstOrDefault(x => x.Value.Depth == depth).Key; + + IntPtr data = mat.Data; + ReadOnlySpan dimensions = stackalloc long[] { height, width, channels }; + + if (data == IntPtr.Zero) + throw new InvalidOperationException($"Got {nameof(Mat)} without backing data, this isn't expected to be possible."); + + return TorchSharpEx.CreateTensorFromUnmanagedMemoryWithManagedAnchor(data, mat, dimensions, tensorType); + } + + private static (int height, int width, int channels) GetImageDimensions(this Tensor tensor) + { + if (tensor.dim() != 3) + throw new ArgumentException("The tensor does not have exactly 3 dimensions."); + + checked + { return ((int)tensor.size(0), (int)tensor.size(1), (int)tensor.size(2)); } + } + + /// + /// Converts a Torch tensor to an OpenCV image. + /// + /// + /// + public unsafe static IplImage ToImage(Tensor tensor) + { + var (height, width, channels) = tensor.GetImageDimensions(); + + var tensorType = tensor.dtype; + var iplDepth = bitDepthLookup[tensorType].IplDepth; + var image = new IplImage(new OpenCV.Net.Size(width, height), iplDepth, channels); + + // Create a temporary tensor backed by the image's memory and copy the source tensor into it + ReadOnlySpan dimensions = stackalloc long[] { height, width, channels }; + using var imageTensor = TorchSharpEx.CreateStackTensor(image.ImageData, image, dimensions, tensorType); + imageTensor.Tensor.copy_(tensor); + + return image; + } + + /// + /// Converts a Torch tensor to an OpenCV mat. + /// + /// + /// + public unsafe static Mat ToMat(Tensor tensor) + { + var (height, width, channels) = tensor.GetImageDimensions(); + + var tensorType = tensor.dtype; + var depth = bitDepthLookup[tensorType].Depth; + var mat = new Mat(new OpenCV.Net.Size(width, height), depth, channels); + + // Create a temporary tensor backed by the matrix's memory and copy the source tensor into it + ReadOnlySpan dimensions = stackalloc long[] { height, width, channels }; + using var matTensor = TorchSharpEx.CreateStackTensor(mat.Data, mat, dimensions, tensorType); + matTensor.Tensor.copy_(tensor); + + return mat; + } + } +} diff --git a/src/Bonsai.ML.Torch/Permute.cs b/src/Bonsai.ML.Torch/Permute.cs new file mode 100644 index 00000000..507d31d2 --- /dev/null +++ b/src/Bonsai.ML.Torch/Permute.cs @@ -0,0 +1,34 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Permutes the dimensions of the input tensor according to the specified permutation. + /// + [Combinator] + [Description("Permutes the dimensions of the input tensor according to the specified permutation.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Permute + { + /// + /// The permutation of the dimensions. + /// + [Description("The permutation of the dimensions.")] + public long[] Dimensions { get; set; } = [0]; + + /// + /// Returns an observable sequence that permutes the dimensions of the input tensor according to the specified permutation. 
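OpenCVHelper aliases the unmanaged pixel buffer when converting to a tensor and copies into freshly allocated storage when converting back. A round-trip sketch with an arbitrary 8-bit grayscale image:

```csharp
using OpenCV.Net;
using Bonsai.ML.Torch;

// 640x480 single-channel 8-bit image (pixel contents uninitialized here).
var image = new IplImage(new OpenCV.Net.Size(640, 480), IplDepth.U8, 1);

// Tensor view over the image buffer, laid out as [height, width, channels].
var tensor = OpenCVHelper.ToTensor(image);   // shape [480, 640, 1], Byte elements

// Copy back into a new IplImage of matching depth and channel count.
var copy = OpenCVHelper.ToImage(tensor);
```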
+ /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => { + return tensor.permute(Dimensions); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Reshape.cs b/src/Bonsai.ML.Torch/Reshape.cs new file mode 100644 index 00000000..fdd07fa5 --- /dev/null +++ b/src/Bonsai.ML.Torch/Reshape.cs @@ -0,0 +1,33 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Reshapes the input tensor according to the specified dimensions. + /// + [Combinator] + [Description("Reshapes the input tensor according to the specified dimensions.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Reshape + { + /// + /// The dimensions of the reshaped tensor. + /// + [Description("The dimensions of the reshaped tensor.")] + public long[] Dimensions { get; set; } = [0]; + + /// + /// Reshapes the input tensor according to the specified dimensions. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(input => input.reshape(Dimensions)); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/SaveTensor.cs b/src/Bonsai.ML.Torch/SaveTensor.cs new file mode 100644 index 00000000..1a3c4772 --- /dev/null +++ b/src/Bonsai.ML.Torch/SaveTensor.cs @@ -0,0 +1,34 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Saves the input tensor to the specified file. + /// + [Combinator] + [Description("Saves the input tensor to the specified file.")] + [WorkflowElementCategory(ElementCategory.Sink)] + public class SaveTensor + { + /// + /// The path to the file where the tensor will be saved. + /// + [FileNameFilter("Binary files(*.bin)|*.bin|All files|*.*")] + [Editor("Bonsai.Design.SaveFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + [Description("The path to the file where the tensor will be saved.")] + public string Path { get; set; } = string.Empty; + + /// + /// Saves the input tensor to the specified file. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Do(tensor => tensor.save(Path)); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/ScalarTypeLookup.cs b/src/Bonsai.ML.Torch/ScalarTypeLookup.cs new file mode 100644 index 00000000..1e4c6c57 --- /dev/null +++ b/src/Bonsai.ML.Torch/ScalarTypeLookup.cs @@ -0,0 +1,53 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Provides methods to look up tensor data types. + /// + public static class ScalarTypeLookup + { + private static readonly Dictionary _lookup = new() + { + { ScalarType.Byte, (typeof(byte), "byte") }, + { ScalarType.Int16, (typeof(short), "short") }, + { ScalarType.Int32, (typeof(int), "int") }, + { ScalarType.Int64, (typeof(long), "long") }, + { ScalarType.Float32, (typeof(float), "float") }, + { ScalarType.Float64, (typeof(double), "double") }, + { ScalarType.Bool, (typeof(bool), "bool") }, + { ScalarType.Int8, (typeof(sbyte), "sbyte") }, + }; + + /// + /// Returns the type corresponding to the specified tensor data type. + /// + /// + /// + public static Type GetTypeFromScalarType(ScalarType type) => _lookup[type].Type; + + /// + /// Returns the string representation corresponding to the specified tensor data type. 
+ /// + /// + /// + public static string GetStringFromScalarType(ScalarType type) => _lookup[type].StringValue; + + /// + /// Returns the tensor data type corresponding to the specified string representation. + /// + /// + /// + public static ScalarType GetScalarTypeFromString(string value) => _lookup.First(x => x.Value.StringValue == value).Key; + + /// + /// Returns the tensor data type corresponding to the specified type. + /// + /// + /// + public static ScalarType GetScalarTypeFromType(Type type) => _lookup.First(x => x.Value.Type == type).Key; + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Set.cs b/src/Bonsai.ML.Torch/Set.cs new file mode 100644 index 00000000..18dcc02a --- /dev/null +++ b/src/Bonsai.ML.Torch/Set.cs @@ -0,0 +1,72 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch; + +/// +/// Sets the value of the input tensor at the specified index. +/// +[Combinator] +[Description("Sets the value of the input tensor at the specified index.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class Set +{ + /// + /// The index at which to set the value. + /// + [Description("The index at which to set the value.")] + public string Index { get; set; } = string.Empty; + + /// + /// The value to set at the specified index. + /// + [XmlIgnore] + [Description("The value to set at the specified index.")] + public Tensor Value { get; set; } = null; + + /// + /// Returns an observable sequence that sets the value of the input tensor at the specified index. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => { + var indexes = Torch.Index.IndexHelper.Parse(Index); + return tensor.index_put_(Value, indexes); + }); + } + + /// + /// Returns an observable sequence that sets the value of the input tensor at the specified index. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(input => { + var tensor = input.Item1; + var index = input.Item2; + return tensor.index_put_(Value, index); + }); + } + + /// + /// Returns an observable sequence that sets the value of the input tensor at the specified index. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(input => { + var tensor = input.Item1; + var indexes = input.Item2; + return tensor.index_put_(Value, indexes); + }); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Sum.cs b/src/Bonsai.ML.Torch/Sum.cs new file mode 100644 index 00000000..1e4c1a2c --- /dev/null +++ b/src/Bonsai.ML.Torch/Sum.cs @@ -0,0 +1,31 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Computes the sum of the input tensor elements along the specified dimensions. + /// + [Combinator] + [Description("Computes the sum of the input tensor elements along the specified dimensions.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Sum + { + /// + /// The dimensions along which to compute the sum. + /// + public long[] Dimensions { get; set; } + + /// + /// Computes the sum of the input tensor elements along the specified dimensions. 
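Set resolves its Index string with IndexHelper.Parse and then calls Tensor.index_put_; roughly (shapes and values arbitrary):

```csharp
using static TorchSharp.torch;
using Bonsai.ML.Torch.Index;

var tensor = zeros(4, 4);
var value = ones(2, 4);

// Equivalent to Set with Index = "1:3" and Value = ones(2, 4):
var indices = IndexHelper.Parse("1:3");
tensor.index_put_(value, indices);   // rows 1 and 2 are now all ones
```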
+ /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(input => input.sum(Dimensions)); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tile.cs b/src/Bonsai.ML.Torch/Tile.cs new file mode 100644 index 00000000..df25b8ac --- /dev/null +++ b/src/Bonsai.ML.Torch/Tile.cs @@ -0,0 +1,33 @@ +using static TorchSharp.torch; +using System; +using System.ComponentModel; +using System.Reactive.Linq; + +namespace Bonsai.ML.Torch +{ + /// + /// Constructs a tensor by repeating the elements of input. + /// + [Combinator] + [Description("Constructs a tensor by repeating the elements of input.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Tile + { + /// + /// The number of repetitions in each dimension. + /// + public long[] Dimensions { get; set; } + + /// + /// Constructs a tensor by repeating the elements of input along the specified dimensions. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => { + return tile(tensor, Dimensions); + }); + } + } +} diff --git a/src/Bonsai.ML.Torch/ToArray.cs b/src/Bonsai.ML.Torch/ToArray.cs new file mode 100644 index 00000000..e9ca21f1 --- /dev/null +++ b/src/Bonsai.ML.Torch/ToArray.cs @@ -0,0 +1,74 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using System.Xml.Serialization; +using System.Linq.Expressions; +using System.Reflection; +using Bonsai.Expressions; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Converts the input tensor into an array of the specified element type. + /// + [Combinator] + [Description("Converts the input tensor into an array of the specified element type.")] + [WorkflowElementCategory(ElementCategory.Transform)] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + public class ToArray : SingleArgumentExpressionBuilder + { + /// + /// Initializes a new instance of the class. + /// + public ToArray() + { + Type = new TypeMapping(); + } + + /// + /// Gets or sets the type mapping used to convert the input tensor into an array. + /// + [Description("Gets or sets the type mapping used to convert the input tensor into an array.")] + public TypeMapping Type { get; set; } + + /// + public override Expression Build(IEnumerable arguments) + { + TypeMapping typeMapping = Type; + var returnType = typeMapping.GetType().GetGenericArguments()[0]; + MethodInfo methodInfo = GetType().GetMethod("Process", BindingFlags.Public | BindingFlags.Instance); + methodInfo = methodInfo.MakeGenericMethod(returnType); + Expression sourceExpression = arguments.First(); + + return Expression.Call( + Expression.Constant(this), + methodInfo, + sourceExpression + ); + } + + /// + /// Converts the input tensor into an array of the specified element type. 
+ /// + /// + /// + /// + public IObservable Process(IObservable source) where T : unmanaged + { + return source.Select(tensor => + { + return tensor.data().ToArray(); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/ToDevice.cs b/src/Bonsai.ML.Torch/ToDevice.cs new file mode 100644 index 00000000..0377df46 --- /dev/null +++ b/src/Bonsai.ML.Torch/ToDevice.cs @@ -0,0 +1,37 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Moves the input tensor to the specified device. + /// + [Combinator] + [Description("Moves the input tensor to the specified device.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class ToDevice + { + /// + /// The device to which the input tensor should be moved. + /// + [XmlIgnore] + [Description("The device to which the input tensor should be moved.")] + public Device Device { get; set; } + + /// + /// Returns the input tensor moved to the specified device. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => { + return tensor.to(Device); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/ToImage.cs b/src/Bonsai.ML.Torch/ToImage.cs new file mode 100644 index 00000000..70c8227e --- /dev/null +++ b/src/Bonsai.ML.Torch/ToImage.cs @@ -0,0 +1,28 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using OpenCV.Net; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Converts the input tensor into an OpenCV image. + /// + [Combinator] + [Description("Converts the input tensor into an OpenCV image.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class ToImage + { + /// + /// Converts the input tensor into an OpenCV image. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(OpenCVHelper.ToImage); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/ToMat.cs b/src/Bonsai.ML.Torch/ToMat.cs new file mode 100644 index 00000000..1b1746ed --- /dev/null +++ b/src/Bonsai.ML.Torch/ToMat.cs @@ -0,0 +1,28 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using OpenCV.Net; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Converts the input tensor into an OpenCV mat. + /// + [Combinator] + [Description("Converts the input tensor into an OpenCV mat.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class ToMat + { + /// + /// Converts the input tensor into an OpenCV mat. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(OpenCVHelper.ToMat); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/ToNDArray.cs b/src/Bonsai.ML.Torch/ToNDArray.cs new file mode 100644 index 00000000..89b7b1e1 --- /dev/null +++ b/src/Bonsai.ML.Torch/ToNDArray.cs @@ -0,0 +1,83 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using System.Xml.Serialization; +using System.Linq.Expressions; +using System.Reflection; +using Bonsai.Expressions; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Converts the input tensor into an array of the specified element type and rank. 
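ToArray and ToNDArray read the tensor contents through TorchSharp's data accessor; with a float TypeMapping the generated calls correspond to:

```csharp
using static TorchSharp.torch;

var tensor = linspace(0, 1, 6).reshape(2, 3);

// ToArray with a float TypeMapping: flat copy of the tensor contents.
float[] flat = tensor.data<float>().ToArray();

// ToNDArray with Rank = 2: a System.Array shaped like the tensor.
var nd = tensor.data<float>().ToNDArray();   // concretely a float[2, 3]
```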
+ /// + [Combinator] + [Description("Converts the input tensor into an array of the specified element type.")] + [WorkflowElementCategory(ElementCategory.Transform)] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + public class ToNDArray : SingleArgumentExpressionBuilder + { + /// + /// Initializes a new instance of the class. + /// + public ToNDArray() + { + Type = new TypeMapping(); + } + + /// + /// Gets or sets the type mapping used to convert the input tensor into an array. + /// + [Description("Gets or sets the type mapping used to convert the input tensor into an array.")] + public TypeMapping Type { get; set; } + + /// + /// Gets or sets the rank of the output array. Must be greater than or equal to 1. + /// + [Description("Gets or sets the rank of the output array. Must be greater than or equal to 1.")] + public int Rank { get; set; } = 1; + + /// + public override Expression Build(IEnumerable arguments) + { + TypeMapping typeMapping = Type; + var returnType = typeMapping.GetType().GetGenericArguments()[0]; + MethodInfo methodInfo = GetType().GetMethod("Process", BindingFlags.Public | BindingFlags.Instance); + var lengths = new int[Rank]; + Type arrayType = Array.CreateInstance(returnType, lengths).GetType(); + methodInfo = methodInfo.MakeGenericMethod(returnType, arrayType); + Expression sourceExpression = arguments.First(); + + return Expression.Call( + Expression.Constant(this), + methodInfo, + sourceExpression + ); + } + + /// + /// Converts the input tensor into an array of the specified element type. + /// + /// + /// + /// + /// + public IObservable Process(IObservable source) where T : unmanaged + { + return source.Select(tensor => + { + return (TResult)(object)tensor.data().ToNDArray(); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/ToTensor.cs b/src/Bonsai.ML.Torch/ToTensor.cs new file mode 100644 index 00000000..7af26dc9 --- /dev/null +++ b/src/Bonsai.ML.Torch/ToTensor.cs @@ -0,0 +1,134 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using OpenCV.Net; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Converts the input value into a tensor. + /// + [Combinator] + [Description("Converts the input value into a tensor.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class ToTensor + { + /// + /// Converts an int into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a double into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a byte into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a bool into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a float into a tensor. 
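The scalar and array overloads of ToTensor all call torch.as_tensor with the incoming value, as the overloads above show (the exact parameter types are not visible in this diff, so this is only an illustration; values are arbitrary and dtypes follow the .NET input type):

```csharp
using static TorchSharp.torch;

var fromDouble = as_tensor(3.14);                      // 0-d float64 tensor
var fromInt = as_tensor(42);                           // 0-d int32 tensor
var fromArray = as_tensor(new double[] { 1, 2, 3 });   // 1-d tensor with 3 elements
```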
diff --git a/src/Bonsai.ML.Torch/ToTensor.cs b/src/Bonsai.ML.Torch/ToTensor.cs
new file mode 100644
index 00000000..7af26dc9
--- /dev/null
+++ b/src/Bonsai.ML.Torch/ToTensor.cs
@@ -0,0 +1,134 @@
+using System;
+using System.ComponentModel;
+using System.Linq;
+using System.Reactive.Linq;
+using OpenCV.Net;
+using static TorchSharp.torch;
+
+namespace Bonsai.ML.Torch
+{
+    /// <summary>
+    /// Converts the input value into a tensor.
+    /// </summary>
+    [Combinator]
+    [Description("Converts the input value into a tensor.")]
+    [WorkflowElementCategory(ElementCategory.Transform)]
+    public class ToTensor
+    {
+        /// <summary>
+        /// Converts an int into a tensor.
+        /// </summary>
+        /// <param name="source"></param>
+        /// <returns></returns>
+        public IObservable<Tensor> Process(IObservable<int> source)
+        {
+            return source.Select(value => {
+                return as_tensor(value);
+            });
+        }
+
+        /// <summary>
+        /// Converts a double into a tensor.
+        /// </summary>
+        /// <param name="source"></param>
+        /// <returns></returns>
+        public IObservable<Tensor> Process(IObservable<double> source)
+        {
+            return source.Select(value => {
+                return as_tensor(value);
+            });
+        }
+
+        /// <summary>
+        /// Converts a byte into a tensor.
+        /// </summary>
+        /// <param name="source"></param>
+        /// <returns></returns>
+        public IObservable<Tensor> Process(IObservable<byte> source)
+        {
+            return source.Select(value => {
+                return as_tensor(value);
+            });
+        }
+
+        /// <summary>
+        /// Converts a bool into a tensor.
+        /// </summary>
+        /// <param name="source"></param>
+        /// <returns></returns>
+        public IObservable<Tensor> Process(IObservable<bool> source)
+        {
+            return source.Select(value => {
+                return as_tensor(value);
+            });
+        }
+
+        /// <summary>
+        /// Converts a float into a tensor.
+        /// </summary>
+        /// <param name="source"></param>
+        /// <returns></returns>
+        public IObservable<Tensor> Process(IObservable<float> source)
+        {
+            return source.Select(value => {
+                return as_tensor(value);
+            });
+        }
+
+        /// <summary>
+        /// Converts a long into a tensor.
+        /// </summary>
+        /// <param name="source"></param>
+        /// <returns></returns>
+        public IObservable<Tensor> Process(IObservable<long> source)
+        {
+            return source.Select(value => {
+                return as_tensor(value);
+            });
+        }
+
+        /// <summary>
+        /// Converts a short into a tensor.
+        /// </summary>
+        /// <param name="source"></param>
+        /// <returns></returns>
+        public IObservable<Tensor> Process(IObservable<short> source)
+        {
+            return source.Select(value => {
+                return as_tensor(value);
+            });
+        }
+
+        /// <summary>
+        /// Converts an array into a tensor.
+        /// </summary>
+        /// <param name="source"></param>
+        /// <returns></returns>
+        public IObservable<Tensor> Process(IObservable<Array> source)
+        {
+            return source.Select(value => {
+                return as_tensor(value);
+            });
+        }
+
+        /// <summary>
+        /// Converts an IplImage into a tensor.
+        /// </summary>
+        /// <param name="source"></param>
+        /// <returns></returns>
+        public IObservable<Tensor> Process(IObservable<IplImage> source)
+        {
+            return source.Select(OpenCVHelper.ToTensor);
+        }
+
+        /// <summary>
+        /// Converts a Mat into a tensor.
+        /// </summary>
+        /// <param name="source"></param>
+        /// <returns></returns>
+        public IObservable<Tensor> Process(IObservable<Mat> source)
+        {
+            return source.Select(OpenCVHelper.ToTensor);
+        }
+    }
+}
\ No newline at end of file
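`ToTensor` maps each supported input type onto `torch.as_tensor`, so a sequence of managed values becomes a sequence of scalar tensors. A minimal Rx sketch of what the `double` overload does per element (the System.Reactive and TorchSharp packages are assumed; this snippet is not part of the PR):

```csharp
using System;
using System.Reactive.Linq;
using static TorchSharp.torch;

class ToTensorSketch
{
    static void Main()
    {
        // Each double in the sequence becomes a scalar tensor,
        // mirroring the double overload of ToTensor.
        var tensors = Observable.Range(0, 5)
            .Select(i => as_tensor((double)i));

        tensors.Subscribe(t => Console.WriteLine($"{t.dtype}: {t.item<double>()}"));
    }
}
```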
diff --git a/src/Bonsai.ML.Torch/TorchSharpEx.cs b/src/Bonsai.ML.Torch/TorchSharpEx.cs
new file mode 100644
index 00000000..47077ae4
--- /dev/null
+++ b/src/Bonsai.ML.Torch/TorchSharpEx.cs
@@ -0,0 +1,124 @@
+#nullable enable
+#pragma warning disable CS1573 // Parameter has no matching param tag in the XML comment
+using System;
+using System.Collections.Concurrent;
+using System.Diagnostics;
+using System.Runtime.InteropServices;
+using TorchSharp;
+
+namespace Bonsai.ML.Torch;
+
+internal unsafe static class TorchSharpEx
+{
+    [DllImport("LibTorchSharp")]
+    private static extern IntPtr THSTensor_new(IntPtr rawArray, DeleterCallback deleter, long* dimensions, int numDimensions, sbyte type, sbyte dtype, int deviceType, int deviceIndex, byte requires_grad);
+
+    [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
+    private delegate void DeleterCallback(IntPtr context);
+
+    // Torch does not expect the deleter callback to be able to be null since it's a C++ reference and LibTorchSharp
+    // does not expose the functions used to create a tensor without a deleter callback, so we must use a no-op callback
+    private static readonly DeleterCallback NullDeleterCallback = _ => { };
+
+    // Acts as GC root for unmanaged callbacks, value is unused
+    private static readonly ConcurrentDictionary<DeleterCallback, byte> ActiveDeleterCallbacks = new();
+
+    /// <summary>Creates a <see cref="torch.Tensor"/> from unmanaged memory that is owned by a managed object</summary>
+    /// <param name="data">The unmanaged memory that will back the tensor, must remain valid and fixed for the lifetime of the tensor</param>
+    /// <param name="managedAnchor">The managed .NET object which owns <paramref name="data"/></param>
+    public static torch.Tensor CreateTensorFromUnmanagedMemoryWithManagedAnchor(IntPtr data, object managedAnchor, ReadOnlySpan<long> dimensions, torch.ScalarType dataType)
+    {
+        //PERF: Ideally the deleter would receive the GCHandle as the context rather than the pointer to the unmanaged memory since that
+        // would allow us to use a GCHandle to root the anchor and free it directly rather than capturing it in the lambda.
+        // Torch itself has the ability to set the context to something else via `TensorMaker::context(void* value, ContextDeleter deleter)`,
+        // but unfortunately this method isn't exposed in LibTorchSharp.
+        // This is similar to the inefficient method TorchSharp uses, which has quite a lot of unnecessary overhead (particularly the unmanaged
+        // delegate allocation), but we do skip some aspects like the GC handle allocation.
+        // It may be tempting to use a GCHandle and a static delegate, looking up the GC handle from the native memory pointer, but doing this
+        // without breaking the ability to create redundant tensors over the same data is overly complicated.
+        DeleterCallback? deleter = null;
+        deleter = (context) =>
+        {
+            GC.KeepAlive(managedAnchor);
+
+            if (!ActiveDeleterCallbacks.TryRemove(deleter!, out _))
+                Debug.Fail($"The same tensor data handle deleter was called more than once!");
+        };
+
+        if (!ActiveDeleterCallbacks.TryAdd(deleter, default))
+            Debug.Fail("Unreachable");
+
+        fixed (long* dimensionsPtr = &dimensions[0])
+        {
+            IntPtr tensorHandle = THSTensor_new(data, deleter, dimensionsPtr, dimensions.Length, (sbyte)dataType, (sbyte)dataType, 0, 0, 0);
+            if (tensorHandle == IntPtr.Zero)
+            {
+                GC.Collect();
+                GC.WaitForPendingFinalizers();
+                tensorHandle = THSTensor_new(data, deleter, dimensionsPtr, dimensions.Length, (sbyte)dataType, (sbyte)dataType, 0, 0, 0);
+            }
+
+            if (tensorHandle == IntPtr.Zero)
+                torch.CheckForErrors();
+
+            return torch.Tensor.UnsafeCreateTensor(tensorHandle);
+        }
+    }
+
+    internal readonly ref struct StackTensor
+    {
+        public readonly torch.Tensor Tensor;
+        private readonly object? Anchor;
+
+        internal StackTensor(torch.Tensor tensor, object? anchor)
+        {
+            Tensor = tensor;
+            Anchor = anchor;
+        }
+
+        public void Dispose()
+        {
+            Tensor.Dispose();
+            GC.KeepAlive(Anchor);
+        }
+    }
+
+    /// <summary>Creates a tensor which is associated with a stack scope.</summary>
+    /// <param name="data">The unmanaged memory that will back the tensor, must remain valid and fixed for the lifetime of the tensor</param>
+    /// <param name="managedAnchor">An optional managed .NET object which owns <paramref name="data"/></param>
+    /// <remarks>
+    /// The returned stack tensor must be disposed. The tensor it refers to will not be valid outside of the scope where it was allocated.
+    /// </remarks>
+    internal static StackTensor CreateStackTensor(IntPtr data, object? managedAnchor, ReadOnlySpan<long> dimensions, torch.ScalarType dataType)
+    {
+        fixed (long* dimensionsPtr = &dimensions[0])
+        {
+            IntPtr tensorHandle = THSTensor_new(data, NullDeleterCallback, dimensionsPtr, dimensions.Length, (sbyte)dataType, (sbyte)dataType, 0, 0, 0);
+            if (tensorHandle == IntPtr.Zero)
+            {
+                GC.Collect();
+                GC.WaitForPendingFinalizers();
+                tensorHandle = THSTensor_new(data, NullDeleterCallback, dimensionsPtr, dimensions.Length, (sbyte)dataType, (sbyte)dataType, 0, 0, 0);
+            }
+
+            if (tensorHandle == IntPtr.Zero)
+                torch.CheckForErrors();
+
+            torch.Tensor result = torch.Tensor.UnsafeCreateTensor(tensorHandle);
+            return new StackTensor(result, managedAnchor);
+        }
+    }
+
+    /// <summary>Gets a pointer to the tensor's backing memory</summary>
+    /// <remarks>The data backing a tensor is not necessarily contiguous or even present on the CPU, consider other strategies before using this method.</remarks>
+    public static IntPtr DangerousGetDataPointer(this torch.Tensor tensor)
+    {
+        [DllImport("LibTorchSharp")]
+        static extern IntPtr THSTensor_data(IntPtr handle);
+
+        IntPtr data = THSTensor_data(tensor.Handle);
+        if (data == IntPtr.Zero)
+            torch.CheckForErrors();
+        return data;
+    }
+}
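The helper above exists so a tensor can wrap memory owned by a managed object without copying: the anchor is captured by the deleter closure, so the owner stays reachable (and its finalizer does not run) while libtorch still references the data pointer. A hypothetical in-assembly caller might look like the following sketch; the `NativeBuffer` type is invented for illustration (an OpenCV image would play this role in practice), and the helper itself is `internal`, so the code assumes it lives inside Bonsai.ML.Torch:

```csharp
using System;
using System.Runtime.InteropServices;
using static TorchSharp.torch;

namespace Bonsai.ML.Torch;

// Invented stand-in for a managed object that owns native memory.
sealed class NativeBuffer
{
    public IntPtr Data { get; }
    public NativeBuffer(int bytes) => Data = Marshal.AllocHGlobal(bytes);
    ~NativeBuffer() => Marshal.FreeHGlobal(Data);
}

static class AnchorSketch
{
    // The buffer must hold at least rows * cols * sizeof(float) bytes.
    public static Tensor Wrap(NativeBuffer buffer, long rows, long cols)
    {
        // Passing the buffer as the managed anchor means the deleter keeps it
        // alive, so its finalizer cannot free the memory before libtorch
        // releases the tensor that points into it.
        return TorchSharpEx.CreateTensorFromUnmanagedMemoryWithManagedAnchor(
            buffer.Data,
            buffer,
            stackalloc long[] { rows, cols },
            ScalarType.Float32);
    }
}
```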
diff --git a/src/Bonsai.ML.Torch/View.cs b/src/Bonsai.ML.Torch/View.cs
new file mode 100644
index 00000000..65a409be
--- /dev/null
+++ b/src/Bonsai.ML.Torch/View.cs
@@ -0,0 +1,33 @@
+using System;
+using System.ComponentModel;
+using System.Linq;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+
+namespace Bonsai.ML.Torch
+{
+    /// <summary>
+    /// Creates a new view of the input tensor with the specified dimensions.
+    /// </summary>
+    [Combinator]
+    [Description("Creates a new view of the input tensor with the specified dimensions.")]
+    [WorkflowElementCategory(ElementCategory.Transform)]
+    public class View
+    {
+        /// <summary>
+        /// The dimensions of the reshaped tensor.
+        /// </summary>
+        [Description("The dimensions of the reshaped tensor.")]
+        public long[] Dimensions { get; set; } = [0];
+
+        /// <summary>
+        /// Creates a new view of the input tensor with the specified dimensions.
+        /// </summary>
+        /// <param name="source"></param>
+        /// <returns></returns>
+        public IObservable<Tensor> Process(IObservable<Tensor> source)
+        {
+            return source.Select(input => input.view(Dimensions));
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/Bonsai.ML.Torch/Vision/Normalize.cs b/src/Bonsai.ML.Torch/Vision/Normalize.cs
new file mode 100644
index 00000000..60b87c44
--- /dev/null
+++ b/src/Bonsai.ML.Torch/Vision/Normalize.cs
@@ -0,0 +1,46 @@
+using Bonsai;
+using System;
+using System.ComponentModel;
+using System.Linq;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torchvision;
+
+namespace Bonsai.ML.Torch.Vision
+{
+    /// <summary>
+    /// Normalizes the input tensor with the mean and standard deviation.
+    /// </summary>
+    [Combinator]
+    [Description("Normalizes the input tensor with the mean and standard deviation.")]
+    [WorkflowElementCategory(ElementCategory.Transform)]
+    public class Normalize
+    {
+        /// <summary>
+        /// The mean values for each channel.
+        /// </summary>
+        [Description("The mean values for each channel.")]
+        public double[] Means { get; set; } = [ 0.1307 ];
+
+        /// <summary>
+        /// The standard deviation values for each channel.
+        /// </summary>
+        [Description("The standard deviation values for each channel.")]
+        public double[] StdDevs { get; set; } = [ 0.3081 ];
+
+        private ITransform transform = null;
+
+        /// <summary>
+        /// Normalizes the input tensor with the mean and standard deviation.
+        /// </summary>
+        /// <param name="source"></param>
+        /// <returns></returns>
+        public IObservable<Tensor> Process(IObservable<Tensor> source)
+        {
+            return source.Select(tensor => {
+                transform ??= transforms.Normalize(Means, StdDevs, tensor.dtype, tensor.device);
+                return transform.call(tensor);
+            }).Finally(() => transform = null);
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/Bonsai.ML.Torch/Zeros.cs b/src/Bonsai.ML.Torch/Zeros.cs
new file mode 100644
index 00000000..e99bdce6
--- /dev/null
+++ b/src/Bonsai.ML.Torch/Zeros.cs
@@ -0,0 +1,43 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+
+namespace Bonsai.ML.Torch
+{
+    /// <summary>
+    /// Creates a tensor filled with zeros.
+    /// </summary>
+    [Combinator]
+    [Description("Creates a tensor filled with zeros.")]
+    [WorkflowElementCategory(ElementCategory.Source)]
+    public class Zeros
+    {
+        /// <summary>
+        /// The size of the tensor.
+        /// </summary>
+        [Description("The size of the tensor.")]
+        public long[] Size { get; set; } = [0];
+
+        /// <summary>
+        /// Generates an observable sequence of tensors filled with zeros.
+        /// </summary>
+        /// <returns></returns>
+        public IObservable<Tensor> Process()
+        {
+            return Observable.Defer(() => Observable.Return(zeros(Size)));
+        }
+
+        /// <summary>
+        /// Generates an observable sequence of tensors filled with zeros for each element of the input sequence.
+        /// </summary>
+        /// <param name="source"></param>
+        /// <returns></returns>
+        public IObservable<Tensor> Process<TSource>(IObservable<TSource> source)
+        {
+            return source.Select(value => {
+                return zeros(Size);
+            });
+        }
+    }
+}
\ No newline at end of file
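Taken together, these source and transform operators compose like ordinary Rx pipelines. A rough equivalent of a `Zeros` → `View` → `Normalize` workflow, written directly against TorchSharp and System.Reactive rather than in Bonsai (illustrative only; the shape and the MNIST statistics mirror the defaults above):

```csharp
using System;
using System.Reactive.Linq;
using static TorchSharp.torch;
using static TorchSharp.torchvision;

class PipelineSketch
{
    static void Main()
    {
        // Normalize with the same MNIST statistics used as defaults above.
        var normalize = transforms.Normalize(new[] { 0.1307 }, new[] { 0.3081 });

        var pipeline = Observable
            .Return(zeros(new long[] { 28 * 28 }))   // Zeros with Size = [784]
            .Select(t => t.view(1, 1, 28, 28))        // View with Dimensions = [1, 1, 28, 28]
            .Select(t => normalize.call(t));          // Normalize

        pipeline.Subscribe(t => Console.WriteLine(string.Join("x", t.shape)));
    }
}
```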