diff --git a/Bonsai.ML.sln b/Bonsai.ML.sln index e44fc390..aaa9ad7f 100644 --- a/Bonsai.ML.sln +++ b/Bonsai.ML.sln @@ -20,6 +20,12 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Directory.Build.props = Directory.Build.props EndProjectSection EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.HiddenMarkovModels", "src\Bonsai.ML.HiddenMarkovModels\Bonsai.ML.HiddenMarkovModels.csproj", "{BAD0A733-8EFB-4EAF-9648-9851656AF7FF}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Python", "src\Bonsai.ML.Python\Bonsai.ML.Python.csproj", "{39A4414F-52B1-42D7-82FA-E65DAD885264}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Data", "src\Bonsai.ML.Data\Bonsai.ML.Data.csproj", "{A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -42,6 +48,18 @@ Global {81DB65B3-EA65-4947-8CF1-0E777324C082}.Debug|Any CPU.Build.0 = Debug|Any CPU {81DB65B3-EA65-4947-8CF1-0E777324C082}.Release|Any CPU.ActiveCfg = Release|Any CPU {81DB65B3-EA65-4947-8CF1-0E777324C082}.Release|Any CPU.Build.0 = Release|Any CPU + {BAD0A733-8EFB-4EAF-9648-9851656AF7FF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {BAD0A733-8EFB-4EAF-9648-9851656AF7FF}.Debug|Any CPU.Build.0 = Debug|Any CPU + {BAD0A733-8EFB-4EAF-9648-9851656AF7FF}.Release|Any CPU.ActiveCfg = Release|Any CPU + {BAD0A733-8EFB-4EAF-9648-9851656AF7FF}.Release|Any CPU.Build.0 = Release|Any CPU + {39A4414F-52B1-42D7-82FA-E65DAD885264}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {39A4414F-52B1-42D7-82FA-E65DAD885264}.Debug|Any CPU.Build.0 = Debug|Any CPU + {39A4414F-52B1-42D7-82FA-E65DAD885264}.Release|Any CPU.ActiveCfg = Release|Any CPU + {39A4414F-52B1-42D7-82FA-E65DAD885264}.Release|Any CPU.Build.0 = Release|Any CPU + {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13}.Debug|Any CPU.Build.0 = Debug|Any CPU + {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13}.Release|Any CPU.ActiveCfg = Release|Any CPU + {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -51,6 +69,9 @@ Global {17AABD18-E275-4409-9E33-3D755B809FF6} = {12312384-8828-4786-AE19-EFCEDF968290} {196AA5C7-AE8A-477B-B01A-B94676EC60EE} = {12312384-8828-4786-AE19-EFCEDF968290} {81DB65B3-EA65-4947-8CF1-0E777324C082} = {461FE3E2-21C4-47F9-8405-DF72326AAB2B} + {BAD0A733-8EFB-4EAF-9648-9851656AF7FF} = {12312384-8828-4786-AE19-EFCEDF968290} + {39A4414F-52B1-42D7-82FA-E65DAD885264} = {12312384-8828-4786-AE19-EFCEDF968290} + {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13} = {12312384-8828-4786-AE19-EFCEDF968290} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {B6468F13-97CD-45E0-9E1E-C122D7F1E09F} diff --git a/Directory.Build.props b/Directory.Build.props index 7a90d2f0..a96cf011 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -14,14 +14,16 @@ icon.png true git - 0.2.0 + 0.3.0 - 9.0 + 12.0 - + + $(RootNamespace).$(MSBuildProjectName).svg + \ No newline at end of file diff --git a/README.md b/README.md index 282052ba..e40ffcf9 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,44 @@ # Bonsai - Machine Learning -The Bonsai.ML project is a collection of packages with reactive infrastructure for adding machine learning algorithms in Bonsai. Below you will find the list of packages (and the included subpackages) currently available within the Bonsai.ML collection. +The **Bonsai.ML** project is a collection of packages designed to integrate machine learning algorithms with Bonsai. This document provides an overview of the available packages and their functionalities. -* Bonsai.ML - provides core functionality across all Bonsai.ML packages. -* Bonsai.ML.LinearDynamicalSystems - package for performing inference of linear dynamical systems. Interfaces with the [lds_python](https://github.com/joacorapela/lds_python) package. - - *Bonsai.ML.LinearDynamicalSystems.Kinematics* - subpackage included in the LinearDynamicalSystems package which supports using the Kalman Filter to infer kinematic data. - - *Bonsai.ML.LinearDynamicalSystems.LinearRegression* - subpackage included in the LinearDynamicalSystems package which supports using the Kalman Filter to perform Bayesian linear regression. -* Bonsai.ML.Visualizers - provides a set of visualizers for dynamic graphing/plotting. +## Core Packages + +- **Bonsai.ML** + Provides common tools and functionality. + +- **Bonsai.ML.Data** + Provides common tools and functionality for working with data. + +- **Bonsai.ML.Python** + Provides common tools and functionality for C# packages to interface with Python. + +## Available Packages + +### Bonsai.ML.LinearDynamicalSystems +Facilitates inference using linear dynamical systems (LDS). It interfaces with the [lds_python](https://github.com/joacorapela/lds_python) package using the [Bonsai - Python Scripting](https://github.com/bonsai-rx/python-scripting) library. + +- **Bonsai.ML.LinearDynamicalSystems.Kinematics** + Supports the use of the Kalman Filter for inferring kinematic data. + +- **Bonsai.ML.LinearDynamicalSystems.LinearRegression** + Utilizes the Kalman Filter to perform online Bayesian linear regression. + +### Bonsai.ML.HiddenMarkovModels +Facilitates inference using Hidden Markov Models (HMMs). It interfaces with the [ssm](https://github.com/lindermanlab/ssm) package using the [Bonsai - Python Scripting](https://github.com/bonsai-rx/python-scripting) library. + +- **Bonsai.ML.HiddenMarkovModels.Observations** + Provides functionality for specifying different types of observations. + +- **Bonsai.ML.HiddenMarkovModels.Transitions** + Provides functionality for specifying different types of transition models. + +### Bonsai.ML.Visualizers +Graphing and plotting library for visualizing data. > [!NOTE] -> Bonsai.ML packages are installed through Bonsai's integrated package manager and are typically available for use immediately. However, certain packages may require additional steps for installation. See the dedicated package section for specific guides and documentation. +> Bonsai.ML packages can be installed through Bonsai's integrated package manager and are generally ready for immediate use. However, some packages may require additional installation steps. Refer to the specific package section for detailed installation guides and documentation. ## Acknowledgments -Development of this package was supported by funding from the Biotechnology and Biological Sciences Research Council [grant number BB/W019132/1]. \ No newline at end of file +Development of the Bonsai.ML package is supported by the Biotechnology and Biological Sciences Research Council [grant number BB/W019132/1]. \ No newline at end of file diff --git a/docs/articles/HiddenMarkovModels/hmm-getting-started.md b/docs/articles/HiddenMarkovModels/hmm-getting-started.md new file mode 100644 index 00000000..8acfa203 --- /dev/null +++ b/docs/articles/HiddenMarkovModels/hmm-getting-started.md @@ -0,0 +1,47 @@ +# Getting Started + +The workflow starts with creating a python runtime, followed by loading the ssm package, referred to as the HMM module. After this, you can instantiate the HMM model and pass it observations of data to perform inference. Since this package relies on communication between Bonsai and Python, the observations that the model uses must be formatted into a valid string representation of a Python data type, namely a list of numbers. + +## Workflow + +```mermaid + +flowchart LR + + A(["Create Python Runtime"]) + B(["Load HMM Module"]) + C(["Create HMM"]) + D(["Generate Observations"]) + E(["Infer Hidden State"]) + + A --> B + B --> C + C --> D + D --> E + +``` + +> [!NOTE] +> Due to the way Bonsai.ML interacts with Python, it is necessary for the first two steps to complete before instantiating the model. It is important to know that the initialization of the Python runtime, loading the module, and creating the model takes time to complete, and that only once the model has been created can inference be performed. + +## Implementation + +Below is a simplified Bonsai workflow that implements the core logic of the package. + +:::workflow +![HMM Implementation](~/workflows/HMMImplementation.bonsai) +::: + +A `CreateRuntime` node is used to initialize a python runtime engine, which gets passed to a `BehaviorSubject` called `PythonEngine`. Bonsai's `CreateRuntime` node should automatically detect the python virtual environment if it was activated in the same terminal that was used to launch Bonsai, otherwise the path to the virtual environment can be specified in the `CreateRuntime` node by setting the `PythonHome` property. + +Next, the `PythonEngine` node is passed to a `LoadHMMModule` node which will load the ssm package into the python environment. + +Once the HMM module has been initialized, the `CreateHMM` node instantiates a python instance of the HMM model. Here, you can specify the initialization parameters of the model and provide a `ModelName` parameter that gets used to reference the model in other parts of the Bonsai workflow. + +It is crucial that the `Data` are formatted into a string that the model can use, namely a string representing a Python list. For example, if you pass a Tuple with 2 items as your data, then the formatter should look something like `"[" + Item1.ToString() + Item2.ToString() + "]"`. The output of this should be used as your observations into the model, so connect your data source to a `Subject` named `Data` and modify the `FormatToPython` node to fit with your data. + +`Observations` are then passed to an `InferState` node, which will use the specified model (given by the `ModelName` property) to infer the latent state of the model and outputs the `StateProbabilities`, or probabilities of being in each state given the observation. + +### Further Examples + +For further examples and demonstrations for how this package works, see the [Bonsai - Machine Learning Examples](~/examples/README.md) section. diff --git a/docs/articles/HiddenMarkovModels/hmm-overview.md b/docs/articles/HiddenMarkovModels/hmm-overview.md new file mode 100644 index 00000000..4b13b883 --- /dev/null +++ b/docs/articles/HiddenMarkovModels/hmm-overview.md @@ -0,0 +1,134 @@ +# Bonsai.ML.HiddenMarkovModels Overview + +The HiddenMarkovModels package provides a Bonsai interface to interact with the [ssm](https://github.com/lindermanlab/ssm) package. + +## General Guide + +Since the package relies on both Bonsai and Python, installation steps for both are required. Detailed instructions are provided for installing the package in a new environment, integrating it with existing workflows, and running examples from the example folder. + +- To install the package for integrating with existing workflows, see the [Installation Guide](#installation-guide). +- To get started with integrating the package into workflows, see the [Getting Started](hmm-getting-started.md) section. +- To test the specific examples provided, check out the [Examples](~/examples/README.md) tab. + +## Installation Guide + +### Dependencies + +To get started, you must install the following tools: + +- [Python (v3.10)](https://www.python.org/downloads/) +- [dotnet-sdk (v8)](https://dotnet.microsoft.com/en-us/download) +- [Git](https://git-scm.com/downloads) +- [Bonsai-Rx Templates tool](https://www.nuget.org/packages/Bonsai.Templates) + + > [!TIP] + > Install Python through the standard installer and add to the system PATH. + +### Installation Guide - Windows + +#### Creating New Project Environment + +1. Open the terminal and create a project folder: + ```cmd + cd ~\Desktop + mkdir HiddenMarkovModels + cd .\HiddenMarkovModels + ``` + +2. Create a Python virtual environment: + ```cmd + python -m venv .venv + ``` + +3. Create a Bonsai environment: + ```cmd + dotnet new bonsaienv + ``` + +#### Python Environment Setup + +1. Activate the Python environment: + ```cmd + .\.venv\Scripts\activate + ``` + +2. Install the ssm package: + ```cmd + pip install numpy cython + pip install ssm@git+https://github.com/lindermanlab/ssm@6c856ad3967941d176eb348bcd490cfaaa08ba60 + ``` + +3. Verify installation: + ```python + import ssm + ``` + +#### Bonsai Environment Setup + +1. Launch Bonsai: + ```cmd + .bonsai\Bonsai.exe + ``` + +2. Install the `Bonsai.ML.HiddenMarkovModels` package from the Package Manager. + > [!TIP] + > You can quickly search for the package by entering `Bonsai.ML.HiddenMarkovModels` into the search bar. + +### Installation Guide - Linux + +#### Creating New Project Environment + +1. Create a project folder: + ```cmd + cd ~/Desktop + mkdir HiddenMarkovModels + cd HiddenMarkovModels + ``` + +2. Create a Python virtual environment: + ```cmd + python3 -m venv .venv + ``` + > [!TIP] + > Install the virtual environment package if needed: + > ```cmd + > sudo apt install python3.10-venv + > ``` + +3. Create a Bonsai environment: + ```cmd + dotnet new bonsaienv + ``` + > [!NOTE] + > This step uses the [Bonsai Linux Environment Template tool](https://github.com/ncguilbeault/bonsai-linux-environment-template) for easy creation of bonsai environments on Linux. + > See [this discussion](https://github.com/orgs/bonsai-rx/discussions/1101) for more information on getting Bonsai running on Linux. + +#### Python Environment Setup + +1. Activate the Python environment: + ```cmd + source .venv/bin/activate + ``` + +2. Install the ssm package: + ```cmd + pip install numpy cython + pip install ssm@git+https://github.com/lindermanlab/ssm@6c856ad3967941d176eb348bcd490cfaaa08ba60 + ``` + +3. Verify installation: + ```python + import ssm + ``` + +#### Bonsai Environment Setup + +1. Activate and launch Bonsai: + ```cmd + source .bonsai/activate + bonsai + ``` + +2. Install the `Bonsai.ML.HiddenMarkovModels` package from the Package Manager. + > [!TIP] + > You can quickly search for the package by entering `Bonsai.ML.HiddenMarkovModels` into the search bar. diff --git a/docs/articles/LinearDynamicalSystems/lds-installation-guide-linux.md b/docs/articles/LinearDynamicalSystems/lds-installation-guide-linux.md index 5e3d1726..9e2e7056 100644 --- a/docs/articles/LinearDynamicalSystems/lds-installation-guide-linux.md +++ b/docs/articles/LinearDynamicalSystems/lds-installation-guide-linux.md @@ -46,7 +46,7 @@ python3 -m venv .venv dotnet new bonsaienvl ``` -When prompted, enter yes to run the powershell setup script. +When prompted, enter yes to run the setup script. ### Python Environment Setup Guide diff --git a/docs/articles/LinearDynamicalSystems/lds-installation-guide-windows.md b/docs/articles/LinearDynamicalSystems/lds-installation-guide-windows.md index 8326a985..d4b5f999 100644 --- a/docs/articles/LinearDynamicalSystems/lds-installation-guide-windows.md +++ b/docs/articles/LinearDynamicalSystems/lds-installation-guide-windows.md @@ -10,7 +10,6 @@ To get started, you must install the following tools: - [dotnet-sdk (v8)](https://dotnet.microsoft.com/en-us/download) - [Git](https://git-scm.com/downloads) - [Bonsai-Rx Templates tool](https://www.nuget.org/packages/Bonsai.Templates) -- [Microsoft Visual C++ Redistributable](https://aka.ms/vs/16/release/vc_redist.x64.exe) > [!WARNING] > Be sure to check the specific python version and dotnet-sdk version you have installed, as different version than the ones we recommend may or may not work with this guide. @@ -32,7 +31,7 @@ python -m venv .venv ``` > [!TIP] -> If receive an error that says, `python cannot be found`, check to ensure that python is available on the system path. If you just installed python, it may be necessary to restart the terminal. +> If you receive an error that says, `python cannot be found`, check to ensure that python is available on the system path. If you just installed python, it may be necessary to restart the terminal. 3. Create a bonsai environment. When prompted, enter yes to run the powershell setup script. diff --git a/docs/articles/toc.yml b/docs/articles/toc.yml index e5756b6f..e22b0b80 100644 --- a/docs/articles/toc.yml +++ b/docs/articles/toc.yml @@ -8,4 +8,9 @@ - name: Installing on Linux href: LinearDynamicalSystems/lds-installation-guide-linux.md - name: Getting Started - href: LinearDynamicalSystems/lds-getting-started.md \ No newline at end of file + href: LinearDynamicalSystems/lds-getting-started.md +- name: HiddenMarkovModels +- name: Overview + href: HiddenMarkovModels/hmm-overview.md +- name: Getting Started + href: HiddenMarkovModels/hmm-getting-started.md \ No newline at end of file diff --git a/docs/workflows/HMMImplementation.bonsai b/docs/workflows/HMMImplementation.bonsai new file mode 100644 index 00000000..02366943 --- /dev/null +++ b/docs/workflows/HMMImplementation.bonsai @@ -0,0 +1,56 @@ + + + + + + + + + PythonEngine + + + PythonEngine + + + + hmm + 2 + 2 + Gaussian + Stationary + + + Data + + + + + + Observation + + + Observation + + + hmm + + + InferredState + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/Bonsai.ML.LinearDynamicalSystems/Bonsai.ML.LinearDynamicalSystems.svg b/elementIcon.svg similarity index 100% rename from src/Bonsai.ML.LinearDynamicalSystems/Bonsai.ML.LinearDynamicalSystems.svg rename to elementIcon.svg diff --git a/src/Bonsai.ML.Data/ArrayHelper.cs b/src/Bonsai.ML.Data/ArrayHelper.cs new file mode 100644 index 00000000..d2f55999 --- /dev/null +++ b/src/Bonsai.ML.Data/ArrayHelper.cs @@ -0,0 +1,184 @@ +using System; +using System.Text; +using System.Collections.Generic; +using System.Linq; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; + +namespace Bonsai.ML.Data +{ + /// + /// Provides a set of static methods for working with arrays. + /// + public static class ArrayHelper + { + /// + /// Serializes the input data into a JSON string representation. + /// + /// The data to serialize. + /// A JSON string representation of the input data. + public static string SerializeToJson(object data) + { + if (data is Array array) + { + return SerializeArrayToJson(array); + } + else + { + return JsonConvert.SerializeObject(data); + } + } + + /// + /// Serializes the input array into a JSON string representation. + /// + /// The array to serialize. + /// A JSON string representation of the input array. + public static string SerializeArrayToJson(Array array) + { + StringBuilder sb = new StringBuilder(); + SerializeArrayRecursive(array, sb, [0]); + return sb.ToString(); + } + + private static void SerializeArrayRecursive(Array array, StringBuilder sb, int[] indices) + { + if (indices.Length < array.Rank) + { + sb.Append("["); + int length = array.GetLength(indices.Length); + for (int i = 0; i < length; i++) + { + int[] newIndices = new int[indices.Length + 1]; + indices.CopyTo(newIndices, 0); + newIndices[indices.Length] = i; + SerializeArrayRecursive(array, sb, newIndices); + if (i < length - 1) + { + sb.Append(", "); + } + } + sb.Append("]"); + } + else + { + object value = array.GetValue(indices); + sb.Append(value.ToString()); + } + } + + private static bool IsValidJson(string input) + { + int squareBrackets = 0; + foreach (char c in input) + { + if (c == '[') squareBrackets++; + else if (c == ']') squareBrackets--; + } + return squareBrackets == 0; + } + + /// + /// Parses the input JSON string into an object of the specified type. If the input is a JSON array, the method will attempt to parse it into an array of the specified type. + /// + /// The JSON string to parse. + /// The data type of the object. + /// An object of the specified type containing the parsed JSON data. + public static object ParseString(string input, Type dtype = null) + { + if (!IsValidJson(input)) + { + throw new ArgumentException($"Parameter: {nameof(input)} is not valid JSON."); + } + var obj = JsonConvert.DeserializeObject(input); + int depth = ParseDepth(obj); + if (depth == 0) + { + return Convert.ChangeType(input, dtype); + } + int[] dimensions = ParseDimensions(obj, depth); + var resultArray = Array.CreateInstance(dtype, dimensions); + PopulateArray(obj, resultArray, [0], dtype); + return resultArray; + } + + private static int ParseDepth(JToken token, int currentDepth = 0) + { + if (token is JArray arr && arr.Count > 0) + { + return ParseDepth(arr[0], currentDepth + 1); + } + return currentDepth; + } + + private static int[] ParseDimensions(JToken token, int depth, int currentLevel = 0) + { + if (depth == 0 || !(token is JArray)) + { + return [0]; + } + + List dimensions = new List(); + JToken current = token; + + while (current != null && current is JArray) + { + JArray currentArray = current as JArray; + dimensions.Add(currentArray.Count); + if (currentArray.Count > 0) + { + if (currentArray.Any(item => !(item is JArray)) && currentArray.Any(item => item is JArray) || currentArray.All(item => item is JArray) && currentArray.Any(item => ((JArray)item).Count != ((JArray)currentArray.First()).Count)) + { + throw new ArgumentException($"Error parsing parameter: {nameof(token)}. Array dimensions are inconsistent."); + } + + if (!(currentArray.First() is JArray)) + { + if (!currentArray.All(item => double.TryParse(item.ToString(), out _)) && !currentArray.All(item => bool.TryParse(item.ToString(), out _))) + { + throw new ArgumentException($"Error parsing parameter: {nameof(token)}. All values in the array must be of the same type. Only numeric or boolean types are supported."); + } + } + } + + current = currentArray.Count > 0 ? currentArray[0] : null; + } + + if (currentLevel > 0 && token is JArray arr && arr.All(x => x is JArray)) + { + var subArrayDimensions = new HashSet(); + foreach (JArray subArr in arr) + { + int[] subDims = ParseDimensions(subArr, depth - currentLevel, currentLevel + 1); + subArrayDimensions.Add(string.Join(",", subDims)); + } + + if (subArrayDimensions.Count > 1) + { + throw new ArgumentException("Inconsistent array dimensions."); + } + } + + return dimensions.ToArray(); + } + + private static void PopulateArray(JToken token, Array array, int[] indices, Type dtype) + { + if (token is JArray arr) + { + for (int i = 0; i < arr.Count; i++) + { + int[] newIndices = new int[indices.Length + 1]; + Array.Copy(indices, newIndices, indices.Length); + newIndices[newIndices.Length - 1] = i; + PopulateArray(arr[i], array, newIndices, dtype); + } + } + else + { + var values = Convert.ChangeType(token, dtype); + array.SetValue(values, indices); + } + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Data/Bonsai.ML.Data.csproj b/src/Bonsai.ML.Data/Bonsai.ML.Data.csproj new file mode 100644 index 00000000..6f8a44b2 --- /dev/null +++ b/src/Bonsai.ML.Data/Bonsai.ML.Data.csproj @@ -0,0 +1,11 @@ + + + Bonsai.ML.Data + Provides common tools and functionality for working with ML data. + Bonsai Rx ML Machine Learning Data + net472;netstandard2.0 + + + + + \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/Bonsai.ML.HiddenMarkovModels.csproj b/src/Bonsai.ML.HiddenMarkovModels/Bonsai.ML.HiddenMarkovModels.csproj new file mode 100644 index 00000000..8dac151d --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/Bonsai.ML.HiddenMarkovModels.csproj @@ -0,0 +1,23 @@ + + + Bonsai.ML.HiddenMarkovModels + A Bonsai package for hidden markov models. Interfaces with the SSM python package using the Bonsai.Scripting.Python package. + Bonsai Rx ML SSM Hidden Markov Models + net472;netstandard2.0 + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/CheckFitFinished.bonsai b/src/Bonsai.ML.HiddenMarkovModels/CheckFitFinished.bonsai new file mode 100644 index 00000000..caef8d29 --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/CheckFitFinished.bonsai @@ -0,0 +1,151 @@ + + + + + + Source1 + + + + + + + CheckFitFinished + + + + + + + + PT0S + PT1S + + + + + + + + + + hmm + + + Name + + + {0}.get_fit_finished() + + + + + + + + + + + HMMModule + + + + + + + + + hmm.get_fit_finished() + + + + FitFinished? + + + + Source1 + + + it.ToString() == "True" + + + + + + + + + + + hmm + + + Name + + + {0}.reset_fit_loop() + + + + + + + + + + + HMMModule + + + + + + + + + hmm.reset_fit_loop() + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/CreateHMM.bonsai b/src/Bonsai.ML.HiddenMarkovModels/CreateHMM.bonsai new file mode 100644 index 00000000..83cce5b3 --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/CreateHMM.bonsai @@ -0,0 +1,128 @@ + + + + + + + + + + hmm + + + + hmm + + + Name + + + Source1 + + + + + + + + + + + 2 + 2 + Gaussian + Stationary + + + + + + + + + + {0}=HiddenMarkovModel({1}) + Item1,Item2 + + + + + + + + HMMModule + + + + + + + + + hmm=HiddenMarkovModel(num_states=14, dimensions=2, observation_type="gaussian", init_state_distribution=[-18.4206626312063, -18.4206809139513, -1.60158960596107E-07, -18.4206809139513, -18.4206809139513, -18.4206809077833, -18.4206809139513, -18.4206809139513, -18.4206809139513, -18.4206809139512, -18.4206809139513, -18.4206809139513, -18.4206809139513, -18.4206809139513, -18.4206809139513, -18.4206809139513, -18.4049264491465], transition_matrix=[[-0.0917453060101488, -81.2027391869505, -81.2024595933819, -81.1937739522692, -81.2027400688462, -2.87807199972272, -81.2024665218064, -81.2027400688462, -81.2027400688462, -69.7226036942302, -81.0262420410483, -81.2027400688462, -81.2027400688462, -81.2027400688462, -72.3853538894001, -81.2027400686459, -3.46032916644607], [-79.3509060625393, -0.238318392160603, -79.3509060627742, -3.59028488542753, -79.3406966323115, -79.3509060343066, -2.07482406652899, -79.3509049453246, -71.8712269840478, -79.1042606674896, -64.7841761999145, -79.3509060627742, -79.3509060627736, -4.91052913823847, -79.3501000762402, -2.96596669611815, -79.3509060627742], [-6.56494712317959, -80.2574114399651, -0.026045953329265, -80.257411439958, -80.2574114399651, -80.2530342501165, -80.2574114399645, -80.2574114399651, -80.2574114399651, -80.2574114383822, -80.257411439963, -80.2574114399651, -80.2574114399651, -80.2574114399651, -80.2574111255061, -80.2574114399651, -3.71724663052794], [-79.7429093428404, -2.53565795122399, -79.7429229020156, -0.224476567256381, -79.5630094990036, -79.7427375380492, -19.7913509211774, -79.7429070271501, -79.7429229020155, -2.4282565377201, -21.6260812291176, -79.7429229020156, -79.7429229020155, -79.7429186717266, -79.5090809839381, -3.39127701396958, -79.7429229020155], [-78.8818948027961, -78.7832645463526, -78.8818948027961, -78.8499675310739, -0.121636510293837, -78.8818948027961, -78.8818779633122, -3.58627108739195, -78.4075671859456, -42.3643447224343, -78.8818948020075, -4.51454403425192, -78.7231231874378, -78.87538610611, -78.8818948027961, -2.57860476672775, -78.8818948027961], [-2.7555263648427, -81.0801260570515, -81.0801929012697, -80.8956240194587, -81.0801928008559, -0.138839571009248, -81.0735688412014, -81.0801929022382, -81.0801929022382, -7.82562008352599, -8.61673274109047, -81.0801929022382, -81.0801929022382, -81.0801929022365, -2.72606578281343, -81.0801915285753, -81.0781516310129], [-79.5041067991954, -4.1599732418435, -79.5041087741123, -2.68539625218839, -79.5041061909036, -79.5040779393979, -0.134619177577566, -79.5041087741123, -79.5041087741123, -73.6956564343321, -3.22954046210177, -79.5041087741123, -79.5041087741123, -79.4233470364088, -5.96249967131849, -79.4986866951327, -79.5041087741123], [-77.6194474368312, -77.6193823080006, -77.6194474368312, -77.619444418796, -77.6081779744754, -77.6194474368312, -77.6194468907703, -0.103028273335152, -77.6190640785059, -77.6194036382625, -77.6194474368306, -77.5387276499137, -2.3238237214254, -77.6193398550474, -77.6194474368312, -77.6189244782998, -77.6194474368312], [-79.1272032685144, -79.1271836355738, -79.1272032685144, -79.1272032685144, -4.66185628154607, -79.1272032685144, -79.1272032685137, -79.1271923677142, -0.0773138048193137, -79.1272032685144, -79.1272032685144, -4.01038951846901, -79.1225247474742, -4.13699455343842, -79.1272032685144, -3.47846842197873, -79.1272032685144], [-79.7969561517743, -79.765068795109, -79.7990815280627, -4.09625383296592, -75.3723004292721, -79.7911688467326, -79.720196927761, -79.7990800333231, -79.7990815280627, -0.133182168954187, -2.25632214661011, -79.7990815280627, -79.7990815280627, -79.7990814068833, -79.561391245246, -5.70642909748951, -79.799081522695], [-80.4624207711008, -80.4542080911501, -80.4674009479227, -2.94217521646244, -80.4642331700328, -80.4245500525541, -4.96905052833919, -80.4674009479227, -80.4674009479227, -4.80319039045672, -0.17771751882944, -80.4674009479227, -80.4674009479227, -80.4674008912272, -2.35475258982744, -80.3817582244586, -80.4674009478945], [-78.6981767986761, -78.6981767986758, -78.6981767986761, -78.6981767986761, -7.12957810042029, -78.6981767986761, -78.6981767986761, -78.6963612933677, -2.80129905358812, -78.6981767986761, -78.6981767986761, -0.0852294830554562, -3.90374244629814, -78.6981474734794, -78.6981767986761, -78.6869101761337, -78.6981767986761], [-78.3498622799896, -78.3498622799895, -78.3498622799896, -78.3498622799896, -4.01387798881909, -78.3498622799896, -78.3498622799895, -78.3452810994854, -78.2643761749177, -78.3498622799896, -78.3498622799896, -2.85311264769245, -0.0787486209906157, -78.3498612775223, -78.3498622799896, -78.3492719452334, -78.3498622799896], [-79.3538024195476, -3.13614011174884, -79.3538024195476, -79.3530894101277, -79.35193649913, -79.3538024195476, -79.33901250438, -79.3538024052913, -4.27770214689655, -79.3537999739468, -79.3538024152413, -79.3537476832867, -79.3538001630013, -0.0590336585215668, -79.3538024194494, -15.1403292390908, -79.3538024195476], [-80.8141761499455, -80.8299171351327, -80.8321243984454, -79.1140751828713, -80.7513948366163, -2.46917174255065, -59.1530651603166, -80.8321243984454, -80.8321243984454, -45.6935849036213, -2.7323847350908, -80.8321243984454, -80.8321243984454, -80.832124389641, -0.16311012032014, -7.15177346756846, -80.8321242340716], [-79.4246034910015, -3.48334852232484, -79.4246034910015, -5.7737892503062, -2.93766284446338, -79.4246034909991, -79.336913429818, -79.4164695681632, -4.34565213609343, -3.36034672586263, -79.4246014668271, -79.4244070303921, -79.4246022597813, -3.36103501832537, -79.4246034902494, -0.18535102870502, -79.4246034910015], [-3.13408601573862, -80.8656106188802, -4.3248839327091, -80.8654855265545, -80.8656106208069, -80.1649555362992, -80.8655973607584, -80.8656106208069, -80.8656106208069, -80.8382022891591, -80.8655171006304, -80.8656106208069, -80.8656106208069, -80.8656106208069, -80.8555535140235, -80.8656106208068, -0.0584500177614161]], observation_means=[[2.05933262307724, 10.8204332645415], [68.7980115395418, 149.313879689245], [0.0113367588749664, 0.0808845676209286], [36.8854430152106, 135.261998572096], [97.9835708320062, 475.786905396804], [4.85126962230773, 23.6361472322517], [44.6355192653933, 49.5566309333179], [266.329727183569, 710.965268736304], [173.521982388192, 173.718928476886], [17.616028067975, 196.233369659563], [17.0875828299435, 70.9425248823598], [204.737788814457, 189.798442944057], [260.143513242843, 262.732194667054], [122.001566709437, 101.435491122701], [9.67037560981843, 42.5625437060252], [96.2721341933364, 273.356625327438], [0.605044362615348, 3.14081365313523]], observation_covs=[[[0.905758344483852, 0], [1.2373681195321, 4.4831763035586]], [[14.3996184698293, 0], [10.9077146872554, 38.9063669683094]], [[0.015436036882364, 0], [0.109020732271086, 0.0444526519767825]], [[8.62408922415511, 0], [23.396230710446, 33.1375162351934]], [[65.0846425847929, 0], [-28.8024157037375, 67.0423181141901]], [[1.82643624659801, 0], [-1.17971110603607, 9.14393736242647]], [[10.9198108801625, 0], [3.33008442316356, 22.7612245320485]], [[51.3305896755247, 0], [13.743041597268, 189.84256808673]], [[12.5414378446675, 0], [-15.4565754923402, 78.6949525926226]], [[8.58696987374004, 0], [9.38430390783284, 66.0305935687787]], [[6.0341314486845, 0], [-9.64384045715633, 21.4295441508927]], [[9.1823933029414, 0], [13.3851117087156, 57.213049685513]], [[28.9288965124962, 0], [31.0513047915833, 60.3716689888219]], [[18.062239316743, 0], [11.8053467734538, 44.8280733649598]], [[3.79541470022462, 0], [-7.87657635998183, 13.537109427599]], [[28.3235630671307, 0], [-20.844035765898, 49.090267281817]], [[0.435879792544103, 0], [1.50536883080018, 1.53287076550159]]]) + + + + + + + + + + + + HMMModule + + + + + + + + + hmm + + + + + 2 + 2 + Gaussian + Stationary + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/DeserializeFromJson.cs b/src/Bonsai.ML.HiddenMarkovModels/DeserializeFromJson.cs new file mode 100644 index 00000000..a3c5ceb0 --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/DeserializeFromJson.cs @@ -0,0 +1,54 @@ +using System.ComponentModel; +using System; +using System.Reactive.Linq; +using Bonsai.Expressions; +using System.Xml.Serialization; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Linq; +using Newtonsoft.Json; + +namespace Bonsai.ML.HiddenMarkovModels +{ + /// + /// Deserializes a sequence of JSON strings into data model objects. + /// + [DefaultProperty(nameof(Type))] + [WorkflowElementCategory(ElementCategory.Transform)] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [Description("Deserializes a sequence of JSON strings into data model objects.")] + public partial class DeserializeFromJson : SingleArgumentExpressionBuilder + { + /// + /// Initializes a new instance of the class. + /// + public DeserializeFromJson() + { + Type = new TypeMapping(); + } + + /// + /// Gets or sets the type of the object to deserialize. + /// + [Description("The type of the object to deserialize.")] + public TypeMapping Type { get; set; } + + /// + public override Expression Build(IEnumerable arguments) + { + TypeMapping typeMapping = Type; + var returnType = typeMapping.GetType().GetGenericArguments()[0]; + return Expression.Call( + typeof(DeserializeFromJson), + nameof(Process), + [ returnType ], + Enumerable.Single(arguments)); + } + + private static IObservable Process(IObservable source) + { + return source.Select(JsonConvert.DeserializeObject); + } + } +} diff --git a/src/Bonsai.ML.HiddenMarkovModels/GetGaussianObservationsStatistics.bonsai b/src/Bonsai.ML.HiddenMarkovModels/GetGaussianObservationsStatistics.bonsai new file mode 100644 index 00000000..5dc43d91 --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/GetGaussianObservationsStatistics.bonsai @@ -0,0 +1,68 @@ + + + + + + Source1 + + + + + + + + + + hmm + + + Name + + + + + + + + + + + HMMModule + + + + + + + + + hmm + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/InferState.bonsai b/src/Bonsai.ML.HiddenMarkovModels/InferState.bonsai new file mode 100644 index 00000000..67cad75f --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/InferState.bonsai @@ -0,0 +1,108 @@ + + + + + + Source1 + + + + + + + + + hmm + + + Name + + + + + + InferState + + + + Source1 + + + {0}.infer_state({1}) + Item2, Item1 + + + + + + + + HMMModule + + + + + + + + + hmm.most_likely_states([59.7382107943162,3.99285183724331]) + + + + + + + + + + + + + + + + + + + + + HMMModule + + + + + + + + + hmm + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/LoadHMMModule.bonsai b/src/Bonsai.ML.HiddenMarkovModels/LoadHMMModule.bonsai new file mode 100644 index 00000000..632b53bc --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/LoadHMMModule.bonsai @@ -0,0 +1,40 @@ + + + + + + Source1 + + + LoadHMMModule + + + + + HiddenMarkovModels + Bonsai.ML.HiddenMarkovModels:main.py + + + + + + + + + + + HMMModule + + + + + + + + + + \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/ModelParameters.cs b/src/Bonsai.ML.HiddenMarkovModels/ModelParameters.cs new file mode 100644 index 00000000..e2a446ec --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/ModelParameters.cs @@ -0,0 +1,224 @@ +using Bonsai; +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using System.Text; +using Newtonsoft.Json; +using System.Xml.Serialization; +using Python.Runtime; +using Bonsai.ML.HiddenMarkovModels.Observations; +using Bonsai.ML.HiddenMarkovModels.Transitions; +using Bonsai.ML.Python; + +namespace Bonsai.ML.HiddenMarkovModels +{ + /// + /// Represents the model parameters of a Hidden Markov Model (HMM). + /// + [Combinator] + [Description("Model parameters of a Hidden Markov Model (HMM).")] + [WorkflowElementCategory(ElementCategory.Source)] + [JsonConverter(typeof(ModelParametersJsonConverter))] + public class ModelParameters : PythonStringBuilder + { + private int numStates; + private int dimensions; + private ObservationsModelType observationsModelType; + private TransitionsModelType transitionsModelType; + private StateParameters stateParameters = null; + + /// + /// The number of states of the HMM model. + /// + [Description("The number of discrete latent states of the HMM model")] + [Category("ModelSpecification")] + public int NumStates + { + get => numStates; + set + { + numStates = value; + UpdateString(); + } + } + + /// + /// The dimensionality of the observations into the HMM model. + /// + [Description("The dimensionality of the observations into the HMM model")] + [Category("ModelSpecification")] + public int Dimensions + { + get => dimensions; + set + { + dimensions = value; + UpdateString(); + } + } + + /// + /// The type of distribution that the HMM will use to model the emission of data observations. + /// + [Description("The type of distribution that the HMM will use to model the emission of data observations.")] + [Category("ModelSpecification")] + public ObservationsModelType ObservationsModelType + { + get => observationsModelType; + set + { + observationsModelType = value; + UpdateString(); + } + } + + /// + /// The type of transition model that the HMM will use to calculate the probabilities of transitioning between states. + /// + [Description("The type of transition model that the HMM will use to calculate the probabilities of transitioning between states.")] + [Category("ModelSpecification")] + public TransitionsModelType TransitionsModelType + { + get => transitionsModelType; + set + { + transitionsModelType = value; + UpdateString(); + } + } + + /// + /// The state parameters of the HMM model. + /// + [XmlIgnore] + [Description("The state parameters of the HMM model.")] + [Category("ModelState")] + public StateParameters StateParameters + { + get => stateParameters; + set + { + stateParameters = value; + if (value != null) + { + if (stateParameters.Observations != null) + { + ObservationsModelType = stateParameters.Observations.ObservationsModelType; + } + if (stateParameters.Transitions != null) + { + TransitionsModelType = stateParameters.Transitions.TransitionsModelType; + } + } + UpdateString(); + } + } + + /// + /// Initializes a new instance of the class. + /// + public ModelParameters() + { + NumStates = 2; + Dimensions = 2; + ObservationsModelType = ObservationsModelType.Gaussian; + TransitionsModelType = TransitionsModelType.Stationary; + } + + /// + /// Returns an observable sequence of objects. + /// + public IObservable Process() + { + return Observable.Return( + new ModelParameters() + { + NumStates = NumStates, + Dimensions = Dimensions, + ObservationsModelType = ObservationsModelType, + TransitionsModelType = TransitionsModelType, + StateParameters = StateParameters + }); + } + + /// + /// Takes an observable seqence and returns an observable sequence of + /// objects that are emitted every time the input sequence emits a new element. + /// + public IObservable Process(IObservable source) + { + return Observable.Select(source, item => + { + return new ModelParameters() + { + NumStates = NumStates, + Dimensions = Dimensions, + ObservationsModelType = ObservationsModelType, + TransitionsModelType = TransitionsModelType, + StateParameters = StateParameters + }; + }); + } + + /// + /// Transforms an observable sequence of into an observable sequence + /// of objects by accessing internal attributes of the . + /// + public IObservable Process(IObservable source) + { + var sharedSource = source.Publish().RefCount(); + var stateParametersObservable = new StateParameters().Process(sharedSource); + return sharedSource.Select(pyObject => + { + numStates = pyObject.GetAttr("num_states"); + dimensions = pyObject.GetAttr("dimensions"); + var observationsModelTypeStrPyObj = pyObject.GetAttr("observations_model_type"); + var transitionsModelTypeStrPyObj = pyObject.GetAttr("transitions_model_type"); + + observationsModelType = ObservationsModelLookup.GetFromString(observationsModelTypeStrPyObj); + transitionsModelType = TransitionsModelLookup.GetFromString(transitionsModelTypeStrPyObj); + + return new ModelParameters() + { + NumStates = NumStates, + Dimensions = Dimensions, + ObservationsModelType = ObservationsModelType, + TransitionsModelType = TransitionsModelType + }; + }).Zip(stateParametersObservable, (modelParameters, stateParameters) => + { + modelParameters.StateParameters = stateParameters; + return modelParameters; + }); + } + + /// + protected override string BuildString() + { + StringBuilder.Clear(); + StringBuilder.Append($"num_states={numStates},") + .Append($"dimensions={dimensions},"); + if (stateParameters == null || string.IsNullOrEmpty(stateParameters.ToString())) + { + StringBuilder.Append($"observations_model_type=\"{ObservationsModelLookup.GetString(observationsModelType)}\","); + StringBuilder.Append($"transitions_model_type=\"{TransitionsModelLookup.GetString(transitionsModelType)}\""); + } + else + { + StringBuilder.Append($"{stateParameters},"); + if (stateParameters.Observations == null) + { + StringBuilder.Append($"observations_model_type=\"{ObservationsModelLookup.GetString(observationsModelType)}\","); + } + if (stateParameters.Transitions == null) + { + StringBuilder.Append($"transitions_model_type=\"{TransitionsModelLookup.GetString(transitionsModelType)}\","); + } + StringBuilder.Remove(StringBuilder.Length - 1, 1); + } + return StringBuilder.ToString(); + } + } +} + diff --git a/src/Bonsai.ML.HiddenMarkovModels/ModelParametersJsonConverter.cs b/src/Bonsai.ML.HiddenMarkovModels/ModelParametersJsonConverter.cs new file mode 100644 index 00000000..e2c9ccee --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/ModelParametersJsonConverter.cs @@ -0,0 +1,73 @@ +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using System; +using Bonsai.ML.HiddenMarkovModels.Observations; +using Bonsai.ML.HiddenMarkovModels.Transitions; + +namespace Bonsai.ML.HiddenMarkovModels +{ + /// + /// Provides a type converter to convert between and a JSON string representation. + /// + public class ModelParametersJsonConverter : JsonConverter + { + /// + public override ModelParameters ReadJson(JsonReader reader, Type objectType, ModelParameters existingValue, bool hasExistingValue, JsonSerializer serializer) + { + JObject jo = JObject.Load(reader); + ModelParameters result = new ModelParameters(); + + result.NumStates = jo["num_states"]?.ToObject() ?? result.NumStates; + result.Dimensions = jo["dimensions"]?.ToObject() ?? result.Dimensions; + result.StateParameters = jo["StateParameters"]?.ToObject(); + + result.ObservationsModelType = result.StateParameters?.Observations?.ObservationsModelType + ?? ObservationsModelLookup.GetFromString(jo["observations_model_type"]?.ToObject()); + + result.TransitionsModelType = result.StateParameters?.Transitions?.TransitionsModelType + ?? TransitionsModelLookup.GetFromString(jo["transitions_model_type"]?.ToObject()); + + return result; + } + + /// + public override void WriteJson(JsonWriter writer, ModelParameters value, JsonSerializer serializer) + { + writer.WriteStartObject(); + + writer.WritePropertyName("num_states"); + serializer.Serialize(writer, value.NumStates); + + writer.WritePropertyName("dimensions"); + serializer.Serialize(writer, value.Dimensions); + + if (value.StateParameters != null) + { + writer.WritePropertyName("StateParameters"); + serializer.Serialize(writer, value.StateParameters); + + if (value.StateParameters.Observations == null) + { + writer.WritePropertyName("observations_model_type"); + serializer.Serialize(writer, value.ObservationsModelType); + } + + if (value.StateParameters.Transitions == null) + { + writer.WritePropertyName("transitions_model_type"); + serializer.Serialize(writer, value.TransitionsModelType); + } + } + else + { + writer.WritePropertyName("observations_model_type"); + serializer.Serialize(writer, ObservationsModelLookup.GetString(value.ObservationsModelType)); + + writer.WritePropertyName("transitions_model_type"); + serializer.Serialize(writer, TransitionsModelLookup.GetString(value.TransitionsModelType)); + } + + writer.WriteEndObject(); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/Observations/AutoRegressiveObservations.cs b/src/Bonsai.ML.HiddenMarkovModels/Observations/AutoRegressiveObservations.cs new file mode 100644 index 00000000..b5018dbd --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/Observations/AutoRegressiveObservations.cs @@ -0,0 +1,184 @@ +using System; +using System.Collections.Generic; +using Python.Runtime; +using System.Reactive.Linq; +using System.ComponentModel; +using System.Xml.Serialization; +using Newtonsoft.Json; +using Bonsai.ML.Python; + +namespace Bonsai.ML.HiddenMarkovModels.Observations +{ + /// + /// Represents an operator that is used to create and transform an observable sequence + /// of objects. + /// + [Combinator] + [Description("Creates an observable sequence of AutoRegressiveObservations objects.")] + [WorkflowElementCategory(ElementCategory.Source)] + [JsonObject(MemberSerialization.OptIn)] + public class AutoRegressiveObservations : ObservationsModel + { + /// + /// The lags of the observations for each state. + /// + [Description("The lags of the observations for each state.")] + public int Lags { get; set; } = 1; + + /// + /// The As of the observations for each state. + /// + [XmlIgnore] + [Description("The As of the observations for each state.")] + public double[,,] As { get; set; } = null; + + /// + /// The bs of the observations for each state. + /// + [XmlIgnore] + [Description("The bs of the observations for each state.")] + public double[,] Bs { get; set; } = null; + + /// + /// The Vs of the observations for each state. + /// + [XmlIgnore] + [Description("The Vs of the observations for each state.")] + public double[,,] Vs { get; set; } = null; + + /// + /// The square root sigmas of the observations for each state. + /// + [XmlIgnore] + [Description("The square root sigmas of the observations for each state.")] + public double[,,] SqrtSigmas { get; set; } = null; + + /// + [JsonProperty] + [JsonConverter(typeof(ObservationsModelTypeJsonConverter))] + [Browsable(false)] + public override ObservationsModelType ObservationsModelType => ObservationsModelType.AutoRegressive; + + /// + [JsonProperty] + public override object[] Params + { + get =>[ As, Bs, Vs, SqrtSigmas ]; + } + + /// + [JsonProperty] + [XmlIgnore] + public override Dictionary Kwargs => new Dictionary + { + ["lags"] = Lags, + }; + + /// + [XmlIgnore] + public static new string[] KwargsArray => [ "lags" ]; + + /// + public AutoRegressiveObservations () : base() + { + } + + /// + public AutoRegressiveObservations (params object[] args) : base(args) + { + } + + /// + protected override void CheckKwargs(params object[] kwargs) + { + if (kwargs is null || kwargs.Length != 1) + { + throw new ArgumentException($"The AutoRegressiveObservations operator requires exactly one constructor argument: {nameof(Lags)}."); + } + } + + /// + protected override void UpdateKwargs(params object[] args) + { + Lags = args[0] switch + { + int lags => lags, + var lags => Convert.ToInt32(lags) + }; + } + + /// + protected override void CheckParams(params object[] @params) + { + if (@params is not null && @params.Length != 4) + { + throw new ArgumentException($"The {nameof(AutoRegressiveObservations)} operator requires exactly four parameters: {nameof(As)}, {nameof(Bs)}, {nameof(Vs)}, and {nameof(SqrtSigmas)}."); + } + } + + /// + protected override void UpdateParams(params object[] @params) + { + As = @params[0] switch + { + double[,,] As => As, + _ => null + }; + + Bs = @params[1] switch + { + double[,] Bs => Bs, + _ => null + }; + + Vs = @params[2] switch + { + double[,,] Vs => Vs, + _ => null + }; + + SqrtSigmas = @params[3] switch + { + double[,,] SqrtSigmas => SqrtSigmas, + _ => null + }; + } + + /// + /// Returns an observable sequence of objects. + /// + public IObservable Process() + { + return Observable.Return( + new AutoRegressiveObservations (Lags) { + Params = [ As, Bs, Vs, SqrtSigmas ], + }); + } + + /// + /// Transforms an observable sequence of into an observable sequence + /// of objects by accessing internal attributes of the . + /// + public IObservable Process(IObservable source) + { + return Observable.Select(source, pyObject => + { + var lagsPyObj = (int)pyObject.GetArrayAttr("lags"); + var asPyObj = (double[,,])pyObject.GetArrayAttr("As"); + var bsPyObj = (double[,])pyObject.GetArrayAttr("bs"); + var vsPyObj = (double[,,])pyObject.GetArrayAttr("Vs"); + var sqrtSigmasPyObj = (double[,,])pyObject.GetArrayAttr("_sqrt_Sigmas"); + + return new AutoRegressiveObservations(Lags) + { + Params = [ + asPyObj, + bsPyObj, + vsPyObj, + sqrtSigmasPyObj + ] + }; + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/Observations/BernoulliObservations.cs b/src/Bonsai.ML.HiddenMarkovModels/Observations/BernoulliObservations.cs new file mode 100644 index 00000000..048d8423 --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/Observations/BernoulliObservations.cs @@ -0,0 +1,103 @@ +using System; +using Python.Runtime; +using System.Reactive.Linq; +using System.ComponentModel; +using System.Xml.Serialization; +using Newtonsoft.Json; +using Bonsai.ML.Python; + +namespace Bonsai.ML.HiddenMarkovModels.Observations +{ + /// + /// Represents an operator that is used to create and transform an observable sequence + /// of objects. + /// + [Combinator] + [Description("Creates an observable sequence of BernoulliObservations objects.")] + [WorkflowElementCategory(ElementCategory.Source)] + [JsonObject(MemberSerialization.OptIn)] + public class BernoulliObservations : ObservationsModel + { + /// + /// The logit P of the observations for each state. + /// + [XmlIgnore] + [Description("The logit P of the observations for each state.")] + public double[,] LogitPs { get; set; } = null; + + /// + [JsonProperty] + [JsonConverter(typeof(ObservationsModelTypeJsonConverter))] + [Browsable(false)] + public override ObservationsModelType ObservationsModelType => ObservationsModelType.Bernoulli; + + /// + [JsonProperty] + public override object[] Params + { + get => [ LogitPs ]; + } + + /// + public BernoulliObservations () : base() + { + } + + /// + public BernoulliObservations (params object[] kwargs) : base(kwargs) + { + } + + /// + protected override void CheckParams(params object[] @params) + { + if (@params is not null && @params.Length != 1) + { + throw new ArgumentException($"The {nameof(BernoulliObservations)} operator requires exactly one parameter: {nameof(LogitPs)}."); + } + } + + /// + protected override void CheckKwargs(params object[] kwargs) + { + if (kwargs is null || kwargs.Length != 0) + { + throw new ArgumentException($"The {nameof(BernoulliObservations)} operator requires exactly zero constructor arguments."); + } + } + + /// + protected override void UpdateParams(params object[] @params) + { + LogitPs = (double[,])@params[0]; + } + + /// + /// Returns an observable sequence of objects. + /// + public IObservable Process() + { + return Observable.Return( + new BernoulliObservations { + Params = [ LogitPs ] + }); + } + + /// + /// Transforms an observable sequence of into an observable sequence + /// of objects by accessing internal attributes of the . + /// + public IObservable Process(IObservable source) + { + return Observable.Select(source, pyObject => + { + var logitPsPyObj = (double[,])pyObject.GetArrayAttr("logit_ps"); + + return new BernoulliObservations + { + Params = [ logitPsPyObj ] + }; + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/Observations/CategoricalObservations.cs b/src/Bonsai.ML.HiddenMarkovModels/Observations/CategoricalObservations.cs new file mode 100644 index 00000000..01897a16 --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/Observations/CategoricalObservations.cs @@ -0,0 +1,135 @@ +using System; +using Python.Runtime; +using System.Reactive.Linq; +using System.ComponentModel; +using System.Collections.Generic; +using System.Xml.Serialization; +using Newtonsoft.Json; +using Bonsai.ML.Python; + +namespace Bonsai.ML.HiddenMarkovModels.Observations +{ + /// + /// Represents an operator that is used to create and transform an observable sequence + /// of objects. + /// + [Combinator] + [Description("Creates an observable sequence of CategoricalObservations objects.")] + [WorkflowElementCategory(ElementCategory.Source)] + [JsonObject(MemberSerialization.OptIn)] + public class CategoricalObservations : ObservationsModel + { + /// + /// The number of categories in the observations. + /// + [Description("The number of categories in the observations.")] + public int Categories { get; set; } = 2; + + /// + /// The logit of the observations for each state. + /// + [XmlIgnore] + [Description("The logit of the observations for each state.")] + public double[,,] Logits { get; set; } = null; + + /// + [JsonProperty] + [JsonConverter(typeof(ObservationsModelTypeJsonConverter))] + [Browsable(false)] + public override ObservationsModelType ObservationsModelType => ObservationsModelType.Categorical; + + /// + [JsonProperty] + public override object[] Params + { + get => [ Logits ]; + } + + /// + [JsonProperty] + [XmlIgnore] + [Browsable(false)] + public override Dictionary Kwargs => new Dictionary + { + ["C"] = Categories, + }; + + /// + [XmlIgnore] + [Browsable(false)] + public static new string[] KwargsArray => [ "C" ]; + + /// + public CategoricalObservations() : base() + { + } + + /// + public CategoricalObservations (params object[] kwargs) : base(kwargs) + { + } + + /// + protected override void CheckKwargs(params object[] kwargs) + { + if (kwargs is null || kwargs.Length != 1) + { + throw new ArgumentException($"The {nameof(CategoricalObservations)} operator requires exactly one keyword argument: {nameof(Categories)}."); + } + } + + /// + protected override void UpdateKwargs(params object[] kwargs) + { + Categories = kwargs[0] switch + { + int c => c, + var c => Convert.ToInt32(c), + }; + } + + /// + protected override void CheckParams(params object[] @params) + { + if (@params is not null && @params.Length != 1) + { + throw new ArgumentException($"The {nameof(CategoricalObservations)} operator requires exactly one parameter: {nameof(Logits)}."); + } + } + + /// + protected override void UpdateParams(params object[] @params) + { + Logits = (double[,,])@params[0]; + } + + /// + /// Returns an observable sequence of objects. + /// + public IObservable Process() + { + return Observable.Return( + new CategoricalObservations (Categories) { + Params = [ Logits ], + }); + } + + /// + /// Transforms an observable sequence of into an observable sequence + /// of objects by accessing internal attributes of the . + /// + public IObservable Process(IObservable source) + { + return Observable.Select(source, pyObject => + { + var categoriesPyObj = (int)pyObject.GetArrayAttr("C"); + var logitsPyObj = (double[,,])pyObject.GetArrayAttr("logits"); + + return new CategoricalObservations(categoriesPyObj) + { + Params = [ logitsPyObj ] + }; + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/Observations/ExponentialObservations.cs b/src/Bonsai.ML.HiddenMarkovModels/Observations/ExponentialObservations.cs new file mode 100644 index 00000000..32aeb45e --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/Observations/ExponentialObservations.cs @@ -0,0 +1,104 @@ +using System; +using Python.Runtime; +using System.Reactive.Linq; +using System.ComponentModel; +using System.Xml.Serialization; +using Newtonsoft.Json; +using Bonsai.ML.Python; + +namespace Bonsai.ML.HiddenMarkovModels.Observations +{ + /// + /// Represents an operator that is used to create and transform an observable sequence + /// of objects. + /// + [Combinator] + [Description("Creates an observable sequence of ExponentialObservations objects.")] + [WorkflowElementCategory(ElementCategory.Source)] + [JsonObject(MemberSerialization.OptIn)] + public class ExponentialObservations : ObservationsModel + { + + /// + /// The log lambdas of the observations for each state. + /// + [XmlIgnore] + [Description("The log lambdas of the observations for each state.")] + public double[,] LogLambdas { get; set; } = null; + + /// + [JsonProperty] + [JsonConverter(typeof(ObservationsModelTypeJsonConverter))] + [Browsable(false)] + public override ObservationsModelType ObservationsModelType => ObservationsModelType.Exponential; + + /// + [JsonProperty] + public override object[] Params + { + get => [ LogLambdas ]; + } + + /// + public ExponentialObservations () : base() + { + } + + /// + public ExponentialObservations (params object[] kwargs) : base(kwargs) + { + } + + /// + protected override void CheckParams(params object[] @params) + { + if (@params is not null && @params.Length != 1) + { + throw new ArgumentException($"The {nameof(ExponentialObservations)} operator requires exactly one parameter: {nameof(LogLambdas)}."); + } + } + + /// + protected override void UpdateParams(params object[] @params) + { + LogLambdas = (double[,])@params[0]; + } + + /// + protected override void CheckKwargs(params object[] kwargs) + { + if (kwargs is null || kwargs.Length != 0) + { + throw new ArgumentException($"The {nameof(ExponentialObservations)} operator requires exactly zero constructor arguments."); + } + } + + /// + /// Returns an observable sequence of objects. + /// + public IObservable Process() + { + return Observable.Return( + new ExponentialObservations { + Params = [ LogLambdas ] + }); + } + + /// + /// Transforms an observable sequence of into an observable sequence + /// of objects by accessing internal attributes of the . + /// + public IObservable Process(IObservable source) + { + return Observable.Select(source, pyObject => + { + var logLambdasPyObj = (double[,])pyObject.GetArrayAttr("log_lambdas"); + + return new ExponentialObservations + { + Params = [ logLambdasPyObj ] + }; + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/Observations/GaussianObservations.cs b/src/Bonsai.ML.HiddenMarkovModels/Observations/GaussianObservations.cs new file mode 100644 index 00000000..f540fb12 --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/Observations/GaussianObservations.cs @@ -0,0 +1,111 @@ +using System; +using Python.Runtime; +using System.Reactive.Linq; +using System.ComponentModel; +using System.Xml.Serialization; +using Newtonsoft.Json; +using Bonsai.ML.Python; + +namespace Bonsai.ML.HiddenMarkovModels.Observations +{ + /// + /// Represents an operator that is used to create and transform an observable sequence + /// of objects. + /// + [Combinator] + [Description("Creates an observable sequence of GaussianObservations objects.")] + [WorkflowElementCategory(ElementCategory.Source)] + [JsonObject(MemberSerialization.OptIn)] + public class GaussianObservations : ObservationsModel + { + /// + /// The means of the observations for each state. + /// + [XmlIgnore] + [Description("The means of the observations for each state.")] + public double[,] Mus { get; set; } = null; + + /// + /// The standard deviations of the observations for each state. + /// + [XmlIgnore] + [Description("The standard deviations of the observations for each state.")] + public double[,,] SqrtSigmas { get; set; } = null; + + /// + [JsonProperty] + [JsonConverter(typeof(ObservationsModelTypeJsonConverter))] + [Browsable(false)] + public override ObservationsModelType ObservationsModelType => ObservationsModelType.Gaussian; + + /// + [JsonProperty] + public override object[] Params + { + get => [ Mus, SqrtSigmas ]; + } + + /// + public GaussianObservations () : base() + { + } + + /// + public GaussianObservations (params object[] kwargs) : base(kwargs) + { + } + + /// + protected override void CheckParams(params object[] @params) + { + if (@params is not null && @params.Length != 2) + { + throw new ArgumentException($"The {nameof(GaussianObservations)} operator requires exactly two parameters: {nameof(Mus)} and {nameof(SqrtSigmas)}."); + } + } + + /// + protected override void UpdateParams(params object[] @params) + { + Mus = (double[,])@params[0]; + SqrtSigmas = (double[,,])@params[1]; + } + + /// + protected override void CheckKwargs(params object[] kwargs) + { + if (kwargs is null || kwargs.Length != 0) + { + throw new ArgumentException($"The {nameof(GaussianObservations)} operator requires exactly zero constructor arguments."); + } + } + + /// + /// Returns an observable sequence of objects. + /// + public IObservable Process() + { + return Observable.Return( + new GaussianObservations { + Params = [ Mus, SqrtSigmas ] + }); + } + + /// + /// Transforms an observable sequence of into an observable sequence + /// of objects by accessing internal attributes of the . + /// + public IObservable Process(IObservable source) + { + return Observable.Select(source, pyObject => + { + var musPyObj = (double[,])pyObject.GetArrayAttr("mus"); + var sqrtSigmasPyObj = (double[,,])pyObject.GetArrayAttr("_sqrt_Sigmas"); + + return new GaussianObservations { + Params = [ Mus, SqrtSigmas ] + }; + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/Observations/GaussianObservationsStatistics.cs b/src/Bonsai.ML.HiddenMarkovModels/Observations/GaussianObservationsStatistics.cs new file mode 100644 index 00000000..ffeb6c00 --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/Observations/GaussianObservationsStatistics.cs @@ -0,0 +1,97 @@ +using System.ComponentModel; +using Python.Runtime; +using System.Xml.Serialization; +using System; +using System.Reactive.Linq; +using Bonsai.ML.Python; + +namespace Bonsai.ML.HiddenMarkovModels.Observations +{ + /// + /// Represents an operator that will transform an observable sequence of + /// into an observable sequence of . + /// + [Combinator] + [Description("")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class GaussianObservationsStatistics + { + /// + /// The means of the observations for each state. + /// + [Description("The means of the observations for each state.")] + [XmlIgnore] + public double[,] Means { get; set; } + + /// + /// The standard deviations of the observations for each state. + /// + [Description("The standard deviations of the observations for each state.")] + [XmlIgnore] + public double[,] StdDevs { get; set; } + + /// + /// The covariance matrices of the observations for each state. + /// + [Description("The covariance matrices of the observations for each state.")] + [XmlIgnore] + public double[,,] CovarianceMatrices { get; set; } + + /// + /// The batch observations that the model has seen. + /// + [Description("The batch observations that the model has seen.")] + [XmlIgnore] + public double[,] BatchObservations { get; set; } + + /// + /// The sequence of inferred most probable states. + /// + [Description("The sequence of inferred most probable states.")] + [XmlIgnore] + public int[] InferredMostProbableStates { get; set; } + + /// + /// Transforms an observable sequence of into an observable sequence + /// of objects by accessing internal attributes of the . + /// + public IObservable Process(IObservable source) + { + return Observable.Select(source, pyObject => + { + var observationsPyObj = pyObject.GetAttr("observations"); + var meansPyObj = (double[,])observationsPyObj.GetArrayAttr("mus"); + var covarianceMatricesPyObj = (double[,,])observationsPyObj.GetArrayAttr("Sigmas"); + var stdDevsPyObj = DiagonalSqrt(covarianceMatricesPyObj); + var batchObservationsPyObj = (double[,])pyObject.GetArrayAttr("batch_observations"); + var inferredMostProbableStatesPyObj = (int[])pyObject.GetArrayAttr("inferred_most_probable_states"); + + return new GaussianObservationsStatistics + { + Means = meansPyObj, + StdDevs = stdDevsPyObj, + CovarianceMatrices = covarianceMatricesPyObj, + BatchObservations = batchObservationsPyObj, + InferredMostProbableStates = inferredMostProbableStatesPyObj + }; + }); + } + + private static double[,] DiagonalSqrt(double[,,] matrix) + { + var states = matrix.GetLength(0); + var dimensions = matrix.GetLength(1); + var diagonal = new double[states, dimensions]; + + for (int i = 0; i < states; i++) + { + for (int j = 0; j < dimensions; j++) + { + diagonal[i, j] = Math.Sqrt(matrix[i, j, j]); + } + } + + return diagonal; + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/Observations/ObservationsModel.cs b/src/Bonsai.ML.HiddenMarkovModels/Observations/ObservationsModel.cs new file mode 100644 index 00000000..301b4e87 --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/Observations/ObservationsModel.cs @@ -0,0 +1,33 @@ +using System.ComponentModel; + +namespace Bonsai.ML.HiddenMarkovModels.Observations +{ + /// + /// An abstract class for creating an Observations model. + /// + public abstract class ObservationsModel : PythonModel + { + /// + /// The type of Observations model. + /// + public abstract ObservationsModelType ObservationsModelType { get; } + + /// + [Browsable(false)] + protected override string ModelName => "observations"; + + /// + [Browsable(false)] + protected override string ModelType => ObservationsModelLookup.GetString(ObservationsModelType); + + /// + public ObservationsModel() : base() + { + } + + /// + public ObservationsModel(params object[] kwargs) : base(kwargs) + { + } + } +} diff --git a/src/Bonsai.ML.HiddenMarkovModels/Observations/ObservationsModelLookup.cs b/src/Bonsai.ML.HiddenMarkovModels/Observations/ObservationsModelLookup.cs new file mode 100644 index 00000000..65aaa403 --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/Observations/ObservationsModelLookup.cs @@ -0,0 +1,43 @@ +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Bonsai.ML.HiddenMarkovModels.Observations +{ + /// + /// A lookup class for relating different to + /// and their corresponding Python string representations. + /// + public static class ObservationsModelLookup + { + private static readonly Dictionary _lookup = new Dictionary + { + { ObservationsModelType.Gaussian, (typeof(GaussianObservations), "gaussian") }, + { ObservationsModelType.Exponential, (typeof(ExponentialObservations), "exponential") }, + { ObservationsModelType.Bernoulli, (typeof(BernoulliObservations), "bernoulli") }, + { ObservationsModelType.Poisson, (typeof(PoissonObservations), "poisson") }, + { ObservationsModelType.AutoRegressive, (typeof(AutoRegressiveObservations), "autoregressive") }, + { ObservationsModelType.Categorical, (typeof(CategoricalObservations), "categorical") } + }; + + /// + /// Gets the of the corresponding to the given . + /// + public static Type GetObservationsClassType(ObservationsModelType type) => _lookup[type].Type; + + /// + /// Gets the Python string representation of the given . + /// + public static string GetString(ObservationsModelType type) => _lookup[type].StringValue; + + /// + /// Gets the corresponding to the given Python string representation. + /// + public static ObservationsModelType GetFromString(string value) => _lookup.First(x => x.Value.StringValue == value).Key; + + /// + /// Gets the corresponding to the given of . + /// + public static ObservationsModelType GetFromType(Type type) => _lookup.First(x => x.Value.Type == type).Key; + } +} diff --git a/src/Bonsai.ML.HiddenMarkovModels/Observations/ObservationsModelType.cs b/src/Bonsai.ML.HiddenMarkovModels/Observations/ObservationsModelType.cs new file mode 100644 index 00000000..42810c8f --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/Observations/ObservationsModelType.cs @@ -0,0 +1,38 @@ +namespace Bonsai.ML.HiddenMarkovModels.Observations +{ + /// + /// Represents the type of observations in a hidden Markov model. + /// + public enum ObservationsModelType + { + /// + /// Gaussian observations. + /// + Gaussian, + + /// + /// Exponential observations. + /// + Exponential, + + /// + /// Bernoulli observations. + /// + Bernoulli, + + /// + /// Poisson observations. + /// + Poisson, + + /// + /// Autoregressive observations. + /// + AutoRegressive, + + /// + /// Categorical observations. + /// + Categorical + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/Observations/ObservationsModelTypeJsonConverter.cs b/src/Bonsai.ML.HiddenMarkovModels/Observations/ObservationsModelTypeJsonConverter.cs new file mode 100644 index 00000000..2cec1a93 --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/Observations/ObservationsModelTypeJsonConverter.cs @@ -0,0 +1,24 @@ +using Newtonsoft.Json; +using System; + +namespace Bonsai.ML.HiddenMarkovModels.Observations +{ + /// + /// Provides a type converter to convert between and the corresponding Python string representation. + /// + public class ObservationsModelTypeJsonConverter : JsonConverter + { + /// + public override ObservationsModelType ReadJson(JsonReader reader, Type objectType, ObservationsModelType existingValue, bool hasExistingValue, JsonSerializer serializer) + { + string stringValue = reader.Value?.ToString(); + return ObservationsModelLookup.GetFromString(stringValue); + } + + /// + public override void WriteJson(JsonWriter writer, ObservationsModelType value, JsonSerializer serializer) + { + writer.WriteValue(ObservationsModelLookup.GetString(value)); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/Observations/PoissonObservations.cs b/src/Bonsai.ML.HiddenMarkovModels/Observations/PoissonObservations.cs new file mode 100644 index 00000000..78b33822 --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/Observations/PoissonObservations.cs @@ -0,0 +1,103 @@ +using System; +using Python.Runtime; +using System.Reactive.Linq; +using System.ComponentModel; +using System.Xml.Serialization; +using Newtonsoft.Json; +using Bonsai.ML.Python; + +namespace Bonsai.ML.HiddenMarkovModels.Observations +{ + /// + /// Represents an operator that is used to create and transform an observable sequence + /// of objects. + /// + [Combinator] + [Description("Creates an observable sequence of PoissonObservations objects.")] + [WorkflowElementCategory(ElementCategory.Source)] + [JsonObject(MemberSerialization.OptIn)] + public class PoissonObservations : ObservationsModel + { + + /// + /// The log lambdas of the observations for each state. + /// + [XmlIgnore] + [Description("The log lambdas of the observations for each state.")] + public double[,] LogLambdas { get; set; } = null; + + /// + [JsonProperty] + [JsonConverter(typeof(ObservationsModelTypeJsonConverter))] + [Browsable(false)] + public override ObservationsModelType ObservationsModelType => ObservationsModelType.Poisson; + + /// + [JsonProperty] + public override object[] Params + { + get => [ LogLambdas ]; + } + + /// + public PoissonObservations () : base() + { + } + + /// + public PoissonObservations (params object[] kwargs) : base(kwargs) + { + } + + /// + protected override void CheckParams(params object[] @params) + { + if (@params is not null && @params.Length != 1) + { + throw new ArgumentException($"The {nameof(PoissonObservations)} operator requires exactly one parameter: {nameof(LogLambdas)}."); + } + } + + /// + protected override void UpdateParams(params object[] @params) + { + LogLambdas = (double[,])@params[0]; + } + + /// + protected override void CheckKwargs(params object[] kwargs) + { + if (kwargs is null || kwargs.Length != 0) + { + throw new ArgumentException($"The {nameof(PoissonObservations)} operator requires exactly zero constructor arguments."); + } + } + + /// + /// Returns an observable sequence of objects. + /// + public IObservable Process() + { + return Observable.Return( + new PoissonObservations { + Params = [ LogLambdas ] + }); + } + + /// + /// Transforms an observable sequence of into an observable sequence + /// of objects by accessing internal attributes of the . + /// + public IObservable Process(IObservable source) + { + return Observable.Select(source, pyObject => + { + var logLambdasPyObj = (double[,])pyObject.GetArrayAttr("log_lambdas"); + + return new PoissonObservations { + Params = [ logLambdasPyObj ] + }; + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/Properties/launchSettings.json b/src/Bonsai.ML.HiddenMarkovModels/Properties/launchSettings.json new file mode 100644 index 00000000..b48bcfa9 --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/Properties/launchSettings.json @@ -0,0 +1,10 @@ +{ + "profiles": { + "Bonsai": { + "commandName": "Executable", + "executablePath": "$(registry:HKEY_CURRENT_USER\\Software\\Bonsai Foundation\\Bonsai@InstallDir)Bonsai.exe", + "commandLineArgs": "--lib:\"$(TargetDir).\"", + "nativeDebugging": true + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/PythonModel.cs b/src/Bonsai.ML.HiddenMarkovModels/PythonModel.cs new file mode 100644 index 00000000..b94bafcf --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/PythonModel.cs @@ -0,0 +1,151 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Bonsai.ML.Data; +using System.Xml.Serialization; +using System.Linq; +using System.Reactive.Linq; +using System.ComponentModel; + +namespace Bonsai.ML.HiddenMarkovModels +{ + /// + /// An abstract class for creating a Python model. + /// + public abstract class PythonModel : PythonStringBuilder + { + /// + /// The parameters that are used to define the model. + /// + [Browsable(false)] + [XmlIgnore] + public virtual object[] Params + { + get => null; + set + { + CheckParams(value); + UpdateParams(value); + UpdateString(); + } + } + + /// + /// The array of keyword arguments used to construct the model. + /// + public static string[] KwargsArray => null; + + /// + /// The dictionary of keyword arguments that are used to construct the model. + /// + [Browsable(false)] + [XmlIgnore] + public virtual Dictionary Kwargs => new(); + + /// + /// Checks if the keyword arguments are valid. + /// + /// The keyword arguments. + protected virtual void CheckKwargs(params object[] kwargs) + { + } + + /// + /// Updates the kwargs dictionary. + /// + /// The keyword arguments. + protected virtual void UpdateKwargs(params object[] kwargs) + { + } + + /// + /// Checks if the parameters are valid. + /// + /// The parameters. + protected virtual void CheckParams(params object[] @params) + { + } + + /// + /// Updates the parameters. + /// + /// The parameters. + protected virtual void UpdateParams(params object[] @params) + { + } + + /// + /// Constructs a new instance of the class. + /// + public PythonModel() + { + BuildString(); + } + + /// + /// Initializes a new instance of the class using keyword arguments. + /// + /// The keyword arguments. + public PythonModel(params object[] kwargs) + { + CheckKwargs(kwargs); + UpdateKwargs(kwargs); + UpdateString(); + } + + /// + /// The name of the base python model class. + /// + protected abstract string ModelName { get; } + + /// + /// The specific type of the model. + /// + protected abstract string ModelType { get; } + + /// + protected override string BuildString() + { + // StringBuilder.Clear(); + StringBuilder.Append($"{ModelName}_model_type=\"{ModelType}\""); + + if (Params != null && Params.Length > 0 && Params.All(p => p != null)) + { + var paramsStringBuilder = new StringBuilder(); + paramsStringBuilder.Append($",{ModelName}_params=("); + + foreach (var param in Params) { + if (param is null) { + paramsStringBuilder.Clear(); + break; + } + var arrString = param is Array array ? ArrayHelper.SerializeArrayToJson(array) : param.ToString(); + paramsStringBuilder.Append($"{arrString},"); + } + + if (paramsStringBuilder.Length > 0) { + paramsStringBuilder.Remove(paramsStringBuilder.Length - 1, 1); + paramsStringBuilder.Append(")"); + StringBuilder.Append(paramsStringBuilder); + } + } + + if (Kwargs is not null && Kwargs.Count > 0) + { + StringBuilder.Append($",{ModelName}_kwargs={{"); + foreach (var kp in Kwargs) { + StringBuilder.Append($"\"{kp.Key}\":{(kp.Value is null ? "None" + : kp.Value is Array array ? ArrayHelper.SerializeArrayToJson(array) + : kp.Value is string ? $"\"{kp.Value}\"" + : kp.Value)},"); + } + StringBuilder.Remove(StringBuilder.Length - 1, 1); + StringBuilder.Append("}"); + } + + var result = StringBuilder.ToString(); + StringBuilder.Clear(); + return result; + } + } +} diff --git a/src/Bonsai.ML.HiddenMarkovModels/PythonStringBuilder.cs b/src/Bonsai.ML.HiddenMarkovModels/PythonStringBuilder.cs new file mode 100644 index 00000000..a480bc51 --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/PythonStringBuilder.cs @@ -0,0 +1,48 @@ +using System.Text; + +namespace Bonsai.ML.HiddenMarkovModels +{ + /// + /// Provides a base class for building string representations of Python objects. + /// + public abstract class PythonStringBuilder + { + + private string _cachedString; + private bool _updateString; + + /// + /// The internal string builder used to build the string representation. + /// + protected readonly StringBuilder StringBuilder = new StringBuilder(); + + /// + /// Sets a flag to update the string cache on the next call to the method. + /// + protected void UpdateString() + { + _updateString = true; + } + + /// + /// Method used to build a string representation of the object. + /// + protected virtual string BuildString() + { + return StringBuilder.ToString(); + } + + /// + public override string ToString() + { + if (_updateString) + { + _cachedString = BuildString(); + _updateString = false; + } + return _cachedString; + } + } +} + + diff --git a/src/Bonsai.ML.HiddenMarkovModels/RunFitAsync.bonsai b/src/Bonsai.ML.HiddenMarkovModels/RunFitAsync.bonsai new file mode 100644 index 00000000..8087ea1a --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/RunFitAsync.bonsai @@ -0,0 +1,200 @@ + + + + + + Source1 + + + + + + + 50 + + + + + + + + + + 50 + + + + + + + + + + 50 + 50 + + + + + + + + + + + + + hmm + + + Name + + + + + + + 50 + + + + + + + + 50 + + + + + + + + false + + + + + + + + true + + + + + + + + true + + + + + + + + true + + + + + + + vars_to_estimate={{"initial_state_distribution":{3},"transitions_params":{4},"observations_params":{5}}},batch_size={0},max_iter={1},flush_data_between_batches={2} + it.Item1, it.Item2, it.Item3,it.Item4,it.Item5,it.Item6 + + + + + + {0}.fit_async({1}, {2}) + it.Item1.Item2,it.Item1.Item1, it.Item2 + + + + + + + + HMMModule + + + + + + + + + hmm.fit_async([1.38021159584039,14.2434706648223], vars_to_estimate={"initial_state_distribution":True,"log_transition_probabilities":True,"observation_means":False,"observation_covs":False},batch_size=50,max_iter=50,flush_data_between_batches=False) + + + + it.ToString() == "True" + + + + + + IsRunning? + + + + Source1 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/SerializeToJson.cs b/src/Bonsai.ML.HiddenMarkovModels/SerializeToJson.cs new file mode 100644 index 00000000..9ce0543a --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/SerializeToJson.cs @@ -0,0 +1,53 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using Newtonsoft.Json; + +namespace Bonsai.ML.HiddenMarkovModels +{ + /// + /// Serializes a sequence of data model objects into JSON strings. + /// + [Combinator] + [WorkflowElementCategory(ElementCategory.Transform)] + [Description("Serializes a sequence of data model objects into JSON strings.")] + public class SerializeToJson + { + private IObservable Process(IObservable source) + { + return source.Select(value => JsonConvert.SerializeObject(value)); + } + + /// + /// Serializes each object in the sequence to + /// a JSON string. + /// + /// + /// A sequence of objects. + /// + /// + /// A sequence of JSON strings representing the corresponding + /// object. + /// + public IObservable Process(IObservable source) + { + return Process(source); + } + + /// + /// Serializes each object in the sequence to + /// a JSON string. + /// + /// + /// A sequence of objects. + /// + /// + /// A sequence of JSON strings representing the corresponding + /// object. + /// + public IObservable Process(IObservable source) + { + return Process(source); + } + } +} diff --git a/src/Bonsai.ML.HiddenMarkovModels/StateParameters.cs b/src/Bonsai.ML.HiddenMarkovModels/StateParameters.cs new file mode 100644 index 00000000..d778031b --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/StateParameters.cs @@ -0,0 +1,213 @@ +using System.ComponentModel; +using System; +using System.Reactive.Linq; +using Python.Runtime; +using System.Xml.Serialization; +using Newtonsoft.Json; +using Bonsai.ML.HiddenMarkovModels.Observations; +using Bonsai.ML.HiddenMarkovModels.Transitions; +using Bonsai.ML.Python; +using Bonsai.ML.Data; + +namespace Bonsai.ML.HiddenMarkovModels +{ + /// + /// Represents the state parameters of a Hidden Markov Model (HMM). + /// + [Combinator] + [JsonConverter(typeof(StateParametersJsonConverter))] + [Description("State parameters of a Hidden Markov Model (HMM).")] + [WorkflowElementCategory(ElementCategory.Source)] + public class StateParameters : PythonStringBuilder + { + + private double[] initialStateDistribution = null; + private TransitionsModel transitions = null; + private ObservationsModel observations = null; + + /// + /// The initial state distribution. + /// + [XmlIgnore] + [JsonProperty("initial_state_distribution")] + [Description("The initial state distribution.")] + [Category("ModelStateParameters")] + public double[] InitialStateDistribution + { + get => initialStateDistribution; + set + { + initialStateDistribution = value; + UpdateString(); + } + } + + /// + /// The transitions model. + /// + [XmlIgnore] + [JsonProperty("transitions_params")] + [Description("The transitions model.")] + [Category("ModelStateParameters")] + public TransitionsModel Transitions + { + get => transitions; + set + { + transitions = value; + UpdateString(); + } + } + + /// + /// The observations. + /// + [XmlIgnore] + [JsonProperty("observations_params")] + [Description("The observations.")] + [Category("ModelStateParameters")] + public ObservationsModel Observations + { + get => observations; + set + { + observations = value; + UpdateString(); + } + } + + /// + /// Returns an observable sequence of objects. + /// + public IObservable Process() + { + return Observable.Return( + new StateParameters() + { + InitialStateDistribution = InitialStateDistribution, + Transitions = Transitions, + Observations = Observations + } + ); + } + + /// + /// Takes an observable seqence and returns an observable sequence of + /// objects that are emitted every time the input sequence emits a new element. + /// + public IObservable Process(IObservable source) + { + return Observable.Select(source, pyObject => + { + return new StateParameters() + { + InitialStateDistribution = InitialStateDistribution, + Transitions = Transitions, + Observations = Observations + }; + }); + } + + /// + /// Transforms an observable sequence of into an observable sequence + /// of objects by accessing internal attributes of the . + /// + public IObservable Process(IObservable source) + { + return Observable.Select(source, pyObject => + { + var initialStateDistributionPyObj = (double[])pyObject.GetArrayAttr("initial_state_distribution"); + + var transitionsModelTypePyObj = pyObject.GetAttr("transitions_model_type"); + var transitionsParamsPyObj = (Array)pyObject.GetArrayAttr("transitions_params"); + var transitionsParams = (object[])transitionsParamsPyObj; + + var transitionsModelType = TransitionsModelLookup.GetFromString(transitionsModelTypePyObj); + var transitionsClassType = TransitionsModelLookup.GetTransitionsClassType(transitionsModelType); + var transitionsKwargsProperty = transitionsClassType.GetProperty("KwargsArray"); + + object[] transitionsConstructorArgs = null; + if (transitionsKwargsProperty is not null) + { + var transitionsConstructorKeys = (string[])transitionsKwargsProperty.GetValue(null); + var transitionsConstructorKeysCount = transitionsConstructorKeys.Length; + if (transitionsConstructorKeysCount > 0) + { + transitionsConstructorArgs = new object[transitionsConstructorKeysCount]; + var transitionsPyObj = pyObject.GetAttr("transitions"); + for (int i = 0; i < transitionsConstructorKeysCount; i++) + { + transitionsConstructorArgs[i] = transitionsPyObj.GetArrayAttr(transitionsConstructorKeys[i]); + } + } + } + + transitions = (TransitionsModel)Activator.CreateInstance(transitionsClassType, transitionsConstructorArgs); + transitions.Params = transitionsParams; + + var observationsModelTypePyObj = pyObject.GetAttr("observations_model_type"); + var observationsParamsPyObj = (Array)pyObject.GetArrayAttr("observations_params"); + var observationsParams = (object[])observationsParamsPyObj; + + var observationsModelType = ObservationsModelLookup.GetFromString(observationsModelTypePyObj); + var observationsClassType = ObservationsModelLookup.GetObservationsClassType(observationsModelType); + + var observationsKwargsProperty = observationsClassType.GetProperty("KwargsArray"); + + object[] observationsConstructorArgs = null; + if (observationsKwargsProperty is not null) + { + var observationsConstructorKeys = (string[])observationsKwargsProperty.GetValue(null); + var observationsConstructorKeysCount = observationsConstructorKeys.Length; + if (observationsConstructorKeysCount > 0) + { + observationsConstructorArgs = new object[observationsConstructorKeysCount]; + var observationsPyObj = pyObject.GetAttr("observations"); + for (int i = 0; i < observationsConstructorKeysCount; i++) + { + observationsConstructorArgs[i] = observationsPyObj.GetArrayAttr(observationsConstructorKeys[i]); + } + } + } + + observations = (ObservationsModel)Activator.CreateInstance(observationsClassType, observationsConstructorArgs); + observations.Params = observationsParams; + + return new StateParameters() + { + InitialStateDistribution = initialStateDistributionPyObj, + Transitions = transitions, + Observations = observations + }; + }); + } + + /// + protected override string BuildString() + { + StringBuilder.Clear(); + + if (InitialStateDistribution != null) + { + StringBuilder.Append($"initial_state_distribution={ArrayHelper.SerializeToJson(InitialStateDistribution)},"); + } + + if (Transitions != null) + { + StringBuilder.Append($"{Transitions},"); + } + + if (Observations != null) + { + StringBuilder.Append($"{Observations},"); + } + + if (StringBuilder.Length > 0) + { + StringBuilder.Remove(StringBuilder.Length - 1, 1); + } + + return StringBuilder.ToString(); + } + } +} diff --git a/src/Bonsai.ML.HiddenMarkovModels/StateParametersJsonConverter.cs b/src/Bonsai.ML.HiddenMarkovModels/StateParametersJsonConverter.cs new file mode 100644 index 00000000..c95db285 --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/StateParametersJsonConverter.cs @@ -0,0 +1,127 @@ +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using System; +using System.Linq; +using Bonsai.ML.HiddenMarkovModels.Observations; +using Bonsai.ML.HiddenMarkovModels.Transitions; +using Bonsai.ML.Data; + +namespace Bonsai.ML.HiddenMarkovModels +{ + /// + /// Provides a type converter to convert between and a JSON string representation. + /// + public class StateParametersJsonConverter : JsonConverter + { + /// + public override StateParameters ReadJson(JsonReader reader, Type objectType, StateParameters existingValue, bool hasExistingValue, JsonSerializer serializer) + { + JObject jo = JObject.Load(reader); + StateParameters result = new StateParameters + { + InitialStateDistribution = jo["InitialStateDistribution"]?.ToObject() + }; + + var transitionsObj = (JObject)jo["Transitions"]; + var transitionsModelType = TransitionsModelLookup.GetFromString(transitionsObj["TransitionsModelType"]?.ToString()); + + object[] transitionsKwargsArray = null; + object[] transitionsParamsArray = []; + + if (transitionsObj.ContainsKey("Kwargs")) + { + var kwargs = (JObject)transitionsObj["Kwargs"]; + transitionsKwargsArray = kwargs.Properties() + .Select(p => p.Value.ToObject()) + .ToArray(); + if (transitionsKwargsArray.Count() == 0) + { + transitionsKwargsArray = null; + } + } + + var transitions = (TransitionsModel)Activator.CreateInstance(TransitionsModelLookup.GetTransitionsClassType(transitionsModelType), transitionsKwargsArray); + + if (transitionsObj.ContainsKey("Params")) + { + var paramsJArray = (JArray)transitionsObj["Params"]; + var nParams = paramsJArray.Count; + transitionsParamsArray = new object[nParams]; + for (int i = 0; i < nParams; i++) + { + try + { + transitionsParamsArray[i] = ArrayHelper.ParseString(paramsJArray[i].ToString(), typeof(double)); + } + catch + { + transitionsParamsArray[i] = JsonConvert.DeserializeObject(paramsJArray[i].ToString()); + } + } + } + + transitions.Params = transitionsParamsArray; + result.Transitions = transitions; + + var observationsObj = (JObject)jo["Observations"]; + var observationsModelType = ObservationsModelLookup.GetFromString(observationsObj["ObservationsModelType"]?.ToString()); + + object[] observationsKwargsArray = null; + object[] observationsParamsArray = []; + + if (observationsObj.ContainsKey("Kwargs")) + { + var kwargs = (JObject)observationsObj["Kwargs"]; + observationsKwargsArray = kwargs.Properties() + .Select(p => p.Value.ToObject()) + .ToArray(); + if (observationsKwargsArray.Count() == 0) + { + observationsKwargsArray = null; + } + } + + var observations = (ObservationsModel)Activator.CreateInstance(ObservationsModelLookup.GetObservationsClassType(observationsModelType), observationsKwargsArray); + + if (observationsObj.ContainsKey("Params")) + { + var paramsJArray = (JArray)observationsObj["Params"]; + var nParams = paramsJArray.Count; + observationsParamsArray = new object[nParams]; + for (int i = 0; i < nParams; i++) + { + try + { + observationsParamsArray[i] = ArrayHelper.ParseString(paramsJArray[i].ToString(), typeof(double)); + } + catch + { + observationsParamsArray[i] = JsonConvert.DeserializeObject(paramsJArray[i].ToString()); + } + } + } + + observations.Params = observationsParamsArray; + result.Observations = observations; + + return result; + } + + /// + public override void WriteJson(JsonWriter writer, StateParameters value, JsonSerializer serializer) + { + writer.WriteStartObject(); + + writer.WritePropertyName("InitialStateDistribution"); + serializer.Serialize(writer, value.InitialStateDistribution); + + writer.WritePropertyName("Transitions"); + serializer.Serialize(writer, value.Transitions); + + writer.WritePropertyName("Observations"); + serializer.Serialize(writer, value.Observations); + + writer.WriteEndObject(); + } + } +} diff --git a/src/Bonsai.ML.HiddenMarkovModels/StateProbability.cs b/src/Bonsai.ML.HiddenMarkovModels/StateProbability.cs new file mode 100644 index 00000000..24857a48 --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/StateProbability.cs @@ -0,0 +1,51 @@ +using Bonsai; +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using Python.Runtime; +using System.Xml.Serialization; +using Bonsai.ML.Python; + +namespace Bonsai.ML.HiddenMarkovModels +{ + /// + /// Represents the probabilities of being in each state of a Hidden Markov Model (HMM) given the observation. + /// + [Combinator] + [Description("The probability of being in each state of a Hidden Markov Model (HMM) given the observation.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class StateProbability + { + /// + /// The probability of being in each state given the observation. + /// + [XmlIgnore] + [Description("The probability of being in each state given the observation.")] + public double[] Probabilities { get; private set; } + + /// + /// The state with the highest probability. + /// + [XmlIgnore] + [Description("The state with the highest probability.")] + public int HighestProbableState => Array.IndexOf(Probabilities, Probabilities.Max()); + + /// + /// Transforms an observable sequence of into an observable sequence + /// of objects by accessing the `state_probabilities` attribute of the . + /// + public IObservable Process(IObservable source) + { + return Observable.Select(source, pyObject => + { + var probabilitiesPyObj = (double[])pyObject.GetArrayAttr("state_probabilities"); + + return new StateProbability + { + Probabilities = probabilitiesPyObj + }; + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/Transitions/ConstrainedStationaryTransitions.cs b/src/Bonsai.ML.HiddenMarkovModels/Transitions/ConstrainedStationaryTransitions.cs new file mode 100644 index 00000000..35c6fb2c --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/Transitions/ConstrainedStationaryTransitions.cs @@ -0,0 +1,175 @@ +using System; +using Python.Runtime; +using System.Reactive.Linq; +using System.ComponentModel; +using System.Collections.Generic; +using System.Xml.Serialization; +using Newtonsoft.Json; +using Bonsai.ML.Data; +using Bonsai.ML.Python; + +namespace Bonsai.ML.HiddenMarkovModels.Transitions +{ + /// + /// Represents an operator that is used to create and transform an observable sequence + /// of objects. + /// + [Combinator] + [Description("Creates an observable sequence of ConstrainedStationaryTransitions objects.")] + [WorkflowElementCategory(ElementCategory.Source)] + [JsonObject(MemberSerialization.OptIn)] + public class ConstrainedStationaryTransitions : TransitionsModel + { + private int[,] transitionMask = null; + + /// + /// The mask which gets applied to the transition matrix to prohibit certain transitions. + /// It must be written in JSON format as an int[,] with the same shape as the transition matrix (nStates x nStates). + /// For example, the mask [[1, 0], [1, 1]] is valid and would prohibit transitions from state 0 to state 1. + /// + [Description("The mask which gets applied to the transition matrix to prohibit certain transitions. It must be written in JSON format as an int[,] with the same shape as the transition matrix (nStates x nStates). For example, the mask [[1, 0], [1, 1]] is valid and would prohibit transitions from state 0 to state 1.")] + public string TransitionMask + { + get => transitionMask != null ? ArrayHelper.SerializeToJson(transitionMask) : ""; + set => transitionMask = (int[,])ArrayHelper.ParseString(value, typeof(int)); + } + + /// + /// The Log Ps of the transitions. + /// + [XmlIgnore] + [Description("The log Ps of the transitions.")] + public double[,] LogPs { get; set; } = null; + + /// + [JsonProperty] + [JsonConverter(typeof(TransitionsModelTypeJsonConverter))] + [Browsable(false)] + public override TransitionsModelType TransitionsModelType => TransitionsModelType.ConstrainedStationary; + + /// + [JsonProperty] + [Browsable(false)] + public override object[] Params + { + get => [LogPs]; + } + + /// + [JsonProperty] + [XmlIgnore] + public override Dictionary Kwargs => new Dictionary + { + ["transition_mask"] = transitionMask, + }; + + /// + [XmlIgnore] + public static new string[] KwargsArray => [ "transition_mask" ]; + + /// + public ConstrainedStationaryTransitions() : base() + { + } + + /// + public ConstrainedStationaryTransitions (params object[] kwargs) : base(kwargs) + { + } + + /// + protected override void CheckKwargs(params object[] kwargs) + { + if (kwargs is null || kwargs.Length != 1) + { + throw new ArgumentException($"The ConstrainedStationaryTransitions operator requires exactly one keyword argument: {nameof(transitionMask)}."); + } + } + + /// + protected override void UpdateKwargs(params object[] kwargs) + { + transitionMask = kwargs[0] switch + { + int[,] mask => mask, + long[,] mask => ConvertLongArrayToIntArray(mask), + bool[,] mask => ConvertBoolArrayToIntArray(mask), + _ => null + }; + } + + private static int[,] ConvertLongArrayToIntArray(long[,] longArray) + { + int rows = longArray.GetLength(0); + int cols = longArray.GetLength(1); + int[,] intArray = new int[rows, cols]; + + for (int i = 0; i < rows; i++) + for (int j = 0; j < cols; j++) + intArray[i, j] = Convert.ToInt32(longArray[i, j]); + + return intArray; + } + + private static int[,] ConvertBoolArrayToIntArray(bool[,] boolArray) + { + int rows = boolArray.GetLength(0); + int cols = boolArray.GetLength(1); + int[,] intArray = new int[rows, cols]; + + for (int i = 0; i < rows; i++) + for (int j = 0; j < cols; j++) + intArray[i, j] = Convert.ToInt32(boolArray[i, j]); + + return intArray; + } + + /// + protected override void CheckParams(params object[] @params) + { + if (@params is not null && @params.Length != 1) + { + throw new ArgumentException($"The ConstrainedStationaryTransitions operator requires exactly one parameter: {nameof(LogPs)}."); + } + } + + /// + protected override void UpdateParams(params object[] @params) + { + LogPs = @params[0] switch { + double[,] logPs => logPs, + _ => null + }; + } + + /// + /// Returns an observable sequence of objects. + /// + public IObservable Process() + { + return Observable.Return( + new ConstrainedStationaryTransitions(transitionMask) + { + Params = [LogPs] + }); + } + + /// + /// Transforms an observable sequence of into an observable sequence + /// of objects by accessing internal attributes of the . + /// + public IObservable Process(IObservable source) + { + return Observable.Select(source, pyObject => + { + var logPsPyObj = (double[,])pyObject.GetArrayAttr("log_Ps"); + var transitionMaskPyObj = (int[,])pyObject.GetArrayAttr("transition_mask"); + + return new ConstrainedStationaryTransitions(transitionMaskPyObj) + { + Params = [logPsPyObj] + }; + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/Transitions/NeuralNetworkRecurrentTransitions.cs b/src/Bonsai.ML.HiddenMarkovModels/Transitions/NeuralNetworkRecurrentTransitions.cs new file mode 100644 index 00000000..d2926b81 --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/Transitions/NeuralNetworkRecurrentTransitions.cs @@ -0,0 +1,200 @@ +using System; +using Python.Runtime; +using System.Reactive.Linq; +using System.ComponentModel; +using System.Collections.Generic; +using Newtonsoft.Json; +using System.Xml.Serialization; +using Bonsai.ML.Python; +using System.Linq; + +namespace Bonsai.ML.HiddenMarkovModels.Transitions +{ + /// + /// Represents an operator that is used to create and transform an observable sequence + /// of objects. + /// + [Combinator] + [Description("Creates an observable sequence of NeuralNetworkRecurrentTransitions objects.")] + [WorkflowElementCategory(ElementCategory.Source)] + [JsonObject(MemberSerialization.OptIn)] + public class NeuralNetworkRecurrentTransitions : TransitionsModel + { + /// + /// The sizes of the hidden layers. + /// + [Description("The sizes of the hidden layers.")] + public int[] HiddenLayerSizes { get; set; } = [50]; + + /// + /// The type of nonlinearity or activation function. + /// + [Description("The type of nonlinearity or activation function.")] + public NonlinearityType NonlinearityType { get; set; } = NonlinearityType.ReLU; + + /// + /// The Log Ps of the transitions. + /// + [Description("The log Ps of the transitions.")] + [XmlIgnore] + public double[,] LogPs { get; set; } = null; + + /// + /// The weights. + /// + [Description("The weights.")] + [XmlIgnore] + public List Weights { get; set; } = null; + + /// + /// The biases. + /// + [Description("The biases.")] + [XmlIgnore] + public List Biases { get; set; } = null; + + /// + [JsonProperty] + [JsonConverter(typeof(TransitionsModelTypeJsonConverter))] + [Browsable(false)] + public override TransitionsModelType TransitionsModelType => TransitionsModelType.NeuralNetworkRecurrent; + + private static readonly Dictionary nonlinearityTypeLookup = new Dictionary + { + { NonlinearityType.ReLU, "relu" }, + { NonlinearityType.Tanh, "tanh" }, + { NonlinearityType.Sigmoid, "sigmoid" } + }; + + /// + [JsonProperty] + [Browsable(false)] + public override object[] Params + { + get => [LogPs, Weights, Biases]; + } + + /// + [JsonProperty] + [XmlIgnore] + public override Dictionary Kwargs => new Dictionary + { + ["hidden_layer_sizes"] = HiddenLayerSizes, + ["nonlinearity_type"] = nonlinearityTypeLookup[NonlinearityType], + }; + + /// + [XmlIgnore] + public static new string[] KwargsArray => [ "hidden_layer_sizes", "nonlinearity_type" ]; + + /// + public NeuralNetworkRecurrentTransitions () : base() + { + } + + /// + public NeuralNetworkRecurrentTransitions (params object[] kwargs) : base(kwargs) + { + } + + /// + protected override void CheckKwargs(params object[] kwargs) + { + if (kwargs is null || kwargs.Length != 2) + { + throw new ArgumentException($"The NeuralNetworkRecurrentTransitions operator requires exactly one constructor argument: {nameof(HiddenLayerSizes)}."); + } + } + + /// + protected override void UpdateKwargs(params object[] kwargs) + { + HiddenLayerSizes = kwargs[0] switch + { + int[] layers => layers, + long[] layers => layers.Select(Convert.ToInt32).ToArray(), + _ => null + }; + try + { + NonlinearityType = (NonlinearityType)kwargs[1]; + } + catch (InvalidCastException) + { + try + { + NonlinearityType = nonlinearityTypeLookup.First(entry => entry.Value == (string)kwargs[1]).Key; + } + catch (KeyNotFoundException) + { + throw new ArgumentException($"The NeuralNetworkRecurrentTransitions operator requires a valid nonlinearity type. The provided value was: {kwargs[1]} which is neither a valid NonlinearityType nor a valid string representation of a nonlinearity type."); + } + } + + } + + /// + protected override void CheckParams(params object[] @params) + { + if (@params is not null && @params.Length != 3) + { + throw new ArgumentException($"The NeuralNetworkRecurrentTransitions operator requires exactly three parameters: {nameof(LogPs)}, {nameof(Weights)}, and {nameof(Biases)}."); + } + } + + /// + protected override void UpdateParams(params object[] @params) + { + LogPs = @params[0] switch + { + double[,] logPs => logPs, + _ => null + }; + + Weights = @params[1] switch + { + List weights => weights.Select(weight => (double[,])weight).ToList(), + _ => null + }; + + Biases = @params[2] switch + { + List biases => biases.Select(bias => (double[])bias).ToList(), + _ => null + }; + } + + /// + /// Returns an observable sequence of objects. + /// + public IObservable Process() + { + return Observable.Return( + new NeuralNetworkRecurrentTransitions([HiddenLayerSizes, NonlinearityType]) + { + Params = [LogPs, Weights, Biases] + }); + } + + /// + /// Transforms an observable sequence of into an observable sequence + /// of objects by accessing internal attributes of the . + /// + public IObservable Process(IObservable source) + { + return Observable.Select(source, pyObject => + { + var logPsPyObj = (double[,])pyObject.GetArrayAttr("log_Ps"); + var weightsPyObj = (List)pyObject.GetArrayAttr("weights"); + var biasesPyObj = (List)pyObject.GetArrayAttr("biases"); + var hiddenLayerSizesPyObj = (int[])pyObject.GetArrayAttr("hidden_layer_sizes"); + var nonlinearityTypePyObj = (string)pyObject.GetArrayAttr("nonlinearity_type"); + + return new NeuralNetworkRecurrentTransitions([HiddenLayerSizes, NonlinearityType]) + { + Params = [logPsPyObj, weightsPyObj, biasesPyObj] + }; + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/Transitions/NonlinearityType.cs b/src/Bonsai.ML.HiddenMarkovModels/Transitions/NonlinearityType.cs new file mode 100644 index 00000000..29cb7225 --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/Transitions/NonlinearityType.cs @@ -0,0 +1,23 @@ +namespace Bonsai.ML.HiddenMarkovModels.Transitions +{ + /// + /// Represents the type of nonlinearity to use in a recurrent neural network. + /// + public enum NonlinearityType + { + /// + /// Rectified linear unit (ReLU) nonlinearity. + /// + ReLU, + + /// + /// Tanh nonlinearity. + /// + Tanh, + + /// + /// Sigmoid nonlinearity. + /// + Sigmoid + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/Transitions/StationaryTransitions.cs b/src/Bonsai.ML.HiddenMarkovModels/Transitions/StationaryTransitions.cs new file mode 100644 index 00000000..5d7d3301 --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/Transitions/StationaryTransitions.cs @@ -0,0 +1,105 @@ +using System; +using Python.Runtime; +using System.Reactive.Linq; +using System.ComponentModel; +using System.Xml.Serialization; +using Newtonsoft.Json; +using Bonsai.ML.Python; + +namespace Bonsai.ML.HiddenMarkovModels.Transitions +{ + /// + /// Represents an operator that is used to create and transform an observable sequence + /// of objects. + /// + [Combinator] + [Description("Creates an observable sequence of StationaryTransitions objects.")] + [WorkflowElementCategory(ElementCategory.Source)] + [JsonObject(MemberSerialization.OptIn)] + public class StationaryTransitions : TransitionsModel + { + /// + /// The Log Ps of the transitions. + /// + [XmlIgnore] + [Description("The log Ps of the transitions.")] + public double[,] LogPs { get; set; } = null; + + /// + [JsonProperty] + [JsonConverter(typeof(TransitionsModelTypeJsonConverter))] + [Browsable(false)] + public override TransitionsModelType TransitionsModelType => TransitionsModelType.Stationary; + + /// + [JsonProperty] + [Browsable(false)] + public override object[] Params + { + get => [LogPs]; + } + + /// + public StationaryTransitions() : base() + { + } + + /// + public StationaryTransitions(params object[] kwargs) : base(kwargs) + { + } + + /// + protected override void CheckParams(params object[] @params) + { + if (@params is not null && @params.Length != 1) + { + throw new ArgumentException($"The StickyTransitions operator requires exactly one parameter: {nameof(LogPs)}."); + } + } + + /// + protected override void UpdateParams(params object[] @params) + { + LogPs = (double[,])@params[0]; + } + + /// + protected override void CheckKwargs(params object[] kwargs) + { + if (kwargs is not null && kwargs.Length != 0) + { + throw new ArgumentException($"The StationaryTransitions operator requires zero constructor arguments."); + } + } + + /// + /// Returns an observable sequence of objects. + /// + public IObservable Process() + { + return Observable.Return( + new StationaryTransitions + { + Params = [LogPs] + }); + } + + /// + /// Transforms an observable sequence of into an observable sequence + /// of objects by accessing internal attributes of the . + /// + public IObservable Process(IObservable source) + { + return Observable.Select(source, pyObject => + { + var logPsPyObj = (double[,])pyObject.GetArrayAttr("log_Ps"); + + return new StationaryTransitions + { + Params = [logPsPyObj] + }; + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/Transitions/StickyTransitions.cs b/src/Bonsai.ML.HiddenMarkovModels/Transitions/StickyTransitions.cs new file mode 100644 index 00000000..6853210c --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/Transitions/StickyTransitions.cs @@ -0,0 +1,150 @@ +using System; +using Python.Runtime; +using System.Reactive.Linq; +using System.ComponentModel; +using System.Collections.Generic; +using System.Xml.Serialization; +using Newtonsoft.Json; +using Bonsai.ML.Python; + +namespace Bonsai.ML.HiddenMarkovModels.Transitions +{ + /// + /// Represents an operator that is used to create and transform an observable sequence + /// of objects. + /// + [Combinator] + [Description("Creates an observable sequence of StickyTransitions objects.")] + [WorkflowElementCategory(ElementCategory.Source)] + [JsonObject(MemberSerialization.OptIn)] + public class StickyTransitions : TransitionsModel + { + /// + /// The alpha parameter. + /// + [Description("The alpha parameter.")] + public double Alpha { get; set; } = 1.0; + + /// + /// The kappa parameter. + /// + [Description("The kappa parameter.")] + public double Kappa { get; set; } = 100.0; + + /// + /// The Log Ps of the transitions. + /// + [XmlIgnore] + [Description("The log Ps of the transitions.")] + public double[,] LogPs { get; set; } = null; + + /// + [JsonProperty] + [JsonConverter(typeof(TransitionsModelTypeJsonConverter))] + [Browsable(false)] + public override TransitionsModelType TransitionsModelType => TransitionsModelType.Sticky; + + /// + [JsonProperty] + [Browsable(false)] + public override object[] Params + { + get => [LogPs]; + } + + /// + [JsonProperty] + [XmlIgnore] + public override Dictionary Kwargs => new Dictionary + { + ["alpha"] = Alpha, + ["kappa"] = Kappa, + }; + + /// + [XmlIgnore] + public static new string[] KwargsArray => [ "alpha", "kappa" ]; + + /// + public StickyTransitions() : base() + { + } + + /// + public StickyTransitions(params object[] kwargs) : base(kwargs) + { + } + + /// + protected override void CheckKwargs(params object[] kwargs) + { + if (kwargs is not null && kwargs.Length != 2) + { + throw new ArgumentException($"The StickyTransitions operator requires exactly two constructor arguments: {nameof(Alpha)} and {nameof(Kappa)}."); + } + } + + /// + protected override void UpdateKwargs(params object[] kwargs) + { + Alpha = kwargs[0] switch + { + double a => a, + var a => Convert.ToDouble(a) + }; + Kappa = kwargs[1] switch + { + double k => k, + var k => Convert.ToDouble(k) + }; + } + + /// + protected override void CheckParams(params object[] @params) + { + if (@params is not null && @params.Length != 1) + { + throw new ArgumentException($"The StickyTransitions operator requires exactly one parameter: {nameof(LogPs)}."); + } + } + + /// + protected override void UpdateParams(params object[] @params) + { + if (@params is not null) + { + LogPs = (double[,])@params[0]; + } + } + + /// + /// Returns an observable sequence of objects. + /// + public IObservable Process() + { + return Observable.Return(new StickyTransitions([Alpha, Kappa]) + { + Params = [LogPs] + }); + } + + /// + /// Transforms an observable sequence of into an observable sequence + /// of objects by accessing internal attributes of the . + /// + public IObservable Process(IObservable source) + { + return Observable.Select(source, pyObject => + { + var alphaPyObj = (int[,])pyObject.GetArrayAttr("alpha"); + var kappaPyObj = (int[,])pyObject.GetArrayAttr("kappa"); + var logPsPyObj = (double[,])pyObject.GetArrayAttr("log_Ps"); + + return new StickyTransitions([alphaPyObj, kappaPyObj]) + { + Params = [logPsPyObj] + }; + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/Transitions/TransitionsModel.cs b/src/Bonsai.ML.HiddenMarkovModels/Transitions/TransitionsModel.cs new file mode 100644 index 00000000..c241c669 --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/Transitions/TransitionsModel.cs @@ -0,0 +1,33 @@ +using System.ComponentModel; + +namespace Bonsai.ML.HiddenMarkovModels.Transitions +{ + /// + /// An abstract class for creating a Transitions model. + /// + public abstract class TransitionsModel : PythonModel + { + /// + /// The type of Transitions model. + /// + public abstract TransitionsModelType TransitionsModelType { get; } + + /// + [Browsable(false)] + protected override string ModelName => "transitions"; + + /// + [Browsable(false)] + protected override string ModelType => TransitionsModelLookup.GetString(TransitionsModelType); + + /// + public TransitionsModel() : base() + { + } + + /// + public TransitionsModel(params object[] kwargs) : base(kwargs) + { + } + } +} diff --git a/src/Bonsai.ML.HiddenMarkovModels/Transitions/TransitionsModelLookup.cs b/src/Bonsai.ML.HiddenMarkovModels/Transitions/TransitionsModelLookup.cs new file mode 100644 index 00000000..4a46b6fb --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/Transitions/TransitionsModelLookup.cs @@ -0,0 +1,41 @@ +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Bonsai.ML.HiddenMarkovModels.Transitions +{ + /// + /// A lookup class for relating different to + /// and their corresponding Python string representations. + /// + public static class TransitionsModelLookup + { + private static readonly Dictionary _lookup = new Dictionary + { + { TransitionsModelType.Stationary, (typeof(StationaryTransitions), "stationary") }, + { TransitionsModelType.ConstrainedStationary, (typeof(ConstrainedStationaryTransitions), "constrained") }, + { TransitionsModelType.Sticky, (typeof(StickyTransitions), "sticky") }, + { TransitionsModelType.NeuralNetworkRecurrent, (typeof(NeuralNetworkRecurrentTransitions), "nn_recurrent") } + }; + + /// + /// Gets the of the corresponding to the given . + /// + public static Type GetTransitionsClassType(TransitionsModelType type) => _lookup[type].Type; + + /// + /// Gets the Python string representation of the given . + /// + public static string GetString(TransitionsModelType type) => _lookup[type].StringValue; + + /// + /// Gets the corresponding to the given Python string representation. + /// + public static TransitionsModelType GetFromString(string value) => _lookup.First(x => x.Value.StringValue == value).Key; + + /// + /// Gets the corresponding to the given of . + /// + public static TransitionsModelType GetFromType(Type type) => _lookup.First(x => x.Value.Type == type).Key; + } +} diff --git a/src/Bonsai.ML.HiddenMarkovModels/Transitions/TransitionsModelType.cs b/src/Bonsai.ML.HiddenMarkovModels/Transitions/TransitionsModelType.cs new file mode 100644 index 00000000..e8be1dea --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/Transitions/TransitionsModelType.cs @@ -0,0 +1,28 @@ +namespace Bonsai.ML.HiddenMarkovModels.Transitions +{ + /// + /// Represents the type of transitions in a hidden Markov model. + /// + public enum TransitionsModelType + { + /// + /// Stationary transitions. + /// + Stationary, + + /// + /// Constrained stationary transitions. + /// + ConstrainedStationary, + + /// + /// Sticky transitions. + /// + Sticky, + + /// + /// Neural network recurrent transitions. + /// + NeuralNetworkRecurrent + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/Transitions/TransitionsModelTypeJsonConverter.cs b/src/Bonsai.ML.HiddenMarkovModels/Transitions/TransitionsModelTypeJsonConverter.cs new file mode 100644 index 00000000..05b53c73 --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/Transitions/TransitionsModelTypeJsonConverter.cs @@ -0,0 +1,24 @@ +using Newtonsoft.Json; +using System; + +namespace Bonsai.ML.HiddenMarkovModels.Transitions +{ + /// + /// Provides a type converter to convert between and the corresponding Python string representation. + /// + public class TransitionsModelTypeJsonConverter : JsonConverter + { + /// + public override TransitionsModelType ReadJson(JsonReader reader, Type objectType, TransitionsModelType existingValue, bool hasExistingValue, JsonSerializer serializer) + { + string stringValue = reader.Value?.ToString(); + return TransitionsModelLookup.GetFromString(stringValue); + } + + /// + public override void WriteJson(JsonWriter writer, TransitionsModelType value, JsonSerializer serializer) + { + writer.WriteValue(TransitionsModelLookup.GetString(value)); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/main.py b/src/Bonsai.ML.HiddenMarkovModels/main.py new file mode 100644 index 00000000..aad1386e --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/main.py @@ -0,0 +1,266 @@ +from concurrent.futures import ThreadPoolExecutor +from functools import partial +import threading +import asyncio +from ssm import HMM, util +import numpy as np +import autograd.numpy.random as npr +from scipy.optimize import linear_sum_assignment +from scipy.special import logsumexp +import pickle + +npr.seed(0) + + +class HiddenMarkovModel(HMM): + + def __init__( + self, + num_states: int, + dimensions: int, + observations_model_type: str, + transitions_model_type: str, + initial_state_distribution: list[float] = None, + observations_params: tuple = None, + observations_kwargs: dict = None, + transitions_params: tuple = None, + transitions_kwargs: dict = None + ): + + self.num_states = num_states + self.dimensions = dimensions + self.observations_model_type = observations_model_type + self.transitions_model_type = transitions_model_type + + if observations_kwargs is not None: + for (key, value) in observations_kwargs.items(): + if isinstance(value, list): + observations_kwargs[key] = np.array(value) + + if transitions_kwargs is not None: + for (key, value) in transitions_kwargs.items(): + if isinstance(value, list) and key != "hidden_layer_sizes": + transitions_kwargs[key] = np.array(value) + if key == "hidden_layer_sizes": + transitions_kwargs[key] = tuple(value) + + if "nonlinearity_type" in transitions_kwargs.keys(): + transitions_kwargs["nonlinearity"] = value + transitions_kwargs.pop(key) + + super(HiddenMarkovModel, self).__init__( + K=self.num_states, + D=self.dimensions, + observations=self.observations_model_type, + observation_kwargs=observations_kwargs, + transitions=self.transitions_model_type, + transition_kwargs=transitions_kwargs + ) + + self.update_params(initial_state_distribution, + transitions_params, observations_params) + + if self.transitions_model_type == "nn_recurrent": + hidden_layer_sizes = np.array([len(layer) for layer in self.transitions.weights[1:]]) + self.transitions.hidden_layer_sizes = hidden_layer_sizes + + def get_nonlinearity_type(func): + if func == util.relu: + return "relu" + elif func == util.logistic: + return "sigmoid" + else: + return "tanh" + + self.transitions.nonlinearity_type = get_nonlinearity_type(self.transitions.nonlinearity) + + self.log_alpha = None + self.state_probabilities = None + + self.batch = None + self.batch_observations = np.array([[]], dtype=float) + self.is_running = False + self._fit_finished = False + self.loop = None + self.thread = None + self.curr_batch_size = 0 + self.flush_data_between_batches = True + self.inferred_most_probable_states = np.array([], dtype=int) + + def update_params(self, initial_state_distribution, transitions_params, observations_params): + hmm_params = self.params + + if initial_state_distribution is not None: + hmm_params = ((np.array(initial_state_distribution),), + ) + hmm_params[1:] + + if transitions_params is not None: + trans_params = tuple([np.array(param) for param in transitions_params]) + if isinstance(hmm_params[1], tuple): + hmm_params = (hmm_params[0],) + (trans_params,) + (hmm_params[2],) + else: + hmm_params = (hmm_params[0],) + trans_params + (hmm_params[2],) + + if observations_params is not None: + obs_params = tuple([np.array(param) for param in observations_params]) + if isinstance(hmm_params[2], tuple): + hmm_params = hmm_params[:2] + (obs_params,) + else: + hmm_params = hmm_params[:2] + obs_params + + self.params = hmm_params + + self.initial_state_distribution = hmm_params[0][0] + + if isinstance(hmm_params[1], tuple): + self.transitions_params = hmm_params[1] + else: + self.transitions_params = (hmm_params[1],) + + if isinstance(hmm_params[2], tuple): + self.observations_params = hmm_params[2] + else: + self.observations_params = (hmm_params[2],) + + def infer_state(self, observation: list[float]): + + self.log_alpha = self.compute_log_alpha( + np.expand_dims(np.array(observation), 0), self.log_alpha) + self.state_probabilities = np.exp(self.log_alpha).astype(np.double) + return self.state_probabilities.argmax() + + def compute_log_alpha(self, obs, log_alpha=None): + + if log_alpha is None: + log_alpha = (np.log(self.init_state_distn.initial_state_distn) + + self.observations.log_likelihoods(obs, None, None, None)).squeeze() + return log_alpha - logsumexp(log_alpha) + + m = np.max(log_alpha) + + log_alpha = (np.log(np.dot(np.exp(log_alpha - m), self.transitions.transition_matrices(obs, None, None, None).squeeze()) + ) + m + self.observations.log_likelihoods(obs, None, None, None)).squeeze() + + return log_alpha - logsumexp(log_alpha) + + def save_model_to_pickle(self, path: str): + pickle.save(self, path) + + def load_model_from_pickle(self, path: str): + self = pickle.load(path) + + def fit_async(self, + observation: list[float], + vars_to_estimate: dict = None, + batch_size: int = 20, + max_iter: int = 50, + flush_data_between_batches: bool = False): + + self.flush_data_between_batches = flush_data_between_batches + + if self.batch is None: + self.batch = np.expand_dims(np.array(observation), 0) + self.curr_batch_size += 1 + + elif self.curr_batch_size < batch_size or not flush_data_between_batches: + self.batch = np.vstack( + [self.batch, np.expand_dims(np.array(observation), 0)]) + self.curr_batch_size += 1 + + elif self.curr_batch_size == batch_size: + self.batch = np.vstack( + [self.batch[1:], np.expand_dims(np.array(observation), 0)]) + + self.batch_observations = self.batch + + if not self.is_running and self.loop is None and self.thread is None: + + if self.curr_batch_size >= batch_size: + + if vars_to_estimate is None or vars_to_estimate == {}: + vars_to_estimate = { + "initial_state_distribution": True, + "transitions_params": True, + "observations_params": True + } + + def calculate_permutation(mat1, mat2): + num_states = mat1.shape[0] + cost_matrix = np.zeros((num_states, num_states)) + for i in range(num_states): + for j in range(num_states): + cost_matrix[i, j] = np.linalg.norm( + mat1[i] - mat2[j]) + return linear_sum_assignment(cost_matrix)[1] + + def start_loop(loop): + asyncio.set_event_loop(loop) + loop.run_forever() + + def on_completion(future): + + if self.observations_model_type == "gaussian": + permutation = calculate_permutation( + self.observations_params[0], self.params[2][0]) + super(HiddenMarkovModel, self).permute(permutation) + + initial_state_distribution = None if vars_to_estimate[ + "initial_state_distribution"] else self.initial_state_distribution + transitions_params = None if vars_to_estimate[ + "transitions_params"] else self.transitions_params + observations_params = None if vars_to_estimate[ + "observations_params"] else self.observations_params + + self.update_params(initial_state_distribution, + transitions_params, observations_params) + + self.is_running = False + self._fit_finished = True + self.curr_batch_size = 0 + + if self.flush_data_between_batches: + self.batch = None + + self.inferred_most_probable_states = np.array([self.infer_state(obs) for obs in self.batch_observations]).astype(int) + + self.is_running = True + + if self.loop is None or self.loop.is_closed(): + self.loop = asyncio.new_event_loop() + + if self.thread is None: + self.thread = threading.Thread( + target=start_loop, args=(self.loop,)) + self.thread.start() + + future = asyncio.run_coroutine_threadsafe(self._fit_async( + self.batch, method="em", num_iters=max_iter, init_method="kmeans"), self.loop) + future.add_done_callback(on_completion) + + return self.is_running + + async def _fit_async(self, *args, **kwargs): + func = partial(super(HiddenMarkovModel, self).fit, *args, **kwargs) + with ThreadPoolExecutor() as pool: + await self.loop.run_in_executor(pool, func) + + def get_fit_finished(self): + return self._fit_finished + + def reset_fit_loop(self): + self._fit_finished = False + + if self.loop.is_running(): + self.loop.call_soon_threadsafe(self.loop.stop) + + if self.thread.is_alive(): + self.thread.join() + + self.loop.stop() + self.loop.close() + + del (self.thread) + del (self.loop) + + self.thread = None + self.loop = None diff --git a/src/Bonsai.ML.LinearDynamicalSystems/Bonsai.ML.LinearDynamicalSystems.csproj b/src/Bonsai.ML.LinearDynamicalSystems/Bonsai.ML.LinearDynamicalSystems.csproj index 7d6d59b3..8110c215 100644 --- a/src/Bonsai.ML.LinearDynamicalSystems/Bonsai.ML.LinearDynamicalSystems.csproj +++ b/src/Bonsai.ML.LinearDynamicalSystems/Bonsai.ML.LinearDynamicalSystems.csproj @@ -13,6 +13,5 @@ - \ No newline at end of file diff --git a/src/Bonsai.ML.Python/Bonsai.ML.Python.csproj b/src/Bonsai.ML.Python/Bonsai.ML.Python.csproj new file mode 100644 index 00000000..2015f856 --- /dev/null +++ b/src/Bonsai.ML.Python/Bonsai.ML.Python.csproj @@ -0,0 +1,13 @@ + + + Bonsai.ML.Python + Bonsai Library containing python integration functions for machine learning. + Bonsai Rx Bonsai.ML.Python + net472;netstandard2.0 + + + + + + + diff --git a/src/Bonsai.ML.Python/FormatToPython.cs b/src/Bonsai.ML.Python/FormatToPython.cs new file mode 100644 index 00000000..d10ba36b --- /dev/null +++ b/src/Bonsai.ML.Python/FormatToPython.cs @@ -0,0 +1,28 @@ +using System.Linq; +using System.ComponentModel; +using System.Reactive.Linq; +using System; + +namespace Bonsai.ML.Python +{ + /// + /// Represents an operator that can convert an object into a properly formatted string that is consistent with python syntax. + /// For example, a tuple (1, 2, 3) will be converted to the string "(1, 2, 3)". A list of [0, 1, 2] will be converted to the string "[0, 1, 2]". + /// + [Combinator] + [Description("Represents an operator that can convert an object into a properly formatted string that is consistent with python.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class FormatToPython + { + /// + /// Transforms the elements of an observable sequence into a properly formatted string that is consistent with python syntax. + /// + public IObservable Process(IObservable source) + { + var stringFormatter = new StringFormatter(); + return source.Select(value => { + return stringFormatter.Format(value); + }); + } + } +} diff --git a/src/Bonsai.ML.Python/NumpyHelper.cs b/src/Bonsai.ML.Python/NumpyHelper.cs new file mode 100644 index 00000000..9ea3dfa7 --- /dev/null +++ b/src/Bonsai.ML.Python/NumpyHelper.cs @@ -0,0 +1,226 @@ +using System; +using Python.Runtime; +using System.Collections.Generic; +using System.ComponentModel; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using System.Text; +using System.Linq; +using System.Runtime.InteropServices; + +namespace Bonsai.ML.Python +{ + /// + /// Provides a set of static methods for working with NumPy arrays. + /// + public static class NumpyHelper + { + /// + /// Represents a NumPy array interface for interacting with representing NumPy arrays. + /// + public class NumpyArrayInterface + { + /// + /// Initializes a new instance of the class. + /// + public NumpyArrayInterface(PyObject obj) + { + if (!IsNumPyArray(obj)) + { + throw new ArgumentException($"Object is not a numpy array.", nameof(obj)); + } + var meta = obj.GetAttr("__array_interface__"); + IsCStyleContiguous = meta["strides"] == null; + Address = new IntPtr(meta["data"][0].As()); + + var typestr = meta["typestr"].As(); + var dtype = typestr.Substring(1); + switch (dtype) + { + case "b1": + DataType = typeof(bool); + break; + case "f4": + DataType = typeof(float); + break; + case "f8": + DataType = typeof(double); + break; + case "i2": + DataType = typeof(short); + break; + case "i4": + DataType = typeof(int); + break; + case "i8": + DataType = typeof(long); + break; + case "u1": + DataType = typeof(byte); + break; + case "u2": + DataType = typeof(ushort); + break; + case "u4": + DataType = typeof(uint); + break; + case "u8": + DataType = typeof(ulong); + break; + default: + throw new ArgumentException($"Type is not currently supported.", nameof(dtype)); + } + Shape = obj.GetAttr("shape").As(); + NBytes = obj.GetAttr("nbytes").As(); + } + + /// + /// The memory address of the NumPy array data. + /// + public readonly IntPtr Address; + + /// + /// The C# data type representing the elements of the NumPy array. + /// + public readonly Type DataType; + + /// + /// The shape of the NumPy array. + /// + public readonly long[] Shape; + + /// + /// The number of bytes in the NumPy array. + /// + public readonly int NBytes; + + /// + /// A value indicating whether the NumPy array is C-style contiguous. + /// + public readonly bool IsCStyleContiguous; + } + + /// + /// Converts a representing a NumPy array to a C# . + /// + public static Array PyObjectToArray(PyObject array) + { + var info = new NumpyArrayInterface(array); + byte[] data = new byte[info.NBytes]; + Marshal.Copy(info.Address, data, 0, info.NBytes); + if (info.DataType == typeof(byte) && info.Shape.Length == 1) + { + return data; + } + var result = Array.CreateInstance(info.DataType, info.Shape); + Buffer.BlockCopy(data, 0, result, 0, info.NBytes); + return result; + } + + private static PyObject deepcopy; + + private static readonly Lazy np = new(InitializeNumpy); + + private static readonly Dictionary np_dtypes = new(); + + private static readonly Dictionary csharp_dtypes = new() + { + { "uint8", typeof(byte) }, + { "uint16", typeof(ushort) }, + { "uint32", typeof(uint) }, + { "uint64", typeof(ulong) }, + { "int16", typeof(short) }, + { "int32", typeof(int) }, + { "int64", typeof(long) }, + { "float32", typeof(float) }, + { "float64", typeof(double) }, + }; + + /// + /// Initializes the NumPy module and returns a reference to the module. + /// + public static PyObject InitializeNumpy() + { + var np = Py.Import("numpy"); + np_dtypes.Add(typeof(byte), np.GetAttr("uint8")); + np_dtypes.Add(typeof(ushort), np.GetAttr("uint16")); + np_dtypes.Add(typeof(uint), np.GetAttr("uint32")); + np_dtypes.Add(typeof(ulong), np.GetAttr("uint64")); + np_dtypes.Add(typeof(short), np.GetAttr("int16")); + np_dtypes.Add(typeof(int), np.GetAttr("int32")); + np_dtypes.Add(typeof(long), np.GetAttr("int64")); + np_dtypes.Add(typeof(float), np.GetAttr("float32")); + np_dtypes.Add(typeof(double), np.GetAttr("float64")); + var copy = Py.Import("copy"); + deepcopy = copy.GetAttr("deepcopy"); + return np; + } + + /// + /// Checks if the is a type of NumPy array. + /// + public static bool IsNumPyArray(PyObject obj) + { + dynamic numpy = np.Value; + return numpy.ndarray.__instancecheck__(obj); + } + + /// + /// Gets the NumPy data type for the specified C# type. + /// + public static PyObject GetNumpyDataType(Type type) + { + PyObject dtype; + np_dtypes.TryGetValue(type, out dtype); + if (dtype == null) + { + throw new ArgumentException("Type is not currently supported.", nameof(type)); + } + return dtype; + } + + /// + /// Gets the C# data type for the specified NumPy data type. + /// + public static Type GetCSharpDataType(string str) + { + Type type; + csharp_dtypes.TryGetValue(str, out type); + if (type == null) + { + throw new ArgumentException("Could not determine data type from string. Data type is either incorrect or not supported.", nameof(str)); + } + return type; + } + + /// + /// A custom type converter for NumPy data types. + /// + public class NumpyDataTypes : StringConverter + { + /// + public override bool GetStandardValuesSupported(ITypeDescriptorContext context) + { + return true; + } + + /// + public override StandardValuesCollection GetStandardValues(ITypeDescriptorContext context) + { + var dtypes = new List + { + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float32", + "float64" + }; + return new StandardValuesCollection(dtypes); + } + } + } +} diff --git a/src/Bonsai.ML.Python/Properties/launchSettings.json b/src/Bonsai.ML.Python/Properties/launchSettings.json new file mode 100644 index 00000000..b48bcfa9 --- /dev/null +++ b/src/Bonsai.ML.Python/Properties/launchSettings.json @@ -0,0 +1,10 @@ +{ + "profiles": { + "Bonsai": { + "commandName": "Executable", + "executablePath": "$(registry:HKEY_CURRENT_USER\\Software\\Bonsai Foundation\\Bonsai@InstallDir)Bonsai.exe", + "commandLineArgs": "--lib:\"$(TargetDir).\"", + "nativeDebugging": true + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Python/PythonHelper.cs b/src/Bonsai.ML.Python/PythonHelper.cs new file mode 100644 index 00000000..c3326c2b --- /dev/null +++ b/src/Bonsai.ML.Python/PythonHelper.cs @@ -0,0 +1,107 @@ +using Python.Runtime; +using System; +using System.Collections.Generic; + +namespace Bonsai.ML.Python +{ + /// + /// Provides a set of extension methods for working with instances. + /// + public static class PythonHelper + { + /// + /// Gets the value of the specified array attribute from the Python object. + /// + /// The Python object. + /// The name of the attribute to retrieve. + /// The array value of the specified attribute. + public static object GetArrayAttr(this PyObject pyObject, string attributeName) + { + using var attr = pyObject.GetAttr(attributeName); + return ConvertPythonObjectToCSharp(attr); + } + + /// + /// Gets the value of the specified attribute from the Python object. + /// + /// The type of the attribute to retrieve. + /// The Python object. + /// The name of the attribute to retrieve. + /// The value of the specified attribute. + public static T GetAttr(this PyObject pyObject, string attributeName) + { + using var attr = pyObject.GetAttr(attributeName); + if (attr == null || attr.IsNone()) + { + return default; + } + return attr.As(); + } + + /// + /// Converts the specified Python object to a C# object. + /// + /// The Python object to convert. + /// The C# object representation of the Python object. + public static object ConvertPythonObjectToCSharp(PyObject pyObject) + { + if (pyObject == null || pyObject.IsNone()) + { + return null; + } + + if (PyInt.IsIntType(pyObject)) + { + return pyObject.As(); + } + + if (PyFloat.IsFloatType(pyObject)) + { + return pyObject.As(); + } + + if (PyString.IsStringType(pyObject)) + { + return pyObject.As(); + } + + if (PyList.IsListType(pyObject)) + { + var pyList = new PyList(pyObject); + var resultList = new List(); + foreach (PyObject item in pyList) + resultList.Add(ConvertPythonObjectToCSharp(item)); + return resultList; + } + + if (PyDict.IsDictType(pyObject)) + { + var pyDict = new PyDict(pyObject); + var resultDict = new Dictionary(); + foreach (PyObject key in pyDict.Keys()) + { + var value = pyDict[key]; + resultDict.Add(ConvertPythonObjectToCSharp(key), ConvertPythonObjectToCSharp(value)); + } + return resultDict; + } + + if (PyTuple.IsTupleType(pyObject)) + { + var pyTuple = new PyTuple(pyObject); + var resultArray = new object[pyTuple.Length()]; + for (int i = 0; i < pyTuple.Length(); i++) { + resultArray[i] = ConvertPythonObjectToCSharp(pyTuple[i]); + } + return resultArray; + } + + if (NumpyHelper.IsNumPyArray(pyObject)) + { + return NumpyHelper.PyObjectToArray(pyObject); + } + + throw new InvalidOperationException($"Unable to convert python data type to C#. Allowed data types include: integer, float, string, list, dictionary, and numpy arrays. Instead, got: {pyObject.GetPythonType()}"); + } + } +} diff --git a/src/Bonsai.ML.Python/StringFormatter.cs b/src/Bonsai.ML.Python/StringFormatter.cs new file mode 100644 index 00000000..f23f2066 --- /dev/null +++ b/src/Bonsai.ML.Python/StringFormatter.cs @@ -0,0 +1,160 @@ +using System; +using System.Text; +using System.Collections; +using System.Collections.Generic; +using System.Reflection; +using System.Linq; + +namespace Bonsai.ML.Python +{ + /// + /// Represents a C# to Python string formatter class. + /// + public class StringFormatter + { + private readonly Dictionary> typeHandlers; + private readonly Dictionary typeProperties; + private readonly StringBuilder sb; + + /// + /// Initializes a new instance of the class. + /// + public StringFormatter() + { + typeHandlers = new Dictionary>(); + typeProperties = new Dictionary(); + sb = new StringBuilder(); + } + + /// + /// Formats the specified object into a string that is consistent with Python syntax. + /// + /// The object to format. + /// A string that is consistent with Python syntax. + public string Format(object obj) + { + sb.Clear(); + ConvertCSharpToPythonStringInternal(obj, sb); + return sb.ToString(); + } + + private void ConvertCSharpToPythonStringInternal(object obj, StringBuilder sb) + { + if (obj == null) + { + sb.Append("None"); + return; + } + + var type = obj.GetType(); + + if (!typeHandlers.TryGetValue(type, out var handler)) + { + handler = CreateTypeHandler(type); + typeHandlers[type] = handler; + } + + handler(obj, sb); + } + + private Action CreateTypeHandler(Type type) + { + if (type == typeof(string) || type == typeof(char)) + { + return (obj, sb) => sb.Append('"').Append(obj).Append('"'); + } + + if (type == typeof(bool)) + { + return (obj, sb) => sb.Append(((bool)obj).ToString().ToLower()); + } + + if (type == typeof(int) || type == typeof(double) || type == typeof(float) || type == typeof(long) || type == typeof(short) || type == typeof(byte) || type == typeof(ushort) || type == typeof(uint) || type == typeof(ulong) || type == typeof(sbyte) || type == typeof(decimal)) + { + return (obj, sb) => sb.Append(obj); + } + + if (type.IsArray) + { + return (obj, sb) => + { + var array = (Array)obj; + sb.Append('['); + for (int i = 0; i < array.Length; i++) + { + if (i > 0) sb.Append(", "); + ConvertCSharpToPythonStringInternal(array.GetValue(i), sb); + } + sb.Append(']'); + }; + } + + if (typeof(IList).IsAssignableFrom(type)) + { + return (obj, sb) => + { + var list = (IList)obj; + sb.Append('['); + for (int i = 0; i < list.Count; i++) + { + if (i > 0) sb.Append(", "); + ConvertCSharpToPythonStringInternal(list[i], sb); + } + sb.Append(']'); + }; + } + + if (typeof(IDictionary).IsAssignableFrom(type)) + { + return (obj, sb) => + { + var dict = (IDictionary)obj; + sb.Append('{'); + bool first = true; + foreach (DictionaryEntry entry in dict) + { + if (!first) sb.Append(", "); + ConvertCSharpToPythonStringInternal(entry.Key, sb); + sb.Append(": "); + ConvertCSharpToPythonStringInternal(entry.Value, sb); + first = false; + } + sb.Append('}'); + }; + } + + if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Tuple<,>)) + { + var itemProperties = type.GetProperties().Where(p => p.Name.StartsWith("Item")).OrderBy(p => p.Name).ToArray(); + return (obj, sb) => + { + sb.Append('('); + for (int i = 0; i < itemProperties.Length; i++) + { + if (i > 0) sb.Append(", "); + ConvertCSharpToPythonStringInternal(itemProperties[i].GetValue(obj), sb); + } + sb.Append(')'); + }; + } + + if (!typeProperties.TryGetValue(type, out var properties)) + { + properties = type.GetProperties(BindingFlags.Public | BindingFlags.Instance); + typeProperties[type] = properties; + } + + return (obj, sb) => + { + sb.Append('{'); + for (int i = 0; i < properties.Length; i++) + { + if (i > 0) sb.Append(", "); + sb.Append('"').Append(properties[i].Name).Append("\": "); + ConvertCSharpToPythonStringInternal(properties[i].GetValue(obj), sb); + } + sb.Append('}'); + }; + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Visualizers/BarSeriesOxyPlotBase.cs b/src/Bonsai.ML.Visualizers/BarSeriesOxyPlotBase.cs new file mode 100644 index 00000000..55188fea --- /dev/null +++ b/src/Bonsai.ML.Visualizers/BarSeriesOxyPlotBase.cs @@ -0,0 +1,252 @@ +using System.Windows.Forms; +using OxyPlot; +using OxyPlot.Series; +using OxyPlot.WindowsForms; +using System.Drawing; +using System; +using OxyPlot.Axes; +using System.Collections; + +namespace Bonsai.ML.Visualizers +{ + internal class BarSeriesOxyPlotBase : UserControl + { + private PlotView view; + private PlotModel model; + private OxyColor defaultBarSeriesColor = OxyColors.Automatic; + + internal Axis xAxis; + private Axis yAxis; + + private StatusStrip statusStrip; + + /// + /// Gets or sets the integer value that determines how many data points should be shown along the x axis. + /// + public int Capacity { get; set; } + + /// + /// Gets the status strip control. + /// + public StatusStrip StatusStrip => statusStrip; + + /// + /// Initializes a new instance of the class + /// + public BarSeriesOxyPlotBase() + { + Initialize(); + } + + private void Initialize() + { + view = new PlotView + { + Size = Size, + Dock = DockStyle.Fill, + }; + + model = new PlotModel(); + + xAxis = new CategoryAxis + { + Position = AxisPosition.Bottom, + Title = "Category", + MajorGridlineStyle = LineStyle.Solid, + MinorGridlineStyle = LineStyle.Dot, + FontSize = 9, + Key = "y1" + }; + + yAxis = new LinearAxis + { + Position = AxisPosition.Left, + Title = "Value", + Key = "x1" + }; + + model.Axes.Add(xAxis); + model.Axes.Add(yAxis); + + view.Model = model; + Controls.Add(view); + + statusStrip = new StatusStrip + { + Visible = false + }; + + view.MouseClick += new MouseEventHandler(onMouseClick); + Controls.Add(statusStrip); + + Controls.Add(statusStrip); + + AutoScaleDimensions = new SizeF(6F, 13F); + } + + private void onMouseClick(object sender, MouseEventArgs e) + { + if (e.Button == MouseButtons.Right) + { + statusStrip.Visible = !statusStrip.Visible; + } + } + + /// + /// Method to add a new bar series to the data plot. + /// Requires a string for the name of the bar series + /// Fill color of the bar series is optional. + /// + public BarSeries AddNewBarSeries(string barSeriesName, OxyColor? fillColor = null, OxyColor? strokeColor = null) + { + BarSeries barSeries = new BarSeries + { + Title = barSeriesName, + FillColor = fillColor ?? defaultBarSeriesColor, + StrokeColor = strokeColor ?? OxyColors.Automatic, + StrokeThickness = strokeColor != null ? 1 : 0, + XAxisKey = "x1", + YAxisKey = "y1" + }; + AddSeriesToModel(barSeries); + return barSeries; + } + + /// + /// Method to add a new bar series to the data plot. + /// Requires a string for the name of the bar series + /// Fill color of the bar series is optional. + /// + public ErrorBarSeries AddNewErrorBarSeries(string barSeriesName, OxyColor? fillColor = null, OxyColor? strokeColor = null) + { + ErrorBarSeries errorBarSeries = new ErrorBarSeries + { + Title = barSeriesName, + FillColor = fillColor ?? defaultBarSeriesColor, + StrokeColor = strokeColor ?? OxyColors.Automatic, + StrokeThickness = strokeColor != null ? 1 : 0, + XAxisKey = "x1", + YAxisKey = "y1" + }; + AddSeriesToModel(errorBarSeries); + return errorBarSeries; + } + + /// + /// Method to add a series to the plot model. + /// Requires the series. + /// + public void AddSeriesToModel(Series series) + { + model.Series.Add(series); + } + + /// + /// Method to add a value to a bar series. + /// Requires the bar series and value. + /// + public void AddValueToBarSeries(BarSeries barSeries, double value, OxyColor? fillColor = null) + { + var barItem = new BarItem { + Value = value, + Color = fillColor ?? barSeries.FillColor + }; + AddBarItemToBarSeries(barSeries, barItem); + } + + /// + /// Method to add a value to a bar series. + /// Requires the bar series and value. + /// + public void AddValueToBarSeries(BarSeries barSeries, int index, double value, OxyColor? fillColor = null) + { + var barItem = new BarItem { + CategoryIndex = index, + Value = value, + Color = fillColor ?? barSeries.FillColor, + }; + AddBarItemToBarSeries(barSeries, barItem); + } + + /// + /// Method to add value with error to a bar series. + /// Requires the bar series, value, and error. + /// + public void AddValueAndErrorToBarSeries(BarSeries barSeries, double value, double error, OxyColor? fillColor = null) + { + var barItem = new ErrorBarItem { + Value = value, + Error = error, + Color = fillColor ?? barSeries.FillColor + }; + AddBarItemToBarSeries(barSeries, barItem); + } + + /// + /// Method to add value with error to a bar series. + /// Requires the bar series, value, and error. + /// + public void AddValueAndErrorToBarSeries(BarSeries barSeries, int index, double value, double error, OxyColor? fillColor = null) + { + var barItem = new ErrorBarItem { + CategoryIndex = index, + Value = value, + Error = error, + Color = fillColor ?? barSeries.FillColor + }; + AddBarItemToBarSeries(barSeries, barItem); + } + + /// + /// Method to add bar item to a bar series. + /// Requires the bar item and bar series. + /// + public void AddBarItemToBarSeries(BarSeries barSeries, BarItem barItem) + { + barSeries.Items.Add(barItem); + } + + /// + /// Set the minimum and maximum values to show along the y axis. + /// Requires the minValue and maxValue. + /// + public void SetAxes(double minValue, double maxValue) + { + yAxis.Minimum = minValue; + yAxis.Maximum = maxValue; + } + + /// + /// Method to update the plot. + /// + public void UpdatePlot() + { + model.InvalidatePlot(true); + } + + /// + /// Method to reset the bar series. + /// + public void ResetBarSeries(BarSeries barSeries) + { + barSeries.Items.Clear(); + } + + /// + /// Method to reset all series in the current PlotModel. + /// + public void ResetModelSeries() + { + model.Series.Clear(); + } + + /// + /// Method to reset the x and y axes to their default. + /// + public void ResetAxes() + { + xAxis.Reset(); + yAxis.Reset(); + } + } +} diff --git a/src/Bonsai.ML.Visualizers/Bonsai.ML.Visualizers.csproj b/src/Bonsai.ML.Visualizers/Bonsai.ML.Visualizers.csproj index a3de92a3..89a48533 100644 --- a/src/Bonsai.ML.Visualizers/Bonsai.ML.Visualizers.csproj +++ b/src/Bonsai.ML.Visualizers/Bonsai.ML.Visualizers.csproj @@ -16,5 +16,6 @@ + \ No newline at end of file diff --git a/src/Bonsai.ML.Visualizers/GaussianObservationsStatisticsClustersVisualizer.cs b/src/Bonsai.ML.Visualizers/GaussianObservationsStatisticsClustersVisualizer.cs new file mode 100644 index 00000000..0393b837 --- /dev/null +++ b/src/Bonsai.ML.Visualizers/GaussianObservationsStatisticsClustersVisualizer.cs @@ -0,0 +1,332 @@ +using System; +using System.Windows.Forms; +using System.Collections.Generic; +using Bonsai; +using Bonsai.Design; +using Bonsai.ML.Visualizers; +using Bonsai.ML.HiddenMarkovModels.Observations; +using OxyPlot; +using OxyPlot.Series; +using OxyPlot.Axes; +using OxyPlot.WindowsForms; +using MathNet.Numerics.LinearAlgebra; + +[assembly: TypeVisualizer(typeof(GaussianObservationsStatisticsClustersVisualizer), Target = typeof(GaussianObservationsStatistics))] + +namespace Bonsai.ML.Visualizers +{ + /// + /// Provides a type visualizer of to display how the observations + /// cluster with respect to the mean and covariance of each state of an HMM with gaussian observations model. + /// + public class GaussianObservationsStatisticsClustersVisualizer : DialogTypeVisualizer + { + private PlotView view; + private PlotModel model; + private List allLineSeries = null; + private List allScatterSeries = null; + private List colorList = null; + private StatusStrip statusStrip; + private int dimension1SelectedIndex = 0; + private ToolStripComboBox dimension1ComboBox; + private ToolStripLabel dimension1Label; + private int dimension2SelectedIndex = 1; + private ToolStripComboBox dimension2ComboBox; + private ToolStripLabel dimension2Label; + private LinearAxis xAxis; + private LinearAxis yAxis; + + /// + /// Gets the status strip. + /// + public StatusStrip StatusStrip => statusStrip; + + /// + /// Gets the selected index of the first dimension. + /// + public int Dimension1SelectedIndex { get => dimension1SelectedIndex; set => dimension1SelectedIndex = value; } + + /// + /// Gets the selected index of the second dimension. + /// + public int Dimension2SelectedIndex { get => dimension2SelectedIndex; set => dimension2SelectedIndex = value; } + + /// + /// Gets the first dimension combo box. + /// + public ToolStripComboBox Dimension1ComboBox => dimension1ComboBox; + + /// + /// Gets the second dimension combo box. + /// + public ToolStripComboBox Dimension2ComboBox => dimension2ComboBox; + + /// + /// Gets or sets a value indicating whether the data should be buffered. + /// + public bool BufferData { get; set; } = true; + + /// + /// Gets or sets the buffer count. + /// + public int BufferCount { get; set; } = 250; + + /// + public override void Load(IServiceProvider provider) + { + view = new PlotView + { + Dock = DockStyle.Fill, + }; + + model = new PlotModel(); + + xAxis = new LinearAxis + { + Position = AxisPosition.Bottom, + Title = $"Observation Dimension: {dimension1SelectedIndex}", + }; + + yAxis = new LinearAxis + { + Position = AxisPosition.Left, + Title = $"Observation Dimension: {dimension2SelectedIndex}", + }; + + model.Axes.Add(xAxis); + model.Axes.Add(yAxis); + + view.Model = model; + + statusStrip = new StatusStrip + { + Visible = false + }; + + dimension1Label = new ToolStripLabel + { + Text = "X Axis Dimension:", + AutoSize = true + }; + + dimension1ComboBox = new ToolStripComboBox() + { + Name = "X Axis Dimension", + AutoSize = true, + }; + + dimension2Label = new ToolStripLabel + { + Text = "Y Axis Dimension:", + AutoSize = true + }; + + dimension2ComboBox = new ToolStripComboBox() + { + Name = "Y Axis Dimension", + AutoSize = true, + }; + + statusStrip.Items.AddRange(new ToolStripItem[] { + dimension1Label, + dimension1ComboBox, + dimension2Label, + dimension2ComboBox + }); + + view.MouseClick += new MouseEventHandler(onMouseClick); + + var visualizerService = (IDialogTypeVisualizerService)provider.GetService(typeof(IDialogTypeVisualizerService)); + + if (visualizerService != null) + { + visualizerService.AddControl(view); + visualizerService.AddControl(statusStrip); + } + } + + /// + public override void Show(object value) + { + if (value is GaussianObservationsStatistics gaussianObservationsStatistics) + { + + var statesCount = gaussianObservationsStatistics.Means.GetLength(0); + var observationDimensions = gaussianObservationsStatistics.Means.GetLength(1); + + if (dimension1ComboBox.Items.Count == 0) + { + for (int i = 0; i < observationDimensions; i++) + { + dimension1ComboBox.Items.Add(i); + dimension2ComboBox.Items.Add(i); + } + + dimension1ComboBox.SelectedIndexChanged += dimension1ComboBoxSelectedIndexChanged; + dimension1ComboBox.SelectedIndex = dimension1SelectedIndex; + + dimension2ComboBox.SelectedIndexChanged += dimension2ComboBoxSelectedIndexChanged; + dimension2SelectedIndex = Math.Max(dimension2ComboBox.Items.Count - 1, dimension2SelectedIndex); + dimension2ComboBox.SelectedIndex = dimension2SelectedIndex; + } + + if (colorList == null) + { + colorList = new List(); + for (int i = 0; i < statesCount; i++) + { + OxyColor color = OxyPalettes.Jet(statesCount).Colors[i]; + colorList.Add(color); + } + } + + if (allScatterSeries != null) + { + foreach (var scatterSeries in allScatterSeries) + { + scatterSeries.Points.Clear(); + } + } + else + { + allScatterSeries = new List(); + for (int i = 0; i < statesCount; i++) + { + var scatterSeries = new ScatterSeries + { + MarkerType = MarkerType.Circle, + MarkerSize = 4, + MarkerFill = colorList[i], + MarkerStroke = OxyColors.Black, + MarkerStrokeThickness = 1 + }; + allScatterSeries.Add(scatterSeries); + model.Series.Add(scatterSeries); + } + } + + if (allLineSeries != null) + { + foreach (var lineSeries in allLineSeries) + { + lineSeries.Points.Clear(); + } + } + else + { + allLineSeries = new List(); + for (int i = 0; i < statesCount; i++) + { + for (int j = 0; j < 3; j++) + { + var lineSeries = new LineSeries { Color = colorList[i] }; + allLineSeries.Add(lineSeries); + model.Series.Add(lineSeries); + } + } + } + + var batchObservationsCount = gaussianObservationsStatistics.BatchObservations.GetLength(0); + var offset = BufferData && batchObservationsCount > BufferCount ? batchObservationsCount - BufferCount : 0; + for (int i = offset; i < batchObservationsCount; i++) + { + var dim1 = gaussianObservationsStatistics.BatchObservations[i, dimension1SelectedIndex]; + var dim2 = gaussianObservationsStatistics.BatchObservations[i, dimension2SelectedIndex]; + var state = gaussianObservationsStatistics.InferredMostProbableStates[i]; + allScatterSeries[(int)state].Points.Add(new ScatterPoint(dim1, dim2, value: state, tag: state)); + } + + for (int i = 0; i < statesCount; i++) + { + var xMean = gaussianObservationsStatistics.Means[i, dimension1SelectedIndex]; + var yMean = gaussianObservationsStatistics.Means[i, dimension2SelectedIndex]; + + var xVar = gaussianObservationsStatistics.CovarianceMatrices[i, dimension1SelectedIndex, dimension1SelectedIndex]; + var yVar = gaussianObservationsStatistics.CovarianceMatrices[i, dimension2SelectedIndex, dimension2SelectedIndex]; + var xyCov = gaussianObservationsStatistics.CovarianceMatrices[i, dimension2SelectedIndex, dimension1SelectedIndex]; + + var covariance = Matrix.Build.DenseOfArray(new double[,] { + { + xVar, + xyCov + }, + { + xyCov, + yVar + }, + }); + + var evd = covariance.Evd(); + var evals = evd.EigenValues.Real(); + evals = evals.PointwiseAbsoluteMaximum(0); + var evecs = evd.EigenVectors; + + double angle = Math.Atan2(evecs[1, 0], evecs[0, 0]); + + for (int j = 1; j < 4; j++) + { + + var majorAxis = j * Math.Sqrt(evals[0]); + var minorAxis = j * Math.Sqrt(evals[1]); + + var points = new List(); + int numPoints = 100; + for (int k = 0; k < numPoints + 1; k++) + { + double theta = 2 * Math.PI * k / numPoints; + double x = majorAxis * Math.Cos(theta); + double y = minorAxis * Math.Sin(theta); + + double xRot = x * Math.Cos(angle) - y * Math.Sin(angle); + double yRot = x * Math.Sin(angle) + y * Math.Cos(angle); + + points.Add(new DataPoint(xMean + xRot, yMean + yRot)); + } + + allLineSeries[i * 3 + j - 1].Points.AddRange(points); + + } + } + model.InvalidatePlot(true); + } + } + + private void dimension1ComboBoxSelectedIndexChanged(object sender, EventArgs e) + { + if (dimension1ComboBox.SelectedIndex != dimension1SelectedIndex) + { + dimension1SelectedIndex = dimension1ComboBox.SelectedIndex; + xAxis.Title = $"Observation Dimension: {dimension1SelectedIndex}"; + } + } + + private void dimension2ComboBoxSelectedIndexChanged(object sender, EventArgs e) + { + if (dimension2ComboBox.SelectedIndex != dimension2SelectedIndex) + { + dimension2SelectedIndex = dimension2ComboBox.SelectedIndex; + yAxis.Title = $"Observation Dimension: {dimension2SelectedIndex}"; + } + } + + private void onMouseClick(object sender, MouseEventArgs e) + { + if (e.Button == MouseButtons.Right) + { + statusStrip.Visible = !statusStrip.Visible; + } + } + + /// + public override void Unload() + { + allLineSeries = null; + colorList = null; + allScatterSeries = null; + if (!view.IsDisposed) + { + view.Dispose(); + } + } + } +} diff --git a/src/Bonsai.ML.Visualizers/GaussianObservationsStatisticsVisualizer.cs b/src/Bonsai.ML.Visualizers/GaussianObservationsStatisticsVisualizer.cs new file mode 100644 index 00000000..64a5884a --- /dev/null +++ b/src/Bonsai.ML.Visualizers/GaussianObservationsStatisticsVisualizer.cs @@ -0,0 +1,117 @@ +using System; +using System.Windows.Forms; +using System.Collections.Generic; +using Bonsai; +using Bonsai.Design; +using Bonsai.ML.Visualizers; +using Bonsai.ML.HiddenMarkovModels.Observations; +using OxyPlot; +using OxyPlot.Series; + +[assembly: TypeVisualizer(typeof(GaussianObservationStatisticsVisualizer), Target = typeof(GaussianObservationsStatistics))] + +namespace Bonsai.ML.Visualizers +{ + /// + /// Provides a type visualizer of to display the means and standard + /// deviations of each state of an HMM with gaussian observations model. + /// + public class GaussianObservationStatisticsVisualizer : DialogTypeVisualizer + { + + private BarSeriesOxyPlotBase Plot; + private List allBarSeries = null; + private GaussianObservationsStatistics shown = null; + + /// + public override void Load(IServiceProvider provider) + { + Plot = new BarSeriesOxyPlotBase() + { + Dock = DockStyle.Fill, + }; + + var visualizerService = (IDialogTypeVisualizerService)provider.GetService(typeof(IDialogTypeVisualizerService)); + + if (visualizerService != null) + { + visualizerService.AddControl(Plot); + } + + if (shown != null) + { + var value = shown; + shown = null; + Show(value); + } + } + + /// + public override void Show(object value) + { + if (value is GaussianObservationsStatistics statistics && statistics != shown) + { + if (statistics.Means == null || statistics.StdDevs == null) + { + return; + } + + if (allBarSeries == null) + { + allBarSeries = new List(); + var seriesCount = statistics.Means.GetLength(1); + + for (int i = 0; i < seriesCount; i++) + { + allBarSeries.Add(Plot.AddNewErrorBarSeries($"Dimension: {i}", strokeColor: OxyColors.Black)); + } + } + + foreach (var barSeries in allBarSeries) + { + Plot.ResetBarSeries(barSeries); + } + + var minValue = 0.0; + var maxValue = 0.0; + var paddingPercentage = 0.05; + + var nStates = statistics.Means.GetLength(0); + var nDims = statistics.Means.GetLength(1); + + for (int i = 0; i < nStates; i++) + { + OxyColor fillColor = OxyPalettes.Jet(nStates).Colors[i]; + for (int j = 0; j < nDims; j++) + { + var val = statistics.Means[i, j]; + var err = statistics.StdDevs[i, j]; + + minValue = Math.Min(minValue, val - err); + maxValue = Math.Max(maxValue, val + err); + + Plot.AddValueAndErrorToBarSeries(allBarSeries[j], val, err, fillColor: fillColor); + } + } + + var pad = Math.Max(Math.Abs(minValue), Math.Abs(maxValue)) * paddingPercentage; + + Plot.SetAxes(minValue - pad, maxValue + pad); + + shown = statistics; + + Plot.UpdatePlot(); + } + } + + /// + public override void Unload() + { + allBarSeries = null; + if (!Plot.IsDisposed) + { + Plot.Dispose(); + } + } + } +} diff --git a/src/Bonsai.ML.Visualizers/StateProbabilityVisualizer.cs b/src/Bonsai.ML.Visualizers/StateProbabilityVisualizer.cs new file mode 100644 index 00000000..13ddf4fd --- /dev/null +++ b/src/Bonsai.ML.Visualizers/StateProbabilityVisualizer.cs @@ -0,0 +1,101 @@ +using System; +using System.Windows.Forms; +using System.Collections.Generic; +using System.Linq; +using Bonsai; +using Bonsai.Design; +using Bonsai.ML.Visualizers; +using Bonsai.ML.HiddenMarkovModels; +using OxyPlot; +using OxyPlot.Series; +using OxyPlot.Axes; + +[assembly: TypeVisualizer(typeof(StateProbabilityVisualizer), Target = typeof(StateProbability))] + +namespace Bonsai.ML.Visualizers +{ + /// + /// Provides a type visualizer of to display the probabilities + /// of being in each state of an HMM given the observation. + /// + public class StateProbabilityVisualizer : DialogTypeVisualizer + { + private BarSeriesOxyPlotBase Plot; + + private List allBarSeries; + + /// + public override void Load(IServiceProvider provider) + { + Plot = new BarSeriesOxyPlotBase() + { + Dock = DockStyle.Fill, + }; + + var visualizerService = (IDialogTypeVisualizerService)provider.GetService(typeof(IDialogTypeVisualizerService)); + + if (visualizerService != null) + { + visualizerService.AddControl(Plot); + } + } + + /// + public override void Show(object value) + { + if (value is StateProbability stateProbability) + { + if (allBarSeries == null) + { + allBarSeries = new List(); + var statesCount = stateProbability.Probabilities.Length; + + for (int i = 0; i < statesCount; i++) + { + OxyColor fillColor = OxyPalettes.Jet(statesCount).Colors[i]; + allBarSeries.Add(Plot.AddNewBarSeries($"State: {i}", fillColor: fillColor, strokeColor: OxyColors.Black)); + } + } + + foreach (var barSeries in allBarSeries) + { + Plot.ResetBarSeries(barSeries); + } + + var minValue = 0.0; + var maxValue = 0.0; + var paddingPercentage = 0.05; + + var nStates = stateProbability.Probabilities.Length; + CategoryAxis categoryAxis = (CategoryAxis)Plot.xAxis; + categoryAxis.ItemsSource = Enumerable.Range(0, nStates); + + for (int i = 0; i < nStates; i++) + { + var val = stateProbability.Probabilities[i]; + + minValue = Math.Min(minValue, val); + maxValue = Math.Max(maxValue, val); + + Plot.AddValueToBarSeries(allBarSeries[i], i, val); + } + + var pad = Math.Max(Math.Abs(minValue), Math.Abs(maxValue)) * paddingPercentage; + + Plot.SetAxes(minValue - pad, maxValue + pad); + + Plot.UpdatePlot(); + } + } + + /// + public override void Unload() + { + allBarSeries = null; + if (!Plot.IsDisposed) + { + Plot.Dispose(); + } + } + } +} diff --git a/src/Bonsai.ML/Bonsai.ML.csproj b/src/Bonsai.ML/Bonsai.ML.csproj index e3600d14..e87b8d67 100644 --- a/src/Bonsai.ML/Bonsai.ML.csproj +++ b/src/Bonsai.ML/Bonsai.ML.csproj @@ -1,14 +1,11 @@ - - + Bonsai - ML Bonsai Library containing reactive infrastructure for machine learning. Bonsai Rx ML Machine Learning net472;netstandard2.0 - - - + \ No newline at end of file diff --git a/src/Bonsai.ML/CreateModelReference.cs b/src/Bonsai.ML/CreateModelReference.cs new file mode 100644 index 00000000..086d5931 --- /dev/null +++ b/src/Bonsai.ML/CreateModelReference.cs @@ -0,0 +1,33 @@ +using System.ComponentModel; +using System; +using System.Reactive.Linq; + +namespace Bonsai.ML +{ + /// + /// Represents an operator that creates a reference for a named model. + /// + [Combinator] + [Description("Creates a reference for a named model.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class CreateModelReference : INamedElement + { + /// + /// Gets or sets the name of the model to reference. + /// + [Description("The name of the model to reference.")] + public string Name { get ; set; } + + /// + /// Generates an observable sequence that contains the model reference object. + /// + /// + /// A sequence containing a single instance of the + /// class. + /// + public IObservable Process() + { + return Observable.Defer(() => Observable.Return(new ModelReference(Name))); + } + } +} diff --git a/src/Bonsai.ML/ModelReference.cs b/src/Bonsai.ML/ModelReference.cs new file mode 100644 index 00000000..fa8f1390 --- /dev/null +++ b/src/Bonsai.ML/ModelReference.cs @@ -0,0 +1,24 @@ + +namespace Bonsai.ML +{ + /// + /// Bonsai.ML model reference base class + /// + public class ModelReference + { + /// + /// Gets or sets the name of the referenced model. + /// + public string Name { get ; set; } + + /// + /// Initializes a new instance of the class + /// with the specified name. + /// + /// The name of the referenced model. + public ModelReference(string name) + { + Name = name; + } + } +}