Skip to content

Commit

Permalink
Flatten everything under the arguments validator
Browse files Browse the repository at this point in the history
  • Loading branch information
Viicos committed Sep 17, 2024
1 parent 5969925 commit 9cab63c
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 66 deletions.
49 changes: 8 additions & 41 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from collections.abc import Mapping
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, List, Pattern, Set, Tuple, Type, Union, overload
from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, List, Pattern, Set, Tuple, Type, Union

from typing_extensions import deprecated

Expand Down Expand Up @@ -3372,53 +3372,15 @@ def arguments_parameter(
return _dict_not_none(name=name, schema=schema, mode=mode, alias=alias)


class VarKwargsSchema(TypedDict):
type: Literal['var_kwargs']
mode: Literal['single', 'typed_dict']
schema: CoreSchema


@overload
def var_kwargs_schema(
*,
mode: Literal['single'],
schema: CoreSchema,
) -> VarKwargsSchema: ...


@overload
def var_kwargs_schema(
*,
mode: Literal['typed_dict'],
schema: TypedDictSchema,
) -> VarKwargsSchema: ...


def var_kwargs_schema(
*,
mode: Literal['single', 'typed_dict'],
schema: CoreSchema,
) -> VarKwargsSchema:
"""Returns a schema describing the variadic keyword arguments of a callable.
Args:
mode: The validation mode to use. If `'single'`, every value of the keyword arguments will
be validated against the core schema from the `schema` argument. If `'typed_dict'`, the
`schema` argument must be a [`typed_dict_schema`][pydantic_core.core_schema.typed_dict_schema].
"""

return _dict_not_none(
type='var_kwargs',
mode=mode,
schema=schema,
)
VarKwargsMode: TypeAlias = Literal['single', 'unpacked-typed-dict']


class ArgumentsSchema(TypedDict, total=False):
type: Required[Literal['arguments']]
arguments_schema: Required[List[ArgumentsParameter]]
populate_by_name: bool
var_args_schema: CoreSchema
var_kwargs_mode: VarKwargsMode
var_kwargs_schema: CoreSchema
ref: str
metadata: Dict[str, Any]
Expand All @@ -3430,6 +3392,7 @@ def arguments_schema(
*,
populate_by_name: bool | None = None,
var_args_schema: CoreSchema | None = None,
var_kwargs_mode: VarKwargsMode | None = None,
var_kwargs_schema: CoreSchema | None = None,
ref: str | None = None,
metadata: Dict[str, Any] | None = None,
Expand All @@ -3456,6 +3419,9 @@ def arguments_schema(
arguments: The arguments to use for the arguments schema
populate_by_name: Whether to populate by name
var_args_schema: The variable args schema to use for the arguments schema
var_kwargs_mode: The validation mode to use for variadic keyword arguments. If `'single'`, every value of the
keyword arguments will be validated against the `var_kwargs_schema` schema. If `'unpacked-typed-dict'`,
the `schema` argument must be a [`typed_dict_schema`][pydantic_core.core_schema.typed_dict_schema]
var_kwargs_schema: The variable kwargs schema to use for the arguments schema
ref: optional unique identifier of the schema, used to reference the schema in other places
metadata: Any other information you want to include with the schema, not used by pydantic-core
Expand All @@ -3466,6 +3432,7 @@ def arguments_schema(
arguments_schema=arguments,
populate_by_name=populate_by_name,
var_args_schema=var_args_schema,
var_kwargs_mode=var_kwargs_mode,
var_kwargs_schema=var_kwargs_schema,
ref=ref,
metadata=metadata,
Expand Down
109 changes: 88 additions & 21 deletions src/validators/arguments.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::str::FromStr;

use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyString, PyTuple};
Expand All @@ -15,6 +17,27 @@ use crate::tools::SchemaDict;
use super::validation_state::ValidationState;
use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator};

#[derive(Debug, PartialEq)]
enum VarKwargsMode {
Single,
UnpackedTypedDict,
}

impl FromStr for VarKwargsMode {
type Err = PyErr;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"single" => Ok(Self::Single),
"unpacked-typed-dict" => Ok(Self::UnpackedTypedDict),
s => py_schema_err!(
"Invalid var_kwargs mode: `{}`, expected `single` or `unpacked-typed-dict`",
s
),
}
}
}

#[derive(Debug)]
struct Parameter {
positional: bool,
Expand All @@ -29,6 +52,7 @@ pub struct ArgumentsValidator {
parameters: Vec<Parameter>,
positional_params_count: usize,
var_args_validator: Option<Box<CombinedValidator>>,
var_kwargs_mode: VarKwargsMode,
var_kwargs_validator: Option<Box<CombinedValidator>>,
loc_by_alias: bool,
extra: ExtraBehavior,
Expand Down Expand Up @@ -117,17 +141,31 @@ impl BuildValidator for ArgumentsValidator {
});
}

let py_var_kwargs_mode: Bound<PyString> = match schema.get_as(intern!(py, "var_kwargs_mode"))? {
Some(v) => v,
None => PyString::new_bound(py, "single"),
};
let var_kwargs_mode = VarKwargsMode::from_str(py_var_kwargs_mode.to_string().as_str())?;
let var_kwargs_validator = match schema.get_item(intern!(py, "var_kwargs_schema"))? {
Some(v) => Some(Box::new(build_validator(&v, config, definitions)?)),
None => None,
};

if var_kwargs_mode == VarKwargsMode::UnpackedTypedDict && var_kwargs_validator.is_none() {
return py_schema_err!(
"`var_kwargs_schema` must be specified when `var_kwargs_mode` is `'unpacked-typed-dict'`"
);
}

Ok(Self {
parameters,
positional_params_count,
var_args_validator: match schema.get_item(intern!(py, "var_args_schema"))? {
Some(v) => Some(Box::new(build_validator(&v, config, definitions)?)),
None => None,
},
var_kwargs_validator: match schema.get_item(intern!(py, "var_kwargs_schema"))? {
Some(v) => Some(Box::new(build_validator(&v, config, definitions)?)),
None => None,
},
var_kwargs_mode,
var_kwargs_validator,
loc_by_alias: config.get_as(intern!(py, "loc_by_alias"))?.unwrap_or(true),
extra: ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Forbid)?,
}
Expand Down Expand Up @@ -258,6 +296,8 @@ impl Validator for ArgumentsValidator {
// if there are kwargs check any that haven't been processed yet
if let Some(kwargs) = args.kwargs() {
if kwargs.len() > used_kwargs.len() {
let remaining_kwargs = PyDict::new_bound(py);

for result in kwargs.iter() {
let (raw_key, value) = result?;
let either_str = match raw_key
Expand All @@ -278,28 +318,55 @@ impl Validator for ArgumentsValidator {
Err(err) => return Err(err),
};
if !used_kwargs.contains(either_str.as_cow()?.as_ref()) {
match self.var_kwargs_validator {
Some(ref validator) => match validator.validate(py, value.borrow_input(), state) {
Ok(value) => {
output_kwargs.set_item(either_str.as_py_string(py, state.cache_str()), value)?;
}
Err(ValError::LineErrors(line_errors)) => {
for err in line_errors {
errors.push(err.with_outer_location(raw_key.clone()));
match self.var_kwargs_mode {
VarKwargsMode::Single => match self.var_kwargs_validator {
Some(ref validator) => match validator.validate(py, value.borrow_input(), state) {
Ok(value) => {
output_kwargs
.set_item(either_str.as_py_string(py, state.cache_str()), value)?;
}
Err(ValError::LineErrors(line_errors)) => {
for err in line_errors {
errors.push(err.with_outer_location(raw_key.clone()));
}
}
Err(err) => return Err(err),
},
None => {
if let ExtraBehavior::Forbid = self.extra {
errors.push(ValLineError::new_with_loc(
ErrorTypeDefaults::UnexpectedKeywordArgument,
value,
raw_key.clone(),
));
}
}
Err(err) => return Err(err),
},
None => {
if let ExtraBehavior::Forbid = self.extra {
errors.push(ValLineError::new_with_loc(
ErrorTypeDefaults::UnexpectedKeywordArgument,
value,
raw_key.clone(),
));
}
VarKwargsMode::UnpackedTypedDict => {
// Save to the remaining kwargs, we will validate as a single dict:
remaining_kwargs.set_item(either_str.as_py_string(py, state.cache_str()), value)?;
}
}
}
}

if self.var_kwargs_mode == VarKwargsMode::UnpackedTypedDict {
// `var_kwargs_validator` is guaranteed to be `Some`:
match self
.var_kwargs_validator
.as_ref()
.unwrap()
.validate(py, remaining_kwargs.as_any(), state)
{
Ok(value) => {
output_kwargs.update(value.downcast_bound::<PyDict>(py).unwrap().as_mapping())?;
}
Err(ValError::LineErrors(line_errors)) => {
for error in line_errors {
errors.push(error);
}
}
Err(err) => return Err(err),
}
}
}
Expand Down
3 changes: 0 additions & 3 deletions src/validators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ mod union;
mod url;
mod uuid;
mod validation_state;
mod var_kwargs;
mod with_default;

pub use self::validation_state::{Exactness, ValidationState};
Expand Down Expand Up @@ -562,7 +561,6 @@ pub fn build_validator(
callable::CallableValidator,
// arguments
arguments::ArgumentsValidator,
var_kwargs::VarKwargsValidator,
// default value
with_default::WithDefaultValidator,
// chain validators
Expand Down Expand Up @@ -718,7 +716,6 @@ pub enum CombinedValidator {
Callable(callable::CallableValidator),
// arguments
Arguments(arguments::ArgumentsValidator),
VarKwargs(var_kwargs::VarKwargsValidator),
// default value
WithDefault(with_default::WithDefaultValidator),
// chain validators
Expand Down
57 changes: 56 additions & 1 deletion tests/validators/test_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,19 @@ def test_build_non_default_follows():
)


def test_build_missing_var_kwargs():
with pytest.raises(
SchemaError, match="`var_kwargs_schema` must be specified when `var_kwargs_mode` is `'unpacked-typed-dict'`"
):
SchemaValidator(
{
'type': 'arguments',
'arguments_schema': [],
'var_kwargs_mode': 'unpacked-typed-dict',
}
)


@pytest.mark.parametrize(
'input_value,expected',
[
Expand All @@ -778,7 +791,7 @@ def test_build_non_default_follows():
],
ids=repr,
)
def test_kwargs(py_and_json: PyAndJson, input_value, expected):
def test_kwargs_single(py_and_json: PyAndJson, input_value, expected):
v = py_and_json(
{
'type': 'arguments',
Expand All @@ -796,6 +809,48 @@ def test_kwargs(py_and_json: PyAndJson, input_value, expected):
assert v.validate_test(input_value) == expected


@pytest.mark.parametrize(
'input_value,expected',
[
[ArgsKwargs((), {'x': 1}), ((), {'x': 1})],
[ArgsKwargs((), {'x': 1.0}), Err('x\n Input should be a valid integer [type=int_type,')],
[ArgsKwargs((), {'x': 1, 'z': 'str'}), ((), {'x': 1, 'y': 'str'})],
[ArgsKwargs((), {'x': 1, 'y': 'str'}), Err('y\n Extra inputs are not permitted [type=extra_forbidden,')],
],
)
def test_kwargs_typed_dict(py_and_json: PyAndJson, input_value, expected):
v = py_and_json(
{
'type': 'arguments',
'arguments_schema': [],
'var_kwargs_mode': 'unpacked-typed-dict',
'var_kwargs_schema': {
'type': 'typed-dict',
'fields': {
'x': {
'type': 'typed-dict-field',
'schema': {'type': 'int', 'strict': True},
'required': True,
},
'y': {
'type': 'typed-dict-field',
'schema': {'type': 'str'},
'required': False,
'validation_alias': 'z',
},
},
'config': {'extra_fields_behavior': 'forbid'},
},
}
)

if isinstance(expected, Err):
with pytest.raises(ValidationError, match=re.escape(expected.message)):
v.validate_test(input_value)
else:
assert v.validate_test(input_value) == expected


@pytest.mark.parametrize(
'input_value,expected',
[
Expand Down

0 comments on commit 9cab63c

Please sign in to comment.