Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 45 additions & 3 deletions ipykernel/displayhook.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@

import builtins
import sys
import threading
import typing as t
from contextvars import ContextVar

from IPython.core.displayhook import DisplayHook
from jupyter_client.session import Session, extract_header
from traitlets import Any, Instance
from traitlets import Any, Instance, default

from ipykernel.jsonutil import encode_images, json_clean

Expand Down Expand Up @@ -80,13 +81,41 @@ class ZMQShellDisplayHook(DisplayHook):
session = Instance(Session, allow_none=True)
pub_socket = Any(allow_none=True)
_parent_header: ContextVar[dict[str, Any]]
_thread_local = Any()
msg: dict[str, t.Any] | None

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._parent_header = ContextVar("parent_header")
self._parent_header.set({})

@default("_thread_local")
def _default_thread_local(self):
return threading.local()

@property
def _hooks(self):
if not hasattr(self._thread_local, "hooks"):
self._thread_local.hooks = []
return self._thread_local.hooks

def register_hook(self, hook):
"""Register a transform hook on the execute_result message.

Mirrors ``ZMQDisplayPublisher.register_hook``. Each hook receives the
outbound message dict and must return either a (possibly mutated)
message dict to continue the chain, or ``None`` to suppress the send.
"""
self._hooks.append(hook)

def unregister_hook(self, hook):
"""Remove a previously registered hook. Returns True on success."""
try:
self._hooks.remove(hook)
return True
except ValueError:
return False

@property
def parent_header(self):
try:
Expand Down Expand Up @@ -124,9 +153,22 @@ def write_format_data(self, format_dict, md_dict=None):
self.msg["content"]["metadata"] = md_dict

def finish_displayhook(self):
"""Finish up all displayhook activities."""
"""Finish up all displayhook activities.

Runs the registered hook chain before ``session.send``. Each hook
either returns a message (to continue) or ``None`` (to suppress the
send). This mirrors the transform pipeline on
``ZMQDisplayPublisher.publish`` so a single hook implementation can
attach to both the ``display_data`` and ``execute_result`` paths.
"""
sys.stdout.flush()
sys.stderr.flush()
if self.msg and self.msg["content"]["data"] and self.session:
self.session.send(self.pub_socket, self.msg, ident=self.topic)
msg = self.msg
for hook in self._hooks:
msg = hook(msg)
if msg is None:
self.msg = None
return
self.session.send(self.pub_socket, msg, ident=self.topic)
self.msg = None
207 changes: 207 additions & 0 deletions tests/test_displayhook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
"""Tests for the ZMQ execute_result displayhook."""

# Copyright (c) IPython Development Team.
# Distributed under the terms of the Modified BSD License.

import unittest
from queue import Queue
from threading import Thread

import zmq
from IPython.core.interactiveshell import InteractiveShell
from jupyter_client.session import Session
from traitlets import Int

from ipykernel.displayhook import ZMQShellDisplayHook


class NoReturnHook:
call_count = 0

def __call__(self, msg):
self.call_count += 1


class ReturnHook(NoReturnHook):
def __call__(self, msg):
super().__call__(msg)
return msg


class MutatingHook(NoReturnHook):
"""Attaches a buffer to the message and returns it."""

def __call__(self, msg):
super().__call__(msg)
msg.setdefault("buffers", []).append(b"arrow-bytes")
return msg


class CounterSession(Session):
send_count = Int(0)
last_msg = None

def send(self, *args, **kwargs):
self.send_count += 1
# args: (stream, msg_or_type, ...)
if len(args) >= 2:
self.last_msg = args[1]
return super().send(*args, **kwargs)


def _drive(hook, data=None):
"""Run a single execute_result emission through the hook."""
if data is None:
data = {"text/plain": "1"}
hook.start_displayhook()
hook.write_format_data(data, {})
hook.finish_displayhook()


class ZMQShellDisplayHookTests(unittest.TestCase):
def setUp(self):
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PUB)
self.session = CounterSession()
self.shell = InteractiveShell()
self.disp = ZMQShellDisplayHook(shell=self.shell)
self.disp.session = self.session
self.disp.pub_socket = self.socket

def tearDown(self):
self.socket.close()
self.context.term()

def test_no_hooks_sends_message(self):
"""With no hooks registered, finish_displayhook still calls send."""
assert self.disp._hooks == []
_drive(self.disp)
assert self.session.send_count == 1

def test_thread_local_hooks(self):
"""_hooks is thread-local: registering on one thread doesn't leak."""
assert self.disp._hooks == []

def hook(msg):
return msg

self.disp.register_hook(hook)
assert self.disp._hooks == [hook]

q: Queue = Queue()

def read_other_thread():
q.put(self.disp._hooks)

t = Thread(target=read_other_thread)
t.start()
other = q.get(timeout=10)
t.join()
assert other == []

def test_hook_returning_none_halts_send(self):
"""A hook that returns None suppresses session.send."""
hook = NoReturnHook()
self.disp.register_hook(hook)

_drive(self.disp)

assert hook.call_count == 1
assert self.session.send_count == 0
assert self.disp.msg is None

def test_hook_returning_msg_calls_send(self):
"""A hook that returns the message lets it through to send."""
hook = ReturnHook()
self.disp.register_hook(hook)

_drive(self.disp)

assert hook.call_count == 1
assert self.session.send_count == 1

def test_hook_can_mutate_message(self):
"""A hook can attach buffers (the original motivation)."""
hook = MutatingHook()
self.disp.register_hook(hook)

_drive(self.disp)

assert hook.call_count == 1
assert self.session.send_count == 1
sent = self.session.last_msg
assert sent is not None
assert sent.get("buffers") == [b"arrow-bytes"]

def test_hook_chain_short_circuits(self):
"""If an early hook returns None, later hooks are not called."""
first = NoReturnHook()
second = NoReturnHook()
self.disp.register_hook(first)
self.disp.register_hook(second)

_drive(self.disp)

assert first.call_count == 1
assert second.call_count == 0
assert self.session.send_count == 0

def test_hook_chain_threads_message(self):
"""Each hook receives the message returned by the previous hook."""
observed: list[dict] = []

def first(msg):
msg["content"]["metadata"]["seen_by_first"] = True
return msg

def second(msg):
observed.append(msg)
return msg

self.disp.register_hook(first)
self.disp.register_hook(second)

_drive(self.disp)

assert len(observed) == 1
assert observed[0]["content"]["metadata"].get("seen_by_first") is True
assert self.session.send_count == 1

def test_unregister_hook(self):
"""Unregistered hooks no longer run; double-unregister returns False."""
hook = NoReturnHook()
self.disp.register_hook(hook)

_drive(self.disp)
assert hook.call_count == 1
assert self.session.send_count == 0

first = self.disp.unregister_hook(hook)
assert bool(first)

_drive(self.disp)
# Hook didn't run again, but the message went out via session.send.
assert hook.call_count == 1
assert self.session.send_count == 1

# Unregistering an unknown hook returns False.
assert not bool(self.disp.unregister_hook(hook))

def test_empty_data_skips_send_and_hooks(self):
"""The existing guard: if content.data is empty, don't send or hook."""
hook = ReturnHook()
self.disp.register_hook(hook)

# start_displayhook initializes self.msg with empty data; if we never
# call write_format_data, the data dict stays empty and finish should
# short-circuit before calling either hooks or send.
self.disp.start_displayhook()
self.disp.finish_displayhook()

assert hook.call_count == 0
assert self.session.send_count == 0
assert self.disp.msg is None


if __name__ == "__main__":
unittest.main()
Loading