diff --git a/src/serializers/extra.rs b/src/serializers/extra.rs index 6d76db48f..ab141a2cf 100644 --- a/src/serializers/extra.rs +++ b/src/serializers/extra.rs @@ -149,10 +149,6 @@ impl<'a, 'py> SerializationState<'a, 'py> { self.include_exclude.1.as_ref() } - pub(crate) fn model_type_name(&self) -> Option> { - self.model.as_ref().and_then(|model| model.get_type().name().ok()) - } - pub fn serialize_infer<'slf>( &'slf mut self, value: &'slf Bound<'py, PyAny>, diff --git a/src/serializers/fields.rs b/src/serializers/fields.rs index a71103285..c79cfeb74 100644 --- a/src/serializers/fields.rs +++ b/src/serializers/fields.rs @@ -171,6 +171,7 @@ impl GeneralFieldsSerializer { pub(crate) fn main_to_python<'py>( &self, py: Python<'py>, + model: &Bound<'py, PyAny>, main_iter: impl Iterator, Bound<'py, PyAny>)>>, state: &mut SerializationState<'_, 'py>, ) -> PyResult> { @@ -218,7 +219,7 @@ impl GeneralFieldsSerializer { return Err(PydanticSerializationUnexpectedValue::new( Some(format!("Unexpected field `{key}`")), Some(key_str.to_string()), - state.model_type_name().map(|bound| bound.to_string()), + model_type_name(model), None, ) .to_py_err()); @@ -244,8 +245,8 @@ impl GeneralFieldsSerializer { Err(PydanticSerializationUnexpectedValue::new( Some(format!("Expected {required_fields} fields but got {used_req_fields}").to_string()), state.field_name.as_ref().map(ToString::to_string), - state.model_type_name().map(|bound| bound.to_string()), - state.model.clone().map(Bound::unbind), + model_type_name(model), + Some(model.clone().unbind()), ) .to_py_err()) } else { @@ -353,7 +354,6 @@ impl GeneralFieldsSerializer { state: &mut SerializationState<'_, 'py>, ) -> PyResult<()> { if let Some(ref computed_fields) = self.computed_fields { - let state = &mut state.scoped_set(|s| &mut s.model, Some(model.clone())); computed_fields.to_python(model, output_dict, &self.filter, state)?; } Ok(()) @@ -366,7 +366,6 @@ impl GeneralFieldsSerializer { state: &mut SerializationState<'_, 'py>, ) -> Result<(), S::Error> { if let Some(ref computed_fields) = self.computed_fields { - // FIXME: need to match state.model setting above in `add_computed_fields_python`?? computed_fields.serde_serialize::(model, map, &self.filter, state)?; } Ok(()) @@ -390,21 +389,14 @@ impl TypeSerializer for GeneralFieldsSerializer { ) -> PyResult> { let py = value.py(); let missing_sentinel = get_missing_sentinel_object(py); - // If there is already a model registered (from a dataclass, BaseModel) - // then do not touch it - // If there is no model, we (a TypedDict) are the model - let model = state.model.clone().unwrap_or_else(|| value.clone()); + + let model = get_model(state)?; let Some((main_dict, extra_dict)) = self.extract_dicts(value) else { state.warn_fallback_py(self.get_name(), value)?; return infer_to_python(value, state); }; - let output_dict = self.main_to_python( - py, - dict_items(&main_dict), - // FIXME: should also set model for extra serialization? - &mut state.scoped_set(|s| &mut s.model, Some(model.clone())), - )?; + let output_dict = self.main_to_python(py, &model, dict_items(&main_dict), state)?; // this is used to include `__pydantic_extra__` in serialization on models if let Some(extra_dict) = extra_dict { @@ -448,10 +440,7 @@ impl TypeSerializer for GeneralFieldsSerializer { return infer_serialize(value, serializer, state); }; let missing_sentinel = get_missing_sentinel_object(value.py()); - // If there is already a model registered (from a dataclass, BaseModel) - // then do not touch it - // If there is no model, we (a TypedDict) are the model - let model = state.model.clone().unwrap_or_else(|| value.clone()); + let model = get_model(state).map_err(py_err_se_err)?; let expected_len = match self.mode { FieldsMode::TypedDictAllow => main_dict.len() + self.computed_field_count(), @@ -459,13 +448,7 @@ impl TypeSerializer for GeneralFieldsSerializer { }; // NOTE! As above, we maintain the order of the input dict assuming that's right // we don't both with `used_req_fields` here because on unions, `to_python(..., mode='json')` is used - let mut map = self.main_serde_serialize( - dict_items(&main_dict), - expected_len, - serializer, - // FIXME: should also set model for extra serialization? - &mut state.scoped_set(|s| &mut s.model, Some(model.clone())), - )?; + let mut map = self.main_serde_serialize(dict_items(&main_dict), expected_len, serializer, state)?; // this is used to include `__pydantic_extra__` in serialization on models if let Some(extra_dict) = extra_dict { @@ -507,3 +490,19 @@ fn dict_items<'py>( let main_items: SmallVec<[_; 16]> = main_dict.iter().collect(); main_items.into_iter().map(Ok) } + +fn get_model<'py>(state: &mut SerializationState<'_, 'py>) -> PyResult> { + state.model.clone().ok_or_else(|| { + PydanticSerializationUnexpectedValue::new( + Some("No model found for fields serialization".to_string()), + None, + None, + None, + ) + .to_py_err() + }) +} + +fn model_type_name(model: &Bound<'_, PyAny>) -> Option { + model.get_type().name().ok().map(|s| s.to_string()) +} diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index 98e91098c..f8a020169 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -110,7 +110,6 @@ combined_serializer! { super::type_serializers::function::FunctionPlainSerializerBuilder; super::type_serializers::function::FunctionWrapSerializerBuilder; super::type_serializers::model::ModelFieldsBuilder; - super::type_serializers::typed_dict::TypedDictBuilder; } // `both` means the struct is added to both the `CombinedSerializer` enum and the match statement in // `find_serializer` so they can be used via a `type` str. @@ -151,6 +150,7 @@ combined_serializer! { Recursive: super::type_serializers::definitions::DefinitionRefSerializer; Tuple: super::type_serializers::tuple::TupleSerializer; Complex: super::type_serializers::complex::ComplexSerializer; + TypedDict: super::type_serializers::typed_dict::TypedDictSerializer; } } @@ -356,6 +356,7 @@ impl PyGcTraverse for CombinedSerializer { CombinedSerializer::Tuple(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Uuid(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Complex(inner) => inner.py_gc_traverse(visit), + CombinedSerializer::TypedDict(inner) => inner.py_gc_traverse(visit), } } } diff --git a/src/serializers/type_serializers/dataclass.rs b/src/serializers/type_serializers/dataclass.rs index 94126547b..f9ccd2bbc 100644 --- a/src/serializers/type_serializers/dataclass.rs +++ b/src/serializers/type_serializers/dataclass.rs @@ -150,21 +150,21 @@ impl TypeSerializer for DataclassSerializer { value: &Bound<'py, PyAny>, state: &mut SerializationState<'_, 'py>, ) -> PyResult> { - let state = &mut state.scoped_set(|s| &mut s.model, Some(value.clone())); if self.allow_value(value, state)? { + let model = value; + let state = &mut state.scoped_set(|s| &mut s.model, Some(value.clone())); let py = value.py(); if let CombinedSerializer::Fields(ref fields_serializer) = *self.serializer { let output_dict: Bound = - fields_serializer.main_to_python(py, known_dataclass_iter(&self.fields, value), state)?; + fields_serializer.main_to_python(py, model, known_dataclass_iter(&self.fields, model), state)?; - fields_serializer.add_computed_fields_python(value, &output_dict, state)?; + fields_serializer.add_computed_fields_python(model, &output_dict, state)?; Ok(output_dict.into()) } else { let inner_value = self.get_inner_value(value)?; self.serializer.to_python(&inner_value, state) } } else { - // FIXME: probably don't want to have state.model set here, should move the scoped_set above? state.warn_fallback_py(self.get_name(), value)?; infer_to_python(value, state) } @@ -189,8 +189,8 @@ impl TypeSerializer for DataclassSerializer { serializer: S, state: &mut SerializationState<'_, 'py>, ) -> Result { - let state = &mut state.scoped_set(|s| &mut s.model, Some(value.clone())); if self.allow_value(value, state).map_err(py_err_se_err)? { + let state = &mut state.scoped_set(|s| &mut s.model, Some(value.clone())); if let CombinedSerializer::Fields(ref fields_serializer) = *self.serializer { let expected_len = self.fields.len() + fields_serializer.computed_field_count(); let mut map = fields_serializer.main_serde_serialize( @@ -206,7 +206,6 @@ impl TypeSerializer for DataclassSerializer { self.serializer.serde_serialize(&inner_value, serializer, state) } } else { - // FIXME: probably don't want to have state.model set here, should move the scoped_set above? state.warn_fallback_ser::(self.get_name(), value)?; infer_serialize(value, serializer, state) } diff --git a/src/serializers/type_serializers/typed_dict.rs b/src/serializers/type_serializers/typed_dict.rs index 9b1751bc2..422525807 100644 --- a/src/serializers/type_serializers/typed_dict.rs +++ b/src/serializers/type_serializers/typed_dict.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::sync::Arc; use pyo3::intern; @@ -9,14 +10,20 @@ use ahash::AHashMap; use crate::build_tools::py_schema_err; use crate::build_tools::{py_schema_error_type, schema_or_config, ExtraBehavior}; use crate::definitions::DefinitionsBuilder; +use crate::serializers::shared::TypeSerializer; +use crate::serializers::SerializationState; use crate::tools::SchemaDict; use super::{BuildSerializer, CombinedSerializer, ComputedFields, FieldsMode, GeneralFieldsSerializer, SerField}; #[derive(Debug)] -pub struct TypedDictBuilder; +pub struct TypedDictSerializer { + serializer: GeneralFieldsSerializer, +} + +impl_py_gc_traverse!(TypedDictSerializer { serializer }); -impl BuildSerializer for TypedDictBuilder { +impl BuildSerializer for TypedDictSerializer { const EXPECTED_TYPE: &'static str = "typed-dict"; fn build( @@ -82,10 +89,51 @@ impl BuildSerializer for TypedDictBuilder { } } + // FIXME: computed fields do not work for TypedDict, and may never + // see the closed https://github.com/pydantic/pydantic-core/pull/1018 let computed_fields = ComputedFields::new(schema, config, definitions)?; Ok(Arc::new( - GeneralFieldsSerializer::new(fields, fields_mode, extra_serializer, computed_fields).into(), + Self { + serializer: GeneralFieldsSerializer::new(fields, fields_mode, extra_serializer, computed_fields), + } + .into(), )) } } + +impl TypeSerializer for TypedDictSerializer { + fn to_python<'py>( + &self, + value: &Bound<'py, PyAny>, + state: &mut SerializationState<'_, 'py>, + ) -> PyResult> { + self.serializer + .to_python(value, &mut state.scoped_set(|s| &mut s.model, Some(value.clone()))) + } + + fn json_key<'a, 'py>( + &self, + key: &'a Bound<'py, PyAny>, + state: &mut SerializationState<'_, 'py>, + ) -> PyResult> { + self.invalid_as_json_key(key, state, "typed-dict") + } + + fn serde_serialize<'py, S: serde::ser::Serializer>( + &self, + value: &Bound<'py, PyAny>, + serializer: S, + state: &mut SerializationState<'_, 'py>, + ) -> Result { + self.serializer.serde_serialize( + value, + serializer, + &mut state.scoped_set(|s| &mut s.model, Some(value.clone())), + ) + } + + fn get_name(&self) -> &'static str { + "typed-dict" + } +} diff --git a/tests/serializers/test_typed_dict.py b/tests/serializers/test_typed_dict.py index 6626a3981..7917ec088 100644 --- a/tests/serializers/test_typed_dict.py +++ b/tests/serializers/test_typed_dict.py @@ -376,3 +376,35 @@ class Model(TypedDict): ) s = SchemaSerializer(schema, config=core_schema.CoreConfig(serialize_by_alias=config or False)) assert s.to_python(Model(my_field=1), by_alias=runtime) == expected + + +def test_nested_typed_dict_field_serializers(): + class Model(TypedDict): + x: Any + + class OuterModel(TypedDict): + model: Model + + schema = core_schema.typed_dict_schema( + { + 'x': core_schema.typed_dict_field( + core_schema.any_schema( + serialization=core_schema.wrap_serializer_function_ser_schema( + # in an incorrect core implementation, self could be OuterModel here + lambda self, v, serializer: f'{list(self.keys())}', + is_field_serializer=True, + schema=core_schema.any_schema(), + ) + ) + ) + } + ) + outer_schema = core_schema.typed_dict_schema({'model': core_schema.typed_dict_field(schema)}) + + s = SchemaSerializer(schema) + assert s.to_python(Model(x=None)) == {'x': "['x']"} + + outer_s = SchemaSerializer(outer_schema) + # if the inner field serializer incorrectly receives OuterModel as self, the keys + # will be ['model'] instead of ['x'] + assert outer_s.to_python(OuterModel(model=Model(x=None))) == {'model': {'x': "['x']"}}