Serialization problem with GridSearchCV #6046

Open
hchekired opened this issue Aug 20, 2024 · 8 comments

hchekired commented Aug 20, 2024

Hello, I am a beginner with GPU-accelerated computing and I can't find what is wrong with my code. I am getting this serialization error and don't understand why.

import numpy as np
import cudf
from dask.distributed import Client
from sklearn.metrics import classification_report
import pandas as pd
from dask_cuda import LocalCUDACluster
from cuml.dask.ensemble import RandomForestClassifier as cuRF
import dask_cudf
from cuml.dask.common.utils import persist_across_workers
import pickle
import cloudpickle
import dask_ml.model_selection as dcv

def generate_synthetic_data(n_samples=10000):
    np.random.seed(42)
    Y = np.random.randn(n_samples)
    A = np.random.randint(0, 5, size=n_samples)
    B = np.random.randint(0, 5, size=n_samples)
    C = np.random.randint(1, 3, size=n_samples)
    DATE = pd.date_range(start='1/1/2022', periods=n_samples, freq='min')

    data = {
        'A': A,
        'B': B,
        'C': C,
        'DATE': DATE,
        'Y': Y,
    }
    return pd.DataFrame(data)

def main():
    # Initialize Dask client for GPU with LocalCUDACluster
    cluster = LocalCUDACluster()
    client = Client(cluster)

    # Load and preprocess data
    df_data = generate_synthetic_data()

    # Data preprocessing
    df_data['DATE'] = pd.to_datetime(df_data['DATE'], errors='coerce')
    df_data.fillna(0, inplace=True)
    df_data['C'] = df_data['C'].astype(float)
    df_data.drop_duplicates(inplace=True)
    df_data = df_data.loc[df_data['C'] == 1]

    df_data['Y_category'] = df_data['Y'].apply(lambda x: 'over 0' if x > 0 else ('under 0' if x < 0 else 'equal to 0'))
    df_encoded = df_data.drop(columns=['Y'])

    label_mapping = {'over 0': 2, 'under 0': 1, 'equal to 0': 0}
    df_encoded['Y_category'] = df_encoded['Y_category'].map(label_mapping)
    df_encoded.sort_values(by='DATE', inplace=True)

    df_encoded = cudf.DataFrame.from_pandas(df_encoded)

    # Split data into features and target
    X = df_encoded.drop(columns=['DATE', 'Y_category']).astype('float32')
    y = df_encoded['Y_category'].astype('int32')
    split_point = int(len(df_encoded) * 0.8)
    X_train, X_test = X.iloc[:split_point], X.iloc[split_point:]
    y_train, y_test = y.iloc[:split_point], y.iloc[split_point:]

    # Balance the classes using undersampling
    y_train_counts = y_train.value_counts().to_pandas()
    min_samples = y_train_counts.min()

    sampled_indices = []
    for label in y_train_counts.index:
        indices = y_train[y_train == label].index.to_pandas().to_series()
        sampled = indices.sample(n=min_samples, random_state=42).tolist()
        sampled_indices.extend(sampled)

    sampled_indices = np.array(sampled_indices)

    # Ensure indices are unique and within bounds
    sampled_indices = np.unique(sampled_indices)
    sampled_indices = sampled_indices[sampled_indices < len(X_train)]

    X_train_balanced = X_train.iloc[sampled_indices]
    y_train_balanced = y_train.iloc[sampled_indices]

    # Convert to Dask DataFrame directly
    X_train_dask = dask_cudf.from_cudf(X_train_balanced, npartitions=10).persist(optimize_graph=True)
    y_train_dask = dask_cudf.from_cudf(y_train_balanced, npartitions=10).persist(optimize_graph=True)

    X_train_dask, y_train_dask = persist_across_workers(client,
                                                      [X_train_dask,
                                                       y_train_dask])

    #Define the parameter grid
    param_grid = {
        'max_depth': [10, 20, 30],
        'max_features': [0.1, 0.5, 0.75, "auto"],
        'n_estimators': [10, 20, 30]
    }

    # Initialize and fit the model using GridSearchCV
    model_rf = cuRF(random_state=42)
    grid_search = dcv.GridSearchCV(model_rf, param_grid, cv=5, scoring='f1_weighted')
    grid_search.fit(X_train_dask, y_train_dask)  # Fit with Dask arrays

    # Train the model with the best parameters
    best_rf = cuRF(**grid_search.best_params_, random_state=42)
    best_rf.fit(X_train_dask, y_train_dask)

    # Predict on the test set
    X_test_dask = dask_cudf.from_cudf(X_test, npartitions=1).to_dask_array(lengths=True)
    y_pred_best = best_rf.predict(X_test_dask)

    # Evaluate the model
    report_best = classification_report(y_test.to_pandas(), y_pred_best.compute().get())
    print(report_best)

if __name__ == "__main__":
    main()

The error I get is this:

/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/dask_expr/_collection.py:301: UserWarning: Dask annotations {'workers': ['tcp://127.0.0.1:45649']} detected. Annotations will be ignored when using query-planning.
  warnings.warn(
/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/cuml/internals/api_decorators.py:344: UserWarning: For reproducible results in Random Forest Classifier or for almost reproducible results in Random Forest Regressor, n_streams=1 is recommended. If n_streams is > 1, results may vary due to stream/thread timing differences, even when random_state is set
  return func(**kwargs)
2024-08-20 14:17:44,015 - distributed.protocol.pickle - ERROR - Failed to serialize <ToPickle: HighLevelGraph with 1 layers.
<dask.highlevelgraph.HighLevelGraph object at 0x7f18e4fadff0>
 0. 139744897497344
>.
Traceback (most recent call last):
  File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/distributed/protocol/pickle.py", line 63, in dumps
    result = pickle.dumps(x, **dump_kwargs)
_pickle.PicklingError: Can't pickle <function _concat at 0x7f19b0cfe050>: it's not the same object as dask.dataframe.core._concat

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/distributed/protocol/pickle.py", line 68, in dumps
    pickler.dump(x)
_pickle.PicklingError: Can't pickle <function _concat at 0x7f19b0cfe050>: it's not the same object as dask.dataframe.core._concat

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/distributed/protocol/pickle.py", line 81, in dumps
    result = cloudpickle.dumps(x, **dump_kwargs)
  File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/cloudpickle/cloudpickle.py", line 1479, in dumps
    cp.dump(obj)
  File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/cloudpickle/cloudpickle.py", line 1245, in dump
    return super().dump(obj)
  File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/cuml/dask/common/base.py", line 60, in __getstate__
    internal_model = self._get_internal_model().result()
AttributeError: 'NoneType' object has no attribute 'result'
Traceback (most recent call last):
  File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/distributed/protocol/pickle.py", line 63, in dumps
    result = pickle.dumps(x, **dump_kwargs)
_pickle.PicklingError: Can't pickle <function _concat at 0x7f19b0cfe050>: it's not the same object as dask.dataframe.core._concat

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/distributed/protocol/pickle.py", line 68, in dumps
    pickler.dump(x)
_pickle.PicklingError: Can't pickle <function _concat at 0x7f19b0cfe050>: it's not the same object as dask.dataframe.core._concat

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/distributed/protocol/serialize.py", line 366, in serialize
    header, frames = dumps(x, context=context) if wants_context else dumps(x)
  File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/distributed/protocol/serialize.py", line 78, in pickle_dumps
    frames[0] = pickle.dumps(
  File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/distributed/protocol/pickle.py", line 81, in dumps
    result = cloudpickle.dumps(x, **dump_kwargs)
  File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/cloudpickle/cloudpickle.py", line 1479, in dumps
    cp.dump(obj)
  File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/cloudpickle/cloudpickle.py", line 1245, in dump
    return super().dump(obj)
  File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/cuml/dask/common/base.py", line 60, in __getstate__
    internal_model = self._get_internal_model().result()
AttributeError: 'NoneType' object has no attribute 'result'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/mnt/c/Users/user/PythonProjects/Snowflake/Stats/Clean RF - Forums.py", line 164, in <module>
    main()
  File "/mnt/c/Users/user/PythonProjects/Snowflake/Stats/Clean RF - Forums.py", line 149, in main
    grid_search.fit(X_train_dask, y_train_dask)  # Fit with Dask arrays
  File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/dask_ml/model_selection/_search.py", line 1266, in fit
    futures = scheduler(
  File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/distributed/client.py", line 3456, in get
    futures = self._graph_to_futures(
  File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/distributed/client.py", line 3351, in _graph_to_futures
    header, frames = serialize(ToPickle(dsk), on_error="raise")
  File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/distributed/protocol/serialize.py", line 392, in serialize
    raise TypeError(msg, str_x) from exc
TypeError: ('Could not serialize object of type HighLevelGraph', '<ToPickle: HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7f18e4fadff0>\n 0. 139744897497344\n>')
2024-08-20 14:17:44,026 - distributed.worker.state_machine - WARNING - Async instruction for <Task cancelled name="execute(('frompandas-a91f82e90b4590ee2b57246953d5e528', 7))" coro=<Worker.execute() done, defined at /home/user/miniconda3/envs/ML/lib/python3.10/site-packages/distributed/worker_state_machine.py:3615>> ended with CancelledError
2024-08-20 14:17:44,027 - distributed.scheduler - WARNING - Removing worker 'tcp://127.0.0.1:45649' caused the cluster to lose already computed task(s), which will be recomputed elsewhere: {('frompandas-9de00375790a2a30c476ba68f4ea2723', 8), '_construct_rf-44891487-927e-43da-8d8d-2bd8190520f0', ('frompandas-760adf4c445d5f36154f13a21c47df23', 9), ('frompandas-9de00375790a2a30c476ba68f4ea2723', 7), ('frompandas-760adf4c445d5f36154f13a21c47df23', 8), ('frompandas-9de00375790a2a30c476ba68f4ea2723', 9)} (stimulus_id='handle-worker-cleanup-1724177864.0271368')

I created a topic on the Dask forum, but they told me to post my problem here since it seemed to be caused by an incompatibility between dask-cuda and dask-ml.

Here is the information about the system I use:

Python version: 3.10.9 | packaged by conda-forge | (main, Feb 2 2023, 20:20:04) [GCC 11.3.0]
CUDA version:
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Mar_28_02:18:24_PDT_2024
Cuda compilation tools, release 12.4, V12.4.131
Build cuda_12.4.r12.4/compiler.34097967_0

cuDF version: 24.08.00a405
cuML version: 24.08.00a50
Dask version: 2024.7.1
CUDA version (nvidia-smi, Tue Aug 20 14:17:43 2024):
NVIDIA-SMI 550.106, Driver Version: 552.86, CUDA Version: 12.4

Also, I am using a WSL2 environment.

Thanks a lot for your help!!

hchekired (Author) commented:

I just changed to the following versions:
cuDF version: 24.08.02
cuML version: 24.08.00

but it still does not work.

raydouglass transferred this issue from rapidsai/dask-cuda Aug 26, 2024
madsbk (Member) commented Aug 26, 2024

Thanks @hchekired, I have moved the issue to cuML. It looks like a cuML issue: for some reason, _get_internal_model() returns None when Dask uses cloudpickle.dumps() to serialize cuML's BaseEstimator.
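
For illustration, a minimal sketch of the failure mode the traceback points at. The class and attribute names are simplified assumptions for the example, not cuML's actual implementation; only the __getstate__ call shape mirrors the traceback:

# Simplified sketch: __getstate__ assumes fit() has already produced an
# internal model future, so pickling an unfitted estimator dereferences None.
import pickle

class MNMGEstimatorSketch:
    def __init__(self):
        # Only set after fit(); None on a freshly constructed estimator.
        self._internal_model_future = None

    def _get_internal_model(self):
        return self._internal_model_future

    def __getstate__(self):
        # Analogous to cuml/dask/common/base.py line 60 in the traceback:
        # calling .result() on None raises AttributeError before training.
        internal_model = self._get_internal_model().result()
        return {"internal_model": internal_model}

est = MNMGEstimatorSketch()
try:
    pickle.dumps(est)  # GridSearchCV triggers serialization before fit()
except AttributeError as exc:
    print(exc)  # 'NoneType' object has no attribute 'result'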

dantegd (Member) commented Aug 26, 2024

I won't be around for the next couple of weeks, but @viclafargue will take a look here.

hchekired (Author) commented:

Hello @viclafargue, hope everything is going well for you. Do you have an idea where the problem comes from?

Thanks.

viclafargue (Contributor) commented Sep 3, 2024

Hello @hchekired, sorry for the late reply, and thanks for opening the issue. I could reproduce it successfully. It looks like GridSearchCV serializes the estimator prior to training, which causes a bug. I will open a PR to fix the serialization of MNMG estimators prior to training. There is also another issue that I am looking into to make things work. However, I would recommend either using sklearn's GridSearchCV with cuML Dask estimators, or (if a single GPU can handle it) dask-ml's GridSearchCV with local cuML estimators.
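
For reference, a rough sketch of the second suggestion: dask-ml's GridSearchCV driving a local, single-GPU cuML RandomForestClassifier. The synthetic stand-in data, grid values, and default scoring below are illustrative assumptions, not part of the original report:

# Sketch of the workaround: dask-ml GridSearchCV over a *local* (single-GPU)
# cuML RandomForestClassifier, so no Dask estimator needs to be pickled.
import numpy as np
from dask.distributed import Client
from dask_cuda import LocalCUDACluster
import dask_ml.model_selection as dcv
from cuml.ensemble import RandomForestClassifier as cuRF  # note: not cuml.dask

if __name__ == "__main__":
    cluster = LocalCUDACluster()
    client = Client(cluster)

    # Illustrative stand-in data (float32 features, int32 labels); replace
    # with the preprocessed features from the original script.
    rng = np.random.default_rng(42)
    X = rng.random((5000, 3), dtype=np.float32)
    y = rng.integers(0, 3, size=5000).astype(np.int32)

    param_grid = {
        "max_depth": [10, 20, 30],
        "n_estimators": [10, 20, 30],
    }

    # Candidate fits are scheduled on the Dask workers, but each candidate is
    # an ordinary single-GPU estimator, which serializes without issue.
    grid_search = dcv.GridSearchCV(cuRF(random_state=42), param_grid, cv=5)
    grid_search.fit(X, y)
    print(grid_search.best_params_, grid_search.best_score_)

    client.close()
    cluster.close()

Because each candidate here is a plain single-GPU estimator, the cuml.dask __getstate__ path that fails before training is never exercised; the trade-off is that every fold must fit on a single GPU.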

hchekired (Author) commented:

Thanks for the reply. Let me know if you find something else that makes things work.

Thanks

viclafargue (Contributor) commented:

I fixed the issue that prevented serialization prior to training. But again, I am not quite sure it is a good idea to use a Dask estimator with a Dask GridSearchCV. Maybe you should try using one of them with Dask and the other without. I will look into solving this when I have more time.

hchekired (Author) commented:

Hello @viclafargue, thanks for fixing the issue. How can I get access to the corrected version?

Also, why is it not a good idea to use a Dask estimator with a Dask GridSearchCV?

Thanks!
