case-kの備忘録

日々の備忘録です。データ分析とか基盤系に興味あります。

API GatewayとVPCエンドポイントを活用した、プライベートネットワークからパブリックネットワークへの接続手法

本稿では、API Gatewayをプロキシサーバとして活用し、プライベートネットワークからパブリックネットワークへ接続する方法をご紹介します。一般的にはNAT Gatewayを利用する構成が多いですが、セキュリティ要件などによりインターネットへのアクセスを厳格に管理する必要がある場合、API Gatewayの活用が有効です。また、最近ではLLMが注目される中、API Gatewayの利用クオータが引き上げられました。今後このような活用事例も増えてくるかもしれません。
aws.amazon.com

また、先日投稿した記事もLLMのサービングに関連する記事です。
www.case-k.jp

本記事では詳細な説明は省略いたしますが、Terraformによる定義例を共有いたします。同様の活用事例を検討される方の参考になれば幸いです。

Terraform

ネットワークリソース

まず、ネットワークリソースについて簡単に説明いたします。今回扱うVPCのリソースマップは以下の通りです。

図から、ネットワーク接続にインターネットゲートウェイやNAT Gatewayが存在しないことが確認できます。
これは、API Gatewayをプロキシサーバとしてインターナルアクセスする場合、これらのリソースを自前で用意する必要がないためです。
なお、参考までに、NAT Gatewayを用いた構成例は以下のようになります。

まずはネットワークリソースの作成から始めます。VPC環境からAPI Gatewayへのインターナルアクセスを実現するため、VPCエンドポイントを活用します。

resource "aws_vpc" "vpc" {
  cidr_block = "10.0.0.0/16"
  enable_dns_support   = true
  enable_dns_hostnames = true

  tags = {
    "Name" = "vpc-for-api-gateway"
  }
  tags_all = {
    "Name" = "vpc-for-api-gateway"
  }
}

# resource "aws_subnet" "aws_subnet_private" {
resource "aws_subnet" "subnet_private" {
  vpc_id     = aws_vpc.vpc.id
  cidr_block = "10.0.1.0/24"
  tags = {
    "Name" = "private-subnet-for-api-gateway"
  }
  tags_all = {
    "Name" = "private-subnet-for-api-gateway"
  }
}

ingressにはself = trueを設定し、同じセキュリティグループ内のリソース同士で通信できるようになります。NAT Gateway には不要ですが、API GatewayVPC エンドポイントアクセスなど、セキュリティグループ内のリソース間で HTTPS 通信が必要な場合に有用です。アウトバウンドのトラフィックは全て許可します。

resource "aws_security_group" "security_group" {
  vpc_id = aws_vpc.vpc.id

  # ingress does not need for nat gateway but need for API gateway VPC endpoint access
  ingress {
    from_port = 443
    to_port   = 443
    protocol  = "tcp"
    self      = true
  }

  egress {
    from_port   = 0
    to_port     = 0
    protocol    = "-1"
    cidr_blocks = ["0.0.0.0/0"]
  }

  tags = {
    "Name" = "sg-private-subnet-for-api-gateway"
  }
  tags_all = {
    "Name" = "sg-private-subnet-for-api-gateway"
  }
}

以下のコードは、Terraform を利用してプライベート API Gateway 用のルートテーブルおよび VPC エンドポイントを構築する例です。インターフェース型エンドポイントは、AWS PrivateLink を利用してプライベートにサービスに接続するためのものです。API Gatewayへのインターナル接続で利用します。

resource "aws_route_table" "private-route" {
  propagating_vgws = []
  tags = {
    Name = "private-route-for-api-gateway"
  }
  tags_all = {
    Name = "private-route-for-api-gateway"
  }
  vpc_id = aws_vpc.vpc.id
}

resource "aws_route_table_association" "route_table_association" {
  subnet_id      = aws_subnet.subnet_private.id
  route_table_id = aws_route_table.private-route.id
}

# ref
# https://docs.aws.amazon.com/ja_jp/apigateway/latest/developerguide/apigateway-private-api-create.html
resource "aws_vpc_endpoint" "vpc_endpoint" {
  vpc_id            = aws_vpc.vpc.id
  service_name      = "com.amazonaws.ap-northeast-1.execute-api"
  vpc_endpoint_type = "Interface"

  subnet_ids = [
    aws_subnet.subnet_private.id
  ]

  security_group_ids = [
    aws_security_group.security_group.id,
  ]

  private_dns_enabled = true

  tags = {
    Name = "vpc-endpoint-for-api-gateway"
  }
  tags_all = {
    Name = "vpc-endpoint-for-api-gateway"
  }
}

API Gateway

以下はAPI Gatewayの定義例です。詳細は省略いたしますが、API Gatewayをプロキシサーバとして活用する構成となっています。APIエンドポイントのタイプはプライベートに設定し、VPCエンドポイントから接続できるようにしています。デプロイ後、API Gatewayのエンドポイントにリクエストを送ると、指定したエンドポイントへアクセスが可能です。

REST API の基本情報(名前やAPIキーの受け取り方法など)を設定し、リソースポリシーで「aws:SourceVpc」が指定の VPC と一致する場合のみ呼び出しを許可する条件を付与しています。また、エンドポイントを PRIVATE に設定し、特定の VPC エンドポイント(aws_vpc_endpoint.vpc_endpoint.id)と連携させています。

resource "aws_api_gateway_rest_api" "api_gateway_rest_api" {
  api_key_source               = "HEADER"
  binary_media_types           = []
  body                         = null
  description                  = null
  disable_execute_api_endpoint = false
  fail_on_warnings             = null
  minimum_compression_size     = null
  name                         = "api-private-gw"
  parameters                   = null
  # policy                       = null
  policy = jsonencode({
    Version = "2012-10-17"
    Statement = [
      {
        Effect    = "Allow"
        Principal = "*"
        Action    = "execute-api:Invoke"
        Resource  = "execute-api:/*"
        # Resource  = "arn:aws:execute-api:ap-northeast-1:132483466678:l9jk54cpy0/*"
        Condition = {
          StringEquals = {
            "aws:SourceVpc" = aws_vpc.vpc.id
          }
        }
      }
    ]
  })
  put_rest_api_mode = "overwrite"
  tags              = {}
  tags_all          = {}
  endpoint_configuration {
    types            = ["PRIVATE"]
    vpc_endpoint_ids = [aws_vpc_endpoint.vpc_endpoint.id]
  }
}

API のルートリソース直下に、パスパラメータ "{proxy+}" を持つリソースを作成します。

resource "aws_api_gateway_resource" "api_gateway_resource" {
  depends_on = [aws_api_gateway_rest_api.api_gateway_rest_api]
  # parent_id   = aws_api_gateway_resource.api_gateway_resource_parent.id
  parent_id   = aws_api_gateway_rest_api.api_gateway_rest_api.root_resource_id
  path_part   = "{proxy+}"
  rest_api_id = aws_api_gateway_rest_api.api_gateway_rest_api.id
}

上記リソースに対して HTTP のすべてのメソッド(ANY)を許可し、APIキー認証を必須としています。リクエストパラメータとして URL のパス部分を必須に設定。

resource "aws_api_gateway_method" "api_gateway_method" {
  depends_on           = [aws_api_gateway_resource.api_gateway_resource]
  api_key_required     = true
  authorization        = "NONE"
  authorization_scopes = []
  authorizer_id        = null
  http_method          = "ANY"
  operation_name       = null
  request_models       = {}
  request_parameters = {
    "method.request.path.proxy" = true
  }
  request_validator_id = null
  resource_id          = aws_api_gateway_resource.api_gateway_resource.id
  rest_api_id          = aws_api_gateway_rest_api.api_gateway_rest_api.id
}

メソッド呼び出し成功時(200)のレスポンスモデルを定義。

resource "aws_api_gateway_method_response" "api_gateway_method_response" {
  rest_api_id = aws_api_gateway_rest_api.api_gateway_rest_api.id
  resource_id = aws_api_gateway_resource.api_gateway_resource.id
  http_method = aws_api_gateway_method.api_gateway_method.http_method
  status_code = "200"
  response_models = {
    "application/json" = "Empty"
  }
}

バックエンドからのレスポンスを、JSON テンプレート(空のスキーマ)に変換する設定を行っています。

resource "aws_api_gateway_integration_response" "api_gateway_integration_response" {
  rest_api_id = aws_api_gateway_rest_api.api_gateway_rest_api.id
  resource_id = aws_api_gateway_resource.api_gateway_resource.id
  http_method = aws_api_gateway_method.api_gateway_method.http_method
  status_code = aws_api_gateway_method_response.api_gateway_method_response.status_code
  response_templates  = {
    "application/json" = jsonencode(
          {
            "$schema" = "http://json-schema.org/draft-04/schema#"
            title     = "Empty Schema"
            type      = "object"
          }
      )
  }
}

API の呼び出しに必要な API キー(ここでは "test-key")を作成。

resource "aws_api_gateway_api_key" "api_gateway_api_key" {
  customer_id = null
  description = null
  enabled     = true
  name        = "test-key"
  tags        = {}
  tags_all    = {}
  value       = null # sensitive
}

1日あたりのリクエスト上限(クォータ)やスロットリング(バースト・レート制限)を設定し、対象の API ステージ("test")を紐付けています。

resource "aws_api_gateway_deployment" "api_gateway_deployment" {
  depends_on        = [aws_api_gateway_method.api_gateway_method]
  description       = null
  rest_api_id       = aws_api_gateway_rest_api.api_gateway_rest_api.id
  stage_description = null
  stage_name        = null
  triggers          = null
  variables         = null
}

作成した API キーと使用量プランを関連付け、API キー利用者に対して制限を適用します。

resource "aws_api_gateway_stage" "api_gateway_stage" {
  cache_cluster_enabled = false
  cache_cluster_size    = null
  client_certificate_id = null
  deployment_id         = aws_api_gateway_deployment.api_gateway_deployment.id
  description           = null
  documentation_version = null
  rest_api_id           = aws_api_gateway_rest_api.api_gateway_rest_api.id
  stage_name            = "test"
  tags                  = {}
  tags_all              = {}
  variables             = {}
  xray_tracing_enabled  = false
}

resource "aws_api_gateway_usage_plan" "api_gateway_usage_plan" {
  description  = null
  name         = "test-gw-plan"
  product_code = null
  tags         = {}
  tags_all     = {}
  api_stages {
    api_id = aws_api_gateway_rest_api.api_gateway_rest_api.id
    stage  = aws_api_gateway_stage.api_gateway_stage.stage_name
  }
  quota_settings {
    limit  = 20
    offset = 0
    period = "DAY"
  }
  throttle_settings {
    burst_limit = 5
    rate_limit  = 5
  }
}

resource "aws_api_gateway_usage_plan_key" "api_gateway_usage_plan_key" {
  key_id        = aws_api_gateway_api_key.api_gateway_api_key.id
  key_type      = "API_KEY"
  usage_plan_id = aws_api_gateway_usage_plan.api_gateway_usage_plan.id
}

定義したメソッド(ANY)に対して、HTTP プロキシ統合を設定し、受け取ったリクエストを https://httpbin.org/get に転送します。キャッシュのキーとしてパスパラメータを利用し、タイムアウトなどの挙動も指定しています。

resource "aws_api_gateway_integration" "api_gateway_integration" {
  cache_key_parameters    = ["method.request.path.proxy"]
  cache_namespace         = aws_api_gateway_resource.api_gateway_resource.id
  connection_id           = null
  connection_type         = "INTERNET"
  content_handling        = null
  credentials             = null
  http_method             = "ANY"
  integration_http_method = "GET"
  passthrough_behavior    = "WHEN_NO_MATCH"
  request_parameters = {
    "integration.request.path.proxy" = "method.request.path.proxy"
  }
  request_templates    = {}
  resource_id          = aws_api_gateway_resource.api_gateway_resource.id
  rest_api_id          = aws_api_gateway_rest_api.api_gateway_rest_api.id
  timeout_milliseconds = 5000
  type                 = "HTTP"
  uri                  = "https://httpbin.org/get"
}

Lambda

疎通確認用にLambdaをVPCにデプロイします。

import requests
import json


def lambda_handler(event, context):
    # ref
    # url = "https://httpbin.org/get"
    url = "https://<api-gw-id>.execute-api.ap-northeast-1.amazonaws.com/test/{proxy+}"
    
    # APIキーを指定
    api_key = "api-key"

    headers = {
        "x-api-key": api_key
    }
    
    response = requests.get(url, headers=headers)
    
    print(response.text)
    
    return {
        "statusCode": response.status_code,
        "body": response.text
    }

Terraform

data "archive_file" "lambda_zip" {
  type        = "zip"
  source_dir  = "../app/lambda/python/package"
  output_path = "../app/lambda/python/lambda_function.zip"
}

resource "aws_lambda_function" "lambda_function_for_api_gateway" {
  depends_on       = [data.archive_file.lambda_zip]
  function_name    = "lambda_function_for_api_gateway"
  role             = aws_iam_role.lambda_sample_function.arn
  handler          = "lambda_function.lambda_handler"
  runtime          = "python3.11"
  memory_size      = 128
  timeout          = 5
  source_code_hash = filebase64sha256(data.archive_file.lambda_zip.output_path)
  filename         = data.archive_file.lambda_zip.output_path
  vpc_config {
    subnet_ids         = [aws_subnet.subnet_private.id]
    security_group_ids = [aws_security_group.security_group.id]
  }
}

疎通確認

VPCにデプロイ済みのLambda関数をテスト実行し、プライベートネットワークからAPI Gatewayをプロキシサーバとして経由し、パブリックネットワークへ接続できるか確認します。


以上となります。

terraform-provider-awsにコントリビュートする方法:Issueの起票からコントリビュートまで

terraform-provider-awsにコントリビュートする機会があったので、その過程を備忘録として記録しておきたいと思います。コントリビュートした内容としてはEventBridgeへの強制削除オプションの追加です。EventBridgeを触っている時に見つけたもので、issueを起票してコントリビュートしました。terraform-provider-awsへのコントリビュートを考えている方の参考になれば幸いです。
github.com

Issueの見つけ方

仕様の違和感に気づく

TerraformでCI/CDを構築する際、デプロイしたEventBridgeを削除しようとしたところ、以下のエラーが発生しました。EventBridgeのターゲットが存在している場合、強制的に削除ができない仕様のようです。この制約は不便に感じたため、APIの仕様を確認しました。私の場合はこのような違和感が、コントリビュートのきっかけになることが多いように思います。

│ Error: deleting EventBridge Rule (): ValidationException: Rule can't be deleted since it has targets.
│       status code: 400, request id: <resource_id>

また、すでに多くのissueが起票されているので、もし興味のあるissueがあればここから探せます。

github.com

AWS API Referenceを確認する

APIで実装されていない機能はTerraformでも利用できないため、まずはEventBridgeのAPI仕様を確認します。今回はEventBridgeの強制削除がサポートされていることを確認しました。
docs.aws.amazon.com

Terraformのドキュメントを確認する

Terraformのドキュメントを確認し、EventBridgeにおける強制削除オプションがサポートされているか確認しました。今回はEventBridgeの強制削除オプションが未実装であることを確認しました。
github.com

Issueを起票する

Terraformで強制削除オプションがサポートされていないことを確認したため、Issueを起票しました。OSSによっては、Issueを作成する際に「Would you like to implement a fix?(修正を実装したいですか?)」という確認項目が含まれる場合があります。この場合、修正を自分で行いたい場合は、その旨を記載して意思を示します。

github.com

逆にそのような項目がないOSSに対してはテンプレートに項目を追加するのはどうか聞いてみるといいと思います。
過去に別のOSSで提案して、追加してもらいました。
github.com

コントリビュートする方法

Contributing Guideを確認する

READMEに記載されているContributing Guideを確認します。また、過去にマージされたPull Requestも参照し、コミュニティでの一般的なお作法や手順を把握します。確認すると、Pull Requestにはテストの実行結果を添付する必要があることが分かります。

hashicorp.github.io

Terraformのコードを確認してPRを作成する

リポジトリをフォークし、ブランチを切って、PRを作成します。この流れはどのOSSでもほとんど同じように思います。ブランチを切ったら、問題となる箇所を特定します。私の場合、まずは既存のテストコードを実行できるように環境を整備します。既存のテストコードを動かすと、処理の全体の流れが掴めます。修正すべき該当箇所が特定できるので、必要な修正を加えた後にPull Requestを作成します。

レビューしてもらう

Pull Request作成後はメンテナの反応を待ちます。個人的にはterraform-provider-awsは人手不足なのか少し反応が遅いように感じました。この辺はOSSによって随分違います。

まとめ

terraform-provider-awsにコントリビュートしするまでの大まかな流れを紹介しました。私の場合、実際に触ってみて違和感を感じた箇所やバグを修正する形でコントリビュートすることが多いです。GitHub上でIssueを確認したり、未解決の課題を探したりすると、さらに多くのコントリビュート機会ができると思います。現時点ではそこまで積極的にやれてないですが、今年はもう少し挑戦したいと思います。

Databricksでモデルサービングを迅速にデプロイするNotebook運用

本記事は、 Databricks - Qiita Advent Calendar 2024 - Qiitaシリーズ 2 の 25 日目の記事です。

モデルサービングをデプロイする際に使用しているNotebookテンプレートの運用をご紹介します。

モデルサービングはTerraformでデプロイすることも可能です。しかし、モデルの運用(再学習)などを考慮すると、Terraformで管理しているモデルのバージョンと実際に利用されているモデルのバージョンが一致しない場合があります。そこで、TerraformではNotebook Jobのみを管理し、Notebook内でモデルサービングのデプロイを行う運用を採用しています。

また、モデルサービングをデプロイする際には、モデルのロジック部分以外を共通化することでテンプレートとして管理しています。このテンプレートを利用することで、モデルの作成者はロジック部分のみを編集するだけで簡単にモデルサービングをデプロイできます。

docs.databricks.com

Notebookテンプレート

以下のようなテンプレートを作成しています。「EDITABLE 」となっている箇所がモデル作成者側で修正する箇所になります。利用者はテンプレートをコピーし、「EDITABLE 」となっているモデルのロジック部分を修正します。「EDITABLE 」の前後にある「pre_hook」と「post_hook」は共通処理となります。後ほど紹介しますが、Notebookに渡されたパラメータに基づいてモデルやモデルサービングの登録や更新、削除などの処理をしています。「EDITABLE」となっている箇所も別ファイルとして管理した方がテンプレートはシンプルになりますが、Databricks上でモデルの実行結果等可視化して確認しやすいようこのような運用を採用しています。

# Databricks notebook source

# MAGIC %run ./model_serving

# COMMAND ----------
import json
import mlflow
mlflow.set_registry_uri("databricks-uc")

# COMMAND ----------

params_string = dbutils.widgets.get("params")
params = json.loads(params_string)
print(f'params: {params}')


# COMMAND ----------

# MAGIC %python
# MAGIC pre_hook(params)

# COMMAND ----------
################ EDITABLE ################ 
################ START ################ 
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier

model = params['model']
artifact_path = model['artifact_path']
endpoint = params['endpoint']
model_name = endpoint['config']['served_entities'][0]['entity_name']

experiment_path = model['experiment_path']
mlflow.set_experiment(experiment_path)

with mlflow.start_run():
    # Train a sklearn model on the iris dataset
    X, y = datasets.load_iris(return_X_y=True, as_frame=True)
    clf = RandomForestClassifier(max_depth=7)
    clf.fit(X, y)
    # Take the first row of the training dataset as the model input example.
    input_example = X.iloc[[0]]
    # Log the model and register it as a new version in UC.
    mlflow.sklearn.log_model(
        sk_model=clf,
        artifact_path=artifact_path,
        # The signature is automatically inferred from the input example and its predicted output.
        input_example=input_example,
        registered_model_name=model_name,
    )
################ END ################ 

# COMMAND ----------

# MAGIC %python
# MAGIC post_hook(params)

次にテンプレートで参照している「pre_hook」と「post_hook」について紹介します。以下のコードを見るとわかりますが、Notebookのパラメータに基づいてモデルやモデルサービングの作成や更新、削除等を行っています。「pre_hook」と「post_hook」の処理から見ていただくと雰囲気が掴めると思います。

# Databricks notebook source

from mlflow.deployments import get_deploy_client
from mlflow.tracking import MlflowClient
import time
from distutils.util import strtobool


# COMMAND ----------

def pre_hook(params):
    model = params['model']
    force_delete = bool(strtobool(params['force_delete']))
    endpoint = params['endpoint']
    endpoint_name = endpoint['endpoint_name']
    config = endpoint['config']
    model_name = config['served_entities'][0]['entity_name']
    model_version = config['served_entities'][0]['entity_version']
    print(f'codfig: {config}')
    print(f'force_delete: {force_delete}')
    if model_registered_exists(model_name) and force_delete:
        delete_registered_model(model_name)
    print('skip delete_registered_model')
    if endpoint_exists(endpoint_name) and force_delete:
        delete_model_serving_endpoint(endpoint_name)
        dbutils.notebook.exit("Model serving endpoint deleted. Exiting the notebook.")
    print('skip delete_model_serving_endpoint')
    if model_registered_exists(model_name) and endpoint_exists(endpoint_name) and model_version != 'latest':
        update_model_serving_endpoint(endpoint_name, config)
        wait_for_endpoint_ready(endpoint_name)
        dbutils.notebook.exit(f'Updated model endpoint version: {model_version}. Exiting the notebook')
    print('skip update_model_serving_endpoint')

def post_hook(params):
    endpoint = params['endpoint']
    endpoint_name = endpoint['endpoint_name']
    config = endpoint['config']
    model_name = config['served_entities'][0]['entity_name']
    model_version = config['served_entities'][0]['entity_version']
    print(f'codfig: {config}')

    if model_version == 'latest':
        model_version = get_model_registered_latest_version(model_name)
        config['served_entities'][0]['entity_version'] = model_version
    print(f'model version: {model_version}')
    # Check if the endpoint exists
    if endpoint_exists(endpoint_name):
        print(f"Endpoint '{endpoint_name}' exists.")
        wait_for_endpoint_ready(endpoint_name)
        print(f'update_model_serving_endpoint')
        update_model_serving_endpoint(endpoint_name, config)
        
    else:
        print(f"Endpoint '{endpoint_name}' does not exist.")
        print(f'create_model_serving_endpoint')
        create_model_serving_endpoint(endpoint_name, config)

    wait_for_endpoint_ready(endpoint_name)
    print('end')

# Model Registerd
def delete_registered_model(model_name):
    try:
        client = MlflowClient()
        client.delete_registered_model(name=model_name)
        print(f"Model '{model_name}' has been deleted.")
    except Exception as e:
        print(f"Error deleting model '{model_name}': {e}")
        raise

    
def model_registered_exists(model_name):
    try:
        client = MlflowClient()
        registered_models = client.search_registered_models()
        return any(model.name == model_name for model in registered_models)
    except Exception as e:
        print(f"Error checking if model '{model_name}' is registered in Unity Catalog: {e}")
        return False

def get_model_registered_latest_version(model_name):
    try:
        client = MlflowClient()
        print(f'model_name:{model_name}')
        client = MlflowClient()
        model_version_infos = client.search_model_versions("name = '%s'" % model_name)
        print(f'model_version_infos:{model_version_infos}')
        latest_version = max([int(model_version_info.version) for model_version_info in model_version_infos])
        return latest_version
    except Exception as e:
        print(f"Error get_model_registered_latest_version: {e}")
        raise


# Model serving
def delete_model_serving_endpoint(endpoint_name):
    try:
        deploy_client = get_deploy_client("databricks")
        deploy_client.delete_endpoint(endpoint=endpoint_name)
        print(f"Model serving endpoint '{endpoint_name}' has been deleted.")
    except Exception as e:
        print(f"Error deleting model serving endpoint '{endpoint_name}': {e}")
        raise

def create_model_serving_endpoint(name, config):
    try:
        deploy_client = get_deploy_client("databricks")
        deploy_client.create_endpoint(name=name,config=config)
        print(f"Model serving endpoint '{name}' has been created. config: {config}")
    except Exception as e:
        print(f"Error createing model serving endpoint '{name}': {e}")
        raise

def update_model_serving_endpoint(endpoint_name, config):
    try:
        deploy_client = get_deploy_client("databricks")
        deploy_client.update_endpoint(endpoint=endpoint_name,config=config)
        print(f"Model serving endpoint '{endpoint_name}' has been updated. config: {config}")
    except Exception as e:
        print(f"Error updating model serving endpoint '{endpoint_name}': {e}")
        raise

def endpoint_exists(endpoint_name):
    try:
        deploy_client = get_deploy_client("databricks")
        endpoints = deploy_client.list_endpoints()
        return any(endpoint['name'] == endpoint_name for endpoint in endpoints)
    except Exception as e:
        print(f"Error checking if endpoint exists '{endpoint_name}': {e}")
        return False

def get_endpoint_status(endpoint_name):
    try:
        deploy_client = get_deploy_client("databricks")
        endpoint_status = deploy_client.get_endpoint(endpoint=endpoint_name)
        return endpoint_status
    except Exception as e:
        print(f"Error getting status for endpoint '{endpoint_name}': {e}")
        raise

def wait_for_endpoint_ready(endpoint_name, timeout=1000, interval=30):
    start_time = time.time()
    while time.time() - start_time < timeout:
        try:
            # https://docs.databricks.com/api/workspace/servingendpoints/get
            status = get_endpoint_status(endpoint_name)
            if status['state']['ready'] == 'READY':
                print(f"Endpoint '{endpoint_name}' is ready.")

            if status['state']['config_update'] == 'NOT_UPDATING':
                print(f"Endpoint '{endpoint_name}' is currently NOT_UPDATING.")
                return
            else:
                print(f"Endpoint '{endpoint_name}' is currently being updated. Waiting...")
        except Exception as e:
            if "RESOURCE_CONFLICT" in str(e):
                print(f"Endpoint '{endpoint_name}' is currently being updated. Waiting...")
            else:
                raise
        time.sleep(interval)
    raise TimeoutError(f"Endpoint '{endpoint_name}' is not ready after {timeout} seconds.")

Databricks Job(Terraform)

Notebook JobはTerraformを使用してデプロイしており、パラメータに関する情報は以下のYAMLファイルに定義されています。このYAMLの設定に基づき、TerraformでNotebook Jobを動的にデプロイしています。

jobs:
  wf_test_model_endpoint_tmpl:
    name: wf_test_model_endpoint_tmpl
    domain_tag: test
    notebook_path: "notebook/model_serving/model_endpoint.tmpl"
    job_params:
      {
        "force_delete": "false",
        "model":
          {
            "experiment_path": "/Shared/common_model_experiments/sample",
            "artifact_path": "model",
          },
        "endpoint":
          {
            "endpoint_name": "workspace-model-endpoint",
            "config": { "served_entities": [
                    {
                      "name": "iris_model_serving",
                      "entity_name": "${Env}_data_science.<project>.sample_model",
                      "entity_version": "latest", # latest or version number for rollback
                      "workload_size": "Small",
                      "scale_to_zero_enabled": "true",
                    },
                  ], "auto_capture_config": { "catalog_name": "${Env}_catalog", "schema_name": "sample_schema, "enabled": "true" } },
          },
      }
    clusters:
      {
        spark_version: "15.3.x-cpu-ml-scala2.12",
        node_type_id: "i3.2xlarge",
        driver_node_type_id: "i3.2xlarge",
        autoscale: { min_workers: 2, max_workers: 10 },
      }
    template_path: "../jobs/template/model_serving_template.json"
    git_url: "https://git-codecommit.ap-northeast-1.amazonaws.com/v1/repos/<repository_name>"

モデルサービング用に作成したNotebook Jobのテンプレートは以下のとおりです。このテンプレートでは、YAMLで定義されたパラメータを活用して動的にJobの定義を構築しています。

{
    "name": "${name}",
    "email_notifications": {
        "no_alert_for_skipped_runs": false
    },
    "notification_settings": {
        "no_alert_for_canceled_runs": true
    },
    "webhook_notifications": {},
    "timeout_seconds": 0,
    "max_concurrent_runs": 1,
    "tags": {
        "product": "${domain_tag}"
    },
    "parameters": [
        {
            "name": "params",
            "default": "${params}"
        }
    ],
    "job_clusters": [
        {
            "job_cluster_key": "job_cluster_key_${env}",
            "new_cluster": {
                "spark_version": "${spark_version}",
                "node_type_id": "${node_type_id}",
                "driver_node_type_id": "${driver_node_type_id}",
                "policy_id": "${policy_id}",
                "autoscale": {
                    "min_workers": "${min_workers}",
                    "max_workers": "${max_workers}"
                },
                "aws_attributes": {
                    "first_on_demand": "${first_on_demand}"
                }
            }
        }
    ],
    "tasks": [
        {
            "task_key": "deploy_model_serving",
            "max_retries": 0,
            "notebook_task": {
                "notebook_path": "${notebook_path}",
                "source": "GIT"
            },
            "job_cluster_key": "job_cluster_key_${env}",
            "libraries": [
                {
                    "pypi": {
                        "package": "mlflow-skinny[databricks]>=2.5.0"
                    }
                }
            ]
        }
    ],
    "git_source": {
        "git_url": "${git_url}",
        "git_provider": "awscodecommit",
        "git_tag": "t_${env}"
    },
    "format": "MULTI_TASK"
}

動的に構築されたJobの情報は、以下のTerraformで定義されたJobに渡されています。この仕組みでは、モデルサービング以外のジョブリソースも含めて共通の定義を活用し、効率的な管理を実現しています。各ジョブの差分はlocalsで吸収し、local.job-association-mapを用いて全てのジョブリソース情報を一元的に受け取る構造となっています。

resource "databricks_job" "job" {
  for_each = local.job-association-map
  depends_on = [
    databricks_cluster.shared_dbx_cluster
  ]
  name                = each.value.name
  timeout_seconds     = each.value.timeout_seconds
  max_concurrent_runs = each.value.max_concurrent_runs
  git_source {
    url      = lookup(each.value, "git_source", null) != null ? each.value.git_source.git_url : null
    provider = lookup(each.value, "git_source", null) != null ? each.value.git_source.git_provider : null
    tag      = lookup(each.value, "git_source", null) != null ? each.value.git_source.git_tag : null
  }
  tags = {
    product = lookup(each.value, "tags", null) != null ? each.value.tags.product : local.tags.product
  }
  email_notifications {
    no_alert_for_skipped_runs = lookup(each.value.email_notifications, "no_alert_for_skipped_runs", null) != null ? each.value.email_notifications.no_alert_for_skipped_runs : null
    on_start                  = lookup(each.value.email_notifications, "on_start", []) != [] ? each.value.email_notifications.on_start : []
    on_success                = lookup(each.value.email_notifications, "on_success", []) != [] ? each.value.email_notifications.on_success : []
    on_failure                = lookup(each.value.email_notifications, "on_failure", []) != [] ? each.value.email_notifications.on_failure : local.on_failure
  }
  format = each.value.format

  dynamic "trigger" {
    for_each = { for key, val in each.value :
    key => val if key == "trigger" && val != null }
    content {
      pause_status = trigger.value.pause_status
      file_arrival {
        url = trigger.value.file_arrival.url
      }
    }
  }

  # use existing cluster instead of new_cluster.This will be used for IDBCDB,To import existing resources. 
  dynamic "job_cluster" {
    for_each = { for key, val in each.value.job_clusters :
    key => val if lookup(val, "new_cluster", null) != null }
    content {
      job_cluster_key = each.value.job_clusters[0].job_cluster_key
      new_cluster {
        spark_version       = lookup(each.value.job_clusters[0].new_cluster, "spark_version", null) != null ? each.value.job_clusters[0].new_cluster.spark_version : local.clusters.spark_version
        node_type_id        = lookup(each.value.job_clusters[0].new_cluster, "node_type_id", null) != null ? each.value.job_clusters[0].new_cluster.node_type_id : local.clusters.node_type_id
        driver_node_type_id = lookup(each.value.job_clusters[0].new_cluster, "driver_node_type_id", null) != null ? each.value.job_clusters[0].new_cluster.driver_node_type_id : local.clusters.driver_node_type_id
        policy_id           = lookup(each.value.job_clusters[0].new_cluster, "policy_id", null) != null ? each.value.job_clusters[0].new_cluster.policy_id : local.clusters.policy_id
        runtime_engine      = lookup(each.value.job_clusters[0].new_cluster, "runtime_engine", null) != null ? each.value.job_clusters[0].new_cluster.runtime_engine : local.clusters.runtime_engine
        spark_conf          = lookup(each.value.job_clusters[0].new_cluster, "spark_conf", null) != null ? each.value.job_clusters[0].new_cluster.spark_conf : null
        autoscale {
          min_workers = lookup(each.value.job_clusters[0].new_cluster.autoscale, "min_workers", null) != null ? each.value.job_clusters[0].new_cluster.autoscale.min_workers : local.clusters.autoscale.min_workers
          max_workers = lookup(each.value.job_clusters[0].new_cluster.autoscale, "max_workers", null) != null ? each.value.job_clusters[0].new_cluster.autoscale.max_workers : local.clusters.autoscale.max_workers
        }
        aws_attributes {
          first_on_demand        = lookup(each.value.job_clusters[0].new_cluster.aws_attributes, "first_on_demand", null) != null ? each.value.job_clusters[0].new_cluster.aws_attributes.first_on_demand : local.clusters.aws_attributes.first_on_demand
          availability           = lookup(each.value.job_clusters[0].new_cluster.aws_attributes, "availability", null) != null ? each.value.job_clusters[0].new_cluster.aws_attributes.availability : local.clusters.aws_attributes.availability
          instance_profile_arn   = lookup(each.value.job_clusters[0].new_cluster.aws_attributes, "instance_profile_arn", null) != null ? each.value.job_clusters[0].new_cluster.aws_attributes.instance_profile_arn : local.clusters.aws_attributes.instance_profile_arn
          zone_id                = lookup(each.value.job_clusters[0].new_cluster.aws_attributes, "zone_id", null) != null ? each.value.job_clusters[0].new_cluster.aws_attributes.zone_id : local.clusters.aws_attributes.zone_id
          spot_bid_price_percent = lookup(each.value.job_clusters[0].new_cluster.aws_attributes, "spot_bid_price_percent", null) != null ? each.value.job_clusters[0].new_cluster.aws_attributes.spot_bid_price_percent : local.clusters.aws_attributes.spot_bid_price_percent
        }
        custom_tags = each.value.job_clusters[0].new_cluster.custom_tags
      }
    }
  }

  dynamic "notification_settings" {
    for_each = { for key, val in each.value :
    key => val if key == "notification_settings" && val != null }
    content {
      no_alert_for_skipped_runs  = lookup(notification_settings.value, "no_alert_for_skipped_runs", null) != null ? notification_settings.value.no_alert_for_skipped_runs : null
      no_alert_for_canceled_runs = lookup(notification_settings.value, "no_alert_for_canceled_runs", null) != null ? notification_settings.value.no_alert_for_canceled_runs : null
    }
  }

  dynamic "schedule" {
    for_each = { for key, val in each.value :
    key => val if key == "schedule" && val != null }
    content {
      pause_status           = schedule.value.pause_status
      quartz_cron_expression = schedule.value.quartz_cron_expression
      timezone_id            = schedule.value.timezone_id
    }
  }

  dynamic "parameter" {
    for_each = contains(keys(each.value), "parameter") ? each.value["parameter"] : []
    content {
      default = parameter.value.default
      name    = parameter.value.name
    }
  }

  dynamic "queue" {
    for_each = { for key, val in each.value : key => val if key == "queue" && val != {} }
    content {
      # enabled = lookup(queue.value, "enabled", null) != null ? queue.value.enabled : false
      enabled = queue.value.enabled
    }
  }

  dynamic "task" {
    for_each = each.value.tasks
    content {
      task_key                  = task.value.task_key
      job_cluster_key           = lookup(task.value, "job_cluster_key", null) != null ? task.value.job_cluster_key : null
      existing_cluster_id       = lookup(task.value, "existing_cluster_id", null) != null ? task.value.existing_cluster_id : null
      max_retries               = contains(keys(task.value), "max_retries") ? task.value["max_retries"] : local.max_retries
      min_retry_interval_millis = contains(keys(task.value), "min_retry_interval_millis") ? task.value["min_retry_interval_millis"] : local.min_retry_interval_millis
      run_if                    = lookup(task.value, "run_if", null) != null ? task.value.run_if : null
      dynamic "notebook_task" {
        for_each = { for key, val in task.value :
        key => val if key == "notebook_task" }
        content {
          notebook_path   = notebook_task.value.notebook_path
          base_parameters = lookup(notebook_task.value, "base_parameters", null) != null ? notebook_task.value.base_parameters : {}
          source          = notebook_task.value.source
        }
      }

      dynamic "depends_on" {
        for_each = contains(keys(task.value), "depends_on") ? task.value["depends_on"] : []

        content {
          task_key = depends_on.value.task_key
          outcome  = lookup(depends_on.value, "outcome", null) != null ? depends_on.value.outcome : null
        }
      }

      dynamic "dbt_task" {
        for_each = { for key, val in task.value :
        key => val if key == "dbt_task" }
        content {
          project_directory = task.value.dbt_task.project_directory
          commands          = task.value.dbt_task.commands
          schema            = task.value.dbt_task.schema
          warehouse_id      = task.value.dbt_task.warehouse_id
          catalog           = task.value.dbt_task.catalog
        }
      }

      dynamic "spark_python_task" {
        for_each = { for key, val in task.value :
        key => val if key == "spark_python_task" }
        content {
          parameters  = task.value.spark_python_task.parameters
          python_file = task.value.spark_python_task.python_file
          source      = task.value.spark_python_task.source
        }
      }

      dynamic "condition_task" {
        for_each = { for key, val in task.value :
        key => val if key == "condition_task" }
        content {
          left  = task.value.condition_task.left
          op    = task.value.condition_task.op
          right = task.value.condition_task.right
        }
      }
      dynamic "library" {
        for_each = contains(keys(task.value), "libraries") ? task.value["libraries"] : []
        content {
          pypi { package = task.value.libraries[0].pypi.package }
        }
      }
      timeout_seconds = lookup(task.value, "timeout_seconds", null) != null ? task.value.timeout_seconds : null

      dynamic "email_notifications" {
        for_each = { for key, val in task.value :
        key => val if key == "email_notifications" && val != {} }
        content {
          on_success = lookup(email_notifications.value, "on_success", null) != null ? email_notifications.value.on_success : null
          on_start   = lookup(email_notifications.value, "on_start", null) != null ? email_notifications.value.on_start : null
          on_failure = lookup(email_notifications.value, "on_failure", null) != null ? email_notifications.value.on_failure : null
        }
      }

      dynamic "notification_settings" {
        for_each = { for key, val in task.value :
        key => val if key == "notification_settings" && val != {} }
        content {
          alert_on_last_attempt      = lookup(notification_settings.value, "alert_on_last_attempt", null) != null ? notification_settings.value.alert_on_last_attempt : false
          no_alert_for_canceled_runs = lookup(notification_settings.value, "no_alert_for_canceled_runs", null) != null ? notification_settings.value.no_alert_for_canceled_runs : false
          no_alert_for_skipped_runs  = lookup(notification_settings.value, "no_alert_for_skipped_runs", null) != null ? notification_settings.value.no_alert_for_skipped_runs : false
        }
      }
    }
  }
}

まとめ

モデルサービングのデプロイはモデルの再学習等を考慮して、Terraformで直接デプロイするのではなく、Notebook Jobを活用しています。Notebookはロジック部分を除き共通化できるので、テンプレート化して利用者に提供しています。利用者はテンプレートを活用することで、モデルを作成したのち高速にデプロイできる環境を利用できます。

Databricks × dbt 運用Tips:失敗したモデルだけを効率的にリトライする方法

本記事は、 Databricks - Qiita Advent Calendar 2024 - Qiitaシリーズ 3 の 24 日目の記事です。

データ基盤を運用する際、依存関係を考慮しながら、失敗したモデルやその依存関係のあるモデルのみを再実行したい場合があります。本記事では、Databricksとdbtを活用したリトライ方法についてご紹介します。

Case 1:失敗したモデルと依存関係のある後続モデルのリトライ

dbtの機能である「dbt retry」を活用し、失敗したモデルとその依存関係にある後続処理を効率的に再実行する方法をご紹介します。
dbtの実行結果は「run_results.json」で確認可能です。「dbt retry」では、この「run_results.json」に記録されたログを基に、失敗したモデルのみを再実行することができます。
dbt Cloudでは、失敗したモデルのみを再実行する機能が提供されています。一方、OSS版のdbt Coreを利用している場合、ログが欠損しないようストレージに永続化する仕組みが必要です。しかし、Databricksが提供する標準的なdbt処理ではこの要件に対応できないため、専用のdbt用共通モジュールを作成し、Notebook内でdbtを実行する運用を採用しています。

docs.getdbt.com

具体的には、以下のコードを使用して運用を行っています。Notebook Jobで渡されたパラメータに基づき、dbt runまたはdbt retryを制御します。また、永続化先のストレージで障害が発生した場合に実行ログを失う可能性を考慮し、ログは標準出力にも記録しています。

# Databricks notebook source

# COMMAND ----------
import os
import time
from datetime import datetime
import pytz
import json

env = dbutils.widgets.get("env")
dbt_tag = dbutils.widgets.get("dbt_tag")
threads = dbutils.widgets.get("threads")
dbt_dbfs_state_dir = dbutils.widgets.get("dbt_dbfs_state_dir")
dbt_tmp_dir = f'/tmp-{time.time()}'
dbt_project_dir = f'{dbt_tmp_dir}/dbt/config/{env}'
dbt_profile_dir = f'{dbt_tmp_dir}/dbt/config'
dbt_state_dir = f'{dbt_project_dir}/target'

os.environ['env'] = env
os.environ['dbt_tag'] = dbt_tag
os.environ['threads'] = threads
os.environ['DBT_TEMP_DIR'] = dbt_tmp_dir
os.environ['DBT_PROFILES_DIR'] = dbt_profile_dir
os.environ['DBT_PROJECT_DIR'] = dbt_project_dir
os.environ['DBT_STATE_DIR'] = dbt_state_dir
os.environ['DBT_ENV_SECRET_TOKEN'] = dbutils.secrets.get(f'<secret>', f'<secret-token>')

# COMMAND ----------
# MAGIC %sh
# MAGIC set -eu
# MAGIC ls -ltr ../
# MAGIC mkdir -p ${DBT_STATE_DIR}
# MAGIC cp -r ../dbt ${DBT_TEMP_DIR}
    
# COMMAND ----------
# This params are used to run dbt command with the custome date

dbt_retry = dbutils.widgets.get('dbt_retry')
dbt_command = f"dbt run   --select {dbt_tag} --target={env} --threads {threads}"
if dbt_retry != 'false':
    dbutils.fs.cp(f'{dbt_dbfs_state_dir}/run_results.json', f'file:{dbt_state_dir}/run_results.json')
    dbt_command = f"dbt retry --target={env} --threads {threads}"
print(f'dbt command: {dbt_command}')
os.environ['DBT_COMMAND'] = dbt_command

# COMMAND ----------
# MAGIC %sh
# MAGIC set -eu
# MAGIC cd ${DBT_TEMP_DIR}
# MAGIC dbt deps
# MAGIC ${DBT_COMMAND}

# COMMAND ----------

file_path = f'{dbt_state_dir}/run_results.json'
try:
    with open(file_path, 'r') as file:
        data = json.load(file)
        print(json.dumps(data))
    if dbt_retry != 'false':
        # Remove crc file to avoid error when copying to dbfs
        dbutils.fs.rm( f'file:{dbt_state_dir}/.run_results.json.crc')
    dbutils.fs.cp(f'file:{file_path}', f'{dbt_dbfs_state_dir}/run_results.json')
    for result in data['results']:
        if result['status'] == 'error':
            raise ValueError("Error detected in results")
except Exception as e:
    raise e

Case 2:データ不整合が判明したモデルのリトライ

次に想定されるシナリオとして、モデルの実行結果に不整合が発生し、再度成功したモデルを実行したい場合があります。このようなユースケースでは、dbtのタグ機能を活用します。以下に、モデルmodel_aおよびmodel_bを実行するケースを示します。

dbt run --select tag:model_a,model_b

```

まとめ

dbtで失敗したモデルを再実行する際には、dbt retryを利用しています。ログの損失を防ぐために、ログを永続化する仕組みを整えた上で、Notebookを通じてdbtを実行しています。また、既に成功しているモデルをリトライしたい場合には、dbtのタグ機能を活用して対応しています。

DatabricksのOSS(terraform, dbt)にコントリビュートした話

この記事は、Databricksとdbtのアドベントカレンダー2023の13日めの記事です。

qiita.com

qiita.com

ここ半年ほどDatabricksを使い始めて、最近terraformやdbtにコントリビュートする機会があったので、その紹介をします。色々見つかって楽しいので来年はもっとコントリビュートしていきたいです。

Open Source Repositories Title & Url Status
terraform-provider-databricks Sort based on the Task Key specified in the 'Depends On' field https://github.com/databricks/terraform-provider-databricks/pull/3000 merged
terraform-provider-databricks Add Test Code for Sorting Tasks in the 'Depends On' Field of a Job https://github.com/databricks/terraform-provider-databricks/pull/3183 merged
terraform-provider-databricks Added an item to check if the ticket opener wants to do a bug fix https://github.com/databricks/terraform-provider-databricks/pull/3020 merged
dbt-databricks Fix dbt incremental_strategy behavior by fixing schema table existing check https://github.com/databricks/dbt-databricks/pull/530 merged
dbt-databricks Add DBT_DATABRICKS_UC_INITIAL_CATALOG Option https://github.com/databricks/dbt-databricks/pull/537 merged
dbt-databricks Eliminate the conversion of the schema to lowercase for schema-related test https://github.com/databricks/dbt-databricks/pull/541 merged
dbt-databricks Add schema option for testing https://github.com/databricks/dbt-databricks/pull/548 / https://github.com/databricks/dbt-databricks/pull/538 merged

Databricksはまだ、発展途上の部分もありOSSに対してコントリビュートできる機会が多く楽しめています。来年は開発や運用で見つけたバグだけではなく、issueも積極的に漁りながら貢献していきたいです。

Javaのバージョン管理 備忘録

MacJavaのバージョン管理をする際に備忘録

jenv install

brew install jenv
echo "if which jenv > /dev/null; then eval "$(jenv init -)"; fi" >> ~/.zshrc 
source ~/.zshrc 

M1 Macの場合

sudo softwareupdate --install-rosetta # M1で入れるのに必要

JDK Install

brew tap AdoptOpenJDK/openjdk
brew install --cask adoptopenjdk8 
brew install --cask adoptopenjdk

/usr/libexec/java_home -v 1.8
/Library/Java/JavaVirtualMachines/adoptopenjdk-8.jdk/Contents/Home
/usr/libexec/java_home -v 16         
/Library/Java/JavaVirtualMachines/adoptopenjdk-16.jdk/Contents/Home

JenvにJDKを追加

jenv add /Library/Java/JavaVirtualMachines/adoptopenjdk-8.jdk/Contents/Home
jenv add /Library/Java/JavaVirtualMachines/adoptopenjdk-16.jdk/Contents/Home
```


Jenvに追加されたバージョンを確認

jenv versions
* system (set by /Users/keisuke.taniguchi/.jenv/version)
  1.8
  1.8.0.292
  16.0
  16.0.1
  openjdk64-1.8.0.292
  openjdk64-16.0.1


利用するバージョンを選択

jenv global 1.8

java -version
openjdk version "1.8.0_292"
OpenJDK Runtime Environment (AdoptOpenJDK)(build 1.8.0_292-b10)
OpenJDK 64-Bit Server VM (AdoptOpenJDK)(build 25.292-b10, mixed mode)

ref
Macで多様なJavaバージョンとディストリビューションを管理:HomebrewとjEnvの活用 #homebrew - Qiita

Java 11

brew install --cask temurin11
jenv add /Library/Java/JavaVirtualMachines/temurin-11.jdk/Contents/Home

Java21
```
brew install --cask temurin21
jenv add /Library/Java/JavaVirtualMachines/temurin-21.jdk/Contents/Home

jenv global 21.0
keisuke.taniguchi@keisuketaniguchinoMacBook-Pro api % jenv versions
system
1.8
1.8.0.292
11.0
11.0.26

21.0 (set by /Users/keisuke.taniguchi/.jenv/version)

21.0.6
openjdk64-1.8.0.292
temurin64-11.0.26
temurin64-21.0.6
```

BigQueryに書き込まれたSQL Serverの変更ログを用いて、変更のあったPKの最新データと変更前のデータを取得する方法

ZOZO Advent Calendar 2022 カレンダー25日目の記事です
qiita.com

BigQueryに書き込まれたSQL Serverの変更追跡ログを使って、変更のあったPKの変更ログと変更前のログを取得する方法をご紹介します。

ZOZOではSQL Serverの変更追跡機能を使い、変更のあったPKの最新のレコードをBigQueryに連携しています。SQL Serverの変更追跡では以下のようなクエリを実行して、変更のあったPKとPKに紐づく最新のデータを取得できます。

  SELECT
    a.SYS_CHANGE_OPERATION as changetrack_type,
    a.SYS_CHANGE_VERSION as changetrack_ver,
    #{columns}
  FROM
    CHANGETABLE(CHANGES #{@tablename},
      @前回更新したバージョン) AS a
  LEFT OUTER JOIN #{@tablename} ON a.#{@primary_key} = b.#{@primary_key}

SQL ServerなどDBの変更を追跡する機能はCDC(Change Data Capture)と呼ばれており、DBで変更のあったデータを全て連携するログベースのCDCから、クエリで変更データをポーリング するCDCなどがあります。SQL Serverの変更追跡は後者に該当します。

ログベースのCDCであれば変更前データも取れますが、SQL Serverの変更追跡機能を使う場合、変更のあったPKの変更前のデータは取得できません。本記事ではSQL Serverの変更追跡機能を使ってBigQueryへ連携した変更のあったPKの変更ログに加えて、変更前ログを取得する方法をご紹介します。

techblog.zozo.com
datacater.io

BigQueryで実現する方法

BigQuery上で変更のあったPKの変更ログと変更前のログを取得する方法を紹介します。

変更のあった差分データを取得

SQL Serverの変更追跡機能で連携された差分テーブルから直近2日分の変更データを取得します。直近2日にしているのは、BigQueryパーティションでコストとパフォーマンスを向上させるのと、後ほど紹介する変更前のデータが変更のあった差分データに含まれていない場合に対応するためです。

  streaming AS (
  SELECT
    changetrack_type,
    changetrack_ver ,
    bigquery_insert_time,
   <primary_key> AS primary_key,
    <columns>
  FROM
    <変更追跡で連携された差分テーブル>
  WHERE
    bigquery_insert_time >= TIMESTAMP_SUB(CAST(FORMAT_TIMESTAMP("%Y-%m-%d", TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 2 day), "Asia/Tokyo") AS timestamp), INTERVAL 9 HOUR)),

各カラムは以下の情報を含んでいます。

  • changetrack_type: 変更処理。どのようは変更があったか確認できます(I: insert, U: update, D: delete)
  • changetrack_ver:変更追跡のバージョン。SQL Serverトランザクションごとに発行されるバージョンです。数値の大きいデータが最新になります
  • primary_key: SQL ServerのPKをセットしています
  • bigquery_insert_time: BigQueryに書き込まれた時刻.BigQueryのパーティション機能や遅延計測に使っています
  • columns: 変更を追跡しているテーブルのカラム

最新の変更追跡バージョンを集計

先ほど抽出したデータを用いて、PKごとに最新の変更追跡バージョンを集計します。

  streaming_latest_version AS (
  SELECT
    primary_key,
    MAX(changetrack_ver) AS changetrack_ver_max
  FROM
    streaming
    -- set instead of tracking version
  WHERE
    bigquery_insert_time >= TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 60 second)
  GROUP BY
    primary_key ),

最新の変更追跡バージョンを用いて最新の変更ログを取得

PKごとに集計した最新の変更追跡バージョンと変更ログをJOINすることで、変更ログから変更のあったPKの最新のデータを取得できます。

  streaming_latest AS (
  SELECT
    changetrack_type,
    changetrack_ver,
    streaming.primary_key,
    <columns>
  FROM
    streaming
  INNER JOIN
    streaming_latest_version
  ON
    streaming.primary_key = streaming_latest_version.primary_key
    AND streaming.changetrack_ver = streaming_latest_version.changetrack_ver_max ),

変更前の変更追跡バージョンを取得

次に変更のあったPKの変更前のデータを取得します。変更のあった差分データと最新のバージョンをLEFT JOINし、突合できなかったデータ(changetrack_ver_max IS NULL)から変更前の変更追跡バージョンを取得します。変更のあった最新のバージョンは除外されるため変更前の変更追跡バージョンを取得できます。

  streaming_before_latest_version AS (
  SELECT
    primary_key,
    MAX(changetrack_ver) AS changetrack_ver_before_latest
  FROM (
    SELECT
      streaming.primary_key,
      streaming.changetrack_ver,
      streaming_latest_version.changetrack_ver_max
    FROM
      streaming
    LEFT OUTER JOIN
      streaming_latest_version
    ON
      streaming.primary_key = streaming_latest_version.primary_key
      AND streaming.changetrack_ver = streaming_latest_version.changetrack_ver_max)
  WHERE
    changetrack_ver_max IS NULL
  GROUP BY
    primary_key ),

変更前の変更追跡バージョンを用いて変更前のログを取得

変更前の変更追跡バージョンを使って、変更のあったPKの変更前のデータを取得できます。

  streaming_before_latest AS (
  SELECT
    changetrack_type,
    changetrack_ver,
    streaming.primary_key,
  FROM
    streaming
  INNER JOIN
    streaming_before_latest_version
  ON
    streaming.primary_key = streaming_before_latest_version.changetrack_ver_before_latest
    AND streaming.changetrack_ver = streaming_before_latest_version.changetrack_ver_before_latest ),

前日分の全量テーブルから、差分データに含まれていない変更前データを取得

BigQueryのパーティションで絞りこんでいるため、変更のあった差分データの中に変更前のデータが含まれているとは限りません。変更のあった差分データの中に変更前のデータが含まれていない場合は前日分の全量テーブルから変更前のデータを取得します。

  daily_before_latest AS (
  SELECT
    CAST(NULL AS string) AS changetrack_type,
    CAST(NULL AS int64) AS changetrack_ver,
 <columns>
  FROM (
    SELECT
     <primary_key> AS primary_key,
      <columns>
    FROM
      <前日の全量日付サフィックステーブル>
    WHERE
      _TABLE_SUFFIX IN (SUBSTR(FORMAT_TIMESTAMP("%Y%m%d", TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 1 day), "Asia/Tokyo"), 3)))
  WHERE
    primary_key NOT IN (
    SELECT
      primary_key
    FROM
      streaming_before_latest_version)
    AND primary_key IN (
    SELECT
      primary_key
    FROM
      streaming_latest_version) )

変更のあったPKの変更ログと変更前のログを取得

最後に変更データと変更前データをUNIONすることで、変更のあったPKの最新データと変更前のデータを取得できます。

SELECT
  *
FROM
  streaming_latest
UNION ALL
SELECT
  *
FROM
  streaming_before_latest
UNION ALL
SELECT
  *
FROM
  daily_before_latest

BigQueryクエリ完成形

完成形のクエリは以下のようになります。

WITH
  # 本日分の差分テーブル
  streaming AS (
  SELECT
    changetrack_type,
    changetrack_ver,
    bigquery_insert_time,
    <primary_key> AS primary_key,
    <columns>
  FROM
    <変更追跡で連携された差分テーブル>
  WHERE
    bigquery_insert_time >= TIMESTAMP_SUB(CAST(FORMAT_TIMESTAMP("%Y-%m-%d", TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 2 day), "Asia/Tokyo") AS timestamp), INTERVAL 9 HOUR)),
  # 差分テーブルから最新のバージョン
  streaming_latest_version AS (
  SELECT
    primary_key,
    MAX(changetrack_ver) AS changetrack_ver_max
  FROM
    streaming
    -- set instead of tracking version
  WHERE
    bigquery_insert_time >= TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 60 second)
  GROUP BY
    primary_key ),
  # 差分テーブルから最新のログ
  streaming_latest AS (
  SELECT
    changetrack_type,
    changetrack_ver,
    streaming.primary_key,
    <columns>
  FROM
    streaming
  INNER JOIN
    streaming_latest_version
  ON
    streaming.primary_key = streaming_latest_version.primary_key
    AND streaming.changetrack_ver = streaming_latest_version.changetrack_ver_max ),
  # 差分テーブルにある変更前のバージョン(最新の一つ前のバージョン)
  streaming_before_latest_version AS (
  SELECT
    primary_key,
    MAX(changetrack_ver) AS changetrack_ver_before_latest
  FROM (
    SELECT
      streaming.primary_key,
      streaming.changetrack_ver,
      streaming_latest_version.changetrack_ver_max
    FROM
      streaming
    LEFT OUTER JOIN
      streaming_latest_version
    ON
      streaming.primary_key = streaming_latest_version.primary_key
      AND streaming.changetrack_ver = streaming_latest_version.changetrack_ver_max)
  WHERE
    changetrack_ver_max IS NULL
  GROUP BY
    primary_key ),
  # 差分テーブルにある変更前のデータ(最新の一つ前のデータ)
  streaming_before_latest AS (
  SELECT
    changetrack_type,
    changetrack_ver,
    streaming.primary_key,
    <columns>
  FROM
    streaming
  INNER JOIN
    streaming_before_latest_version
  ON
    streaming.primary_key = streaming_before_latest_version.changetrack_ver_before_latest
    AND streaming.changetrack_ver = streaming_before_latest_version.changetrack_ver_before_latest ),
  # 差分テーブルにない変更前のデータ
  daily_before_latest AS (
  SELECT
    CAST(NULL AS string) AS changetrack_type,
    CAST(NULL AS int64) AS changetrack_ver,
    primary_key,
    <columns>
  FROM (
    SELECT
     <primary_key> AS primary_key,
     <columns>
    FROM
       <前日の全量日付サフィックステーブル>
    WHERE
      _TABLE_SUFFIX IN (SUBSTR(FORMAT_TIMESTAMP("%Y%m%d", TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 1 day), "Asia/Tokyo"), 3)))
  WHERE
    primary_key NOT IN (
    SELECT
      primary_key
    FROM
      streaming_before_latest_version)
    AND primary_key IN (
    SELECT
      primary_key
    FROM
      streaming_latest_version) )
SELECT
  *
FROM
  streaming_latest
UNION ALL
SELECT
  *
FROM
  streaming_before_latest
UNION ALL
SELECT
  *
FROM
  daily_before_latest

Dataflow JDBC テンプレート検証(Java)

Python版を調べてみたがクエリの上書きができなかったり、余計な通信が発生していたりと現時点で本番運用できる状態ではなかった。Java版が使えるか検証してみる。Java版はテンプレートが用意されていたので、PostgresとSQL Serverでそれぞれ検証してみた。結論としてはテンプレートをそのまま使うことはできないけど、クエリの上書きや無駄に実行環境からコネクション張る等Apache Beam Pythonで問題だった動きはしませんでした。本番運用でも使えそうです。
www.case-k.jp
www.case-k.jp
www.case-k.jp


テンプレートでサポートしてるのはBigQueryへの追加のみで全量置換はできていない。全量置換する場合自前で作る必要がある。パフォーマンス等調査は必要だがクエリの上書きやSQL Serverもサポートしていたので自前で用意すれば実務には使えそう。

        .apply(
            "Write to BigQuery",
            BigQueryIO.writeTableRows()
                .withoutValidation()
                .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_NEVER)
                .withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_APPEND)
                .withCustomGcsTempLocation(options.getBigQueryLoadingTemporaryDirectory())
                .to(options.getOutputTable()));

github.com

並列実行する場合JOBの同時実行数が懸念。プロジェクトごとに25。上限をあげてもオーガナイゼーションレベルで125が最大となっている。相談すれば増やせはするかも。
増やせないなら数百テーブルを高速かつ、並列実行でやる場合GKEのAutopilotで対応した方がコスト、パフォーマンス面で良さそうに思える。

- Each Google Cloud project can run at most 25 concurrent Dataflow jobs.
- If you opt-in to organization level quotas, each organization can run at most 125 concurrent Dataflow jobs. 

Note: If you would like to run more than 25 concurrent Dataflow jobs for your project or more than 125 concurrent Dataflow jobs for your organization, contact Google Cloud Support and we will increase the limit to a value that better suits your needs.
cloud.google.com


設定のパラメータ

  • Postgres
# required parameters
template: Jdbc to BigQuery
jdbc connection URL string:jdbc:postgresql://<private ip address>:5432/beam
jdbc driver class name:org.postgresql.Driver
jdbc source sql query:<query>
BigQuery output table : <project>:<dataset>.<table>
gcs paths for Jdbc drivers: gs://<gcs-bucket>/postgresql-42.2.18.jre7.jar
Temporary directory for BigQuery loading process:gs: gs://<gcs-bucket>//tmp/


# option parameters
Jdbc connection username: <user-name>
Jdbc connection password: <password>
subnetwork:  <subnetwork>
# required parameters
jdbc connection URL string:jdbc:sqlserver://<private-ip>:1433;database=beam;
jdbc driver class name:com.microsoft.sqlserver.jdbc.SQLServerDriver
jdbc source sql query:<query>
BigQuery output table : <project>:<dataset>.<table>
gcs paths for Jdbc drivers: gs://<gcs-bucket>/mssql-jdbc-8.4.1.jre8.jar
Temporary directory for BigQuery loading process:gs: gs://<gcs-bucket>//tmp/


# option parameters
Jdbc connection username: <user-name>
Jdbc connection password: <password>
subnetwork:  <subnetwork>

Apache Beam Python JDBCを使いDataflowを動かすには、ジョブの実行環境からもコネクションを張れる必要があった

Apache BeamのPython jdbcコネクタを使いDataflowでジョブを実行してみました。Cloud SQLとDataflowを同一サブネット内に作りプライベートIPで接続を試みました。検証したところジョブ実行時に実行環境からPostgresにコネクションを張ろうとしていることがわかりました。

class PostgresToBigQueryDataflow():

    def __init__(self):
        self._username = '<username>'
        self._password = '<password>'
        self._driver_class_name = 'org.postgresql.Driver'
        self._query = "select id from beam_table;"
        self._jdbc_url = 'jdbc:postgresql://<private_IP>:5432/beam'
        self._project = '<project id>'
        self._dataset = '<dataset>'
        self._table = '<table>'
        self._options = DebugOptions([
            "--runner=DataflowRunner",
            "--project=<project id>",
            "--job_name=<job name>",
            "--temp_location=gs://<project id>/tmp/",
            "--region=us-central1",
            "--experiments=use_runner_v2",
            "--subnetwork=regions/us-central1/subnetworks/<subnet>",
        ])
    def test(self):
        JdbcToBigQuery(self._username, self._password, self._driver_class_name, self._query, self._jdbc_url, self._project, self._dataset,self._table, self._options).run()

ローカル環境からDataflowジョブを実行した際、不可解なことにコネクションが張れずJOBの実行に失敗しました。WiresharkでパケットをみたところローカルPCからコネクションを試みていました。

f:id:casekblog:20220214124507p:plain

ジョブを実行する場合、実行環境からもDBに接続できる必要がありそうです。

GCPの同一サブネット内からは問題なくJOBを実行できました。テンプレート化することで回避できるかもしれませんが未検証。
クエリの上書きできない問題もあるので、現時点でPython版を使うのはやめた方がよさそう。

BigQuery Flex SlotsをPython版に置き換えた

bashでやっていたがエラーハンドリングが色々きつかったのでPythonに書き換えました。ドキュメントは不十分だったのでGitHub見ながら作る感じになります。
一通り機能はあるので同じようなことしようとしてる方の参考になれば幸いです。

techblog.zozo.com


コード;
github.com

from google.cloud.bigquery_reservation_v1.services import reservation_service
from google.cloud.bigquery_reservation_v1.types import reservation as reservation_types
from google.protobuf import field_mask_pb2
import adapter.repository.workflow.workflow_repository as workflow_repository
import dto.workflow.workflow_input_objects as workflow_input_objects

class BigQueryReservation():
    def __init__(self, admin_project, assignee_project, location, reservation, commitment_slot_count, reservation_slot_count):
        self._admin_project = admin_project
        self._assignee_project = assignee_project
        self._location = location
        self._reservation = reservation
        self._commitment_slot_count = commitment_slot_count
        self._reservation_slot_count = reservation_slot_count
        self._parent = f'projects/{self._admin_project}/locations/{self._location}'
        self._reservation_client = reservation_service.ReservationServiceClient()

    def upgrade_bigquery_slot_capacity(self):
        self._create_capacity_commitment(plan='FLEX', slot_count=self._commitment_slot_count)
        state = self._fetch_commitment_state(plan='FLEX', slot_count=self._commitment_slot_count)
        if state == reservation_types.CapacityCommitment.State.ACTIVE:
            self._update_reservation(slot_capacity=self._reservation_slot_count)
            self._set_bigquery_plan(bigquery_ondemand_plan=False)
        elif state == reservation_types.CapacityCommitment.State.PENDING:
            # use ondemand paln could be took several hours until successful
            # https://techblog.zozo.com/entry/bigquery-flex-slots
            assignment_id = self._fetch_assignment_id()
            self._delete_assignment(assignment_id)
            commitment_id = self._fetch_commitment_id(plan='FLEX', slot_count=self._commitment_slot_count)
            self._delete_capacity_commitment(commitment_id)
            # set ondemand_plan for digdag slack notice. cloud be delay batch
            self._set_bigquery_plan(bigquery_ondemand_plan=True)
        else:
            commitment_id = self._fetch_commitment_id(plan='FLEX', slot_count=self._commitment_slot_count)
            if len(commitment_id) != 0:
                # not to delete before create capacity commitment because other team could be buy and use same amount of flex slots
                self._delete_capacity_commitment(commitment_id)
            raise Exception(f'failed to buy commitment')

    def downgrade_bigquery_slot_capacity(self):
        assignment_id = self._fetch_assignment_id()
        if len(assignment_id) == 0:
            self._create_assignments()
        self._update_reservation(slot_capacity=self._reservation_slot_count)
        commitment_id = self._fetch_commitment_id(plan='FLEX', slot_count=self._commitment_slot_count)
        if len(commitment_id) != 0:
            self._delete_capacity_commitment(commitment_id)

    def _set_bigquery_plan(self, bigquery_ondemand_plan):
        workflow_input = workflow_input_objects.WorkflowInputObjects.generate_workflow_input_objects([{'bigquery_ondemand_plan': bigquery_ondemand_plan}])
        workflow_repository.WorkflowRepository().save(workflow_input)

    def _create_capacity_commitment(self, plan, slot_count):
        commit_config = reservation_types.CapacityCommitment(plan=plan, slot_count=slot_count)
        self._reservation_client.create_capacity_commitment(parent=self._parent,capacity_commitment=commit_config)

    def _create_assignments(self):
        assign_config = reservation_types.Assignment(job_type='QUERY',assignee=f'projects/{self._assignee_project}')
        assign = self._reservation_client.create_assignment(parent=f'{self._parent }/reservations/{self._reservation}', assignment=assign_config)

    def _update_reservation(self, slot_capacity):
        reservation_name = self._reservation_client.reservation_path(
            project=self._admin_project, location=self._location, reservation=self._reservation
        )
        reservation = reservation_types.Reservation(
            name=reservation_name, slot_capacity=slot_capacity,
        )
        field_mask = field_mask_pb2.FieldMask(paths=["slot_capacity"])
        self._reservation_client.update_reservation(reservation=reservation, update_mask=field_mask)

    def _to_commitment_plan(self, plan):
        if plan == 'FLEX':
            return reservation_types.CapacityCommitment.CommitmentPlan.FLEX
        elif plan == 'MONTHLY':
            return reservation_types.CapacityCommitment.CommitmentPlan.MONTHLY
        elif plan == 'ANNUAL':
            return reservation_types.CapacityCommitment.CommitmentPlan.ANNUAL
        else:
            raise Exception(f'plan is not match {plan}')

    def _fetch_commitment_id(self, plan, slot_count):
        commitments = self._reservation_client.list_capacity_commitments(parent=self._parent)
        commitment_id = [ commitment.name for commitment in commitments if commitment.plan == self._to_commitment_plan(plan) and commitment.slot_count == slot_count ]
        if len(commitment_id) != 0:
            commitment_id = commitment_id[0]
        return commitment_id

    def _fetch_commitment_state(self, plan, slot_count):
        commitments = self._reservation_client.list_capacity_commitments(parent=self._parent)
        commitment_state = [ commitment.state for commitment in commitments if commitment.plan == self._to_commitment_plan(plan) and commitment.slot_count == slot_count ]
        if len(commitment_state) != 0:
            commitment_state = commitment_state[0]
        return commitment_state

    def _fetch_assignment_id(self):
        assignments = self._reservation_client.list_assignments(parent=f'{self._parent }/reservations/{self._reservation}')
        assignment_id = [ assignment.name for assignment in assignments if assignment.assignee == f'projects/{self._assignee_project}']
        if len(assignment_id) != 0:
            assignment_id = assignment_id[0]
        return assignment_id

    def _delete_assignment(self, assignment):
        self._reservation_client.delete_assignment(name=assignment)

    def _delete_capacity_commitment(self, commitment_id):
        self._reservation_client.delete_capacity_commitment(name=commitment_id)