"""This module contains the schema classes for the medrecord module."""
from __future__ import annotations
from enum import Enum, auto
from typing import (
TYPE_CHECKING,
Dict,
List,
Literal,
Optional,
Tuple,
TypeAlias,
Union,
overload,
)
from medmodels._medmodels import (
PyAttributeDataType,
PyAttributeType,
PyGroupSchema,
PySchema,
PySchemaType,
)
from medmodels.medrecord.datatype import (
DataType,
DateTime,
Duration,
Float,
Int,
Null,
Option,
)
from medmodels.medrecord.datatype import Union as DataTypeUnion
from medmodels.medrecord.types import (
Attributes,
EdgeIndex,
MedRecordAttribute,
NodeIndex,
)
if TYPE_CHECKING:
from medmodels.medrecord.medrecord import MedRecord
from medmodels.medrecord.types import Group
[docs]
class AttributeType(Enum):
    """Enumeration of attribute types.

    Each member has a counterpart with the same name on ``PyAttributeType``
    (imported from ``medmodels._medmodels``); the private helpers
    ``_from_py_attribute_type`` and ``_into_py_attribute_type`` convert
    between the two representations.
    """

    Categorical = auto()
    Continuous = auto()
    Temporal = auto()
    Unstructured = auto()

    @staticmethod
    def _from_py_attribute_type(py_attribute_type: PyAttributeType) -> AttributeType:
        """Converts a PyAttributeType to an AttributeType.

        Args:
            py_attribute_type (PyAttributeType): The PyAttributeType to convert.

        Returns:
            AttributeType: The converted AttributeType.

        Raises:
            NotImplementedError: If ``py_attribute_type`` matches none of the
                known variants.
        """
        if py_attribute_type == PyAttributeType.Categorical:
            return AttributeType.Categorical
        if py_attribute_type == PyAttributeType.Continuous:
            return AttributeType.Continuous
        if py_attribute_type == PyAttributeType.Temporal:
            return AttributeType.Temporal
        if py_attribute_type == PyAttributeType.Unstructured:
            return AttributeType.Unstructured

        msg = "Should never be reached"
        raise NotImplementedError(msg)

    @staticmethod
    def infer(data_type: DataType) -> AttributeType:
        """Infers the attribute type from the data type.

        The inference itself is delegated to ``PyAttributeType.infer``.

        Args:
            data_type (DataType): The data type to infer the attribute type from.

        Returns:
            AttributeType: The inferred attribute type.
        """
        return AttributeType._from_py_attribute_type(
            PyAttributeType.infer(data_type._inner())
        )

    def _into_py_attribute_type(self) -> PyAttributeType:
        """Converts an AttributeType to a PyAttributeType.

        Returns:
            PyAttributeType: The converted PyAttributeType.

        Raises:
            NotImplementedError: If this member has no PyAttributeType mapping.
        """
        # Enum members are singletons, so identity checks are exact and avoid
        # going through the custom ``__eq__`` below.
        if self is AttributeType.Categorical:
            return PyAttributeType.Categorical
        if self is AttributeType.Continuous:
            return PyAttributeType.Continuous
        if self is AttributeType.Temporal:
            return PyAttributeType.Temporal
        if self is AttributeType.Unstructured:
            return PyAttributeType.Unstructured

        msg = "Should never be reached"
        raise NotImplementedError(msg)

    def __repr__(self) -> str:
        """Returns a string representation of the AttributeType instance.

        Returns:
            str: String representation of the attribute type, e.g.
                ``"AttributeType.Categorical"``.
        """
        return f"AttributeType.{self.name}"

    def __str__(self) -> str:
        """Returns a string representation of the AttributeType instance.

        Returns:
            str: The member name, e.g. ``"Categorical"``.
        """
        return self.name

    def __hash__(self) -> int:
        """Returns the hash of the AttributeType instance.

        Defined explicitly because overriding ``__eq__`` would otherwise make
        the members unhashable.

        Returns:
            int: The hash of the member name.
        """
        # NOTE(review): an AttributeType can compare equal to a PyAttributeType,
        # but whether their hashes agree depends on PyAttributeType's __hash__;
        # avoid mixing both kinds as keys of the same dict/set.
        return hash(self.name)

    def __eq__(self, value: object) -> bool:
        """Compares the AttributeType instance to another object for equality.

        Args:
            value (object): The object to compare against. Both ``AttributeType``
                members and ``PyAttributeType`` values are supported.

        Returns:
            bool: True if the objects are equal, False otherwise. For any other
                operand type ``NotImplemented`` is returned so Python can try the
                reflected comparison before falling back to identity.
        """
        if isinstance(value, AttributeType):
            # Members are singletons: identity is equivalent to comparing names.
            return self is value
        if isinstance(value, PyAttributeType):
            return self._into_py_attribute_type() == value
        return NotImplemented
# Any data type may back a categorical attribute.
CategoricalType: TypeAlias = DataType
CategoricalPair: TypeAlias = Tuple[CategoricalType, Literal[AttributeType.Categorical]]

# Data types accepted for continuous attributes: Int, Float, Null, and
# Option/Union types built recursively from them.
ContinuousType: TypeAlias = Union[
    Int,
    Float,
    Null,
    Option["ContinuousType"],
    DataTypeUnion["ContinuousType", "ContinuousType"],
]
ContinuousPair: TypeAlias = Tuple[ContinuousType, Literal[AttributeType.Continuous]]

# Data types accepted for temporal attributes: DateTime, Duration, Null, and
# Option/Union types built recursively from them. Annotated explicitly as a
# TypeAlias for consistency with the other aliases in this module.
TemporalType: TypeAlias = Union[
    DateTime,
    Duration,
    Null,
    Option["TemporalType"],
    DataTypeUnion["TemporalType", "TemporalType"],
]
TemporalPair: TypeAlias = Tuple[TemporalType, Literal[AttributeType.Temporal]]

# Any data type may back an unstructured attribute.
UnstructuredType: TypeAlias = DataType
UnstructuredPair: TypeAlias = Tuple[
    UnstructuredType, Literal[AttributeType.Unstructured]
]

# A (data type, attribute type) pair whose data type is restricted according
# to the attribute type it is paired with.
AttributeDataType: TypeAlias = Union[
    CategoricalPair, ContinuousPair, TemporalPair, UnstructuredPair
]

# Maps each attribute name to its (data type, attribute type) pair.
AttributesSchema: TypeAlias = Dict[MedRecordAttribute, AttributeDataType]
[docs]
class GroupSchema:
    """A schema for a group of nodes and edges.

    Wraps a ``PyGroupSchema`` (from ``medmodels._medmodels``) stored in
    ``_group_schema``; validation is delegated to it.
    """

    _group_schema: PyGroupSchema

    def __init__(
        self,
        *,
        nodes: Optional[
            Dict[
                MedRecordAttribute,
                Union[DataType, AttributeDataType],
            ],
        ] = None,
        edges: Optional[
            Dict[
                MedRecordAttribute,
                Union[DataType, AttributeDataType],
            ],
        ] = None,
    ) -> None:
        """Initializes a new instance of GroupSchema.

        Args:
            nodes (Optional[Dict[MedRecordAttribute, Union[DataType, AttributeDataType]]]):
                A dictionary mapping node attributes to their data
                types and optional attribute types. Defaults to an empty dictionary.
                When no attribute type is provided, it is inferred from the data type.
            edges (Optional[Dict[MedRecordAttribute, Union[DataType, AttributeDataType]]]):
                A dictionary mapping edge attributes to their data types and
                optional attribute types. Defaults to an empty dictionary.
                When no attribute type is provided, it is inferred from the data type.
        """  # noqa: W505
        if nodes is None:
            nodes = {}
        if edges is None:
            edges = {}

        self._group_schema = PyGroupSchema(
            nodes={
                attribute: self._to_py_attribute_data_type(value)
                for attribute, value in nodes.items()
            },
            edges={
                attribute: self._to_py_attribute_data_type(value)
                for attribute, value in edges.items()
            },
        )

    @staticmethod
    def _to_py_attribute_data_type(
        value: Union[DataType, AttributeDataType],
    ) -> PyAttributeDataType:
        """Converts a user-facing attribute specification to its native form.

        Args:
            value (Union[DataType, AttributeDataType]): Either a bare DataType, in
                which case the attribute type is inferred from it, or a
                ``(DataType, AttributeType)`` pair.

        Returns:
            PyAttributeDataType: The native data type / attribute type pair.
        """
        if isinstance(value, tuple):
            return PyAttributeDataType(
                value[0]._inner(), value[1]._into_py_attribute_type()
            )

        inner = value._inner()
        return PyAttributeDataType(inner, PyAttributeType.infer(inner))

    @staticmethod
    def _from_py_attribute_data_type(
        value: PyAttributeDataType,
    ) -> AttributeDataType:
        """Converts a native attribute data type to a (DataType, AttributeType) pair.

        Args:
            value (PyAttributeDataType): The native value to convert.

        Returns:
            AttributeDataType: The corresponding Python-side pair.
        """
        # SAFETY: The typing is guaranteed to be correct
        return (
            DataType._from_py_data_type(value.data_type),
            AttributeType._from_py_attribute_type(value.attribute_type),
        )  # pyright: ignore[reportReturnType]

    @classmethod
    def _from_py_group_schema(cls, group_schema: PyGroupSchema) -> GroupSchema:
        """Creates a GroupSchema instance from an existing PyGroupSchema.

        Args:
            group_schema (PyGroupSchema): The PyGroupSchema instance to convert.

        Returns:
            GroupSchema: A new GroupSchema instance wrapping ``group_schema``.
        """
        # Bypass ``__init__`` so no throwaway PyGroupSchema is built only to be
        # overwritten on the next line.
        new_group_schema = cls.__new__(cls)
        new_group_schema._group_schema = group_schema
        return new_group_schema

    @property
    def nodes(self) -> AttributesSchema:
        """Returns the node attributes in the GroupSchema instance.

        Returns:
            AttributesSchema: An AttributesSchema object containing the node attributes
                and their data types.
        """
        # Read the native mapping once rather than once per key; the binding
        # may build a fresh dict on every attribute access.
        py_nodes = self._group_schema.nodes
        return {
            attribute: self._from_py_attribute_data_type(value)
            for attribute, value in py_nodes.items()
        }

    @property
    def edges(self) -> AttributesSchema:
        """Returns the edge attributes in the GroupSchema instance.

        Returns:
            AttributesSchema: An AttributesSchema object containing the edge attributes
                and their data types.
        """
        # Read the native mapping once rather than once per key; the binding
        # may build a fresh dict on every attribute access.
        py_edges = self._group_schema.edges
        return {
            attribute: self._from_py_attribute_data_type(value)
            for attribute, value in py_edges.items()
        }

    def validate_node(self, index: NodeIndex, attributes: Attributes) -> None:
        """Validates the attributes of a node against the schema.

        Delegates to ``PyGroupSchema.validate_node``. Presumably an exception
        is raised on failure (the method returns None); its type is determined
        by the native layer.

        Args:
            index (NodeIndex): The index of the node.
            attributes (Attributes): The attributes of the node.
        """
        self._group_schema.validate_node(index, attributes)

    def validate_edge(self, index: EdgeIndex, attributes: Attributes) -> None:
        """Validates the attributes of an edge against the schema.

        Delegates to ``PyGroupSchema.validate_edge``. Presumably an exception
        is raised on failure (the method returns None); its type is determined
        by the native layer.

        Args:
            index (EdgeIndex): The index of the edge.
            attributes (Attributes): The attributes of the edge.
        """
        self._group_schema.validate_edge(index, attributes)
[docs]
class SchemaType(Enum):
"""Enumeration of schema types."""
Provided = auto()
Inferred = auto()
@staticmethod
def _from_py_schema_type(py_schema_type: PySchemaType) -> SchemaType:
"""Converts a PySchemaType to a SchemaType.
Args:
py_schema_type (PySchemaType): The PySchemaType to convert.
Returns:
SchemaType: The converted SchemaType.
"""
if py_schema_type == PySchemaType.Provided:
return SchemaType.Provided
if py_schema_type == PySchemaType.Inferred:
return SchemaType.Inferred
msg = "Should never be reached"
raise NotImplementedError(msg)
def _into_py_schema_type(self) -> PySchemaType:
"""Converts a SchemaType to a PySchemaType.
Returns:
PySchemaType: The converted PySchemaType.
"""
if self == SchemaType.Provided:
return PySchemaType.Provided
if self == SchemaType.Inferred:
return PySchemaType.Inferred
msg = "Should never be reached"
raise NotImplementedError(msg)
[docs]
class Schema:
    """A schema for a collection of groups.

    Wraps a ``PySchema`` (from ``medmodels._medmodels``) stored in
    ``_schema``; all operations are delegated to it.
    """

    _schema: PySchema

    def __init__(
        self,
        *,
        groups: Optional[Dict[Group, GroupSchema]] = None,
        ungrouped: Optional[GroupSchema] = None,
        schema_type: Optional[SchemaType] = None,
    ) -> None:
        """Initializes a new instance of Schema.

        Args:
            groups (Dict[Group, GroupSchema], optional): A dictionary of group names
                to their schemas. Defaults to None.
            ungrouped (Optional[GroupSchema], optional): The group schema for all nodes
                not in a group. If not provided, an empty group schema is used.
                Defaults to None.
            schema_type (Optional[SchemaType], optional): The type of the schema.
                If not provided, the schema is of type provided. Defaults to None.
        """
        # Explicit ``is None`` checks: only a missing argument should trigger
        # the defaults, regardless of how the values behave in a boolean context.
        if ungrouped is None:
            ungrouped = GroupSchema()
        if groups is None:
            groups = {}

        py_groups = {
            group: group_schema._group_schema for group, group_schema in groups.items()
        }

        if schema_type is None:
            # Omit the argument so the native constructor applies its default.
            self._schema = PySchema(
                groups=py_groups,
                ungrouped=ungrouped._group_schema,
            )
        else:
            self._schema = PySchema(
                groups=py_groups,
                ungrouped=ungrouped._group_schema,
                schema_type=schema_type._into_py_schema_type(),
            )

    @classmethod
    def infer(cls, medrecord: MedRecord) -> Schema:
        """Infers a schema from a MedRecord instance.

        Args:
            medrecord (MedRecord): The MedRecord instance to infer the schema from.

        Returns:
            Schema: The inferred schema.
        """
        return cls._from_py_schema(PySchema.infer(medrecord._medrecord))

    @classmethod
    def _from_py_schema(cls, schema: PySchema) -> Schema:
        """Creates a Schema instance from an existing PySchema.

        Args:
            schema (PySchema): The PySchema instance to convert.

        Returns:
            Schema: A new Schema instance wrapping ``schema``.
        """
        # Bypass ``__init__`` so no throwaway PySchema/GroupSchema is built only
        # to be overwritten on the next line.
        new_schema = cls.__new__(cls)
        new_schema._schema = schema
        return new_schema

    @staticmethod
    def _resolve_attribute_type(
        data_type: DataType, attribute_type: Optional[AttributeType]
    ) -> PyAttributeType:
        """Returns the native attribute type to store alongside ``data_type``.

        Args:
            data_type (DataType): The data type of the attribute.
            attribute_type (Optional[AttributeType]): The requested attribute
                type, or None to infer it from ``data_type``.

        Returns:
            PyAttributeType: The explicit or inferred attribute type in native form.
        """
        if attribute_type is None:
            attribute_type = AttributeType.infer(data_type)
        return attribute_type._into_py_attribute_type()

    @property
    def groups(self) -> List[Group]:
        """Lists all the groups in the Schema instance.

        Returns:
            List[Group]: A list of groups.
        """
        return self._schema.groups

    def group(self, group: Group) -> GroupSchema:
        """Retrieves the schema for a specific group.

        Args:
            group (Group): The name of the group.

        Returns:
            GroupSchema: The schema for the specified group.

        Raises:
            ValueError: If the group does not exist in the schema.
        """  # noqa: DOC502
        return GroupSchema._from_py_group_schema(self._schema.group(group))

    @property
    def ungrouped(self) -> GroupSchema:
        """Retrieves the group schema for all ungrouped nodes and edges.

        Returns:
            GroupSchema: The ungrouped group schema.
        """
        return GroupSchema._from_py_group_schema(self._schema.ungrouped)

    @property
    def schema_type(self) -> SchemaType:
        """Retrieves the schema type.

        Returns:
            SchemaType: The schema type.
        """
        return SchemaType._from_py_schema_type(self._schema.schema_type)

    def validate_node(
        self, index: NodeIndex, attributes: Attributes, group: Optional[Group] = None
    ) -> None:
        """Validates the attributes of a node against the schema.

        Args:
            index (NodeIndex): The index of the node.
            attributes (Attributes): The attributes of the node.
            group (Optional[Group], optional): The group to validate the node against.
                If not provided, the ungrouped schema is used. Defaults to None.
        """
        self._schema.validate_node(index, attributes, group)

    def validate_edge(
        self, index: EdgeIndex, attributes: Attributes, group: Optional[Group] = None
    ) -> None:
        """Validates the attributes of an edge against the schema.

        Args:
            index (EdgeIndex): The index of the edge.
            attributes (Attributes): The attributes of the edge.
            group (Optional[Group], optional): The group to validate the edge against.
                If not provided, the ungrouped schema is used. Defaults to None.
        """
        self._schema.validate_edge(index, attributes, group)

    @overload
    def set_node_attribute(
        self,
        attribute: MedRecordAttribute,
        data_type: DataType,
        attribute_type: Optional[
            Literal[AttributeType.Categorical, AttributeType.Unstructured]
        ] = None,
        group: Optional[Group] = None,
    ) -> None: ...

    @overload
    def set_node_attribute(
        self,
        attribute: MedRecordAttribute,
        data_type: ContinuousType,
        attribute_type: Literal[AttributeType.Continuous],
        group: Optional[Group] = None,
    ) -> None: ...

    @overload
    def set_node_attribute(
        self,
        attribute: MedRecordAttribute,
        data_type: TemporalType,
        attribute_type: Literal[AttributeType.Temporal],
        group: Optional[Group] = None,
    ) -> None: ...

    def set_node_attribute(
        self,
        attribute: MedRecordAttribute,
        data_type: DataType,
        attribute_type: Optional[AttributeType] = None,
        group: Optional[Group] = None,
    ) -> None:
        """Sets the data type and attribute type of a node attribute.

        If a data type for the attribute already exists, it is overwritten.

        Args:
            attribute (MedRecordAttribute): The name of the attribute.
            data_type (DataType): The data type of the attribute.
            attribute_type (Optional[AttributeType], optional): The attribute type of
                the attribute. If not provided, the attribute type is inferred
                from the data type. Defaults to None.
            group (Optional[Group], optional): The group to set the attribute for.
                If no schema for the group exists, a new schema is created.
                If not provided, the ungrouped schema is used. Defaults to None.
        """
        self._schema.set_node_attribute(
            attribute,
            data_type._inner(),
            self._resolve_attribute_type(data_type, attribute_type),
            group,
        )

    @overload
    def set_edge_attribute(
        self,
        attribute: MedRecordAttribute,
        data_type: DataType,
        attribute_type: Optional[
            Literal[AttributeType.Categorical, AttributeType.Unstructured]
        ] = None,
        group: Optional[Group] = None,
    ) -> None: ...

    @overload
    def set_edge_attribute(
        self,
        attribute: MedRecordAttribute,
        data_type: ContinuousType,
        attribute_type: Literal[AttributeType.Continuous],
        group: Optional[Group] = None,
    ) -> None: ...

    @overload
    def set_edge_attribute(
        self,
        attribute: MedRecordAttribute,
        data_type: TemporalType,
        attribute_type: Literal[AttributeType.Temporal],
        group: Optional[Group] = None,
    ) -> None: ...

    def set_edge_attribute(
        self,
        attribute: MedRecordAttribute,
        data_type: DataType,
        attribute_type: Optional[AttributeType] = None,
        group: Optional[Group] = None,
    ) -> None:
        """Sets the data type and attribute type of an edge attribute.

        If a data type for the attribute already exists, it is overwritten.

        Args:
            attribute (MedRecordAttribute): The name of the attribute.
            data_type (DataType): The data type of the attribute.
            attribute_type (Optional[AttributeType], optional): The attribute type of
                the attribute. If not provided, the attribute type is inferred
                from the data type. Defaults to None.
            group (Optional[Group], optional): The group to set the attribute for.
                If no schema for this group exists, a new schema is created.
                If not provided, the ungrouped schema is used. Defaults to None.
        """
        self._schema.set_edge_attribute(
            attribute,
            data_type._inner(),
            self._resolve_attribute_type(data_type, attribute_type),
            group,
        )

    @overload
    def update_node_attribute(
        self,
        attribute: MedRecordAttribute,
        data_type: DataType,
        attribute_type: Optional[
            Literal[AttributeType.Categorical, AttributeType.Unstructured]
        ] = None,
        group: Optional[Group] = None,
    ) -> None: ...

    @overload
    def update_node_attribute(
        self,
        attribute: MedRecordAttribute,
        data_type: ContinuousType,
        attribute_type: Literal[AttributeType.Continuous],
        group: Optional[Group] = None,
    ) -> None: ...

    @overload
    def update_node_attribute(
        self,
        attribute: MedRecordAttribute,
        data_type: TemporalType,
        attribute_type: Literal[AttributeType.Temporal],
        group: Optional[Group] = None,
    ) -> None: ...

    def update_node_attribute(
        self,
        attribute: MedRecordAttribute,
        data_type: DataType,
        attribute_type: Optional[AttributeType] = None,
        group: Optional[Group] = None,
    ) -> None:
        """Updates the data type and attribute type of a node attribute.

        If a data type for the attribute already exists, it is merged
        with the new data type.

        Args:
            attribute (MedRecordAttribute): The name of the attribute.
            data_type (DataType): The data type of the attribute.
            attribute_type (Optional[AttributeType], optional): The attribute type of
                the attribute. If not provided, the attribute type is inferred
                from the data type. Defaults to None.
            group (Optional[Group], optional): The group to update the attribute for.
                If no schema for this group exists, a new schema is created.
                If not provided, the ungrouped schema is used. Defaults to None.
        """
        self._schema.update_node_attribute(
            attribute,
            data_type._inner(),
            self._resolve_attribute_type(data_type, attribute_type),
            group,
        )

    @overload
    def update_edge_attribute(
        self,
        attribute: MedRecordAttribute,
        data_type: DataType,
        attribute_type: Optional[
            Literal[AttributeType.Categorical, AttributeType.Unstructured]
        ] = None,
        group: Optional[Group] = None,
    ) -> None: ...

    @overload
    def update_edge_attribute(
        self,
        attribute: MedRecordAttribute,
        data_type: ContinuousType,
        attribute_type: Literal[AttributeType.Continuous],
        group: Optional[Group] = None,
    ) -> None: ...

    @overload
    def update_edge_attribute(
        self,
        attribute: MedRecordAttribute,
        data_type: TemporalType,
        attribute_type: Literal[AttributeType.Temporal],
        group: Optional[Group] = None,
    ) -> None: ...

    def update_edge_attribute(
        self,
        attribute: MedRecordAttribute,
        data_type: DataType,
        attribute_type: Optional[AttributeType] = None,
        group: Optional[Group] = None,
    ) -> None:
        """Updates the data type and attribute type of an edge attribute.

        If a data type for the attribute already exists, it is merged
        with the new data type.

        Args:
            attribute (MedRecordAttribute): The name of the attribute.
            data_type (DataType): The data type of the attribute.
            attribute_type (Optional[AttributeType], optional): The attribute type of
                the attribute. If not provided, the attribute type is inferred
                from the data type. Defaults to None.
            group (Optional[Group], optional): The group to update the attribute for.
                If no schema for this group exists, a new schema is created.
                If not provided, the ungrouped schema is used. Defaults to None.
        """
        self._schema.update_edge_attribute(
            attribute,
            data_type._inner(),
            self._resolve_attribute_type(data_type, attribute_type),
            group,
        )

    def remove_node_attribute(
        self, attribute: MedRecordAttribute, group: Optional[Group] = None
    ) -> None:
        """Removes a node attribute from the schema.

        Args:
            attribute (MedRecordAttribute): The name of the attribute to remove.
            group (Optional[Group], optional): The group to remove the attribute from.
                If not provided, the ungrouped schema is used. Defaults to None.
        """
        self._schema.remove_node_attribute(attribute, group)

    def remove_edge_attribute(
        self, attribute: MedRecordAttribute, group: Optional[Group] = None
    ) -> None:
        """Removes an edge attribute from the schema.

        Args:
            attribute (MedRecordAttribute): The name of the attribute to remove.
            group (Optional[Group], optional): The group to remove the attribute from.
                If not provided, the ungrouped schema is used. Defaults to None.
        """
        self._schema.remove_edge_attribute(attribute, group)

    def add_group(self, group: Group, group_schema: GroupSchema) -> None:
        """Adds a new group to the schema.

        Args:
            group (Group): The name of the group.
            group_schema (GroupSchema): The schema for the group.
        """
        self._schema.add_group(group, group_schema._group_schema)

    def remove_group(self, group: Group) -> None:
        """Removes a group from the schema.

        Args:
            group (Group): The name of the group to remove.
        """
        self._schema.remove_group(group)

    def freeze(self) -> None:
        """Freezes the schema. No changes are automatically inferred."""
        self._schema.freeze()

    def unfreeze(self) -> None:
        """Unfreezes the schema. Changes are automatically inferred."""
        self._schema.unfreeze()