Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions src/serializers/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,6 @@ impl<'a, 'py> SerializationState<'a, 'py> {
self.include_exclude.1.as_ref()
}

pub(crate) fn model_type_name(&self) -> Option<Bound<'py, PyString>> {
self.model.as_ref().and_then(|model| model.get_type().name().ok())
}

pub fn serialize_infer<'slf>(
&'slf mut self,
value: &'slf Bound<'py, PyAny>,
Expand Down
51 changes: 25 additions & 26 deletions src/serializers/fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ impl GeneralFieldsSerializer {
pub(crate) fn main_to_python<'py>(
&self,
py: Python<'py>,
model: &Bound<'py, PyAny>,
main_iter: impl Iterator<Item = PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)>>,
state: &mut SerializationState<'_, 'py>,
) -> PyResult<Bound<'py, PyDict>> {
Expand Down Expand Up @@ -218,7 +219,7 @@ impl GeneralFieldsSerializer {
return Err(PydanticSerializationUnexpectedValue::new(
Some(format!("Unexpected field `{key}`")),
Some(key_str.to_string()),
state.model_type_name().map(|bound| bound.to_string()),
model_type_name(model),
None,
)
.to_py_err());
Expand All @@ -244,8 +245,8 @@ impl GeneralFieldsSerializer {
Err(PydanticSerializationUnexpectedValue::new(
Some(format!("Expected {required_fields} fields but got {used_req_fields}").to_string()),
state.field_name.as_ref().map(ToString::to_string),
state.model_type_name().map(|bound| bound.to_string()),
state.model.clone().map(Bound::unbind),
model_type_name(model),
Some(model.clone().unbind()),
)
.to_py_err())
} else {
Expand Down Expand Up @@ -353,7 +354,6 @@ impl GeneralFieldsSerializer {
state: &mut SerializationState<'_, 'py>,
) -> PyResult<()> {
if let Some(ref computed_fields) = self.computed_fields {
let state = &mut state.scoped_set(|s| &mut s.model, Some(model.clone()));
computed_fields.to_python(model, output_dict, &self.filter, state)?;
}
Ok(())
Expand All @@ -366,7 +366,6 @@ impl GeneralFieldsSerializer {
state: &mut SerializationState<'_, 'py>,
) -> Result<(), S::Error> {
if let Some(ref computed_fields) = self.computed_fields {
// FIXME: need to match state.model setting above in `add_computed_fields_python`??
computed_fields.serde_serialize::<S>(model, map, &self.filter, state)?;
}
Ok(())
Expand All @@ -390,21 +389,14 @@ impl TypeSerializer for GeneralFieldsSerializer {
) -> PyResult<Py<PyAny>> {
let py = value.py();
let missing_sentinel = get_missing_sentinel_object(py);
// If there is already a model registered (from a dataclass, BaseModel)
// then do not touch it
// If there is no model, we (a TypedDict) are the model
let model = state.model.clone().unwrap_or_else(|| value.clone());

let model = get_model(state)?;

let Some((main_dict, extra_dict)) = self.extract_dicts(value) else {
state.warn_fallback_py(self.get_name(), value)?;
return infer_to_python(value, state);
};
let output_dict = self.main_to_python(
py,
dict_items(&main_dict),
// FIXME: should also set model for extra serialization?
&mut state.scoped_set(|s| &mut s.model, Some(model.clone())),
)?;
let output_dict = self.main_to_python(py, &model, dict_items(&main_dict), state)?;

// this is used to include `__pydantic_extra__` in serialization on models
if let Some(extra_dict) = extra_dict {
Expand Down Expand Up @@ -448,24 +440,15 @@ impl TypeSerializer for GeneralFieldsSerializer {
return infer_serialize(value, serializer, state);
};
let missing_sentinel = get_missing_sentinel_object(value.py());
// If there is already a model registered (from a dataclass, BaseModel)
// then do not touch it
// If there is no model, we (a TypedDict) are the model
let model = state.model.clone().unwrap_or_else(|| value.clone());
let model = get_model(state).map_err(py_err_se_err)?;

let expected_len = match self.mode {
FieldsMode::TypedDictAllow => main_dict.len() + self.computed_field_count(),
_ => self.fields.len() + option_length!(extra_dict) + self.computed_field_count(),
};
// NOTE! As above, we maintain the order of the input dict assuming that's right
// we don't both with `used_req_fields` here because on unions, `to_python(..., mode='json')` is used
let mut map = self.main_serde_serialize(
dict_items(&main_dict),
expected_len,
serializer,
// FIXME: should also set model for extra serialization?
&mut state.scoped_set(|s| &mut s.model, Some(model.clone())),
)?;
let mut map = self.main_serde_serialize(dict_items(&main_dict), expected_len, serializer, state)?;

// this is used to include `__pydantic_extra__` in serialization on models
if let Some(extra_dict) = extra_dict {
Expand Down Expand Up @@ -507,3 +490,19 @@ fn dict_items<'py>(
let main_items: SmallVec<[_; 16]> = main_dict.iter().collect();
main_items.into_iter().map(Ok)
}

fn get_model<'py>(state: &mut SerializationState<'_, 'py>) -> PyResult<Bound<'py, PyAny>> {
state.model.clone().ok_or_else(|| {
PydanticSerializationUnexpectedValue::new(
Some("No model found for fields serialization".to_string()),
None,
None,
None,
)
.to_py_err()
})
}

fn model_type_name(model: &Bound<'_, PyAny>) -> Option<String> {
model.get_type().name().ok().map(|s| s.to_string())
}
3 changes: 2 additions & 1 deletion src/serializers/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ combined_serializer! {
super::type_serializers::function::FunctionPlainSerializerBuilder;
super::type_serializers::function::FunctionWrapSerializerBuilder;
super::type_serializers::model::ModelFieldsBuilder;
super::type_serializers::typed_dict::TypedDictBuilder;
}
// `both` means the struct is added to both the `CombinedSerializer` enum and the match statement in
// `find_serializer` so they can be used via a `type` str.
Expand Down Expand Up @@ -151,6 +150,7 @@ combined_serializer! {
Recursive: super::type_serializers::definitions::DefinitionRefSerializer;
Tuple: super::type_serializers::tuple::TupleSerializer;
Complex: super::type_serializers::complex::ComplexSerializer;
TypedDict: super::type_serializers::typed_dict::TypedDictSerializer;
}
}

Expand Down Expand Up @@ -356,6 +356,7 @@ impl PyGcTraverse for CombinedSerializer {
CombinedSerializer::Tuple(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Uuid(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Complex(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::TypedDict(inner) => inner.py_gc_traverse(visit),
}
}
}
Expand Down
11 changes: 5 additions & 6 deletions src/serializers/type_serializers/dataclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,21 +150,21 @@ impl TypeSerializer for DataclassSerializer {
value: &Bound<'py, PyAny>,
state: &mut SerializationState<'_, 'py>,
) -> PyResult<Py<PyAny>> {
let state = &mut state.scoped_set(|s| &mut s.model, Some(value.clone()));
if self.allow_value(value, state)? {
let model = value;
let state = &mut state.scoped_set(|s| &mut s.model, Some(value.clone()));
let py = value.py();
if let CombinedSerializer::Fields(ref fields_serializer) = *self.serializer {
let output_dict: Bound<PyDict> =
fields_serializer.main_to_python(py, known_dataclass_iter(&self.fields, value), state)?;
fields_serializer.main_to_python(py, model, known_dataclass_iter(&self.fields, model), state)?;

fields_serializer.add_computed_fields_python(value, &output_dict, state)?;
fields_serializer.add_computed_fields_python(model, &output_dict, state)?;
Ok(output_dict.into())
} else {
let inner_value = self.get_inner_value(value)?;
self.serializer.to_python(&inner_value, state)
}
} else {
// FIXME: probably don't want to have state.model set here, should move the scoped_set above?
state.warn_fallback_py(self.get_name(), value)?;
infer_to_python(value, state)
}
Expand All @@ -189,8 +189,8 @@ impl TypeSerializer for DataclassSerializer {
serializer: S,
state: &mut SerializationState<'_, 'py>,
) -> Result<S::Ok, S::Error> {
let state = &mut state.scoped_set(|s| &mut s.model, Some(value.clone()));
if self.allow_value(value, state).map_err(py_err_se_err)? {
let state = &mut state.scoped_set(|s| &mut s.model, Some(value.clone()));
if let CombinedSerializer::Fields(ref fields_serializer) = *self.serializer {
let expected_len = self.fields.len() + fields_serializer.computed_field_count();
let mut map = fields_serializer.main_serde_serialize(
Expand All @@ -206,7 +206,6 @@ impl TypeSerializer for DataclassSerializer {
self.serializer.serde_serialize(&inner_value, serializer, state)
}
} else {
// FIXME: probably don't want to have state.model set here, should move the scoped_set above?
state.warn_fallback_ser::<S>(self.get_name(), value)?;
infer_serialize(value, serializer, state)
}
Expand Down
54 changes: 51 additions & 3 deletions src/serializers/type_serializers/typed_dict.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::borrow::Cow;
use std::sync::Arc;

use pyo3::intern;
Expand All @@ -9,14 +10,20 @@ use ahash::AHashMap;
use crate::build_tools::py_schema_err;
use crate::build_tools::{py_schema_error_type, schema_or_config, ExtraBehavior};
use crate::definitions::DefinitionsBuilder;
use crate::serializers::shared::TypeSerializer;
use crate::serializers::SerializationState;
use crate::tools::SchemaDict;

use super::{BuildSerializer, CombinedSerializer, ComputedFields, FieldsMode, GeneralFieldsSerializer, SerField};

#[derive(Debug)]
pub struct TypedDictBuilder;
pub struct TypedDictSerializer {
serializer: GeneralFieldsSerializer,
}

impl_py_gc_traverse!(TypedDictSerializer { serializer });

impl BuildSerializer for TypedDictBuilder {
impl BuildSerializer for TypedDictSerializer {
const EXPECTED_TYPE: &'static str = "typed-dict";

fn build(
Expand Down Expand Up @@ -82,10 +89,51 @@ impl BuildSerializer for TypedDictBuilder {
}
}

// FIXME: computed fields do not work for TypedDict, and may never
// see the closed https://github.com/pydantic/pydantic-core/pull/1018
let computed_fields = ComputedFields::new(schema, config, definitions)?;

Ok(Arc::new(
GeneralFieldsSerializer::new(fields, fields_mode, extra_serializer, computed_fields).into(),
Self {
serializer: GeneralFieldsSerializer::new(fields, fields_mode, extra_serializer, computed_fields),
}
.into(),
))
}
}

impl TypeSerializer for TypedDictSerializer {
fn to_python<'py>(
&self,
value: &Bound<'py, PyAny>,
state: &mut SerializationState<'_, 'py>,
) -> PyResult<Py<PyAny>> {
self.serializer
.to_python(value, &mut state.scoped_set(|s| &mut s.model, Some(value.clone())))
}

fn json_key<'a, 'py>(
&self,
key: &'a Bound<'py, PyAny>,
state: &mut SerializationState<'_, 'py>,
) -> PyResult<Cow<'a, str>> {
self.invalid_as_json_key(key, state, "typed-dict")
}

fn serde_serialize<'py, S: serde::ser::Serializer>(
&self,
value: &Bound<'py, PyAny>,
serializer: S,
state: &mut SerializationState<'_, 'py>,
) -> Result<S::Ok, S::Error> {
self.serializer.serde_serialize(
value,
serializer,
&mut state.scoped_set(|s| &mut s.model, Some(value.clone())),
)
}

fn get_name(&self) -> &'static str {
"typed-dict"
}
}
32 changes: 32 additions & 0 deletions tests/serializers/test_typed_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,3 +376,35 @@ class Model(TypedDict):
)
s = SchemaSerializer(schema, config=core_schema.CoreConfig(serialize_by_alias=config or False))
assert s.to_python(Model(my_field=1), by_alias=runtime) == expected


def test_nested_typed_dict_field_serializers():
class Model(TypedDict):
x: Any

class OuterModel(TypedDict):
model: Model

schema = core_schema.typed_dict_schema(
{
'x': core_schema.typed_dict_field(
core_schema.any_schema(
serialization=core_schema.wrap_serializer_function_ser_schema(
# in an incorrect core implementation, self could be OuterModel here
lambda self, v, serializer: f'{list(self.keys())}',
is_field_serializer=True,
schema=core_schema.any_schema(),
)
)
)
}
)
outer_schema = core_schema.typed_dict_schema({'model': core_schema.typed_dict_field(schema)})

s = SchemaSerializer(schema)
assert s.to_python(Model(x=None)) == {'x': "['x']"}

outer_s = SchemaSerializer(outer_schema)
# if the inner field serializer incorrectly receives OuterModel as self, the keys
# will be ['model'] instead of ['x']
assert outer_s.to_python(OuterModel({'model': Model({'x': None})})) == {'model': {'x': "['x']"}}
Loading