
Add optional Arrow deserialization support (#2632)
pquentin committed Aug 7, 2024
1 parent d4eb86e commit ef3da6e
Showing 7 changed files with 69 additions and 3 deletions.
2 changes: 1 addition & 1 deletion docs/guide/configuration.asciidoc
@@ -359,7 +359,7 @@ The calculation is equal to `min(dead_node_backoff_factor * (2 ** (consecutive_f
[[serializer]]
=== Serializers

Serializers transform bytes on the wire into native Python objects and vice-versa. By default the client ships with serializers for `application/json`, `application/x-ndjson`, `text/*`, and `application/mapbox-vector-tile`.
Serializers transform bytes on the wire into native Python objects and vice-versa. By default the client ships with serializers for `application/json`, `application/x-ndjson`, `text/*`, `application/vnd.apache.arrow.stream` and `application/mapbox-vector-tile`.
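
A quick way to see which serializers are active in a given environment is to inspect the default registry. This is a sketch, not part of the diff, and assumes the package was installed with the `pyarrow` extra so the Arrow entry is registered:

    from elasticsearch.serializer import DEFAULT_SERIALIZERS

    # With pyarrow installed, the Arrow stream mimetype appears alongside the defaults.
    print(sorted(DEFAULT_SERIALIZERS))
    # expect "application/vnd.apache.arrow.stream" in the output when pyarrow is available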

You can define custom serializers via the `serializers` parameter:
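
The documentation's own example is collapsed in this view. As a rough sketch of the pattern, assuming the underlying JSON serializer exposes a `default()` hook for unhandled types (the class name, endpoint, and set-handling logic below are illustrative, not taken from the diff):

    from typing import Any

    from elasticsearch import Elasticsearch
    from elasticsearch.serializer import JsonSerializer


    class SetFriendlyJsonSerializer(JsonSerializer):
        """JSON serializer that also knows how to encode Python sets."""

        def default(self, data: Any) -> Any:
            if isinstance(data, set):
                return list(data)
            return super().default(data)


    client = Elasticsearch(
        "https://localhost:9200",  # illustrative endpoint
        # Keys are mimetypes, values are serializer instances.
        serializers={SetFriendlyJsonSerializer.mimetype: SetFriendlyJsonSerializer()},
    )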

34 changes: 34 additions & 0 deletions elasticsearch/serializer.py
@@ -49,6 +49,14 @@
_OrjsonSerializer = None # type: ignore[assignment,misc]


try:
    import pyarrow as pa

    __all__.append("PyArrowSerializer")
except ImportError:
    pa = None


class JsonSerializer(_JsonSerializer):
    mimetype: ClassVar[str] = "application/json"

@@ -114,6 +122,29 @@ def dumps(self, data: bytes) -> bytes:
raise SerializationError(f"Cannot serialize {data!r} into a MapBox vector tile")


if pa is not None:

    class PyArrowSerializer(Serializer):
        """PyArrow serializer for deserializing Arrow Stream data."""

        mimetype: ClassVar[str] = "application/vnd.apache.arrow.stream"

        def loads(self, data: bytes) -> pa.Table:
            try:
                with pa.ipc.open_stream(data) as reader:
                    return reader.read_all()
            except pa.ArrowException as e:
                raise SerializationError(
                    message=f"Unable to deserialize as Arrow stream: {data!r}",
                    errors=(e,),
                )

        def dumps(self, data: Any) -> bytes:
            raise SerializationError(
                message="Elasticsearch does not accept Arrow input data"
            )


DEFAULT_SERIALIZERS: Dict[str, Serializer] = {
    JsonSerializer.mimetype: JsonSerializer(),
    MapboxVectorTileSerializer.mimetype: MapboxVectorTileSerializer(),
@@ -122,6 +153,9 @@ def dumps(self, data: bytes) -> bytes:
    CompatibilityModeNdjsonSerializer.mimetype: CompatibilityModeNdjsonSerializer(),
}

if pa is not None:
    DEFAULT_SERIALIZERS[PyArrowSerializer.mimetype] = PyArrowSerializer()

# Alias for backwards compatibility
JSONSerializer = JsonSerializer
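
For orientation (not part of the diff), a minimal round trip through the new serializer; the Arrow payload is built locally here to stand in for a response body, and pyarrow is assumed to be installed:

    import pyarrow as pa

    from elasticsearch.serializer import PyArrowSerializer

    # Build a small Arrow IPC stream in memory to play the role of a response body.
    table = pa.table({"value": [1, 2, 3]})
    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, table.schema) as writer:
        writer.write_table(table)

    serializer = PyArrowSerializer()
    result = serializer.loads(sink.getvalue())
    print(result.to_pydict())  # {'value': [1, 2, 3]}

    # dumps() deliberately raises: Elasticsearch does not accept Arrow input data.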

2 changes: 1 addition & 1 deletion noxfile.py
@@ -94,7 +94,7 @@ def lint(session):
session.run("flake8", *SOURCE_FILES)
session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES)

session.install(".[async,requests,orjson,vectorstore_mmr]", env=INSTALL_ENV)
session.install(".[async,requests,orjson,pyarrow,vectorstore_mmr]", env=INSTALL_ENV)

# Run mypy on the package and then the type examples separately for
# the two different mypy use-cases, ourselves and our users.
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -48,6 +48,7 @@ dependencies = [
async = ["aiohttp>=3,<4"]
requests = ["requests>=2.4.0, !=2.32.2, <3.0.0"]
orjson = ["orjson>=3"]
pyarrow = ["pyarrow>=1"]
# Maximal Marginal Relevance (MMR) for search results
vectorstore_mmr = ["numpy>=1", "simsimd>=3"]
dev = [
@@ -69,6 +70,7 @@ dev = [
"orjson",
"numpy",
"simsimd",
"pyarrow",
"pandas",
"mapbox-vector-tile",
]
2 changes: 2 additions & 0 deletions test_elasticsearch/test_client/test_deprecated_options.py
@@ -134,6 +134,7 @@ class CustomSerializer(JsonSerializer):
"application/x-ndjson",
"application/json",
"text/*",
"application/vnd.apache.arrow.stream",
"application/vnd.elasticsearch+json",
"application/vnd.elasticsearch+x-ndjson",
}
@@ -154,6 +155,7 @@ class CustomSerializer(JsonSerializer):
"application/x-ndjson",
"application/json",
"text/*",
"application/vnd.apache.arrow.stream",
"application/vnd.elasticsearch+json",
"application/vnd.elasticsearch+x-ndjson",
"application/cbor",
3 changes: 3 additions & 0 deletions test_elasticsearch/test_client/test_serializers.py
@@ -94,6 +94,7 @@ class CustomSerializer:
"application/json",
"text/*",
"application/x-ndjson",
"application/vnd.apache.arrow.stream",
"application/vnd.mapbox-vector-tile",
"application/vnd.elasticsearch+json",
"application/vnd.elasticsearch+x-ndjson",
@@ -121,6 +122,7 @@ class CustomSerializer:
"application/json",
"text/*",
"application/x-ndjson",
"application/vnd.apache.arrow.stream",
"application/vnd.mapbox-vector-tile",
"application/vnd.elasticsearch+json",
"application/vnd.elasticsearch+x-ndjson",
@@ -140,6 +142,7 @@ class CustomSerializer:
"application/json",
"text/*",
"application/x-ndjson",
"application/vnd.apache.arrow.stream",
"application/vnd.mapbox-vector-tile",
"application/vnd.elasticsearch+json",
"application/vnd.elasticsearch+x-ndjson",
27 changes: 26 additions & 1 deletion test_elasticsearch/test_serializer.py
@@ -19,6 +19,7 @@
from datetime import datetime
from decimal import Decimal

import pyarrow as pa
import pytest

try:
@@ -31,7 +32,12 @@

from elasticsearch import Elasticsearch
from elasticsearch.exceptions import SerializationError
from elasticsearch.serializer import JSONSerializer, OrjsonSerializer, TextSerializer
from elasticsearch.serializer import (
    JSONSerializer,
    OrjsonSerializer,
    PyArrowSerializer,
    TextSerializer,
)

requires_numpy_and_pandas = pytest.mark.skipif(
    np is None or pd is None, reason="Test requires numpy and pandas to be available"
@@ -157,6 +163,25 @@ def test_serializes_pandas_category(json_serializer):
assert b'{"d":[1,2,3]}' == json_serializer.dumps({"d": cat})


def test_pyarrow_loads():
    data = [
        pa.array([1, 2, 3, 4]),
        pa.array(["foo", "bar", "baz", None]),
        pa.array([True, None, False, True]),
    ]
    batch = pa.record_batch(data, names=["f0", "f1", "f2"])
    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, batch.schema) as writer:
        writer.write_batch(batch)

    serializer = PyArrowSerializer()
    assert serializer.loads(sink.getvalue()).to_pydict() == {
        "f0": [1, 2, 3, 4],
        "f1": ["foo", "bar", "baz", None],
        "f2": [True, None, False, True],
    }


def test_json_raises_serialization_error_on_dump_error(json_serializer):
    with pytest.raises(SerializationError):
        json_serializer.dumps(object())
