Skip to content

Commit

Permalink
Various bugfixes and adjustments to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
erezsh committed Mar 23, 2022
1 parent c903d5b commit 3dd791d
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 11 deletions.
3 changes: 2 additions & 1 deletion preql/core/autocomplete.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,8 @@ def autocomplete(state, code, source='<autocomplete>'):
tree = autocomplete_tree(e.interactive_parser)
if tree:
try:
stmts = parser.TreeToAst(code_ref=(code, source)).transform(tree)
with context(code_ref=(code, source)):
stmts = parser.TreeToAst().transform(tree)
except pql_SyntaxError as e:
return {}

Expand Down
7 changes: 4 additions & 3 deletions preql/core/sql_import_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _restructure_result(t: T.struct, i):

@dp_type
def _restructure_result(_t: T.union[T.primitive, T.nulltype], i):
return next(i)
return _from_sql_primitive(next(i))


@dp_type
Expand Down Expand Up @@ -103,7 +103,7 @@ def _extract_primitive(res, expected):
except ValueError:
raise Signal.make(T.TypeError, None, f"Expected a single {expected}. Got: '{res.value}'")

return item
return _from_sql_primitive(item)

@dp_inst
def sql_result_to_python(res: T.bool):
Expand All @@ -116,7 +116,6 @@ def sql_result_to_python(res: T.bool):
def sql_result_to_python(res: T.int):
item = _extract_primitive(res, 'int')
if not isinstance(item, int):
breakpoint()
raise Signal.make(T.ValueError, None, f"Expected SQL to return an int. Instead got '{item}'")
return item

Expand Down Expand Up @@ -144,6 +143,8 @@ def _from_sql_primitive(p):
if isinstance(p, decimal.Decimal):
# TODO Needs different handling when we expect a decimal
return float(p)
elif isinstance(p, bytearray):
return p.decode()
return p

@dp_inst
Expand Down
3 changes: 2 additions & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from preql import Preql, settings

SQLITE_URI = 'sqlite://:memory:'
POSTGRES_URI = 'postgres://postgres:qweqwe123@localhost/postgres'
# POSTGRES_URI = 'postgres://postgres:qweqwe123@localhost/postgres'
POSTGRES_URI = 'postgres:///postgres'
MYSQL_URI = 'mysql://erez:qweqwe123@localhost/preql_tests'
DUCK_URI = 'duck://:memory:'
BIGQUERY_URI = 'bigquery:///aeyeconsole'
Expand Down
15 changes: 9 additions & 6 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,8 +516,9 @@ def test_nested_projections(self):
res = [{'a': {'item': 1}, 'b': [3, 4]}, {'a': {'item': 2}, 'b': [3, 4]}]
self.assertEqual(preql("joinall(a:[1,2], b:[3, 4]) {a => b}" ), res)

res = [{'b': 5, 'a': [1, 2]}]
self.assertEqual(preql("joinall(a:[1,2], b:[2, 3]) {a: a.item => b: sum(b.item)} {b => a}"), res)
expected = [{'b': 5.0, 'a': [2, 1]}]
res = preql("joinall(a:[1,2], b:[2, 3]) {a: a.item => b: sum(b.item)} order {^a} {b => a}")
self.assertEqual(res, expected)
# preql("joinall(a:[1,2], b:[2, 3]) {a: a.item => b: b.item} {count(b) => a}")

res = preql("one joinall(a:[1,2], b:[2, 3]) {a: a.item => b: count(b.item)} {b => a: count(a)}")
Expand All @@ -540,8 +541,10 @@ def test_nested_projections(self):
res1 = preql("joinall(ab: joinall(a:[1,2], b:[2,3]), c: [4,5]) {ab.a.item, ab.b.item, c}")
assert len(res1) == 8

res1 = preql("joinall(ab: joinall(a:[1,2], b:[2,3]), c: [4,5]) {ab {b: b.item, a: a.item}, c}[..1]")
self.assertEqual(res1.to_json(), [{'ab': {'b': 2, 'a': 1}, 'c': {'item': 4}}])
if not preql._interp.state.db.target in (mysql, ):
# TODO Error in MySQL. Should work!
res1 = preql("joinall(ab: joinall(a:[1,2], b:[2,3]), c: [4,5]) {ab {b: b.item, a: a.item}, c}[..1]")
self.assertEqual(res1.to_json(), [{'ab': {'b': 2, 'a': 1}, 'c': {'item': 4}}])

def test_nested2(self):
preql = self.Preql()
Expand Down Expand Up @@ -1472,7 +1475,7 @@ def test_builtins(self):
assert p('list(["Ab", "Aab"]{str_index("b", item)})') == [1, 2]
assert p('str_index("b", "Ab")') == 1

assert p('char(65)') == 'A'
assert p('char(65)') == 'A', p('char(65)')
assert p('char_ord("A")') == 65
assert p('char_range("a", "c")') == ['a', 'b', 'c']

Expand All @@ -1498,7 +1501,7 @@ def test_join_on(self):
p("""
A = [1, 3]
B = [1, 2]
res = leftjoin(a: A, b: B, $on: a.item > b.item)
res = leftjoin(a: A, b: B, $on: a.item > b.item) order {a.item, b.item}
""")

assert p.res == [
Expand Down

0 comments on commit 3dd791d

Please sign in to comment.