Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/scripts/_gen_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _snap_image(_obj: type, _name: str) -> str:
from qtpy.QtWidgets import QVBoxLayout, QWidget

outer = QWidget()
if _obj is widgets.Container:
if _obj in (widgets.Container, widgets.ModelContainerWidget):
return ""
if issubclass(_obj, widgets.FunctionGui):
return ""
Expand Down
18 changes: 11 additions & 7 deletions src/magicgui/schema/_guiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,23 @@

from magicgui.schema._ui_field import build_widget
from magicgui.widgets import PushButton
from magicgui.widgets.bases import BaseValueWidget, ContainerWidget
from magicgui.widgets.bases import BaseValueWidget

if TYPE_CHECKING:
from collections.abc import Mapping
from typing import Protocol

from typing_extensions import TypeGuard

from magicgui.widgets._concrete import ModelContainerWidget
from magicgui.widgets.bases._container_widget import BaseContainerWidget

# fmt: off
class GuiClassProtocol(Protocol):
"""Protocol for a guiclass."""

@property
def gui(self) -> ContainerWidget: ...
def gui(self) -> ModelContainerWidget: ...
@property
def events(self) -> SignalGroup: ...
# fmt: on
Expand Down Expand Up @@ -233,7 +236,7 @@ def __set_name__(self, owner: type, name: str) -> None:
evented(owner, events_namespace=self._events_namespace)
setattr(owner, _GUICLASS_FLAG, True)

def widget(self) -> ContainerWidget:
def widget(self) -> ModelContainerWidget:
"""Return a widget for the dataclass or instance."""
if self._owner is None:
raise TypeError(
Expand All @@ -243,7 +246,7 @@ def widget(self) -> ContainerWidget:

def __get__(
self, instance: object | None, owner: type
) -> ContainerWidget[BaseValueWidget] | GuiBuilder:
) -> ModelContainerWidget[BaseValueWidget] | GuiBuilder:
if instance is None:
return self
wdg = build_widget(instance)
Expand All @@ -253,7 +256,8 @@ def __get__(
for k, v in vars(owner).items():
if hasattr(v, _BUTTON_ATTR):
kwargs = getattr(v, _BUTTON_ATTR)
button = PushButton(**kwargs)
# gui_only=True excludes button from model value construction
button = PushButton(gui_only=True, **kwargs)
if instance is not None:
# call the bound method if we're in an instance
button.clicked.connect(getattr(instance, k))
Expand All @@ -277,7 +281,7 @@ def __get__(


def bind_gui_to_instance(
gui: ContainerWidget, instance: Any, two_way: bool = True
gui: BaseContainerWidget, instance: Any, two_way: bool = True
) -> None:
"""Set change events in `gui` to update the corresponding attributes in `model`.

Expand Down Expand Up @@ -340,7 +344,7 @@ def bind_gui_to_instance(
signals[name].connect_setattr(widget, "value")


def unbind_gui_from_instance(gui: ContainerWidget, instance: Any) -> None:
def unbind_gui_from_instance(gui: BaseContainerWidget, instance: Any) -> None:
"""Unbind a gui from an instance.

This will disconnect all events that were connected by `bind_gui_to_instance`.
Expand Down
122 changes: 57 additions & 65 deletions src/magicgui/schema/_ui_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,14 @@
Literal,
TypeVar,
Union,
cast,
)

from typing_extensions import TypeGuard, get_args, get_origin

from magicgui.types import JsonStringFormats, Undefined, _Undefined

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping
from collections.abc import Iterator
from typing import Protocol

import attrs
Expand All @@ -32,7 +31,8 @@
from attrs import Attribute
from pydantic.fields import FieldInfo, ModelField

from magicgui.widgets.bases import BaseValueWidget, ContainerWidget
from magicgui.widgets import ModelContainerWidget
from magicgui.widgets.bases import BaseValueWidget

class HasAttrs(Protocol):
"""Protocol for objects that have an ``attrs`` attribute."""
Expand Down Expand Up @@ -441,6 +441,16 @@ def create_widget(self, value: T | _Undefined = Undefined) -> BaseValueWidget[T]
opts["min"] = d["exclusive_minimum"] + m

value = value if value is not Undefined else self.get_default() # type: ignore
# build a container widget from a dataclass-like object
# TODO: should this eventually move to get_widget_class?
if _is_dataclass_like(self.type):
wdg = build_widget(self.type)
wdg.name = self.name or ""
wdg.label = self.name or ""
if value is not None:
wdg.value = value
return wdg
# create widget subclass for everything else
cls, kwargs = get_widget_class(value=value, annotation=self.type, options=opts)
return cls(**kwargs) # type: ignore

Expand Down Expand Up @@ -718,6 +728,24 @@ def _ui_fields_from_annotation(cls: type) -> Iterator[UiField]:
yield field.parse_annotated()


def _is_dataclass_like(object: Any) -> bool:
# check if it's a pydantic1 style dataclass
model = _get_pydantic_model(object)
if model is not None:
if hasattr(model, "model_fields"):
return True
# check if it's a pydantic2 style dataclass
if hasattr(object, "__pydantic_fields__"):
return True
# check if it's a (non-pydantic) dataclass
if dc.is_dataclass(object):
return True
# check if it's an attrs class
if _is_attrs_model(object):
return True
return False


def _iter_ui_fields(object: Any) -> Iterator[UiField]:
# check if it's a pydantic model
model = _get_pydantic_model(object)
Expand Down Expand Up @@ -781,76 +809,40 @@ def get_ui_fields(cls_or_instance: object) -> tuple[UiField, ...]:
return tuple(_iter_ui_fields(cls_or_instance))


def _uifields_to_container(
ui_fields: Iterable[UiField],
values: Mapping[str, Any] | None = None,
*,
container_kwargs: Mapping | None = None,
) -> ContainerWidget[BaseValueWidget]:
"""Create a container widget from a sequence of UiFields.
# TODO: unify this with magicgui
# this todo could be the same thing as moving the logic in create_widget above
# to get_widget_cls...
def build_widget(cls_or_instance: Any) -> ModelContainerWidget:
"""Build a magicgui widget from a dataclass, attrs, pydantic, or function.

This function is the heart of build_widget.
Returns a ModelContainerWidget whose `.value` property returns an instance
of the model type, constructed from the current widget values.

Parameters
----------
ui_fields : Iterable[UiField]
A sequence of UiFields to use to create the container.
values : Mapping[str, Any], optional
A mapping of field name to values to use to initialize each widget the
container, by default None.
container_kwargs : Mapping, optional
A mapping of keyword arguments to pass to the container constructor,
by default None.
cls_or_instance : Any
The class or instance to build the widget from.

Returns
-------
ContainerWidget[ValueWidget]
A container widget with a widget for each UiField.
"""
from magicgui import widgets

container = widgets.Container(
widgets=[field.create_widget() for field in ui_fields],
**(container_kwargs or {}),
)
if values is not None:
container.update(values)
return container


def _get_values(obj: Any) -> dict | None:
"""Return a dict of values from an object.

The object can be a dataclass, attrs, pydantic object or named tuple.
ModelContainerWidget
The constructed widget.
"""
if isinstance(obj, dict):
return obj

# named tuple
if isinstance(obj, tuple) and hasattr(obj, "_asdict"):
return cast("dict", obj._asdict())

# dataclass
if dc.is_dataclass(type(obj)):
return dc.asdict(obj)

# attrs
attr = sys.modules.get("attr")
if attr is not None and attr.has(obj):
return cast("dict", attr.asdict(obj))

# pydantic models
if hasattr(obj, "model_dump"):
return cast("dict", obj.model_dump())
elif hasattr(obj, "dict"):
return cast("dict", obj.dict())

return None
from magicgui.widgets import ModelContainerWidget

# Get the class (type) for the model
if isinstance(cls_or_instance, type):
model_type = cls_or_instance
value = None
else:
model_type = type(cls_or_instance)
value = cls_or_instance

# TODO: unify this with magicgui
def build_widget(cls_or_instance: Any) -> ContainerWidget[BaseValueWidget]:
"""Build a magicgui widget from a dataclass, attrs, pydantic, or function."""
values = None if isinstance(cls_or_instance, type) else _get_values(cls_or_instance)
fields = get_ui_fields(cls_or_instance)
return _uifields_to_container(fields, values=values)
inner_widgets = [f.create_widget() for f in fields]

return ModelContainerWidget(
value_type=model_type,
widgets=inner_widgets,
value=value,
)
2 changes: 2 additions & 0 deletions src/magicgui/widgets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
LiteralEvalLineEdit,
LogSlider,
MainWindow,
ModelContainerWidget,
Password,
ProgressBar,
PushButton,
Expand Down Expand Up @@ -92,6 +93,7 @@
"LogSlider",
"MainFunctionGui",
"MainWindow",
"ModelContainerWidget",
"Password",
"ProgressBar",
"PushButton",
Expand Down
99 changes: 99 additions & 0 deletions src/magicgui/widgets/_concrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import inspect
import math
import os
import sys
from pathlib import Path
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -68,6 +69,7 @@
WidgetVar = TypeVar("WidgetVar", bound=Widget)
WidgetTypeVar = TypeVar("WidgetTypeVar", bound=type[Widget])
_V = TypeVar("_V")
_M = TypeVar("_M") # For model/dataclass types


@overload
Expand Down Expand Up @@ -994,6 +996,103 @@ def set_value(self, vals: Sequence[Any]) -> None:
self.changed.emit(self.value)


class ModelContainerWidget(ValuedContainerWidget[_M], Generic[_M]):
"""A container widget for dataclass-like models (dataclass, pydantic, attrs).

This widget wraps a structured type (dataclass, pydantic model, attrs class, etc.)
and provides a `.value` property that returns an instance of that type, constructed
from the values of its child widgets.

Parameters
----------
value_type : type[_M]
The model class to construct when getting the value.
widgets : Sequence[Widget], optional
Child widgets representing the model's fields.
**kwargs : Any
Additional arguments passed to ValuedContainerWidget.
"""

def __init__(
self,
value_type: type[_M],
widgets: Sequence[Widget] = (),
value: _M | None | _Undefined = Undefined,
**kwargs: Any,
) -> None:
self._value_type = value_type
super().__init__(widgets=widgets, **kwargs)
# Connect child widget changes to emit our changed signal
for w in self._list:
if isinstance(w, BaseValueWidget):
w.changed.connect(self._on_child_changed)
if not isinstance(value, _Undefined):
self.set_value(value)

def _on_child_changed(self, _: Any = None) -> None:
"""Emit changed signal when any child widget changes."""
self.changed.emit(self.value)

def get_value(self) -> _M:
"""Construct a model instance from child widget values."""
values: dict[str, Any] = {}
for w in self._list:
if not w.name or w.gui_only:
continue
if hasattr(w, "value"):
values[w.name] = w.value
return self._value_type(**values)

def set_value(self, value: _M | None) -> None:
"""Distribute model instance values to child widgets."""
if value is None:
return

vals = self._get_values(value)
if vals is None:
return
with self.changed.blocked():
for w in self._list:
if w.name and hasattr(w, "value") and w.name in vals:
w.value = vals[w.name]

def __repr__(self) -> str:
"""Return string representation."""
return f"<{self.__class__.__name__} value_type={self._value_type.__name__!r}>"

@staticmethod
def _get_values(obj: Any) -> dict | None:
"""Return a dict of values from an object.

The object can be a dataclass, attrs, pydantic object or named tuple.
"""
if isinstance(obj, dict):
return obj

# named tuple
if isinstance(obj, tuple) and hasattr(obj, "_asdict"):
return cast("dict", obj._asdict())

import dataclasses

# dataclass
if dataclasses.is_dataclass(type(obj)):
return dataclasses.asdict(obj)

# attrs
attr = sys.modules.get("attr")
if attr is not None and attr.has(obj):
return cast("dict", attr.asdict(obj))

# pydantic models
if hasattr(obj, "model_dump"):
return cast("dict", obj.model_dump())
elif hasattr(obj, "dict"):
return cast("dict", obj.dict())

return None


@backend_widget
class ToolBar(ToolBarWidget):
"""Toolbar that contains a set of controls."""
Expand Down
Loading
Loading