# Source code for mdadash.backend.widgets.base
"""
Base Class for Widgets and Widget Manager
"""
import inspect
from abc import ABC
from contextlib import contextmanager
from threading import Thread
from typing import Any
from uuid import uuid1
import IPython
import MDAnalysis as mda
from joblib import Parallel
# [docs]
class WidgetBase(ABC):
"""WidgetBase
This is the base class for all widgets.
"""
_run_frequency = "per-frame"
_run_mode = "serial"
def __init_subclass__(cls, **kwargs):
"""Register any derived class with the WidgetManager"""
super().__init_subclass__(**kwargs)
WidgetManager.register_class(cls)
def __init__(self):
self.uid = None
self.u = None
self.uuid = None
self._input_errors = {}
def _set_universe(self, u: mda.Universe):
"""Internal: Set the universe"""
self.u = u
def _get_inputs(self):
"""Internal: Get the current instance inputs"""
inputs = getattr(self, "_inputs")
for _input in inputs:
# set the value and error states
_input["value"] = getattr(self, _input["attribute"], None)
_input["error"] = self._input_errors.get(_input["attribute"], None)
return inputs
def _set_input_state(self, attribute: str, error: str = None):
"""Internal: Set input attribute validation state"""
if error is not None:
self._input_errors[attribute] = error
else:
if attribute in self._input_errors:
del self._input_errors[attribute]
[docs]
def post_connect(self) -> None:
"""post_connect handler
This handler is called after connecting to the simulation
"""
[docs]
def post_disconnect(self) -> None:
"""post_disconnect handler
This handler is called after disconnection from simulation
"""
[docs]
def post_pause(self) -> None:
"""post_pause handler
This handler is called after user pauses trajectory iteration
"""
[docs]
def pre_resume(self) -> None:
"""pre_resume handler
This handler is called after user resumes trajectory iteration
"""
[docs]
def on_input_change(self, attribute: str, old_value: Any, new_value: Any) -> None:
"""on_input_change handler
This handler is called after a widget input has changed.
Validations can be performed in this handler and any exceptions
raised with messages will show up as errors in the UI
Parameters
----------
attribute: str
The input attribute that changed
old_value: Any
The previous value held by this attribute
new_value: Any
The current value of this attribute
"""
[docs]
def run_per_frame(self) -> None:
"""run_per_frame handler
This handler is called during every trajectory iteration if the run
frequency is set to `per-frame` (`_run_frequency='per-frame'`). The
trajectory timestep is the current frame.
"""
[docs]
def run_batch(self, batch_size: int) -> None:
"""run_batch handler
This handler is called every time a new batch of timesteps is full
and ready to be run if the run frequency is set to `batch`
(`_run_frequency='batch'`).
Parameters
----------
batch_size: int
The batch size, which indicates the number of buffered timesteps
available in the buffer is passed
"""
[docs]
def get_parallel_job(self, batch_size: int) -> Any:
"""get_parallel_job handler
This handler is called if the run mode is set to `parallel`
(`_run_mode='parallel'`) to get the parallel job to run.
Parameters
----------
batch_size: int
The configured batch size to use in the parallel jobs
if their run frequency is `batch`
Returns
-------
job: Any
A joblib's delayed function
"""
[docs]
def apply_parallel_results(self, values: Any) -> None:
"""apply_parallel_results handler
This handler is called with the results of the parallel job
execution. This is invoked when the run mode is set to `parallel`
(`_run_mode='parallel'`) after the parallel job completes.
Parameters
----------
values: Any
The results returned by the parallel job run
"""
# [docs]
class WidgetManager:
"""WidgetManager
This is the manager that manager all widgets.
"""
_classes = {}
_instances = {}
def __init__(self):
self.n_jobs = 2
self._patch_IMDReader()
@property
def classes(self):
"""Dictionary of registered widget classes keyed by widget name"""
return self._classes
@property
def instances(self):
"""Dictionary of widget instances keyed by widget uuid"""
return self._instances
[docs]
@classmethod
def register_class(cls, widget_class: WidgetBase) -> None:
"""Register widget class
Parameters
----------
widget_class
A widget class that is derived from WidgetBase
"""
cls._validate_widget_class(widget_class)
widget_name = widget_class.name
if widget_name in cls._classes:
raise ValueError(f"Widget name '{widget_name}' already registered")
cls._classes[widget_name] = widget_class
@classmethod
def _validate_widget_class(cls, widget_class: WidgetBase) -> None:
"""Internal: Method to validate a widget class"""
if not issubclass(widget_class, WidgetBase):
raise ValueError(f"{widget_class} is not a widget class")
if not hasattr(widget_class, "name"):
raise ValueError("name not specified in widget class")
# check for one of the run methods to exist with correct params
run_methods = {
"run_per_frame": 1,
"run_batch": 2,
}
has_valid_run_method = False
for run_method, n_params in run_methods.items():
method = getattr(widget_class, run_method)
if method == getattr(WidgetBase, run_method):
continue
if not callable(method):
continue
signature = inspect.signature(method)
has_valid_run_method = len(signature.parameters.values()) == n_params
break
if not has_valid_run_method:
raise ValueError("run method not found in class")
# TODO: add more validations
def _invoke_widget_lifecyle_method(self, widget: WidgetBase, method: str) -> None:
"""Internal: Invoke the lifecycle method if implemented"""
if widget._input_errors:
# lifecycle methods not invoked when
# there are input errors
return
if hasattr(widget, method):
handler = getattr(widget, method)
if callable(handler):
try:
handler()
# pylint: disable=broad-exception-caught
except Exception as e: # pragma: no cover
print(e)
def _set_widget_universe(
self, widget: WidgetBase, uid: int, u: mda.Universe
) -> None:
"""Internal: Set the universe for instance"""
if widget.uid == uid:
widget._set_universe(u)
# invoke the post_connect handler
self._invoke_widget_lifecyle_method(widget, "post_connect")
def _set_universe(self, uid: int, u: mda.Universe, uuid: str = None) -> None:
"""Internal: Set the universe for all or given widget"""
if uuid is None:
for widget in self.instances.values():
self._set_widget_universe(widget, uid, u)
else:
widget = self.instances[uuid]
self._set_widget_universe(widget, uid, u)
def _invoke_lifecycle_method(self, method: str) -> None:
"""Internal: Invoke given lifecycle method for all instances"""
for widget in self.instances.values():
self._invoke_widget_lifecyle_method(widget, method)
[docs]
def add_widget_instance(self, uid: int, widget_name: str) -> str | None:
"""Add widget instance
Add a widget instance based on the widget name already
registered with the manager.
Parameters
----------
uid: int
Universe ID (index into universes array)
widget_name: str
Name of the widget class registered
Returns
-------
uuid of instance added or None
"""
if widget_name in self.classes:
widget_class = self.classes[widget_name]
uuid = str(uuid1())
instance = widget_class()
setattr(instance, "uid", uid)
setattr(instance, "uuid", uuid)
self.instances[uuid] = instance
return uuid
return None
[docs]
def duplicate_widget_instance(self, uid: int, uuid: str) -> str | None:
"""Duplicate widget instance
Duplicate widget instance based on existing instance uuid
Parameters
----------
uid: int
Universe ID (index into universes array)
uuid: str
The uuid of the instance to be duplicated
Returns
-------
uuid of new instance created
"""
# get existing instance
instance = self.instances[uuid]
# duplicate instance
widget_class = instance.__class__
new_instance = widget_class()
setattr(new_instance, "uid", uid)
# set inputs for new instance
inputs = instance._get_inputs()
for _input in inputs:
attribute = _input["attribute"]
value = _input["value"]
setattr(new_instance, attribute, value)
# add new instance to instances list
new_uuid = str(uuid1())
setattr(new_instance, "uuid", new_uuid)
self.instances[new_uuid] = new_instance
return new_uuid
[docs]
def delete_widget_instance(self, uuid: str) -> str | None:
"""Remove widget instance
Remove widget instance based on uuid returned during
the instance creation using :meth:`add_widget_instance`
Parameters
----------
uuid: str
The uuid of the instance to be removed
Returns
-------
uuid of instance deleted or None
"""
if uuid in self.instances:
del self.instances[uuid]
return uuid
return None
[docs]
def get_inputs(self, uuid: str) -> list:
"""Get inputs for widget instance
Parameters
----------
uuid: str
The uuid of the widget instance
Returns
-------
response: list
List of input dicts
"""
widget = self.instances[uuid]
return widget._get_inputs()
[docs]
def set_input(self, uuid: str, attribute: str, value: Any) -> bool:
"""Set input for a widget instance attribute
Parameters
----------
uuid: str
The uuid of the widget instance
attribute: str
The input attribute to set
value: Any
The value to set for the attribute
Returns
-------
response: bool
True or False to indicate if input validation succeeded
"""
widget = self.instances[uuid]
old_value = getattr(widget, attribute, value)
old_type = type(old_value)
# set input using the same existing type
setattr(widget, attribute, old_type(value))
try:
widget.on_input_change(attribute, old_value, value)
widget._set_input_state(attribute)
return True
except Exception as e: # pylint: disable=broad-exception-caught
widget._set_input_state(attribute, str(e))
return False
[docs]
def update_n_jobs(self, data: dict) -> None:
"""Update n_jobs for joblib.Parallel"""
self.n_jobs = data["n_jobs"]
def _patch_IMDReader(self):
"""Internal: Patch `IMDReader` to make it serializable"""
# pylint: disable=import-outside-toplevel
from MDAnalysis.coordinates.IMD import IMDReader
def custom_getstate(self):
state = self.__dict__.copy()
del state["_imdclient"]
return state
def custom_setstate(self, state):
self.__dict__.update(state)
self._imdclient = None
IMDReader.__setstate__ = custom_setstate
IMDReader.__getstate__ = custom_getstate
def _run_parallel_jobs(self, parallel_widgets, batch_size, parallel_results):
"""Internal: Run parallel jobs using joblib.Parallel"""
parallel_jobs = []
for widget in parallel_widgets:
parallel_jobs.append(widget.get_parallel_job(batch_size))
try:
results = Parallel(n_jobs=self.n_jobs, initializer=self._patch_IMDReader)(
parallel_jobs
)
parallel_results.extend(results)
# pylint: disable=broad-exception-caught
except Exception as e: # pragma: no cover
print(f"parallel run failed with '{e}'")
[docs]
def run_widgets(self, uid: int, batch_ready: bool, batch_size: int) -> None:
"""Run widget instances
Parameters
----------
uid: int
Universe ID (index into universes array)
batch_ready: bool
Flag indicating if a batch of timesteps is full
batch_size: int
Size of the batch
"""
# collect widgets that need to be run
parallel_widgets = []
serial_widgets = []
for widget in self.instances.values():
# only run widget if there are no input errors
if widget.uid != uid or widget._input_errors:
continue
if widget._run_mode == "parallel":
if widget._run_frequency == "per-frame" or batch_ready:
parallel_widgets.append(widget)
else:
serial_widgets.append(widget)
# run parallel widgets in separate thread
if parallel_widgets:
parallel_results = []
parallel_thread = Thread(
target=self._run_parallel_jobs,
args=(
parallel_widgets,
batch_size,
parallel_results,
),
)
parallel_thread.start()
# run serial widgets
for widget in serial_widgets:
with _widget_uuid_in_metadata(widget.uuid):
try:
if widget._run_frequency == "per-frame":
widget.run_per_frame()
elif batch_ready:
widget.run_batch(batch_size)
# pylint: disable=broad-exception-caught
except Exception as e: # pragma: no cover
print(f"{widget.uuid} serial run failed with '{e}'")
# apply parallel results back
if parallel_widgets:
# wait for all parallel jobs to be done
parallel_thread.join()
for i, widget in enumerate(parallel_widgets):
with _widget_uuid_in_metadata(widget.uuid):
widget.apply_parallel_results(parallel_results[i])
@contextmanager
def _widget_uuid_in_metadata(uuid: str):
    """Internal: Add uuid in content metadata sent from kernel

    While the context is active, the ``send`` method of the running IPython
    kernel session is wrapped so that every outgoing ``display_data``
    message gets ``metadata["widget_uuid"]`` set to ``uuid``. The original
    ``send`` is restored on exit, even if the body raises.

    When no IPython kernel session is available the body runs unchanged.

    Parameters
    ----------
    uuid: str
        The uuid of the widget instance producing the output
    """
    shell = IPython.get_ipython()
    kernel = getattr(shell, "kernel", None)
    session = getattr(kernel, "session", None)
    if session is None:
        # not running inside an IPython kernel: nothing to patch
        yield
        return

    original_send = session.send

    def patched_send(stream, msg_type_or_msg, *args, **kwargs):
        msg_type = None
        content = None
        if isinstance(msg_type_or_msg, str):
            # called with a message type; the content follows either
            # positionally or as the `content` keyword
            msg_type = msg_type_or_msg
            candidate = args[0] if args else kwargs.get("content")
            if isinstance(candidate, dict):
                content = candidate
        elif isinstance(msg_type_or_msg, dict):
            # called with an already built message dict
            msg_type = msg_type_or_msg.get("msg_type") or msg_type_or_msg.get(
                "header", {}
            ).get("msg_type")
            content = msg_type_or_msg.get("content", msg_type_or_msg)
        # Add widget uuid to metadata
        if msg_type == "display_data" and isinstance(content, dict):
            metadata = content.get("metadata")
            if not isinstance(metadata, dict):
                # create the metadata dict when it is missing
                metadata = content["metadata"] = {}
            metadata["widget_uuid"] = uuid
        return original_send(stream, msg_type_or_msg, *args, **kwargs)

    session.send = patched_send
    try:
        yield
    finally:
        session.send = original_send