Source code for airflow.providers.amazon.aws.hooks.sagemaker_unified_studio_notebook

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""This module contains the Amazon SageMaker Unified Studio Notebook Run hook."""

from __future__ import annotations

import json
import logging
import math
import time
import uuid
from functools import cached_property
from typing import Any
from urllib.parse import urlparse

from botocore.exceptions import ClientError

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook

#: Fallback run timeout, in minutes (12 hours), used when no ``runTimeoutInMinutes`` is configured.
TWELVE_HOURS_IN_MINUTES = 60 * 12
#: Minimum botocore version required for the DataZone NotebookRun APIs.
MIN_BOTOCORE_VERSION = "1.43.1"
#: Terminal success states for a notebook run.
NOTEBOOK_RUN_SUCCESS_STATES = frozenset(("SUCCEEDED",))
#: Non-terminal states: the notebook run has not finished yet.
NOTEBOOK_RUN_IN_PROGRESS_STATES = frozenset(("QUEUED", "STARTING", "RUNNING", "STOPPING"))
#: Terminal failure states for a notebook run.
NOTEBOOK_RUN_FAILURE_STATES = frozenset(("FAILED", "STOPPED"))
#: Prefix of the XCom keys under which notebook output variables are published.
NOTEBOOK_OUTPUT_KEY_PREFIX = "NOTEBOOK_OUTPUT"
class SageMakerUnifiedStudioNotebookHook(AwsBaseHook):
    """
    Interact with Sagemaker Unified Studio Workflows for asynchronous notebook execution.

    This hook provides a wrapper around the DataZone StartNotebookRun / GetNotebookRun APIs.

    Examples:
     .. code-block:: python

        from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio_notebook import (
            SageMakerUnifiedStudioNotebookHook,
        )

        hook = SageMakerUnifiedStudioNotebookHook(aws_conn_id="my_aws_conn")

    Additional arguments (such as ``aws_conn_id`` or ``region_name``) may be specified and
    are passed down to the underlying AwsBaseHook. The optional ``endpoint_url`` keyword is
    consumed by this hook and, when set, is used as the endpoint of the DataZone client.

    .. seealso::
        - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
    """

    def __init__(self, *args: Any, **kwargs: Any):
        # ``endpoint_url`` is popped so it is not forwarded to AwsBaseHook; ``conn`` uses it.
        self._endpoint_url = kwargs.pop("endpoint_url", None)
        kwargs.setdefault("client_type", "datazone")
        super().__init__(*args, **kwargs)

    @cached_property
    def conn(self):
        """Get the underlying boto3 DataZone client, optionally with a custom endpoint URL."""
        if self._endpoint_url:
            session = self.get_session()
            return session.client(
                "datazone",
                endpoint_url=self._endpoint_url,
                config=self.config,
                verify=self.verify,
            )
        return super().conn

    def _validate_api_availability(self) -> None:
        """
        Verify that the NotebookRun APIs are available in the installed boto3/botocore version.

        :raises RuntimeError: If the required APIs are not available.
        """
        required_methods = ("start_notebook_run", "get_notebook_run")
        for method_name in required_methods:
            if not hasattr(self.conn, method_name):
                raise RuntimeError(
                    f"The '{method_name}' API is not available in the installed boto3/botocore version. "
                    f"Please upgrade botocore to version {MIN_BOTOCORE_VERSION} or later to use the "
                    f"DataZone NotebookRun APIs: pip install 'botocore>={MIN_BOTOCORE_VERSION}'"
                )

    def start_notebook_run(
        self,
        notebook_identifier: str,
        domain_identifier: str,
        owning_project_identifier: str,
        client_token: str | None = None,
        notebook_parameters: dict | None = None,
        compute_configuration: dict | None = None,
        timeout_configuration: dict | None = None,
        workflow_name: str | None = None,
    ) -> dict:
        """
        Start an asynchronous notebook run via the DataZone StartNotebookRun API.

        Optional arguments that are empty/falsy are omitted from the request.

        :param notebook_identifier: The ID of the notebook to execute.
        :param domain_identifier: The ID of the DataZone domain containing the notebook.
        :param owning_project_identifier: The ID of the DataZone project containing the notebook.
        :param client_token: Idempotency token. A random UUID4 is generated if not provided.
        :param notebook_parameters: Parameters to pass to the notebook.
        :param compute_configuration: Compute config (e.g. instanceType).
        :param timeout_configuration: Timeout settings (runTimeoutInMinutes).
        :param workflow_name: Name of the workflow (DAG) that triggered this run; sent as a
            ``WORKFLOW`` trigger source.
        :return: The StartNotebookRun API response dict.
        :raises RuntimeError: If the installed botocore lacks the NotebookRun APIs.
        """
        self._validate_api_availability()

        params: dict = {
            "domainIdentifier": domain_identifier,
            "owningProjectIdentifier": owning_project_identifier,
            "notebookIdentifier": notebook_identifier,
            "clientToken": client_token or str(uuid.uuid4()),
        }
        if notebook_parameters:
            params["parameters"] = notebook_parameters
        if compute_configuration:
            params["computeConfiguration"] = compute_configuration
        if timeout_configuration:
            params["timeoutConfiguration"] = timeout_configuration
        if workflow_name:
            params["triggerSource"] = {"type": "WORKFLOW", "name": workflow_name}

        self.log.info(
            "Starting notebook run for notebook %s in domain %s", notebook_identifier, domain_identifier
        )
        return self.conn.start_notebook_run(**params)

    def get_notebook_run(self, notebook_run_id: str, domain_identifier: str) -> dict:
        """
        Get the status of a notebook run via the DataZone GetNotebookRun API.

        :param notebook_run_id: The ID of the notebook run.
        :param domain_identifier: The ID of the DataZone domain.
        :return: The GetNotebookRun API response dict.
        :raises RuntimeError: If the installed botocore lacks the NotebookRun APIs.
        """
        self._validate_api_availability()
        return self.conn.get_notebook_run(
            domainIdentifier=domain_identifier,
            identifier=notebook_run_id,
        )

    def wait_for_notebook_run(
        self,
        notebook_run_id: str,
        domain_identifier: str,
        waiter_delay: int = 10,
        timeout_configuration: dict | None = None,
    ) -> dict:
        """
        Poll GetNotebookRun until the run reaches a terminal state.

        :param notebook_run_id: The ID of the notebook run to monitor.
        :param domain_identifier: The ID of the DataZone domain.
        :param waiter_delay: Interval in seconds to poll the notebook run status.
        :param timeout_configuration: Timeout settings for the notebook execution.
            When ``runTimeoutInMinutes`` is set, the maximum number of poll attempts is
            ``ceil(runTimeoutInMinutes * 60 / waiter_delay)`` (at least 1). When it is
            missing or ``None``, 12 hours is used.
        :return: A dict with Status and NotebookRunId on success.
        :raises ValueError: If ``waiter_delay`` is not positive.
        :raises RuntimeError: If the run fails or times out.
        """
        if waiter_delay <= 0:
            raise ValueError("waiter_delay must be a positive integer")

        run_timeout = (timeout_configuration or {}).get("runTimeoutInMinutes")
        if run_timeout is None:
            # Also covers an explicit ``None`` value, which would otherwise break the
            # arithmetic below with a TypeError.
            run_timeout = TWELVE_HOURS_IN_MINUTES
        waiter_max_attempts = max(1, math.ceil(run_timeout * 60 / waiter_delay))

        for attempt in range(1, waiter_max_attempts + 1):
            response = self.get_notebook_run(notebook_run_id, domain_identifier=domain_identifier)
            status = response.get("status", "")
            error_message = response.get("errorMessage", "")
            ret = self._handle_status(notebook_run_id, status, error_message, waiter_delay)
            if ret:
                return ret
            # Sleep only between polls; sleeping after the final poll would just delay the
            # timeout error below without checking the status again.
            if attempt < waiter_max_attempts:
                time.sleep(waiter_delay)

        error_message = "Execution timed out"
        self.log.error("Notebook run %s failed with error: %s", notebook_run_id, error_message)
        raise RuntimeError(error_message)

    def _handle_status(
        self, notebook_run_id: str, status: str, error_message: str, waiter_delay: int = 10
    ) -> dict | None:
        """
        Evaluate the current notebook run status and return or raise accordingly.

        :param notebook_run_id: The ID of the notebook run.
        :param status: The current status string.
        :param error_message: Error message from the API response, if any.
        :param waiter_delay: Interval in seconds between polls (for logging).
        :return: A dict with Status and NotebookRunId on success, None if still in progress.
        :raises RuntimeError: If the run has failed or reached a status that is neither
            in progress, successful, nor a known failure state.
        """
        if status in NOTEBOOK_RUN_IN_PROGRESS_STATES:
            self.log.info(
                "Notebook run %s is still in progress with status: %s, "
                "will check for a terminal status again in %ss",
                notebook_run_id,
                status,
                waiter_delay,
            )
            return None

        execution_message = f"Exiting notebook run {notebook_run_id}. Status: {status}"
        if status in NOTEBOOK_RUN_SUCCESS_STATES:
            self.log.info(execution_message)
            return {"Status": status, "NotebookRunId": notebook_run_id}

        if status in NOTEBOOK_RUN_FAILURE_STATES:
            self.log.error("Notebook run %s failed with error: %s", notebook_run_id, error_message)
        else:
            self.log.error("Notebook run %s reached unexpected status: %s", notebook_run_id, status)

        # Fall back to a descriptive message when the API gave none (empty or None).
        if not error_message:
            error_message = execution_message
        raise RuntimeError(error_message)

    def get_project_s3_path(self, domain_identifier: str, project_id: str) -> tuple[str, str]:
        """
        Look up the S3 location for a SageMaker Unified Studio project.

        The bucket and key prefix are read from the ``s3BucketPath`` provisioned resource
        of the project's default ("Tooling") environment via the DataZone APIs. This
        mirrors how SageMaker Unified Studio resolves the project bucket and accommodates
        projects whose bucket name does not follow the
        ``amazon-sagemaker-{account_id}-{region}-{project_id}`` template (for example,
        BYOR-bucket projects).

        :param domain_identifier: The ID of the DataZone domain.
        :param project_id: The ID of the DataZone project.
        :return: A ``(bucket, prefix)`` tuple. ``bucket`` is the S3 bucket name.
            ``prefix`` is the path component of the project's ``s3BucketPath`` (with no
            leading or trailing ``/``).
        :raises RuntimeError: If the default tooling environment or the ``s3BucketPath``
            provisioned resource cannot be found, or the resource value is empty or has no
            bucket component.
        """
        environment = self._get_default_tooling_environment(domain_identifier, project_id)
        environment_id = environment.get("id")
        provisioned_resources = environment.get("provisionedResources", []) or []

        for resource in provisioned_resources:
            if resource.get("name") == "s3BucketPath":
                value = resource.get("value")
                if not value:
                    raise RuntimeError(
                        f"s3BucketPath provisioned resource is empty in default tooling "
                        f"environment {environment_id} for project {project_id} in domain "
                        f"{domain_identifier}"
                    )
                # value looks like "s3://<bucket>/shared/<suffix>" (IAM) or
                # "s3://<bucket>/<domain>/<project>/dev/<suffix>" (IDC). Return both
                # parts so callers can construct project-scoped keys.
                parts = urlparse(value, allow_fragments=False)
                bucket = parts.netloc
                if not bucket:
                    raise RuntimeError(
                        f"s3BucketPath provisioned resource has unexpected format "
                        f"'{value}' in default tooling environment {environment_id} for "
                        f"project {project_id} in domain {domain_identifier}"
                    )
                prefix = parts.path.strip("/")
                return bucket, prefix

        raise RuntimeError(
            f"s3BucketPath provisioned resource not found in default tooling environment "
            f"{environment_id} for project {project_id} in domain {domain_identifier}"
        )

    def _get_default_tooling_environment(self, domain_identifier: str, project_id: str) -> dict:
        """
        Resolve the project's default ("Tooling") environment via DataZone APIs.

        1. ``ListEnvironmentBlueprints(managed=True, name="Tooling")`` → resolve the
           Tooling blueprint id.
        2. ``ListEnvironments(environmentBlueprintIdentifier=...)`` → list the project's
           environments for that blueprint.
        3. Pick the environment with the smallest non-null ``deploymentOrder``; if no
           environment has one, pick the first listed environment.
        4. If the Tooling blueprint is missing or has no environments for the project,
           repeat steps 1-3 with the ``ToolingLite`` blueprint.
        5. ``GetEnvironment(identifier=...)`` → read the full record (including
           ``provisionedResources``).

        :param domain_identifier: The ID of the DataZone domain.
        :param project_id: The ID of the DataZone project.
        :return: The full environment dict from ``GetEnvironment``.
        :raises RuntimeError: If no default Tooling/ToolingLite environment is found, the
            ToolingLite blueprint is missing when it is needed, or the DataZone APIs
            return an error.
        """
        try:
            # A missing Tooling blueprint must not prevent the ToolingLite fallback.
            default_env_summary = self._find_default_tooling_environment_summary(
                domain_identifier=domain_identifier,
                project_id=project_id,
                blueprint_name="Tooling",
                raise_if_blueprint_missing=False,
            )
            if default_env_summary is None:
                default_env_summary = self._find_default_tooling_environment_summary(
                    domain_identifier=domain_identifier,
                    project_id=project_id,
                    blueprint_name="ToolingLite",
                )
            if default_env_summary is None:
                raise RuntimeError(
                    f"No default Tooling or ToolingLite environment found for project "
                    f"{project_id} in domain {domain_identifier}"
                )
            return self.conn.get_environment(
                domainIdentifier=domain_identifier,
                identifier=default_env_summary["id"],
            )
        except ClientError as e:
            raise RuntimeError(
                f"Failed to resolve default tooling environment for project {project_id} "
                f"in domain {domain_identifier}: {e}"
            ) from e

    def _find_default_tooling_environment_summary(
        self,
        domain_identifier: str,
        project_id: str,
        blueprint_name: str,
        raise_if_blueprint_missing: bool = True,
    ) -> dict | None:
        """
        Resolve the default tooling environment summary for a given blueprint.

        Returns ``None`` when the blueprint has no environments for the project (so the
        caller can fall back to ``ToolingLite``). When environments exist, prefers the one
        with the lowest non-null ``deploymentOrder``; when ``deploymentOrder`` is absent on
        every env (the field is optional in the DataZone response shape), falls back to the
        first item.

        :param domain_identifier: The ID of the DataZone domain.
        :param project_id: The ID of the DataZone project.
        :param blueprint_name: ``"Tooling"`` or ``"ToolingLite"``.
        :param raise_if_blueprint_missing: When ``True`` (default), raise if the blueprint
            itself is not found; when ``False``, return ``None`` instead.
        :return: The environment summary dict, or ``None``.
        :raises RuntimeError: If the blueprint is missing and ``raise_if_blueprint_missing``
            is ``True``.
        """
        blueprints = (
            self.conn.list_environment_blueprints(
                domainIdentifier=domain_identifier,
                managed=True,
                name=blueprint_name,
            ).get("items", [])
            or []
        )
        if not blueprints:
            if not raise_if_blueprint_missing:
                return None
            raise RuntimeError(
                f"{blueprint_name} environment blueprint not found in domain {domain_identifier}"
            )
        # Prefer an exact name match in case the service-side ``name`` filter is not exact
        # (e.g. "Tooling" also matching "ToolingLite"); otherwise keep the first item.
        blueprint = next(
            (bp for bp in blueprints if bp.get("name") == blueprint_name),
            blueprints[0],
        )
        blueprint_id = blueprint["id"]

        environments = (
            self.conn.list_environments(
                domainIdentifier=domain_identifier,
                projectIdentifier=project_id,
                environmentBlueprintIdentifier=blueprint_id,
            ).get("items", [])
            or []
        )
        if not environments:
            return None

        ordered = [env for env in environments if env.get("deploymentOrder") is not None]
        if ordered:
            return min(ordered, key=lambda env: env["deploymentOrder"])
        # ``deploymentOrder`` is optional in the EnvironmentSummary shape; when
        # absent on every item, fall back to the first env for this blueprint.
        return environments[0]

    def get_notebook_outputs(
        self,
        notebook_identifier: str,
        notebook_run_id: str,
        domain_identifier: str,
        owning_project_identifier: str,
    ) -> dict[str, Any]:
        """
        Read notebook output artifacts from the S3 project bucket.

        After a notebook run completes, the SDK writes output variables as a JSON file
        to a well-known S3 location within the project bucket. This method reads that
        file and returns the parsed key-value pairs. It is best-effort: every failure is
        logged and results in an empty dict rather than an exception.

        :param notebook_identifier: The ID of the notebook that was executed.
        :param notebook_run_id: The ID of the completed notebook run.
        :param domain_identifier: The ID of the DataZone domain.
        :param owning_project_identifier: The ID of the DataZone project.
        :return: A dict of notebook output key-value pairs. Returns an empty dict if no
            outputs were written, the project location cannot be resolved, or the file
            cannot be read or parsed as a JSON object.
        """
        log = logging.getLogger(__name__)
        try:
            bucket, prefix = self.get_project_s3_path(domain_identifier, owning_project_identifier)
        except Exception:
            log.warning(
                "Failed to resolve project S3 location for project %s in domain %s, "
                "skipping notebook outputs read.",
                owning_project_identifier,
                domain_identifier,
                exc_info=True,
            )
            return {}

        # IDC domains have a non-empty prefix (e.g. "<domain>/<project>/<scope>")
        # and the project role's IAM policy only allows S3 reads under that prefix.
        # IAM domains have prefix == "" and the key is unchanged from the
        # legacy bucket-root layout.
        run_key = f".sys/notebooks/{notebook_identifier}/runs/{notebook_run_id}/notebook_outputs.json"
        key = f"{prefix}/{run_key}" if prefix else run_key
        log.info("Reading notebook outputs from s3://%s/%s", bucket, key)

        s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, region_name=self.conn_region_name)
        try:
            content = s3_hook.read_key(key=key, bucket_name=bucket)
            outputs = json.loads(content)
            if not isinstance(outputs, dict):
                log.warning(
                    "Notebook outputs at s3://%s/%s is not a JSON object, ignoring.",
                    bucket,
                    key,
                )
                return {}
            log.info("Successfully read %d notebook output(s).", len(outputs))
            return outputs
        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code")
            # A missing object simply means the notebook wrote no outputs.
            if error_code in ("NoSuchKey", "404"):
                log.info("No notebook outputs found at s3://%s/%s.", bucket, key)
                return {}
            log.warning(
                "Unexpected error reading notebook outputs from s3://%s/%s, ignoring.",
                bucket,
                key,
                exc_info=True,
            )
            return {}
        except (json.JSONDecodeError, UnicodeDecodeError):
            log.warning(
                "Failed to parse notebook outputs at s3://%s/%s as JSON, ignoring.",
                bucket,
                key,
            )
            return {}
        except Exception:
            log.warning(
                "Unexpected error reading notebook outputs from s3://%s/%s, ignoring.",
                bucket,
                key,
                exc_info=True,
            )
            return {}

Was this entry helpful?