Skip to content

Commit

Permalink
Allow quoting % in non-strict mode #21
Browse files Browse the repository at this point in the history
  • Loading branch information
Nikolay Kim committed Feb 7, 2017
1 parent 18761a5 commit 96d36c6
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 31 deletions.
2 changes: 1 addition & 1 deletion tests/test_quoting.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_quate_broken_unicode(quote):
def test_quate_ignore_broken_unicode(quote):
s = quote('j\x1a\udcf4q\udcda/\udc97g\udcee\udccb\x0ch\udccb'
'\x18\udce4v\x1b\udce2\udcce\udccecom/y\udccepj\x16',
errors='ignore')
strict=False)

assert s == 'j%1Aq%2Fg%0Ch%18v%1Bcom%2Fypj%16'

Expand Down
34 changes: 19 additions & 15 deletions yarl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,15 @@ class URL:
# / path-noscheme
# / path-empty
# absolute-URI = scheme ":" hier-part [ "?" query ]
__slots__ = ('_cache', '_val')
__slots__ = ('_cache', '_val', '_strict')

def __new__(cls, val='', *, encoded=False):
def __new__(cls, val='', *, encoded=False, strict=False):
if isinstance(val, URL):
return val
else:
return super(URL, cls).__new__(cls)

def __init__(self, val='', *, encoded=False):
def __init__(self, val='', *, encoded=False, strict=False):
if isinstance(val, URL):
return
if isinstance(val, str):
Expand All @@ -154,6 +154,8 @@ def __init__(self, val='', *, encoded=False):
else:
raise TypeError("Constructor parameter should be str")

self._strict = strict

if not encoded:
if not val[1]: # netloc
netloc = ''
Expand Down Expand Up @@ -185,10 +187,10 @@ def __init__(self, val='', *, encoded=False):
val = SplitResult(
val[0], # scheme
netloc,
_quote(val[2], safe='@:', protected='/'),
_quote(val[2], safe='@:', protected='/', strict=strict),
query=_quote(val[3], safe='=+&?/:@',
protected=PROTECT_CHARS, qs=True),
fragment=_quote(val[4], safe='?/:@'))
protected=PROTECT_CHARS, qs=True, strict=strict),
fragment=_quote(val[4], safe='?/:@', strict=strict))

self._val = val
self._cache = {}
Expand Down Expand Up @@ -234,7 +236,7 @@ def __gt__(self, other):
return self._val > other._val

def __truediv__(self, name):
name = _quote(name, safe=':@', protected='/')
name = _quote(name, safe=':@', protected='/', strict=self._strict)
if name.startswith('/'):
raise ValueError("Appending path "
"starting from slash is forbidden")
Expand Down Expand Up @@ -645,7 +647,7 @@ def with_port(self, port):
def with_path(self, path, encoded=False):
"""Return a new URL with path replaced."""
if not encoded:
path=_quote(path, safe='@:', protected='/')
path=_quote(path, safe='@:', protected='/', strict=self._strict)
return URL(self._val._replace(path=path), encoded=True)

def with_query(self, *args, **kwargs):
Expand Down Expand Up @@ -675,7 +677,7 @@ def with_query(self, *args, **kwargs):
if query is None:
query = ''
elif isinstance(query, Mapping):
quoter = partial(_quote, safe='/?:@', qs=True)
quoter = partial(_quote, safe='/?:@', qs=True, strict=self._strict)
lst = []
for k, v in query.items():
if isinstance(v, str):
Expand All @@ -689,12 +691,13 @@ def with_query(self, *args, **kwargs):
query = '&'.join(lst)
elif isinstance(query, str):
query = _quote(query, safe='/?:@',
protected=PROTECT_CHARS, qs=True)
protected=PROTECT_CHARS,
qs=True, strict=self._strict)
elif isinstance(query, (bytes, bytearray, memoryview)):
raise TypeError("Invalid query type: bytes, bytearray and "
"memoryview are forbidden")
elif isinstance(query, Sequence):
quoter = partial(_quote, safe='/?:@', qs=True)
quoter = partial(_quote, safe='/?:@', qs=True, strict=self._strict)
query = '&'.join(quoter(k)+'='+quoter(v)
for k, v in query)
else:
Expand Down Expand Up @@ -722,8 +725,8 @@ def update_query(self, *args, **kwargs):
lambda x: x.split('=', 1),
_quote(new_query,
safe='/?:@', protected=PROTECT_CHARS,
qs=True).lstrip("?").split("&")
)
qs=True, strict=self._strict).lstrip("?").split("&")
)
)

else:
Expand All @@ -747,8 +750,9 @@ def with_fragment(self, fragment):
fragment = ''
elif not isinstance(fragment, str):
raise TypeError("Invalid fragment type")
return URL(self._val._replace(fragment=_quote(fragment, safe='?/:@')),
encoded=True)
return URL(self._val._replace(
fragment=_quote(fragment, safe='?/:@', strict=self._strict)),
encoded=True)

def with_name(self, name):
"""Return a new URL with name (last part of path) replaced.
Expand Down
38 changes: 27 additions & 11 deletions yarl/_quoting.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,17 @@ cdef inline int _from_hex(Py_UCS4 v):
return -1


def _quote(val, *, str safe='', str protected='', bint qs=False, errors='strict'):
def _quote(val, *, str safe='', str protected='', bint qs=False, bint strict=True):
if val is None:
return None
if not isinstance(val, str):
raise TypeError("Argument should be str")
if not val:
return ''
return _do_quote(<str>val, safe, protected, qs, errors)
return _do_quote(<str>val, safe, protected, qs, strict)


cdef str _do_quote(str val, str safe, str protected, bint qs, errors):
cdef str _do_quote(str val, str safe, str protected, bint qs, bint strict):
cdef uint8_t b
cdef Py_UCS4 ch, unquoted
cdef str tmp
Expand All @@ -62,16 +62,32 @@ cdef str _do_quote(str val, str safe, str protected, bint qs, errors):
if not qs:
safe += '+&=;'
safe += protected
for ch in val:
cdef int idx = 0
while idx < len(val):
ch = val[idx]
idx += 1

if has_pct:
pct[has_pct-1] = ch
has_pct += 1
if has_pct == 3:
digit1 = _from_hex(pct[0])
digit2 = _from_hex(pct[1])
if digit1 == -1 or digit2 == -1:
raise ValueError("Unallowed PCT %{}{}".format(pct[0],
pct[1]))
if strict:
raise ValueError("Unallowed PCT %{}{}".format(pct[0],
pct[1]))
else:
PyUnicode_WriteChar(ret, ret_idx, '%')
ret_idx += 1
PyUnicode_WriteChar(ret, ret_idx, '2')
ret_idx += 1
PyUnicode_WriteChar(ret, ret_idx, '5')
ret_idx += 1
idx -= 2
has_pct = 0
continue

ch = <Py_UCS4>(digit1 << 4 | digit2)
has_pct = 0

Expand Down Expand Up @@ -111,7 +127,7 @@ cdef str _do_quote(str val, str safe, str protected, bint qs, errors):
ret_idx +=1
continue

ch_bytes = ch.encode("utf-8", errors=errors)
ch_bytes = ch.encode("utf-8", errors= 'strict' if strict else 'ignore')

for b in ch_bytes:
PyUnicode_WriteChar(ret, ret_idx, '%')
Expand Down Expand Up @@ -157,9 +173,9 @@ cdef str _do_unquote(str val, str unsafe='', bint qs=False):
pass
else:
if qs and unquoted in '+=&;':
ret.append(_do_quote(unquoted, '', '', True, 'strict'))
ret.append(_do_quote(unquoted, '', '', True, True))
elif unquoted in unsafe:
ret.append(_do_quote(unquoted, '', '', False, 'strict'))
ret.append(_do_quote(unquoted, '', '', False, True))
else:
ret.append(unquoted)
del pcts[:]
Expand Down Expand Up @@ -195,9 +211,9 @@ cdef str _do_unquote(str val, str unsafe='', bint qs=False):
ret.append(last_pct) # %F8
else:
if qs and unquoted in '+=&;':
ret.append(_do_quote(unquoted, '', '', True, 'strict'))
ret.append(_do_quote(unquoted, '', '', True, True))
elif unquoted in unsafe:
ret.append(_do_quote(unquoted, '', '', False, 'strict'))
ret.append(_do_quote(unquoted, '', '', False, True))
else:
ret.append(unquoted)
return ''.join(ret)
19 changes: 15 additions & 4 deletions yarl/quoting.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,26 @@
ALLOWED = UNRESERVED + SUB_DELIMS_WITHOUT_QS


def _py_quote(val, *, safe='', protected='', qs=False, errors='strict'):
def _py_quote(val, *, safe='', protected='', qs=False, strict=True):
if val is None:
return None
if not isinstance(val, str):
raise TypeError("Argument should be str")
if not val:
return ''
val = val.encode('utf8', errors=errors)
val = val.encode('utf8', errors='strict' if strict else 'ignore')
ret = bytearray()
pct = b''
safe += ALLOWED
if not qs:
safe += '+&=;'
safe += protected
bsafe = safe.encode('ascii')
for ch in val:
idx = 0
while idx < len(val):
ch = val[idx]
idx += 1

if pct:
if ch in BASCII_LOWERCASE:
ch = ch - 32
Expand All @@ -35,7 +39,14 @@ def _py_quote(val, *, safe='', protected='', qs=False, errors='strict'):
try:
unquoted = chr(int(pct[1:].decode('ascii'), base=16))
except ValueError:
raise ValueError("Unallowed PCT {}".format(pct))
if strict:
raise ValueError("Unallowed PCT {}".format(pct))
else:
ret.extend(b'%25')
pct = b''
idx -= 2
continue

if unquoted in protected:
ret.extend(pct)
elif unquoted in safe:
Expand Down

0 comments on commit 96d36c6

Please sign in to comment.