From cea0e9b34d5a08f7f33c405b54a4e9851a7787a3 Mon Sep 17 00:00:00 2001 From: Aahil Mehta Date: Fri, 4 Oct 2024 01:30:43 +0000 Subject: [PATCH 1/2] Fix embedding for list input shape --- keras/src/layers/core/embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/src/layers/core/embedding.py b/keras/src/layers/core/embedding.py index 310596ce216..6600d72096f 100644 --- a/keras/src/layers/core/embedding.py +++ b/keras/src/layers/core/embedding.py @@ -146,7 +146,7 @@ def compute_mask(self, inputs, mask=None): return ops.not_equal(inputs, 0) def compute_output_shape(self, input_shape): - return input_shape + (self.output_dim,) + return (*input_shape, self.output_dim) def enable_lora( self, rank, a_initializer="he_uniform", b_initializer="zeros" From f0eb57089c7f22a129b12a467fca5abcae90274a Mon Sep 17 00:00:00 2001 From: Aahil Mehta Date: Sun, 13 Oct 2024 23:49:42 +0000 Subject: [PATCH 2/2] Allow passing custom data adapter --- keras/src/trainers/data_adapters/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/keras/src/trainers/data_adapters/__init__.py b/keras/src/trainers/data_adapters/__init__.py index 3dc04b75498..b10e1d233a3 100644 --- a/keras/src/trainers/data_adapters/__init__.py +++ b/keras/src/trainers/data_adapters/__init__.py @@ -2,6 +2,7 @@ from keras.src.distribution import distribution_lib from keras.src.trainers.data_adapters import array_data_adapter +from keras.src.trainers.data_adapters import data_adapter from keras.src.trainers.data_adapters import py_dataset_adapter from keras.src.trainers.data_adapters.array_data_adapter import ArrayDataAdapter from keras.src.trainers.data_adapters.generator_data_adapter import ( @@ -23,6 +24,10 @@ def get_data_adapter( shuffle=False, class_weight=None, ): + # Allow passing a custom data adapter. + if isinstance(x, data_adapter.DataAdapter): + return x + # Check for multi-process/worker distribution. Since only tf.dataset # is supported at the moment, we will raise error if the inputs fail # the type check