From efcea13a1296a30131c1495e90b46c8f91c8d06f Mon Sep 17 00:00:00 2001 From: Madhavendra Rathore Date: Mon, 6 Apr 2026 18:16:30 +0530 Subject: [PATCH] Add SPOG routing support for account-level vanity URLs SPOG replaces per-workspace hostnames with account-level URLs. When httpPath contains ?o=, the connector now extracts the workspace ID and injects x-databricks-org-id as an HTTP header on all non-OAuth endpoints (SEA, telemetry, feature flags). Changes: - Fix warehouse ID regex to stop at query params ([^?&]+ instead of .+) - Extract ?o= from httpPath once during session init, store as _spog_headers - Propagate org-id header to telemetry client via extra_headers param - Propagate org-id header to feature flags client - Do NOT propagate to OAuth endpoints (they reject it with 400) Signed-off-by: Madhavendra Rathore Co-authored-by: Isaac Signed-off-by: Madhavendra Rathore --- src/databricks/sql/backend/sea/backend.py | 5 ++- src/databricks/sql/client.py | 1 + src/databricks/sql/common/feature_flag.py | 1 + src/databricks/sql/session.py | 32 ++++++++++++++ .../sql/telemetry/telemetry_client.py | 6 +++ tests/unit/test_sea_backend.py | 33 ++++++++++++++ tests/unit/test_session.py | 44 +++++++++++++++++++ 7 files changed, 120 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index ff130cd39..04c79a18b 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -188,8 +188,9 @@ def _extract_warehouse_id(self, http_path: str) -> str: ValueError: If the warehouse ID cannot be extracted from the path """ - warehouse_pattern = re.compile(r".*/warehouses/(.+)") - endpoint_pattern = re.compile(r".*/endpoints/(.+)") + # [^?&]+ stops at query params (e.g. ?o= for SPOG routing) + warehouse_pattern = re.compile(r".*/warehouses/([^?&]+)") + endpoint_pattern = re.compile(r".*/endpoints/([^?&]+)") for pattern in [warehouse_pattern, endpoint_pattern]: match = pattern.match(http_path) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 2aeea175e..fe52f0c79 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -353,6 +353,7 @@ def read(self) -> Optional[OAuthToken]: host_url=self.session.host, batch_size=self.telemetry_batch_size, client_context=client_context, + extra_headers=self.session.get_spog_headers(), ) self._telemetry_client = TelemetryClientFactory.get_telemetry_client( diff --git a/src/databricks/sql/common/feature_flag.py b/src/databricks/sql/common/feature_flag.py index 36e4b8a02..0b2c7490b 100644 --- a/src/databricks/sql/common/feature_flag.py +++ b/src/databricks/sql/common/feature_flag.py @@ -113,6 +113,7 @@ def _refresh_flags(self): # Authenticate the request self._connection.session.auth_provider.add_headers(headers) headers["User-Agent"] = self._connection.session.useragent_header + headers.update(self._connection.session.get_spog_headers()) response = self._http_client.request( HttpMethod.GET, self._feature_flag_endpoint, headers=headers, timeout=30 diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 0f723d144..44077650b 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -67,6 +67,14 @@ def __init__( base_headers = [("User-Agent", self.useragent_header)] all_headers = (http_headers or []) + base_headers + # Extract ?o= from http_path for SPOG routing. + # On SPOG hosts, the httpPath contains ?o= which routes Thrift + # requests via the URL. For SEA, telemetry, and feature flags (which use + # separate endpoints), we inject x-databricks-org-id as an HTTP header. + self._spog_headers = self._extract_spog_headers(http_path, all_headers) + if self._spog_headers: + all_headers = all_headers + list(self._spog_headers.items()) + self.ssl_options = SSLOptions( # Double negation is generally a bad thing, but we have to keep backward compatibility tls_verify=not kwargs.get( @@ -131,6 +139,30 @@ def _create_backend( } return databricks_client_class(**common_args) + @staticmethod + def _extract_spog_headers(http_path, existing_headers): + """Extract ?o= from http_path and return as a header dict for SPOG routing.""" + if not http_path or "?" not in http_path: + return {} + + from urllib.parse import parse_qs + + query_string = http_path.split("?", 1)[1] + params = parse_qs(query_string) + org_id = params.get("o", [None])[0] + if not org_id: + return {} + + # Don't override if explicitly set + if any(k == "x-databricks-org-id" for k, _ in existing_headers): + return {} + + return {"x-databricks-org-id": org_id} + + def get_spog_headers(self): + """Returns SPOG routing headers (x-databricks-org-id) if ?o= was in http_path.""" + return dict(self._spog_headers) + def open(self): self._session_id = self.backend.open_session( session_configuration=self.session_configuration, diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 408162400..55d845e46 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -188,6 +188,7 @@ def __init__( executor, batch_size: int, client_context, + extra_headers: Optional[Dict[str, str]] = None, ) -> None: logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex) self._telemetry_enabled = telemetry_enabled @@ -195,6 +196,7 @@ def __init__( self._session_id_hex = session_id_hex self._auth_provider = auth_provider self._user_agent = None + self._extra_headers = extra_headers or {} # OPTIMIZATION: Use lock-free Queue instead of list + lock # Queue is thread-safe internally and has better performance under concurrency @@ -287,6 +289,8 @@ def _send_telemetry(self, events): if self._auth_provider: self._auth_provider.add_headers(headers) + headers.update(self._extra_headers) + try: logger.debug("Submitting telemetry request to thread pool") @@ -587,6 +591,7 @@ def initialize_telemetry_client( host_url, batch_size, client_context, + extra_headers=None, ): """ Initialize a telemetry client for a specific connection if telemetry is enabled. @@ -627,6 +632,7 @@ def initialize_telemetry_client( executor=TelemetryClientFactory._executor, batch_size=batch_size, client_context=client_context, + extra_headers=extra_headers, ) TelemetryClientFactory._clients[ host_url diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index f71e60943..24a5e8242 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -143,6 +143,39 @@ def test_initialization(self, mock_http_client): ) assert client2.warehouse_id == "def456" + # Test with SPOG query param ?o= in http_path + client_spog = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/1.0/warehouses/abc123?o=6051921418418893", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + assert client_spog.warehouse_id == "abc123" + + # Test with SPOG query param on endpoints path + client_spog_ep = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/1.0/endpoints/def456?o=6051921418418893", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + assert client_spog_ep.warehouse_id == "def456" + + # Test with multiple query params + client_spog_multi = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/1.0/warehouses/abc123?o=123&extra=val", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + assert client_spog_multi.warehouse_id == "abc123" + # Test with custom max_download_threads client3 = SeaDatabricksClient( server_hostname="test-server.databricks.com", diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 3a43c1a75..36ebd1e6c 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -8,6 +8,7 @@ THandleIdentifier, ) from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.session import Session import databricks.sql @@ -223,3 +224,46 @@ def test_query_tags_dict_takes_precedence_over_session_config(self, mock_client_ call_kwargs = mock_client_class.return_value.open_session.call_args[1] assert call_kwargs["session_configuration"]["QUERY_TAGS"] == "team:new-team" + + +class TestSpogHeaders: + """Unit tests for SPOG header extraction from http_path.""" + + def test_extracts_org_id_from_query_param(self): + result = Session._extract_spog_headers( + "/sql/1.0/warehouses/abc123?o=6051921418418893", [] + ) + assert result == {"x-databricks-org-id": "6051921418418893"} + + def test_no_query_param_returns_empty(self): + result = Session._extract_spog_headers( + "/sql/1.0/warehouses/abc123", [] + ) + assert result == {} + + def test_no_o_param_returns_empty(self): + result = Session._extract_spog_headers( + "/sql/1.0/warehouses/abc123?other=value", [] + ) + assert result == {} + + def test_empty_http_path_returns_empty(self): + result = Session._extract_spog_headers("", []) + assert result == {} + + def test_none_http_path_returns_empty(self): + result = Session._extract_spog_headers(None, []) + assert result == {} + + def test_explicit_header_takes_precedence(self): + existing = [("x-databricks-org-id", "explicit-value")] + result = Session._extract_spog_headers( + "/sql/1.0/warehouses/abc123?o=6051921418418893", existing + ) + assert result == {} + + def test_multiple_query_params(self): + result = Session._extract_spog_headers( + "/sql/1.0/warehouses/abc123?o=12345&extra=val", [] + ) + assert result == {"x-databricks-org-id": "12345"}