Azure Databricks
# Trigger an existing Azure Databricks job with a Service Principal, using the Databricks SDK.
import argparse

from databricks.sdk import WorkspaceClient
from databricks.sdk.core import Config


def main():
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Run an Azure Databricks job using a Service Principal")
    parser.add_argument("--client_id", required=True, help="Azure Service Principal client ID")
    parser.add_argument("--client_secret", required=True, help="Azure Service Principal client secret")
    parser.add_argument("--tenant_id", required=True, help="Azure tenant ID")
    parser.add_argument("--workspace_url", required=True, help="Databricks workspace URL (e.g., https://adb-1234567890123456.7.azuredatabricks.net)")
    parser.add_argument("--job_id", required=True, help="Databricks job ID to execute")
    args = parser.parse_args()

    # Set up SDK configuration with Service Principal credentials.
    # Azure SPN credentials go in the azure_* fields with the "azure-client-secret" auth type.
    config = Config(
        host=args.workspace_url,
        auth_type="azure-client-secret",
        azure_client_id=args.client_id,
        azure_client_secret=args.client_secret,
        azure_tenant_id=args.tenant_id
    )

    # Initialize the WorkspaceClient
    client = WorkspaceClient(config=config)

    # Trigger the job
    run = client.jobs.run_now(job_id=int(args.job_id))

    # Output the run ID
    print(f"Job triggered successfully. Run ID: {run.run_id}")


if __name__ == "__main__":
    main()
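For reference, assuming the script above is saved as run_job.py (a placeholder name), a typical invocation looks like this, with all argument values replaced by your own:

python run_job.py \
  --client_id <app-id> \
  --client_secret <secret> \
  --tenant_id <tenant-id> \
  --workspace_url https://adb-1234567890123456.7.azuredatabricks.net \
  --job_id 123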
# Resolve the workspace URL via the Azure management SDK, then run an existing
# Databricks job or submit a one-off notebook run, and poll the run's status.
import time
from datetime import datetime

from azure.identity import ClientSecretCredential
from azure.mgmt.databricks import DatabricksManagementClient
from databricks.sdk import WorkspaceClient
from databricks.sdk.service import compute, jobs


def execute_databricks_job(
    subscription_id,
    resource_group,
    workspace_name,
    tenant_id,
    client_id,
    client_secret,
    job_id=None,
    notebook_path=None,
    parameters=None
):
    """
    Execute an Azure Databricks job using both the Azure SDK and the Databricks SDK.

    Args:
        subscription_id (str): Azure subscription ID
        resource_group (str): Azure resource group containing the Databricks workspace
        workspace_name (str): Name of the Databricks workspace
        tenant_id (str): Azure tenant ID
        client_id (str): Azure SPN client ID
        client_secret (str): Azure SPN client secret
        job_id (int, optional): ID of an existing job to run
        notebook_path (str, optional): Path to notebook if creating a new run
        parameters (dict, optional): Parameters to pass to the job or notebook

    Returns:
        run_id (int): The ID of the executed run
        workspace_url (str): The URL of the Databricks workspace
    """
    # Authenticate using the Azure SDK
    credential = ClientSecretCredential(
        tenant_id=tenant_id,
        client_id=client_id,
        client_secret=client_secret
    )

    # Look up the Databricks workspace using the Azure management SDK
    databricks_client = DatabricksManagementClient(
        credential=credential,
        subscription_id=subscription_id
    )

    # Get the workspace URL
    workspace = databricks_client.workspaces.get(
        resource_group_name=resource_group,
        workspace_name=workspace_name
    )
    workspace_url = workspace.workspace_url
    print(f"Found workspace URL: {workspace_url}")

    # Create a Databricks SDK client authenticated with the same Service Principal
    db_client = WorkspaceClient(
        host=f"https://{workspace_url}",
        azure_client_id=client_id,
        azure_client_secret=client_secret,
        azure_tenant_id=tenant_id
    )
    # Execute an existing job if job_id is provided
    if job_id:
        run_parameters = {}
        if parameters:
            run_parameters["notebook_params"] = parameters
        run = db_client.jobs.run_now(
            job_id=job_id,
            **run_parameters
        )
        run_id = run.run_id
        print(f"Started job run with ID: {run_id}")
    # Create a new one-off run if notebook_path is provided
    elif notebook_path:
        run_name = f"Run {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"

        # Define a new cluster configuration
        new_cluster = compute.ClusterSpec(
            spark_version="11.3.x-scala2.12",
            node_type_id="Standard_DS3_v2",
            num_workers=1
        )

        # Set up the notebook task
        notebook_task = jobs.NotebookTask(
            notebook_path=notebook_path,
            base_parameters=parameters if parameters else None
        )

        # One-off runs are submitted with SubmitTask (jobs.Task is for job definitions)
        task = jobs.SubmitTask(
            task_key="notebook_task",
            notebook_task=notebook_task,
            new_cluster=new_cluster
        )

        # Submit the run
        run = db_client.jobs.submit(
            run_name=run_name,
            tasks=[task]
        )
        run_id = run.run_id
        print(f"Submitted notebook run with ID: {run_id}")
    else:
        raise ValueError("Either job_id or notebook_path must be provided")

    return run_id, workspace_url
def check_run_status(
    run_id,
    tenant_id,
    client_id,
    client_secret,
    workspace_url=None,
    workspace_name=None,
    subscription_id=None,
    resource_group=None
):
    """
    Check the status of a Databricks job run.

    Args:
        run_id (int): The run ID to check
        tenant_id (str): Azure tenant ID
        client_id (str): Azure SPN client ID
        client_secret (str): Azure SPN client secret
        workspace_url (str, optional): Direct URL to the Databricks workspace
        workspace_name (str, optional): Name of the Databricks workspace (used with subscription_id and resource_group)
        subscription_id (str, optional): Azure subscription ID (used with workspace_name)
        resource_group (str, optional): Azure resource group (used with workspace_name)

    Returns:
        dict: Information about the run status including:
            - status: current run state (PENDING, RUNNING, TERMINATED, etc.)
            - result: result state if completed (SUCCESS, FAILED, etc.)
            - is_completed: boolean indicating if the run is complete
            - run_details: complete run details object
    """
    # Determine the workspace URL
    if not workspace_url and workspace_name and subscription_id and resource_group:
        # Resolve the workspace URL from Azure if it was not provided directly
        credential = ClientSecretCredential(
            tenant_id=tenant_id,
            client_id=client_id,
            client_secret=client_secret
        )
        databricks_client = DatabricksManagementClient(
            credential=credential,
            subscription_id=subscription_id
        )
        workspace = databricks_client.workspaces.get(
            resource_group_name=resource_group,
            workspace_name=workspace_name
        )
        workspace_url = workspace.workspace_url
    elif not workspace_url:
        raise ValueError("Either workspace_url or (workspace_name, subscription_id, and resource_group) must be provided")

    # Create a Databricks SDK client authenticated with the Service Principal
    db_client = WorkspaceClient(
        host=f"https://{workspace_url}",
        azure_client_id=client_id,
        azure_client_secret=client_secret,
        azure_tenant_id=tenant_id
    )

    # Get run details
    run_details = db_client.jobs.get_run(run_id=run_id)

    # Extract key status information (life_cycle_state and result_state are enums)
    life_cycle_state = run_details.state.life_cycle_state
    result_state = run_details.state.result_state

    # A run is complete once it reaches a terminal life-cycle state
    is_completed = life_cycle_state in (
        jobs.RunLifeCycleState.TERMINATED,
        jobs.RunLifeCycleState.SKIPPED,
        jobs.RunLifeCycleState.INTERNAL_ERROR,
    )

    # start_time and end_time are Unix timestamps in milliseconds (0/None if unset)
    start_time = run_details.start_time
    end_time = run_details.end_time
    if start_time:
        start_time_str = datetime.fromtimestamp(start_time / 1000).strftime('%Y-%m-%d %H:%M:%S')
    else:
        start_time_str = None
    if end_time:
        end_time_str = datetime.fromtimestamp(end_time / 1000).strftime('%Y-%m-%d %H:%M:%S')
        duration = (end_time - start_time) / 1000 if start_time else None
    else:
        end_time_str = None
        duration = None

    # Get the state message (e.g., error details) if available
    error_message = run_details.state.state_message

    # Assemble the response with the useful status information
    status_info = {
        "run_id": run_id,
        "status": life_cycle_state.value if life_cycle_state else None,
        "result": result_state.value if result_state else None,
        "is_completed": is_completed,
        "start_time": start_time_str,
        "end_time": end_time_str,
        "duration_seconds": duration,
        "error_message": error_message,
        "run_page_url": run_details.run_page_url,
        "run_details": run_details
    }
    return status_info
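# For illustration, a completed run might yield a status dict along these
# lines (all values below are made up):
#
#     {
#         "run_id": 456,
#         "status": "TERMINATED",
#         "result": "SUCCESS",
#         "is_completed": True,
#         "start_time": "2025-03-16 14:00:05",
#         "end_time": "2025-03-16 14:12:41",
#         "duration_seconds": 756.0,
#         "error_message": None,
#         ...
#     }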
def monitor_job_run(
    run_id,
    tenant_id,
    client_id,
    client_secret,
    workspace_url,
    polling_interval=30,
    timeout=3600
):
    """
    Monitor a Databricks job run until completion.

    Args:
        run_id (int): The run ID to monitor
        tenant_id (str): Azure tenant ID
        client_id (str): Azure SPN client ID
        client_secret (str): Azure SPN client secret
        workspace_url (str): The URL of the Databricks workspace
        polling_interval (int): Seconds to wait between status checks
        timeout (int): Maximum seconds to wait before timing out

    Returns:
        dict: Final run status
    """
    start_time = time.time()
    while True:
        status = check_run_status(
            run_id=run_id,
            tenant_id=tenant_id,
            client_id=client_id,
            client_secret=client_secret,
            workspace_url=workspace_url
        )
        print(f"Current state: {status['status']}, Result: {status['result']}")

        if status['is_completed']:
            print(f"Run completed with result: {status['result']}")
            return status

        # Check whether we've hit the timeout
        if time.time() - start_time > timeout:
            print("Monitoring timed out")
            return status

        time.sleep(polling_interval)
# Example usage
if __name__ == "__main__":
    # Configuration - replace with your values
    subscription_id = "your-subscription-id"
    resource_group = "your-resource-group"
    workspace_name = "your-workspace-name"
    tenant_id = "your-tenant-id"
    client_id = "your-spn-client-id"
    client_secret = "your-spn-client-secret"

    # Option 1: Run an existing job by ID
    job_id = 123  # Replace with your job ID

    # Option 2: Run a specific notebook
    notebook_path = "/Path/To/Your/Notebook"  # Must start with /

    # Parameters to pass to the job (optional)
    parameters = {
        "param1": "value1",
        "param2": "value2"
    }

    try:
        # Execute the job - choose one of these options:
        # Option 1: Run an existing job
        run_id, workspace_url = execute_databricks_job(
            subscription_id=subscription_id,
            resource_group=resource_group,
            workspace_name=workspace_name,
            tenant_id=tenant_id,
            client_id=client_id,
            client_secret=client_secret,
            job_id=job_id,
            parameters=parameters
        )
        print(f"Job execution initiated. Run ID: {run_id}")
        print(f"Check status at: https://{workspace_url}/#job/runs/detail/{run_id}")

        # Option 1: Monitor until completion
        final_status = monitor_job_run(
            run_id=run_id,
            tenant_id=tenant_id,
            client_id=client_id,
            client_secret=client_secret,
            workspace_url=workspace_url
        )

        # Option 2: Check the status once
        # status = check_run_status(
        #     run_id=run_id,
        #     tenant_id=tenant_id,
        #     client_id=client_id,
        #     client_secret=client_secret,
        #     workspace_url=workspace_url
        # )
        #
        # print(f"Current status: {status['status']}")
        # print(f"Is completed: {status['is_completed']}")
        # if status['result']:
        #     print(f"Result: {status['result']}")
    except Exception as e:
        print(f"Error: {str(e)}")
# Variant of the script above: execute_databricks_job returns the full run
# details object, and monitoring reuses an already-constructed Databricks client.
import time
from datetime import datetime

from azure.identity import ClientSecretCredential
from azure.mgmt.databricks import DatabricksManagementClient
from databricks.sdk import WorkspaceClient
from databricks.sdk.service import compute, jobs


def execute_databricks_job(
    subscription_id,
    resource_group,
    workspace_name,
    tenant_id,
    client_id,
    client_secret,
    job_id=None,
    notebook_path=None,
    parameters=None
):
    """
    Execute an Azure Databricks job using both the Azure SDK and the Databricks SDK.

    Args:
        subscription_id (str): Azure subscription ID
        resource_group (str): Azure resource group containing the Databricks workspace
        workspace_name (str): Name of the Databricks workspace
        tenant_id (str): Azure tenant ID
        client_id (str): Azure SPN client ID
        client_secret (str): Azure SPN client secret
        job_id (int, optional): ID of an existing job to run
        notebook_path (str, optional): Path to notebook if creating a new run
        parameters (dict, optional): Parameters to pass to the job or notebook

    Returns:
        Run: The run details object for the triggered run
    """
    # Authenticate using the Azure SDK
    credential = ClientSecretCredential(
        tenant_id=tenant_id,
        client_id=client_id,
        client_secret=client_secret
    )

    # Look up the Databricks workspace using the Azure management SDK
    databricks_client = DatabricksManagementClient(
        credential=credential,
        subscription_id=subscription_id
    )

    # Get the workspace URL
    workspace = databricks_client.workspaces.get(
        resource_group_name=resource_group,
        workspace_name=workspace_name
    )
    workspace_url = workspace.workspace_url
    print(f"Found workspace URL: {workspace_url}")

    # Create a Databricks SDK client authenticated with the same Service Principal
    db_client = WorkspaceClient(
        host=f"https://{workspace_url}",
        azure_client_id=client_id,
        azure_client_secret=client_secret,
        azure_tenant_id=tenant_id
    )
    # Execute an existing job if job_id is provided
    if job_id:
        run_parameters = {}
        if parameters:
            run_parameters["notebook_params"] = parameters
        run = db_client.jobs.run_now(
            job_id=job_id,
            **run_parameters
        )
        run_id = run.run_id
        print(f"Started job run with ID: {run_id}")
    # Create a new one-off run if notebook_path is provided
    elif notebook_path:
        run_name = f"Run {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"

        # Define a new cluster configuration
        new_cluster = compute.ClusterSpec(
            spark_version="11.3.x-scala2.12",
            node_type_id="Standard_DS3_v2",
            num_workers=1
        )

        # Set up the notebook task
        notebook_task = jobs.NotebookTask(
            notebook_path=notebook_path,
            base_parameters=parameters if parameters else None
        )

        # One-off runs are submitted with SubmitTask (jobs.Task is for job definitions)
        task = jobs.SubmitTask(
            task_key="notebook_task",
            notebook_task=notebook_task,
            new_cluster=new_cluster
        )

        # Submit the run
        run = db_client.jobs.submit(
            run_name=run_name,
            tasks=[task]
        )
        run_id = run.run_id
        print(f"Submitted notebook run with ID: {run_id}")
    else:
        raise ValueError("Either job_id or notebook_path must be provided")

    # Fetch and return the full run details
    run_details = db_client.jobs.get_run(run_id=run_id)
    return run_details
def monitor_job_run(db_client, run_id, polling_interval=30, timeout=3600):
    """
    Monitor a Databricks job run until completion.

    Args:
        db_client: Databricks SDK WorkspaceClient
        run_id (int): The run ID to monitor
        polling_interval (int): Seconds to wait between status checks
        timeout (int): Maximum seconds to wait before timing out

    Returns:
        Run: Final run details
    """
    start_time = time.time()
    while True:
        run_details = db_client.jobs.get_run(run_id=run_id)
        life_cycle_state = run_details.state.life_cycle_state
        result_state = run_details.state.result_state
        print(f"Current state: {life_cycle_state}, Result: {result_state}")

        # life_cycle_state is an enum, so compare against the enum member
        if life_cycle_state == jobs.RunLifeCycleState.TERMINATED:
            print(f"Run completed with result: {result_state}")
            return run_details

        # Check whether we've hit the timeout
        if time.time() - start_time > timeout:
            print("Monitoring timed out")
            return run_details

        time.sleep(polling_interval)
# Example usage
if __name__ == "__main__":
    # Configuration - replace with your values
    subscription_id = "your-subscription-id"
    resource_group = "your-resource-group"
    workspace_name = "your-workspace-name"
    tenant_id = "your-tenant-id"
    client_id = "your-spn-client-id"
    client_secret = "your-spn-client-secret"

    # The per-workspace URL is not derived from the workspace name; use the URL
    # printed by execute_databricks_job (or shown in the Azure portal)
    workspace_url = "adb-1234567890123456.7.azuredatabricks.net"  # Replace with your workspace URL

    # Option 1: Run an existing job by ID
    job_id = 123  # Replace with your job ID

    # Option 2: Run a specific notebook
    notebook_path = "/Path/To/Your/Notebook"  # Must start with /

    # Parameters to pass to the job (optional)
    parameters = {
        "param1": "value1",
        "param2": "value2"
    }

    try:
        # Execute the job - choose one of these options:
        # Option 1: Run an existing job
        run_details = execute_databricks_job(
            subscription_id=subscription_id,
            resource_group=resource_group,
            workspace_name=workspace_name,
            tenant_id=tenant_id,
            client_id=client_id,
            client_secret=client_secret,
            job_id=job_id,
            parameters=parameters
        )

        # Option 2: Submit a notebook run
        # run_details = execute_databricks_job(
        #     subscription_id=subscription_id,
        #     resource_group=resource_group,
        #     workspace_name=workspace_name,
        #     tenant_id=tenant_id,
        #     client_id=client_id,
        #     client_secret=client_secret,
        #     notebook_path=notebook_path,
        #     parameters=parameters
        # )

        # Get the run ID for monitoring
        run_id = run_details.run_id

        # Create a Databricks client for monitoring
        db_client = WorkspaceClient(
            host=f"https://{workspace_url}",
            azure_client_id=client_id,
            azure_client_secret=client_secret,
            azure_tenant_id=tenant_id
        )

        # Monitor the job until completion (optional)
        # final_status = monitor_job_run(db_client, run_id)

        print(f"Job execution initiated. Run ID: {run_id}")
        print(f"Check status at: {run_details.run_page_url}")
    except Exception as e:
        print(f"Error: {str(e)}")
# Airflow DAG: run an Azure Databricks job using a User-Assigned Managed Identity (UAMI).
from datetime import datetime, timedelta

from airflow import DAG
from airflow.hooks.base import BaseHook
from airflow.models.connection import Connection
from airflow.operators.python import PythonOperator
from airflow.providers.databricks.operators.databricks import DatabricksRunNowOperator
import json
import requests

# Default arguments for the DAG
default_args = {
    'owner': 'airflow',
    'depends_on_past': False,
    'email_on_failure': False,
    'email_on_retry': False,
    'retries': 1,
    'retry_delay': timedelta(minutes=5),
}

# UAMI and Databricks configuration
UAMI_CLIENT_ID = "your-uami-client-id"  # Replace with your actual UAMI client ID
DATABRICKS_URL = "https://your-databricks-workspace.azuredatabricks.net"  # Replace with your Databricks workspace URL
DATABRICKS_JOB_ID = "your-databricks-job-id"  # Replace with your Databricks job ID

# Connection ID that will be created programmatically
DATABRICKS_CONN_ID = "databricks_uami_conn"
def get_token_from_uami():
    """
    Get an access token using the User-Assigned Managed Identity.
    This uses the Azure Instance Metadata Service (IMDS) endpoint.
    """
    metadata_url = "http://169.254.169.254/metadata/identity/oauth2/token"
    params = {
        "api-version": "2018-02-01",
        "resource": "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d",  # Well-known resource ID for Azure Databricks
        "client_id": UAMI_CLIENT_ID
    }
    headers = {"Metadata": "true"}

    response = requests.get(metadata_url, headers=headers, params=params)
    response.raise_for_status()
    return response.json()["access_token"]
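# For reference, a successful IMDS token response is a JSON object along these
# lines (the field values here are illustrative placeholders, not real token
# material):
#
#     {
#         "access_token": "eyJ0eXAi...",
#         "expires_in": "3599",
#         "expires_on": "1742137566",
#         "resource": "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d",
#         "token_type": "Bearer"
#     }
#
# Note that the token expires (typically after about an hour), so a connection
# that stores it verbatim will eventually need the token refreshed.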
def create_databricks_connection(**kwargs):
    """
    Create a Databricks connection using UAMI authentication.
    This function runs as the first task in the DAG.
    """
    try:
        # Check whether the connection already exists
        try:
            connection = BaseHook.get_connection(DATABRICKS_CONN_ID)
            print(f"Connection {DATABRICKS_CONN_ID} already exists, no need to create it.")
            return
        except Exception:
            print(f"Connection {DATABRICKS_CONN_ID} does not exist, creating it...")

        # Get a token from the UAMI (the token expires, typically after ~1 hour)
        token = get_token_from_uami()

        # Create a new connection
        conn = Connection(
            conn_id=DATABRICKS_CONN_ID,
            conn_type='databricks',
            host=DATABRICKS_URL,
            extra=json.dumps({
                "token": token,
                "host": DATABRICKS_URL
            })
        )

        # Persist the connection in the Airflow metadata database
        from airflow.settings import Session
        session = Session()
        session.add(conn)
        session.commit()
        print(f"Created connection: {DATABRICKS_CONN_ID}")
    except Exception as e:
        print(f"Error creating connection: {str(e)}")
        raise
# Define the DAG
with DAG(
    'databricks_uami_job',
    default_args=default_args,
    description='Run Azure Databricks job using UAMI',
    schedule_interval=timedelta(days=1),
    start_date=datetime(2025, 3, 15),
    catchup=False,
    tags=['databricks', 'uami'],
) as dag:

    # Task to create the Databricks connection (provide_context is no longer
    # needed in Airflow 2; the context is passed automatically)
    create_connection = PythonOperator(
        task_id='create_databricks_connection',
        python_callable=create_databricks_connection,
    )

    # Task to run the Databricks job
    run_databricks_job = DatabricksRunNowOperator(
        task_id='run_databricks_job',
        databricks_conn_id=DATABRICKS_CONN_ID,
        job_id=DATABRICKS_JOB_ID,
    )

    # Set task dependencies
    create_connection >> run_databricks_job
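Once the DAG file is deployed to the Airflow dags folder, a manual run can be triggered from the Airflow CLI (the DAG ID matches the one defined above):

airflow dags trigger databricks_uami_job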