Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/databricks/sql/backend/sea/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,9 @@ def _extract_warehouse_id(self, http_path: str) -> str:
ValueError: If the warehouse ID cannot be extracted from the path
"""

warehouse_pattern = re.compile(r".*/warehouses/(.+)")
endpoint_pattern = re.compile(r".*/endpoints/(.+)")
# [^?&]+ stops at query params (e.g. ?o= for SPOG routing)
warehouse_pattern = re.compile(r".*/warehouses/([^?&]+)")
endpoint_pattern = re.compile(r".*/endpoints/([^?&]+)")

for pattern in [warehouse_pattern, endpoint_pattern]:
match = pattern.match(http_path)
Expand Down
1 change: 1 addition & 0 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ def read(self) -> Optional[OAuthToken]:
host_url=self.session.host,
batch_size=self.telemetry_batch_size,
client_context=client_context,
extra_headers=self.session.get_spog_headers(),
)

self._telemetry_client = TelemetryClientFactory.get_telemetry_client(
Expand Down
1 change: 1 addition & 0 deletions src/databricks/sql/common/feature_flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def _refresh_flags(self):
# Authenticate the request
self._connection.session.auth_provider.add_headers(headers)
headers["User-Agent"] = self._connection.session.useragent_header
headers.update(self._connection.session.get_spog_headers())

response = self._http_client.request(
HttpMethod.GET, self._feature_flag_endpoint, headers=headers, timeout=30
Expand Down
32 changes: 32 additions & 0 deletions src/databricks/sql/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ def __init__(
base_headers = [("User-Agent", self.useragent_header)]
all_headers = (http_headers or []) + base_headers

# Extract ?o=<workspaceId> from http_path for SPOG routing.
# On SPOG hosts, the httpPath contains ?o=<workspaceId> which routes Thrift
# requests via the URL. For SEA, telemetry, and feature flags (which use
# separate endpoints), we inject x-databricks-org-id as an HTTP header.
self._spog_headers = self._extract_spog_headers(http_path, all_headers)
if self._spog_headers:
all_headers = all_headers + list(self._spog_headers.items())

self.ssl_options = SSLOptions(
# Double negation is generally a bad thing, but we have to keep backward compatibility
tls_verify=not kwargs.get(
Expand Down Expand Up @@ -131,6 +139,30 @@ def _create_backend(
}
return databricks_client_class(**common_args)

@staticmethod
def _extract_spog_headers(http_path, existing_headers):
    """Derive the SPOG routing header from ``?o=<workspaceId>`` in http_path.

    Args:
        http_path: HTTP path that may carry a query string, e.g.
            ``/sql/1.0/warehouses/abc123?o=123``. May be None or empty.
        existing_headers: Iterable of ``(name, value)`` header pairs that
            are already configured. May be None or empty.

    Returns:
        ``{"x-databricks-org-id": <workspaceId>}`` when http_path has a
        non-empty ``o`` query parameter and no org-id header is already
        present in existing_headers; otherwise an empty dict.
    """
    if not http_path or "?" not in http_path:
        return {}

    from urllib.parse import parse_qs

    query_string = http_path.split("?", 1)[1]
    # parse_qs drops blank values by default, so "?o=" yields no "o" entry.
    org_id = parse_qs(query_string).get("o", [None])[0]
    if not org_id:
        return {}

    # Don't override an explicitly supplied header. HTTP header names are
    # case-insensitive, so "X-Databricks-Org-Id" must count as a match too;
    # otherwise a second, conflicting org-id header would be injected.
    target = "x-databricks-org-id"
    for name, _ in existing_headers or ():
        if isinstance(name, str) and name.lower() == target:
            return {}

    return {target: org_id}

def get_spog_headers(self):
    """Return a copy of the SPOG routing headers stored on this session.

    The result contains ``x-databricks-org-id`` when one was derived from
    ``?o=`` in http_path, and is empty otherwise. A new dict is returned on
    every call so callers can mutate it without affecting the session.
    """
    spog_headers = {}
    spog_headers.update(self._spog_headers)
    return spog_headers

def open(self):
self._session_id = self.backend.open_session(
session_configuration=self.session_configuration,
Expand Down
6 changes: 6 additions & 0 deletions src/databricks/sql/telemetry/telemetry_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,15 @@ def __init__(
executor,
batch_size: int,
client_context,
extra_headers: Optional[Dict[str, str]] = None,
) -> None:
logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex)
self._telemetry_enabled = telemetry_enabled
self._batch_size = batch_size
self._session_id_hex = session_id_hex
self._auth_provider = auth_provider
self._user_agent = None
self._extra_headers = extra_headers or {}

# OPTIMIZATION: Use lock-free Queue instead of list + lock
# Queue is thread-safe internally and has better performance under concurrency
Expand Down Expand Up @@ -287,6 +289,8 @@ def _send_telemetry(self, events):
if self._auth_provider:
self._auth_provider.add_headers(headers)

headers.update(self._extra_headers)

try:
logger.debug("Submitting telemetry request to thread pool")

Expand Down Expand Up @@ -587,6 +591,7 @@ def initialize_telemetry_client(
host_url,
batch_size,
client_context,
extra_headers=None,
):
"""
Initialize a telemetry client for a specific connection if telemetry is enabled.
Expand Down Expand Up @@ -627,6 +632,7 @@ def initialize_telemetry_client(
executor=TelemetryClientFactory._executor,
batch_size=batch_size,
client_context=client_context,
extra_headers=extra_headers,
)
TelemetryClientFactory._clients[
host_url
Expand Down
33 changes: 33 additions & 0 deletions tests/unit/test_sea_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,39 @@ def test_initialization(self, mock_http_client):
)
assert client2.warehouse_id == "def456"

# Test with SPOG query param ?o= in http_path
client_spog = SeaDatabricksClient(
server_hostname="test-server.databricks.com",
port=443,
http_path="/sql/1.0/warehouses/abc123?o=6051921418418893",
http_headers=[],
auth_provider=AuthProvider(),
ssl_options=SSLOptions(),
)
assert client_spog.warehouse_id == "abc123"

# Test with SPOG query param on endpoints path
client_spog_ep = SeaDatabricksClient(
server_hostname="test-server.databricks.com",
port=443,
http_path="/sql/1.0/endpoints/def456?o=6051921418418893",
http_headers=[],
auth_provider=AuthProvider(),
ssl_options=SSLOptions(),
)
assert client_spog_ep.warehouse_id == "def456"

# Test with multiple query params
client_spog_multi = SeaDatabricksClient(
server_hostname="test-server.databricks.com",
port=443,
http_path="/sql/1.0/warehouses/abc123?o=123&extra=val",
http_headers=[],
auth_provider=AuthProvider(),
ssl_options=SSLOptions(),
)
assert client_spog_multi.warehouse_id == "abc123"

# Test with custom max_download_threads
client3 = SeaDatabricksClient(
server_hostname="test-server.databricks.com",
Expand Down
44 changes: 44 additions & 0 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
THandleIdentifier,
)
from databricks.sql.backend.types import SessionId, BackendType
from databricks.sql.session import Session

import databricks.sql

Expand Down Expand Up @@ -223,3 +224,46 @@ def test_query_tags_dict_takes_precedence_over_session_config(self, mock_client_

call_kwargs = mock_client_class.return_value.open_session.call_args[1]
assert call_kwargs["session_configuration"]["QUERY_TAGS"] == "team:new-team"


class TestSpogHeaders:
    """Checks how Session._extract_spog_headers turns http_path into SPOG headers."""

    def test_extracts_org_id_from_query_param(self):
        # A lone ?o=<id> becomes the x-databricks-org-id header.
        http_path = "/sql/1.0/warehouses/abc123?o=6051921418418893"
        headers = Session._extract_spog_headers(http_path, [])
        expected = {"x-databricks-org-id": "6051921418418893"}
        assert headers == expected

    def test_no_query_param_returns_empty(self):
        # No query string at all: nothing to extract.
        http_path = "/sql/1.0/warehouses/abc123"
        headers = Session._extract_spog_headers(http_path, [])
        assert headers == {}

    def test_no_o_param_returns_empty(self):
        # A query string without "o" yields no header.
        http_path = "/sql/1.0/warehouses/abc123?other=value"
        headers = Session._extract_spog_headers(http_path, [])
        assert headers == {}

    def test_empty_http_path_returns_empty(self):
        headers = Session._extract_spog_headers("", [])
        assert headers == {}

    def test_none_http_path_returns_empty(self):
        headers = Session._extract_spog_headers(None, [])
        assert headers == {}

    def test_explicit_header_takes_precedence(self):
        # A caller-supplied org-id header suppresses the extracted one.
        http_path = "/sql/1.0/warehouses/abc123?o=6051921418418893"
        existing = [("x-databricks-org-id", "explicit-value")]
        headers = Session._extract_spog_headers(http_path, existing)
        assert headers == {}

    def test_multiple_query_params(self):
        # Extra query parameters are ignored; only "o" matters.
        http_path = "/sql/1.0/warehouses/abc123?o=12345&extra=val"
        headers = Session._extract_spog_headers(http_path, [])
        expected = {"x-databricks-org-id": "12345"}
        assert headers == expected
Loading