Skip to content

Commit

Permalink
Added TextSlice; Lark can now parse/lex a text-slice
Browse files Browse the repository at this point in the history
Based on previous PR by MegaIng
  • Loading branch information
erezsh committed Aug 19, 2024
1 parent bd70893 commit 0b131a0
Show file tree
Hide file tree
Showing 8 changed files with 185 additions and 38 deletions.
3 changes: 2 additions & 1 deletion lark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .lark import Lark
from .lexer import Token
from .tree import ParseTree, Tree
from .utils import logger
from .utils import logger, TextSlice
from .visitors import Discard, Transformer, Transformer_NonRecursive, Visitor, v_args

__version__: str = "1.2.2"
Expand All @@ -33,6 +33,7 @@
"Discard",
"Transformer",
"Transformer_NonRecursive",
"TextSlice",
"Visitor",
"v_args",
)
8 changes: 4 additions & 4 deletions lark/lark.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .parser_frontends import ParsingFrontend

from .exceptions import ConfigurationError, assert_config, UnexpectedInput
from .utils import Serialize, SerializeMemoizer, FS, logger
from .utils import Serialize, SerializeMemoizer, FS, logger, TextOrSlice
from .load_grammar import load_grammar, FromPackageLoader, Grammar, verify_used_files, PackageResource, sha256_digest
from .tree import Tree
from .common import LexerConf, ParserConf, _ParserArgType, _LexerArgType
Expand Down Expand Up @@ -598,7 +598,7 @@ def __repr__(self):
return 'Lark(open(%r), parser=%r, lexer=%r, ...)' % (self.source_path, self.options.parser, self.options.lexer)


def lex(self, text: str, dont_ignore: bool=False) -> Iterator[Token]:
def lex(self, text: TextOrSlice, dont_ignore: bool=False) -> Iterator[Token]:
"""Only lex (and postlex) the text, without parsing it. Only relevant when lexer='basic'
When dont_ignore=True, the lexer will return all tokens, even those marked for %ignore.
Expand All @@ -620,7 +620,7 @@ def get_terminal(self, name: str) -> TerminalDef:
"""Get information about a terminal"""
return self._terminals_dict[name]

def parse_interactive(self, text: Optional[str]=None, start: Optional[str]=None) -> 'InteractiveParser':
def parse_interactive(self, text: Optional[TextOrSlice]=None, start: Optional[str]=None) -> 'InteractiveParser':
"""Start an interactive parsing session.
Parameters:
Expand All @@ -634,7 +634,7 @@ def parse_interactive(self, text: Optional[str]=None, start: Optional[str]=None)
"""
return self.parser.parse_interactive(text, start=start)

def parse(self, text: str, start: Optional[str]=None, on_error: 'Optional[Callable[[UnexpectedInput], bool]]'=None) -> 'ParseTree':
def parse(self, text: TextOrSlice, start: Optional[str]=None, on_error: 'Optional[Callable[[UnexpectedInput], bool]]'=None) -> 'ParseTree':
"""Parse the given text, according to the options provided.
Parameters:
Expand Down
49 changes: 32 additions & 17 deletions lark/lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .common import LexerConf
from .parsers.lalr_parser_state import ParserState

from .utils import classify, get_regexp_width, Serialize, logger
from .utils import classify, get_regexp_width, Serialize, logger, TextSlice, TextOrSlice
from .exceptions import UnexpectedCharacters, LexError, UnexpectedToken
from .grammar import TOKEN_DEFAULT_PRIORITY

Expand Down Expand Up @@ -289,7 +289,7 @@ def __eq__(self, other):

return self.char_pos == other.char_pos and self.newline_char == other.newline_char

def feed(self, token: Token, test_newline=True):
def feed(self, token: TextOrSlice, test_newline=True):
"""Consume a token and calculate the new line & column.
As an optional optimization, set test_newline=False if token doesn't contain a newline.
Expand Down Expand Up @@ -382,9 +382,9 @@ def _build_mres(self, terminals, max_size):
terminals = terminals[max_size:]
return mres

def match(self, text, pos):
def match(self, text: TextSlice, pos):
for mre in self._mres:
m = mre.match(text, pos)
m = mre.match(text.text, pos, text.end)
if m:
return m.group(0), m.lastgroup

Expand All @@ -394,6 +394,7 @@ def fullmatch(self, text: str) -> Optional[str]:
m = mre.fullmatch(text)
if m:
return m.lastgroup
return

def _regexp_has_newline(r: str):
r"""Expressions that may indicate newlines in a regexp:
Expand All @@ -413,20 +414,31 @@ class LexerState:

__slots__ = 'text', 'line_ctr', 'last_token'

text: str
text: TextSlice
line_ctr: LineCounter
last_token: Optional[Token]

def __init__(self, text: str, line_ctr: Optional[LineCounter]=None, last_token: Optional[Token]=None):
def __init__(self, text: TextSlice, line_ctr: Optional[LineCounter] = None, last_token: Optional[Token]=None):
if line_ctr is None:
line_ctr = LineCounter(b'\n' if isinstance(text.text, bytes) else '\n')

if text.start > 0:
# Advance the line-count until line_ctr.char_pos == text.start
line_ctr.feed(TextSlice(text.text, 0, text.start))

if not (text.start <= line_ctr.char_pos <= text.end):
raise ValueError("LineCounter.char_pos is out of bounds")

self.text = text
self.line_ctr = line_ctr or LineCounter(b'\n' if isinstance(text, bytes) else '\n')
self.line_ctr = line_ctr
self.last_token = last_token


def __eq__(self, other):
if not isinstance(other, LexerState):
return NotImplemented

return self.text is other.text and self.line_ctr == other.line_ctr and self.last_token == other.last_token
return self.text == other.text and self.line_ctr == other.line_ctr and self.last_token == other.last_token

def __copy__(self):
return type(self)(self.text, copy(self.line_ctr), self.last_token)
Expand All @@ -441,10 +453,13 @@ def __init__(self, lexer: 'Lexer', lexer_state: LexerState):
self.state = lexer_state

@classmethod
def from_text(cls, lexer: 'Lexer', text: str) -> 'LexerThread':
def from_text(cls, lexer: 'Lexer', text_or_slice: TextOrSlice) -> 'LexerThread':
text = TextSlice.cast_from(text_or_slice)
return cls(lexer, LexerState(text))

def lex(self, parser_state):
if self.state is None:
raise TypeError("Cannot lex: No text assigned to lexer state")
return self.lexer.lex(self.state, parser_state)

def __copy__(self):
Expand All @@ -465,9 +480,9 @@ class Lexer(ABC):
def lex(self, lexer_state: LexerState, parser_state: Any) -> Iterator[Token]:
return NotImplemented

def make_lexer_state(self, text):
def make_lexer_state(self, text: str):
"Deprecated"
return LexerState(text)
return LexerState(TextSlice.cast_from(text))


def _check_regex_collisions(terminal_to_regexp: Dict[TerminalDef, str], comparator, strict_mode, max_collisions_to_show=8):
Expand Down Expand Up @@ -569,7 +584,7 @@ def __init__(self, conf: 'LexerConf', comparator=None) -> None:

self._scanner = None

def _build_scanner(self):
def _build_scanner(self) -> Scanner:
terminals, self.callback = _create_unless(self.terminals, self.g_regex_flags, self.re, self.use_bytes)
assert all(self.callback.values())

Expand All @@ -580,26 +595,26 @@ def _build_scanner(self):
else:
self.callback[type_] = f

self._scanner = Scanner(terminals, self.g_regex_flags, self.re, self.use_bytes)
return Scanner(terminals, self.g_regex_flags, self.re, self.use_bytes)

@property
def scanner(self):
def scanner(self) -> Scanner:
if self._scanner is None:
self._build_scanner()
self._scanner = self._build_scanner()
return self._scanner

def match(self, text, pos):
return self.scanner.match(text, pos)

def next_token(self, lex_state: LexerState, parser_state: Any = None) -> Token:
line_ctr = lex_state.line_ctr
while line_ctr.char_pos < len(lex_state.text):
while line_ctr.char_pos < lex_state.text.end:
res = self.match(lex_state.text, line_ctr.char_pos)
if not res:
allowed = self.scanner.allowed_types - self.ignore_types
if not allowed:
allowed = {"<END-OF-FILE>"}
raise UnexpectedCharacters(lex_state.text, line_ctr.char_pos, line_ctr.line, line_ctr.column,
raise UnexpectedCharacters(lex_state.text.text, line_ctr.char_pos, line_ctr.line, line_ctr.column,
allowed=allowed, token_history=lex_state.last_token and [lex_state.last_token],
state=parser_state, terminals_by_name=self.terminals_by_name)

Expand Down
41 changes: 30 additions & 11 deletions lark/parser_frontends.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Callable, Dict, Optional, Collection, Union, TYPE_CHECKING

from .exceptions import ConfigurationError, GrammarError, assert_config
from .utils import get_regexp_width, Serialize
from .utils import get_regexp_width, Serialize, TextOrSlice, TextSlice
from .lexer import LexerThread, BasicLexer, ContextualLexer, Lexer
from .parsers import earley, xearley, cyk
from .parsers.lalr_parser import LALR_Parser
Expand All @@ -15,16 +15,31 @@
###{standalone

def _wrap_lexer(lexer_class):
future_interface = getattr(lexer_class, '__future_interface__', False)
if future_interface:
future_interface = getattr(lexer_class, '__future_interface__', 0)
if future_interface == 2:
return lexer_class
else:
class CustomLexerWrapper(Lexer):
elif future_interface == 1:
class CustomLexerWrapper1(Lexer):
def __init__(self, lexer_conf):
self.lexer = lexer_class(lexer_conf)
def lex(self, lexer_state, parser_state):
return self.lexer.lex(lexer_state.text)
return CustomLexerWrapper
if not lexer_state.text.is_complete_text():
raise TypeError("Interface=1 Custom Lexer don't support TextSlice")
lexer_state.text = lexer_state.text
return self.lexer.lex(lexer_state, parser_state)
return CustomLexerWrapper1
elif future_interface == 0:
class CustomLexerWrapper0(Lexer):
def __init__(self, lexer_conf):
self.lexer = lexer_class(lexer_conf)

def lex(self, lexer_state, parser_state):
if not lexer_state.text.is_complete_text():
raise TypeError("Interface=0 Custom Lexer don't support TextSlice")
return self.lexer.lex(lexer_state.text.text)
return CustomLexerWrapper0
else:
raise ValueError(f"Unknown __future_interface__ value {future_interface}, integer 0-2 expected")


def _deserialize_parsing_frontend(data, memo, lexer_conf, callbacks, options):
Expand Down Expand Up @@ -93,17 +108,21 @@ def _verify_start(self, start=None):
raise ConfigurationError("Unknown start rule %s. Must be one of %r" % (start, self.parser_conf.start))
return start

def _make_lexer_thread(self, text: str) -> Union[str, LexerThread]:
def _make_lexer_thread(self, text: Optional[TextOrSlice]) -> Union[TextOrSlice, LexerThread]:
cls = (self.options and self.options._plugins.get('LexerThread')) or LexerThread
return text if self.skip_lexer else cls.from_text(self.lexer, text)
return text if self.skip_lexer else cls(self.lexer, None) if text is None else cls.from_text(self.lexer, text)

def parse(self, text: Optional[TextOrSlice], start=None, on_error=None):
if self.lexer_conf.lexer_type in ("dynamic", "dynamic_complete"):
if isinstance(text, TextSlice) and not text.is_complete_text():
raise TypeError(f"Lexer {self.lexer_conf.lexer_type} does not support text slices.")

def parse(self, text: str, start=None, on_error=None):
chosen_start = self._verify_start(start)
kw = {} if on_error is None else {'on_error': on_error}
stream = self._make_lexer_thread(text)
return self.parser.parse(stream, chosen_start, **kw)

def parse_interactive(self, text: Optional[str]=None, start=None):
def parse_interactive(self, text: Optional[TextOrSlice]=None, start=None):
# TODO BREAK - Change text from Optional[str] to text: str = ''.
# Would break behavior of exhaust_lexer(), which currently raises TypeError, and after the change would just return []
chosen_start = self._verify_start(start)
Expand Down
2 changes: 1 addition & 1 deletion lark/parsers/lalr_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def parse(self, lexer, start, on_error=None):
if isinstance(e, UnexpectedCharacters):
# If user didn't change the character position, then we should
if p == s.line_ctr.char_pos:
s.line_ctr.feed(s.text[p:p+1])
s.line_ctr.feed(s.text.text[p:p+1])

try:
return e.interactive_parser.resume_parse()
Expand Down
47 changes: 47 additions & 0 deletions lark/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
###{standalone
import sys, re
import logging
from dataclasses import dataclass
from typing import Generic, AnyStr

logger: logging.Logger = logging.getLogger("lark")
logger.addHandler(logging.StreamHandler())
Expand Down Expand Up @@ -158,6 +160,50 @@ def get_regexp_width(expr: str) -> Union[Tuple[int, int], List[int]]:
else:
return 0, int(MAXWIDTH)


@dataclass(frozen=True)
class TextSlice(Generic[AnyStr]):
text: AnyStr
start: int
end: int

def __post_init__(self):
if self.start < 0:
object.__setattr__(self, 'start', self.start + len(self.text))
assert self.start >=0

if self.end is None:
object.__setattr__(self, 'end', len(self.text))
elif self.end < 0:
object.__setattr__(self, 'end', self.end + len(self.text))
assert self.end <= len(self.text)

@classmethod
def cast_from(cls, text: Union[AnyStr, 'TextSlice[AnyStr]', None]) -> 'TextSlice[AnyStr]':
if isinstance(text, TextSlice):
return text

assert isinstance(text, (str, bytes)), text
return cls(text, 0, len(text))

def is_complete_text(self):
return self.start == 0 and self.end == len(self.text)

def start_from(self, pos: int):
return TextSlice(self.text, pos, self.end)

def __len__(self):
return self.end - self.start

def count(self, substr: AnyStr):
return self.text.count(substr, self.start, self.end)

def rindex(self, substr: AnyStr):
return self.text.rindex(substr, self.start, self.end)


TextOrSlice = Union[str, 'TextSlice']

###}


Expand Down Expand Up @@ -344,3 +390,4 @@ def __len__(self) -> int:

def __repr__(self):
return f"{type(self).__name__}({', '.join(map(repr,self))})"

17 changes: 15 additions & 2 deletions tests/test_lexer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from unittest import TestCase, main

from lark import Lark, Tree
from lark import Lark, Tree, TextSlice


class TestLexer(TestCase):
def setUp(self):
Expand All @@ -18,6 +19,18 @@ def test_basic(self):
res = list(p.lex("abc cba dd", dont_ignore=True))
assert res == list('abc cba dd')

def test_subset_lex(self):
p = Lark("""
start: "a" "b" "c" "d"
%ignore " "
""")

res = list(p.lex(TextSlice("xxxabc cba ddxx", 3, -2)))
assert res == list('abccbadd')

res = list(p.lex(TextSlice("aaaabc cba dddd", 3, -2)))
assert res == list('abccbadd')


if __name__ == '__main__':
main()
main()
Loading

0 comments on commit 0b131a0

Please sign in to comment.