Skip to content

Commit

Permalink
Fix error with _base
Browse files Browse the repository at this point in the history
  • Loading branch information
blythed committed Oct 9, 2024
1 parent 89d58f8 commit 6d27d37
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 4 deletions.
44 changes: 44 additions & 0 deletions plugins/ibis/plugin_test/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,47 @@ def test_execute_complex_query_sqldb_auto_schema(db):
expected = [f"is a test {i}" for i in range(99, 89, -1)]
cur_this = [r["this"] for r in cur]
assert sorted(cur_this) == sorted(expected)


def test_select_using_ids(db):
    """Check that ``select_using_ids`` narrows a plain select to the given ids.

    Four documents with string ids ``'0'`` .. ``'3'`` are inserted into the
    ``documents`` table; the unrestricted select must return all four and the
    select restricted to ids ``'1'`` and ``'2'`` must return exactly two.

    :param db: Datalayer fixture under test.
    """
    # Let the backend derive the table schema from the inserted documents.
    db.cfg.auto_schema = True

    documents = db["documents"]
    rows = [Document({"this": f"is a test {n}", "id": str(n)}) for n in range(4)]
    documents.insert(rows).execute()

    everything = db['documents'].select()

    all_rows = everything.tolist()
    assert len(all_rows) == 4

    restricted_rows = everything.select_using_ids(['1', '2']).tolist()
    assert len(restricted_rows) == 2


def test_select_using_ids_of_outputs(db):
    """Check that ``select_using_ids`` also works on a listener's outputs table.

    Four documents are inserted, a listener built from ``my_func`` on the
    ``this`` key is applied, and the listener's outputs table is selected.
    The full select must return four rows; restricting it to the first two
    returned ``id`` values must return exactly two.

    :param db: Datalayer fixture under test.
    """
    from superduper import model

    @model
    def my_func(x):
        return x + ' ' + x

    # Let the backend derive the table schema from the inserted documents.
    db.cfg.auto_schema = True

    documents = db["documents"]
    rows = [Document({"this": f"is a test {n}", "id": str(n)}) for n in range(4)]
    documents.insert(rows).execute()

    source_select = db['documents'].select()
    listener = my_func.to_listener(key='this', select=source_select)
    db.apply(listener)

    outputs_query = db[listener.outputs].select()
    output_rows = outputs_query.tolist()

    assert len(output_rows) == 4

    # Re-query the outputs table using only the first two ids it returned.
    output_ids = [row['id'] for row in output_rows]
    restricted_rows = outputs_query.select_using_ids(output_ids[:2]).tolist()

    assert len(restricted_rows) == 2
8 changes: 7 additions & 1 deletion superduper/backends/base/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,10 @@ def __repr__(self):
for i, doc in enumerate(docs):
doc_string = str(doc)
if isinstance(doc, Document):
doc_string = str(doc.unpack())
r = doc.unpack()
if '_base' in r:
r = r['_base']
doc_string = str(r)
output = output.replace(f'documents[{i}]', doc_string)
return output

Expand Down Expand Up @@ -428,6 +431,9 @@ def _encode_or_unpack_args(self, r, db, method='encode', parent=None):
out.pop_builds()
out.pop_files()
out.pop_blobs()

if '_base' in out:
return out['_base']
return out

if isinstance(r, (tuple, list)):
Expand Down
1 change: 1 addition & 0 deletions superduper/backends/base/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def _consume_event_type(event_type, ids, table, db: 'Datalayer'):
jobs=jobs,
event_type=event_type,
)

for job in sub_jobs:
job_lookup[component.uuid][job.method] = job.job_id
jobs += sub_jobs
Expand Down
2 changes: 2 additions & 0 deletions superduper/backends/local/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def submit(self, job: Job) -> str:
:param dependencies: List of `job_ids`
"""
args, kwargs = job.get_args_kwargs(self.futures[job.context])

component = self.db.load(uuid=job.uuid)
self.db.metadata.update_job(job.identifier, 'status', 'running')

Expand All @@ -71,6 +72,7 @@ def submit(self, job: Job) -> str:
except Exception as e:
self.db.metadata.update_job(job.identifier, 'status', 'failed')
raise e

self.db.metadata.update_job(job.identifier, 'status', 'success')
self.futures[job.context][job.job_id] = output
assert job.job_id is not None
Expand Down
4 changes: 1 addition & 3 deletions test/integration/usecase/test_vector_index.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import typing as t

import pytest

if t.TYPE_CHECKING:
from superduper.base.datalayer import Datalayer

from test.utils.usecase.vector_search import add_data, build_vector_index


@pytest.mark.skip
# @pytest.mark.skip
def test_vector_index(db: "Datalayer"):
def check_result(out, sample_data):
scores = out.scores
Expand Down

0 comments on commit 6d27d37

Please sign in to comment.