Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add ARIMA model #577

Merged
merged 68 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from 63 commits
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
9b1f555
added arima model
Gerhardsa0 Mar 6, 2024
5098a6d
updated some stuff
Gerhardsa0 Mar 11, 2024
493f1af
Merge branch 'main' of https://github.com/Safe-DS/Library into 570-fe…
Gerhardsa0 Mar 18, 2024
ee8655e
added split_rows to time series
Gerhardsa0 Mar 18, 2024
4c21e3c
updated test
Gerhardsa0 Mar 19, 2024
4fde946
added arima functionalitie and docstrings
Gerhardsa0 Mar 19, 2024
c8fa2c6
updated poetry
Gerhardsa0 Mar 20, 2024
bd355fb
updated test
Gerhardsa0 Mar 20, 2024
efb5074
updated test
Gerhardsa0 Mar 20, 2024
f4ec286
added test for timeseries_from_csv
Gerhardsa0 Mar 20, 2024
1fb3f7c
updated test
Gerhardsa0 Mar 20, 2024
ccd52e9
updated test
Gerhardsa0 Mar 20, 2024
451145f
updated test
Gerhardsa0 Mar 20, 2024
0fe5527
saved files for now
Gerhardsa0 Mar 20, 2024
80dee60
saved files for now
Gerhardsa0 Mar 25, 2024
6b2e434
fixed linter changes
Gerhardsa0 Mar 25, 2024
410f142
fixed linter changes
Gerhardsa0 Mar 25, 2024
ffcfdb2
fixed linter changes
Gerhardsa0 Mar 25, 2024
23e2acc
fixed linter changes
Gerhardsa0 Mar 25, 2024
86bc0c9
fixed linter changes
Gerhardsa0 Mar 25, 2024
d3d881b
fixed linter changes
Gerhardsa0 Mar 25, 2024
bc4d99d
fixed linter changes
Gerhardsa0 Mar 25, 2024
add42a9
style: apply automated linter fixes
megalinter-bot Mar 25, 2024
52f6542
style: apply automated linter fixes
megalinter-bot Mar 25, 2024
f81a7c5
fixed linter changes
Gerhardsa0 Mar 26, 2024
aaa75d6
Merge remote-tracking branch 'origin/570-feat-add-arima-model' into 5…
Gerhardsa0 Mar 26, 2024
6162161
fixed linter changes
Gerhardsa0 Mar 26, 2024
a610b72
style: apply automated linter fixes
megalinter-bot Mar 26, 2024
2c265af
fixed mk docs
Gerhardsa0 Mar 26, 2024
20822af
Merge branch 'main' into 570-feat-add-arima-model
Gerhardsa0 Mar 26, 2024
a8edca8
added the compared plot
Gerhardsa0 Mar 27, 2024
a915294
Merge remote-tracking branch 'origin/570-feat-add-arima-model' into 5…
Gerhardsa0 Mar 27, 2024
a9d33b3
pushed pictures
Gerhardsa0 Mar 27, 2024
038ca3e
fixed linter
Gerhardsa0 Mar 27, 2024
e970daa
style: apply automated linter fixes
megalinter-bot Mar 27, 2024
faf3964
changed the workflow with arima model
Gerhardsa0 Mar 28, 2024
ad4a979
Merge remote-tracking branch 'origin/570-feat-add-arima-model' into 5…
Gerhardsa0 Mar 28, 2024
9b16821
fixed linter and snapshot
Gerhardsa0 Mar 28, 2024
172950d
style: apply automated linter fixes
megalinter-bot Mar 28, 2024
5eaacaf
added Hash function
Gerhardsa0 Apr 1, 2024
038dec6
Merge branch 'main' of https://github.com/Safe-DS/Library into 570-fe…
Gerhardsa0 Apr 1, 2024
ded4818
resolved merge conflict
Gerhardsa0 Apr 1, 2024
f50b2f7
style: apply automated linter fixes
megalinter-bot Apr 1, 2024
4c150a9
Update src/safeds/data/tabular/containers/_time_series.py
Gerhardsa0 Apr 2, 2024
d0c1f12
Update src/safeds/data/tabular/containers/_time_series.py
Gerhardsa0 Apr 2, 2024
bbadc77
Update src/safeds/ml/classical/regression/_arima.py
Gerhardsa0 Apr 2, 2024
43deda5
Update src/safeds/data/tabular/containers/_time_series.py
Gerhardsa0 Apr 2, 2024
4fc6c84
Update src/safeds/data/tabular/containers/_time_series.py
Gerhardsa0 Apr 2, 2024
d21c8d8
Update src/safeds/ml/classical/regression/_arima.py
Gerhardsa0 Apr 2, 2024
5c09057
Update src/safeds/ml/classical/regression/_arima.py
Gerhardsa0 Apr 2, 2024
b232e3c
Update src/safeds/ml/classical/regression/_arima.py
Gerhardsa0 Apr 2, 2024
94769a5
Update src/safeds/ml/classical/regression/_arima.py
Gerhardsa0 Apr 2, 2024
c052d06
Update src/safeds/ml/classical/regression/_arima.py
Gerhardsa0 Apr 2, 2024
b32efaf
Update src/safeds/ml/classical/regression/_arima.py
Gerhardsa0 Apr 2, 2024
069b944
Update src/safeds/ml/classical/regression/_arima.py
Gerhardsa0 Apr 2, 2024
5d246ba
added code review changes
Gerhardsa0 Apr 2, 2024
c55c83c
added enumerate
Gerhardsa0 Apr 2, 2024
bd91258
delted useless var
Gerhardsa0 Apr 2, 2024
9ba25c3
delted useless var
Gerhardsa0 Apr 2, 2024
4de046e
merge conflict
Gerhardsa0 Apr 2, 2024
ef4c8f3
Merge branch 'main' of https://github.com/Safe-DS/Library into 570-fe…
Gerhardsa0 Apr 2, 2024
67a7a2c
changed lock for docs
Gerhardsa0 Apr 2, 2024
06822e5
style: apply automated linter fixes
megalinter-bot Apr 2, 2024
2dc79ff
Apply suggestions from code review
Gerhardsa0 Apr 8, 2024
210a2a3
Update src/safeds/data/tabular/containers/_time_series.py
Gerhardsa0 Apr 8, 2024
3ea40e3
Update src/safeds/data/tabular/containers/_time_series.py
Gerhardsa0 Apr 8, 2024
af1104b
Update src/safeds/data/tabular/containers/_time_series.py
Gerhardsa0 Apr 8, 2024
0d5196f
Merge branch 'main' into 570-feat-add-arima-model
Gerhardsa0 Apr 8, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,546 changes: 838 additions & 708 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ seaborn = "^0.13.0"
torch = {version = "^2.2.0", source = "torch_cuda121"}
torchvision = {version = "^0.17.0", source = "torch_cuda121"}
xxhash = "^3.4.1"
statsmodels = "^0.14.1"

[tool.poetry.group.dev.dependencies]
pytest = ">=7.2.1,<9.0.0"
Expand Down
165 changes: 164 additions & 1 deletion src/safeds/data/tabular/containers/_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

if TYPE_CHECKING:
from collections.abc import Callable, Mapping, Sequence
from pathlib import Path
from typing import Any


Expand All @@ -29,6 +30,58 @@ class TimeSeries(Table):
# ------------------------------------------------------------------------------------------------------------------
# Creation
# ------------------------------------------------------------------------------------------------------------------

@staticmethod
def timeseries_from_csv_file(
path: str | Path,
target_name: str,
time_name: str,
feature_names: list[str] | None = None,
) -> TimeSeries:
"""
Read data from a CSV file into a table.

Parameters
----------
path :
Gerhardsa0 marked this conversation as resolved.
Show resolved Hide resolved
The path to the CSV file.

target_name:
The name of the target column

time_name :
Gerhardsa0 marked this conversation as resolved.
Show resolved Hide resolved
The name of the time column

feature_names:
The name(s) of the column(s)

Gerhardsa0 marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
table :
Gerhardsa0 marked this conversation as resolved.
Show resolved Hide resolved
The time series created from the CSV file.

Raises
------
FileNotFoundError
If the specified file does not exist.
WrongFileExtensionError
If the file is not a csv file.
Gerhardsa0 marked this conversation as resolved.
Show resolved Hide resolved
UnknownColumnNameError
If target_name or time_name matches none of the column names.
Value Error
If one column is target and feature
Value Error
If one column is time and feature

"""
return TimeSeries._from_table(
Table.from_csv_file(path=path),
target_name=target_name,
time_name=time_name,
feature_names=feature_names,
)

@staticmethod
def _from_tagged_table(
tagged_table: TaggedTable,
Expand Down Expand Up @@ -128,12 +181,17 @@ def _from_table(

if target_name not in table.column_names:
raise UnknownColumnNameError([target_name])

result = object.__new__(TimeSeries)
result._data = table._data

result._schema = table._schema
result._time = table.get_column(time_name)
result._target = table.get_column(target_name)
# empty Columns have dtype Object
if len(result._time._data) == 0:
result._time._data = pd.Series(name=time_name)
if len(result.target._data) == 0:
result.target._data = pd.Series(name=target_name)
if feature_names is None or len(feature_names) == 0:
result._feature_names = []
result._features = Table()
Expand Down Expand Up @@ -203,6 +261,11 @@ def __init__(
raise UnknownColumnNameError([time_name])
self._time: Column = _data.get_column(time_name)
self._target: Column = _data.get_column(target_name)
# empty Columns have dtype Object
if len(self._time._data) == 0:
self._time._data = pd.Series(name=time_name)
if len(self.target._data) == 0:
self.target._data = pd.Series(name=target_name)

def __eq__(self, other: object) -> bool:
"""
Expand All @@ -216,6 +279,7 @@ def __eq__(self, other: object) -> bool:
return NotImplemented
if self is other:
return True

return (
self.time == other.time
and self.target == other.target
Expand Down Expand Up @@ -1113,3 +1177,102 @@ def plot_scatterplot(
buffer.seek(0)
self._data = self._data.reset_index()
return Image.from_bytes(buffer.read())

def split_rows(self, percentage_in_first: float) -> tuple[TimeSeries, TimeSeries]:
"""
Split the table into two new tables.

The original table is not modified.
Gerhardsa0 marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
percentage_in_first:
The desired size of the first table in percentage to the given table; must be between 0 and 1.
Gerhardsa0 marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
result : (TimeSeries, TimeSeries)
Gerhardsa0 marked this conversation as resolved.
Show resolved Hide resolved
A tuple containing the two resulting tables. The first table has the specified size, the second table
Gerhardsa0 marked this conversation as resolved.
Show resolved Hide resolved
contains the rest of the data.

Raises
------
ValueError:
if the 'percentage_in_first' is not between 0 and 1.

Examples
--------
>>> from safeds.data.tabular.containers import Table
>>> table = TimeSeries({"time":[0, 1, 2, 3, 4], "temperature": [10, 15, 20, 25, 30], "sales": [54, 74, 90, 206, 210]}, time_name="time", target_name="sales")
>>> slices = table.split_rows(0.4)
Gerhardsa0 marked this conversation as resolved.
Show resolved Hide resolved
>>> slices[0]
time temperature sales
0 0 10 54
1 1 15 74
>>> slices[1]
time temperature sales
0 2 20 90
1 3 25 206
2 4 30 210
"""
temp = self._as_table()
t1, t2 = temp.split_rows(percentage_in_first=percentage_in_first)
return (
TimeSeries._from_table(
t1,
time_name=self.time.name,
target_name=self._target.name,
feature_names=self._feature_names,
),
TimeSeries._from_table(
t2,
time_name=self.time.name,
target_name=self._target.name,
feature_names=self._feature_names,
),
)

def plot_compare_time_series(self, time_series: list[TimeSeries]) -> Image:
"""
Plot the given time series targets along the time on the x-axis.

Parameters
----------
time_series:
A list of time series to be plotted.

Returns
-------
plot:
A plot with all the time series targets plotted by the time on the x-axis.

Raises
------
NonNumericColumnError
if the target column contains non numerical values
Gerhardsa0 marked this conversation as resolved.
Show resolved Hide resolved

"""
if not self._target.type.is_numeric():
raise NonNumericColumnError("The time series plotted column contains non-numerical columns.")
Gerhardsa0 marked this conversation as resolved.
Show resolved Hide resolved

data = pd.DataFrame()
data[self.time.name] = self.time._data
data[self.target.name] = self.target._data
for index, ts in enumerate(time_series):
if not ts.target.type.is_numeric():
raise NonNumericColumnError("The time series plotted column contains non-numerical columns.")
data[ts.target.name + " " + str(index)] = ts.target._data
fig = plt.figure()

data = pd.melt(data, [self.time.name])
sns.lineplot(x=self.time.name, y="value", hue="variable", data=data)
plt.title("Multiple Series Plot")
plt.xlabel("Time")

plt.tight_layout()
buffer = io.BytesIO()
fig.savefig(buffer, format="png")
plt.close() # Prevents the figure from being displayed directly
buffer.seek(0)
self._data = self._data.reset_index()
return Image.from_bytes(buffer.read())
2 changes: 2 additions & 0 deletions src/safeds/exceptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
DatasetMissesFeaturesError,
LearningError,
ModelNotFittedError,
NonTimeSeriesError,
PredictionError,
UntaggedTableError,
)
Expand Down Expand Up @@ -56,6 +57,7 @@
"DatasetMissesFeaturesError",
"LearningError",
"ModelNotFittedError",
"NonTimeSeriesError",
"PredictionError",
"UntaggedTableError",
# Other
Expand Down
12 changes: 12 additions & 0 deletions src/safeds/exceptions/_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,15 @@ def __init__(self) -> None:
" features and which are the target to predict.\nUse Table.tag_column() to create a tagged table."
),
)


class NonTimeSeriesError(Exception):
"""Raised when a table is used instead of a TimeSeries in a regression or classification."""

def __init__(self) -> None:
super().__init__(
(
"This method needs a time series.\nA time series is a table that additionally knows which columns are"
" time and which are the target to predict.\n"
),
)
2 changes: 2 additions & 0 deletions src/safeds/ml/classical/regression/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Models for regression tasks."""

from ._ada_boost import AdaBoostRegressor
from ._arima import ArimaModelRegressor
from ._decision_tree import DecisionTreeRegressor
from ._elastic_net_regression import ElasticNetRegressor
from ._gradient_boosting import GradientBoostingRegressor
Expand All @@ -14,6 +15,7 @@

__all__ = [
"AdaBoostRegressor",
"ArimaModelRegressor",
"DecisionTreeRegressor",
"ElasticNetRegressor",
"GradientBoostingRegressor",
Expand Down
Loading