Skip to content

Hook handler

Handler for application hooks. This class manages the execution of application hooks, which can be either SQL files or Python functions.

Source code in pum/hook.py
class HookHandler:
    """Handler for application hooks.

    This class manages the execution of application hooks, which can be
    either SQL files, inline SQL code, or Python files defining a ``Hook``
    class.
    """

    def __init__(
        self,
        *,
        file: str | Path | None = None,
        code: str | None = None,
        base_path: Path | None = None,
    ) -> None:
        """Initialize a Hook instance.

        For a Python hook file the module is loaded immediately, its ``Hook``
        class is checked and instantiated, and the extra ``run_hook``
        argument names are recorded in ``parameter_args``.

        Args:
            file: The file path of the hook (``.sql`` or ``.py``).
            code: The SQL code for the hook.
            base_path: Directory used to resolve a relative ``file``. For
                Python hooks it is also put on ``sys.path``.

        Raises:
            ValueError: If both ``file`` and ``code`` are given, or if
                ``file`` is relative and ``base_path`` is missing.
            PumHookError: If the file does not exist, is not a regular file,
                or is a Python file that does not define a suitable ``Hook``
                class.

        """
        if file and code:
            raise ValueError("Cannot specify both file and code. Choose one.")

        self.file = file
        self.code = code
        self.hook_instance = None
        self.sys_path_additions = []  # Store paths to add during execution
        # Extra run_hook argument names; stays empty for non-Python hooks
        self.parameter_args = []

        if file:
            if isinstance(file, str):
                self.file = Path(file)
            if not self.file.is_absolute():
                if base_path is None:
                    raise ValueError("Base path must be provided for relative file paths.")
                self.file = base_path.absolute() / self.file
            if not self.file.exists():
                raise PumHookError(f"Hook file {self.file} does not exist.")
            if not self.file.is_file():
                raise PumHookError(f"Hook file {self.file} is not a file.")

        if self.file and self.file.suffix == ".py":
            # Support local imports in hook files by adding parent dir and base_path to sys.path
            parent_dir = str(self.file.parent.resolve())

            # Store paths that need to be added for hook execution
            # Always add paths even if already in sys.path - we need them at position 0
            # for priority and we'll track what we added for cleanup
            self.sys_path_additions.append(parent_dir)

            # Also add base_path if provided, to support imports from sibling directories
            if base_path is not None:
                base_path_str = str(base_path.resolve())
                if base_path_str != parent_dir:
                    self.sys_path_additions.append(base_path_str)

            # Add paths for module loading - insert at position 0 for priority
            # (reversed so that the first tracked path ends up first in sys.path)
            for path in reversed(self.sys_path_additions):
                sys.path.insert(0, path)

            # Invalidate caches so Python recognizes the new paths
            importlib.invalidate_caches()

            try:
                logger.debug(f"Loading hook from: {self.file}")
                logger.debug(f"sys.path additions: {self.sys_path_additions}")
                spec = importlib.util.spec_from_file_location(
                    self.file.stem,
                    self.file,
                    submodule_search_locations=[parent_dir],
                )
                module = importlib.util.module_from_spec(spec)
                # Set __path__ to enable package-like imports from the hook's directory
                module.__path__ = [parent_dir]
                # Add to sys.modules before executing so imports can find it
                sys.modules[self.file.stem] = module
                spec.loader.exec_module(module)

                # Check that the module contains a class named Hook inheriting from HookBase
                hook_class = getattr(module, "Hook", None)
                if not hook_class or not inspect.isclass(hook_class):
                    raise PumHookError(
                        f"Python hook file {self.file} must define a class named 'Hook'."
                    )

                # Check inheritance by class name to handle multiple pum installations
                # (e.g., installed pum vs. libs/pum in QGIS plugin)
                base_classes = [base.__name__ for base in inspect.getmro(hook_class)]
                if "HookBase" not in base_classes:
                    # Get more info for debugging
                    hook_bases = [
                        f"{base.__module__}.{base.__name__}" for base in inspect.getmro(hook_class)
                    ]
                    logger.error(f"Hook class MRO: {hook_bases}")
                    logger.error(
                        f"Expected HookBase from: {HookBase.__module__}.{HookBase.__name__}"
                    )
                    raise PumHookError(
                        f"Class 'Hook' in {self.file} must inherit from HookBase. "
                        f"Found bases: {', '.join(base_classes)}"
                    )

                if not hasattr(hook_class, "run_hook"):
                    raise PumHookError(f"Hook function 'run_hook' not found in {self.file}.")

                self.hook_instance = hook_class()
                arg_names = list(inspect.signature(hook_class.run_hook).parameters.keys())
                if "connection" not in arg_names:
                    raise PumHookError(
                        f"Hook function 'run_hook' in {self.file} must accept 'connection' as an argument."
                    )
                self.parameter_args = [
                    arg for arg in arg_names if arg not in ("self", "connection")
                ]

            except Exception:
                # On error, clean up paths we added
                self.cleanup_sys_path()
                raise

    def cleanup_sys_path(self) -> None:
        """Remove paths that were added to sys.path for this hook.

        This should be called when the hook is no longer needed to prevent
        sys.path pollution. Every occurrence of each tracked path is removed
        and the tracking list is emptied.
        """
        for path in self.sys_path_additions:
            # Remove all occurrences of this path (we may have added it multiple times)
            while path in sys.path:
                sys.path.remove(path)
        self.sys_path_additions.clear()

    def __repr__(self) -> str:
        """Return a string representation of the Hook instance."""
        return f"<hook: {self.file}>"

    def __eq__(self, other: "HookHandler") -> bool:
        """Check if two Hook instances are equal.

        Two handlers are equal when both their file and their code match.
        The comparison is symmetric: a handler with neither file nor code is
        only equal to another such handler.
        """
        if not isinstance(other, HookHandler):
            return NotImplemented
        return self.file == other.file and self.code == other.code

    def validate(self, parameters: dict) -> None:
        """Check if the parameters match the expected parameter definitions.

        For Python hooks every extra ``run_hook`` argument must be present in
        ``parameters``. For SQL files the check is delegated to
        ``SqlContent.validate``. Inline SQL code is not checked.

        Args:
            parameters (dict): The parameters to check.

        Raises:
            PumHookError: If the parameters do not match the expected definitions.

        """
        if self.file and self.file.suffix == ".py":
            for parameter_arg in self.parameter_args:
                if parameter_arg not in parameters:
                    raise PumHookError(
                        f"Hook function 'run_hook' in {self.file} has an unexpected argument "
                        f"'{parameter_arg}' which is not specified in the parameters."
                    )

        if self.file and self.file.suffix == ".sql":
            SqlContent(self.file).validate(parameters=parameters)

    def execute(
        self,
        connection: psycopg.Connection,
        *,
        commit: bool = False,
        parameters: dict | None = None,
    ) -> None:
        """Execute the migration hook.

        This method executes the SQL code or the Python file specified in the hook.

        Args:
            connection: The database connection.
            commit: Whether to commit the transaction after executing the hook.
            parameters (dict, optional): Parameters to bind to the SQL statement
                or to pass to ``run_hook``. Defaults to None.

        Raises:
            ValueError: If the hook has neither a file nor SQL code.
            PumHookError: If the file type is unsupported, if a ``run_hook``
                argument is missing from ``parameters``, or if the Python hook
                raises a ``PumSqlError``.

        """
        logger.debug(
            f"Executing hook from file: {self.file} or SQL code with parameters: {parameters}",
        )

        parameters_literals = SqlContent.prepare_parameters(parameters)

        if self.file is None and self.code is None:
            raise ValueError("No file or SQL code specified for the migration hook.")

        if self.file:
            if self.file.suffix == ".sql":
                SqlContent(self.file).execute(
                    connection=connection, commit=False, parameters=parameters_literals
                )
            elif self.file.suffix == ".py":
                supplied = parameters or {}
                # Every extra run_hook argument must be supplied by the caller
                for parameter_arg in self.parameter_args:
                    if parameter_arg not in supplied:
                        raise PumHookError(
                            f"Hook function 'run_hook' in {self.file} has an unexpected "
                            f"argument '{parameter_arg}' which is not specified in the parameters."
                        )

                # Only forward the parameters that run_hook actually declares
                _hook_parameters = {
                    key: value for key, value in supplied.items() if key in self.parameter_args
                }
                self.hook_instance._prepare(connection=connection, parameters=parameters)

                # Temporarily add sys.path entries for hook execution
                # This allows dynamic imports inside run_hook to work.
                # Reversed so the priority order matches the one used in __init__.
                for path in reversed(self.sys_path_additions):
                    if path not in sys.path:
                        sys.path.insert(0, path)

                try:
                    self.hook_instance.run_hook(connection=connection, **_hook_parameters)
                except PumSqlError as e:
                    raise PumHookError(f"Error executing Python hook from {self.file}: {e}") from e
                finally:
                    # Remove the paths after execution
                    for path in self.sys_path_additions:
                        if path in sys.path:
                            sys.path.remove(path)

            else:
                raise PumHookError(
                    f"Unsupported file type for migration hook: {self.file.suffix}. Only .sql and .py files are supported."
                )
        elif self.code:
            SqlContent(self.code).execute(connection, parameters=parameters_literals, commit=False)

        if commit:
            connection.commit()

__eq__

__eq__(other: HookHandler) -> bool

Check if two Hook instances are equal.

Source code in pum/hook.py
def __eq__(self, other: "HookHandler") -> bool:
    """Check if two Hook instances are equal.

    Two handlers are equal when both their file and their code match. The
    comparison is symmetric: a handler with neither file nor code is only
    equal to another such handler.
    """
    if not isinstance(other, HookHandler):
        return NotImplemented
    return self.file == other.file and self.code == other.code

__init__

__init__(*, file: str | Path | None = None, code: str | None = None, base_path: Path | None = None) -> None

Initialize a Hook instance.

Parameters:

Name Type Description Default
base_path Path | None

Base directory used to resolve a relative hook file path.

None
file str | Path | None

The file path of the hook.

None
code str | None

The SQL code for the hook.

None
Source code in pum/hook.py
def __init__(
    self,
    *,
    file: str | Path | None = None,
    code: str | None = None,
    base_path: Path | None = None,
) -> None:
    """Initialize a Hook instance.

    For a Python hook file the module is loaded immediately, its ``Hook``
    class is checked and instantiated, and the extra ``run_hook`` argument
    names are stored in ``parameter_args``.

    Args:
        file: The file path of the hook.
        code: The SQL code for the hook.
        base_path: Directory used to resolve a relative ``file``. For Python
            hooks it is also inserted into ``sys.path``.

    Raises:
        ValueError: If both ``file`` and ``code`` are given, or if ``file``
            is relative and ``base_path`` is None.
        PumHookError: If the file does not exist or is not a file, or if a
            Python hook does not define a ``Hook`` class with a ``HookBase``
            ancestor and a ``run_hook`` method accepting ``connection``.

    """
    if file and code:
        raise ValueError("Cannot specify both file and code. Choose one.")

    self.file = file
    self.code = code
    self.hook_instance = None
    self.sys_path_additions = []  # Store paths to add during execution

    if file:
        # Normalise to an absolute Path before checking it on disk
        if isinstance(file, str):
            self.file = Path(file)
        if not self.file.is_absolute():
            if base_path is None:
                raise ValueError("Base path must be provided for relative file paths.")
            self.file = base_path.absolute() / self.file
        if not self.file.exists():
            raise PumHookError(f"Hook file {self.file} does not exist.")
        if not self.file.is_file():
            raise PumHookError(f"Hook file {self.file} is not a file.")

    if self.file and self.file.suffix == ".py":
        # Support local imports in hook files by adding parent dir and base_path to sys.path
        parent_dir = str(self.file.parent.resolve())

        # Store paths that need to be added for hook execution
        # Always add paths even if already in sys.path - we need them at position 0
        # for priority and we'll track what we added for cleanup
        self.sys_path_additions.append(parent_dir)

        # Also add base_path if provided, to support imports from sibling directories
        if base_path is not None:
            base_path_str = str(base_path.resolve())
            if base_path_str != parent_dir:
                self.sys_path_additions.append(base_path_str)

        # Add paths for module loading - insert at position 0 for priority.
        # Iterating in reverse leaves the first tracked path at index 0.
        for path in reversed(self.sys_path_additions):
            sys.path.insert(0, path)

        # Invalidate caches so Python recognizes the new paths
        importlib.invalidate_caches()

        # NOTE: on success the inserted paths are left in sys.path; they are
        # only removed here when loading fails (see the except clause below).
        try:
            logger.debug(f"Loading hook from: {self.file}")
            logger.debug(f"sys.path additions: {self.sys_path_additions}")
            spec = importlib.util.spec_from_file_location(
                self.file.stem,
                self.file,
                submodule_search_locations=[parent_dir],
            )
            module = importlib.util.module_from_spec(spec)
            # Set __path__ to enable package-like imports from the hook's directory
            module.__path__ = [parent_dir]
            # Add to sys.modules before executing so imports can find it
            sys.modules[self.file.stem] = module
            spec.loader.exec_module(module)

            # Check that the module contains a class named Hook inheriting from HookBase
            hook_class = getattr(module, "Hook", None)
            if not hook_class or not inspect.isclass(hook_class):
                raise PumHookError(
                    f"Python hook file {self.file} must define a class named 'Hook'."
                )

            # Check inheritance by class name to handle multiple pum installations
            # (e.g., installed pum vs. libs/pum in QGIS plugin)
            base_classes = [base.__name__ for base in inspect.getmro(hook_class)]
            if "HookBase" not in base_classes:
                # Get more info for debugging
                hook_bases = [
                    f"{base.__module__}.{base.__name__}" for base in inspect.getmro(hook_class)
                ]
                logger.error(f"Hook class MRO: {hook_bases}")
                logger.error(
                    f"Expected HookBase from: {HookBase.__module__}.{HookBase.__name__}"
                )
                raise PumHookError(
                    f"Class 'Hook' in {self.file} must inherit from HookBase. "
                    f"Found bases: {', '.join(base_classes)}"
                )

            if not hasattr(hook_class, "run_hook"):
                raise PumHookError(f"Hook function 'run_hook' not found in {self.file}.")

            self.hook_instance = hook_class()
            # Inspect run_hook's signature to learn which parameters it expects
            arg_names = list(inspect.signature(hook_class.run_hook).parameters.keys())
            if "connection" not in arg_names:
                raise PumHookError(
                    f"Hook function 'run_hook' in {self.file} must accept 'connection' as an argument."
                )
            # Everything besides self/connection must be supplied as a hook parameter
            self.parameter_args = [
                arg for arg in arg_names if arg not in ("self", "connection")
            ]

        except Exception:
            # On error, clean up paths we added
            self.cleanup_sys_path()
            raise

__repr__

__repr__() -> str

Return a string representation of the Hook instance.

Source code in pum/hook.py
def __repr__(self) -> str:
    """Describe the hook by the file it was created from."""
    representation = f"<hook: {self.file}>"
    return representation

cleanup_sys_path

cleanup_sys_path() -> None

Remove paths that were added to sys.path for this hook.

This should be called when the hook is no longer needed to prevent sys.path pollution.

Source code in pum/hook.py
def cleanup_sys_path(self) -> None:
    """Purge this hook's tracked entries from sys.path.

    Call this once the hook is no longer needed so that sys.path is not
    left polluted. Every occurrence of each tracked path is dropped, then
    the tracking list itself is emptied.
    """
    tracked = self.sys_path_additions
    # Filter in place (slice assignment) so sys.path stays the same list object
    sys.path[:] = [entry for entry in sys.path if entry not in tracked]
    tracked.clear()

execute

execute(connection: Connection, *, commit: bool = False, parameters: dict | None = None) -> None

Execute the migration hook. This method executes the SQL code or the Python file specified in the hook.

Parameters:

Name Type Description Default
connection Connection

The database connection.

required
commit bool

Whether to commit the transaction after executing the SQL.

False
parameters dict

Parameters to bind to the SQL statement. Defaults to None.

None
Source code in pum/hook.py
def execute(
    self,
    connection: psycopg.Connection,
    *,
    commit: bool = False,
    parameters: dict | None = None,
) -> None:
    """Execute the migration hook.

    This method executes the SQL code or the Python file specified in the hook.

    Args:
        connection: The database connection.
        commit: Whether to commit the transaction after executing the hook.
        parameters (dict, optional): Parameters to bind to the SQL statement
            or to pass to ``run_hook``. Defaults to None.

    Raises:
        ValueError: If the hook has neither a file nor SQL code.
        PumHookError: If the file type is unsupported, if a ``run_hook``
            argument is missing from ``parameters``, or if the Python hook
            raises a ``PumSqlError``.

    """
    logger.debug(
        f"Executing hook from file: {self.file} or SQL code with parameters: {parameters}",
    )

    parameters_literals = SqlContent.prepare_parameters(parameters)

    if self.file is None and self.code is None:
        raise ValueError("No file or SQL code specified for the migration hook.")

    if self.file:
        if self.file.suffix == ".sql":
            SqlContent(self.file).execute(
                connection=connection, commit=False, parameters=parameters_literals
            )
        elif self.file.suffix == ".py":
            supplied = parameters or {}
            # Every extra run_hook argument must be supplied by the caller
            for parameter_arg in self.parameter_args:
                if parameter_arg not in supplied:
                    raise PumHookError(
                        f"Hook function 'run_hook' in {self.file} has an unexpected "
                        f"argument '{parameter_arg}' which is not specified in the parameters."
                    )

            # Only forward the parameters that run_hook actually declares
            _hook_parameters = {
                key: value for key, value in supplied.items() if key in self.parameter_args
            }
            self.hook_instance._prepare(connection=connection, parameters=parameters)

            # Temporarily add sys.path entries for hook execution
            # This allows dynamic imports inside run_hook to work.
            # Reversed so the first tracked path gets the highest priority.
            for path in reversed(self.sys_path_additions):
                if path not in sys.path:
                    sys.path.insert(0, path)

            try:
                self.hook_instance.run_hook(connection=connection, **_hook_parameters)
            except PumSqlError as e:
                raise PumHookError(f"Error executing Python hook from {self.file}: {e}") from e
            finally:
                # Remove the paths after execution
                for path in self.sys_path_additions:
                    if path in sys.path:
                        sys.path.remove(path)

        else:
            raise PumHookError(
                f"Unsupported file type for migration hook: {self.file.suffix}. Only .sql and .py files are supported."
            )
    elif self.code:
        SqlContent(self.code).execute(connection, parameters=parameters_literals, commit=False)

    if commit:
        connection.commit()

validate

validate(parameters: dict) -> None

Check if the parameters match the expected parameter definitions. This is only effective for Python hooks for now.

Parameters:

Name Type Description Default
parameters dict

The parameters to check.

required

Raises:

Type Description
PumHookError

If the parameters do not match the expected definitions.

Source code in pum/hook.py
def validate(self, parameters: dict) -> None:
    """Verify that the supplied parameters satisfy what the hook expects.

    Python hooks require every extra ``run_hook`` argument to be present in
    ``parameters``; SQL files delegate the check to ``SqlContent``. Hooks
    without a file, or with any other suffix, are not checked.

    Args:
        parameters (dict): The parameters to check.

    Raises:
        PumHookError: If the parameters do not match the expected definitions.

    """
    if not self.file:
        return

    suffix = self.file.suffix
    if suffix == ".sql":
        SqlContent(self.file).validate(parameters=parameters)
        return
    if suffix != ".py":
        return

    for parameter_arg in self.parameter_args:
        if parameter_arg in parameters:
            continue
        raise PumHookError(
            f"Hook function 'run_hook' in {self.file} has an unexpected argument "
            f"'{parameter_arg}' which is not specified in the parameters."
        )