Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/databricks/sql/common/feature_flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Dict, Optional, List, Any, TYPE_CHECKING

from databricks.sql.common.http import HttpMethod
from databricks.sql.common.url_utils import normalize_host_with_protocol

if TYPE_CHECKING:
from databricks.sql.client import Connection
Expand Down Expand Up @@ -67,7 +68,8 @@ def __init__(

endpoint_suffix = FEATURE_FLAGS_ENDPOINT_SUFFIX_FORMAT.format(__version__)
self._feature_flag_endpoint = (
f"https://{self._connection.session.host}{endpoint_suffix}"
normalize_host_with_protocol(self._connection.session.host)
+ endpoint_suffix
)

# Use the provided HTTP client
Expand Down
44 changes: 44 additions & 0 deletions src/databricks/sql/common/url_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""
URL utility functions for the Databricks SQL connector.
"""


def normalize_host_with_protocol(host: str) -> str:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was rechecking this piece of code : auth_utils.py and thrift_backend - Has proper check on this already. should we use this util instead?

Also, the sea flow looks incorrect at the moment : backend/sea/utils/http_client.py

"""
Normalize a connection hostname by ensuring it has a protocol and removing trailing slashes.

This is useful for handling cases where users may provide hostnames with or without protocols
(common with dbt-databricks users copying URLs from their browser).

Args:
host: Connection hostname which may or may not include a protocol prefix (https:// or http://)
and may or may not have a trailing slash

Returns:
Normalized hostname with protocol prefix and no trailing slash

Examples:
normalize_host_with_protocol("myserver.com") -> "https://myserver.com"
normalize_host_with_protocol("https://myserver.com") -> "https://myserver.com"
normalize_host_with_protocol("HTTPS://myserver.com") -> "https://myserver.com"

Raises:
ValueError: If host is None or empty string
"""
# Handle None or empty host
if not host or not host.strip():
raise ValueError("Host cannot be None or empty")

# Remove trailing slash
host = host.rstrip("/")

# Add protocol if not present (case-insensitive check)
host_lower = host.lower()
if not host_lower.startswith("https://") and not host_lower.startswith("http://"):
host = f"https://{host}"
elif host_lower.startswith("https://") or host_lower.startswith("http://"):
# Normalize protocol to lowercase
protocol_end = host.index("://") + 3
host = host[:protocol_end].lower() + host[protocol_end:]

return host
3 changes: 2 additions & 1 deletion src/databricks/sql/telemetry/telemetry_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
TelemetryPushClient,
CircuitBreakerTelemetryPushClient,
)
from databricks.sql.common.url_utils import normalize_host_with_protocol

if TYPE_CHECKING:
from databricks.sql.client import Connection
Expand Down Expand Up @@ -278,7 +279,7 @@ def _send_telemetry(self, events):
if self._auth_provider
else self.TELEMETRY_UNAUTHENTICATED_PATH
)
url = f"https://{self._host_url}{path}"
url = normalize_host_with_protocol(self._host_url) + path

headers = {"Accept": "application/json", "Content-Type": "application/json"}

Expand Down
72 changes: 57 additions & 15 deletions tests/e2e/test_circuit_breaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,46 @@
from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager


def wait_for_circuit_state(circuit_breaker, expected_state, timeout=5):
"""
Wait for circuit breaker to reach expected state with polling.

Args:
circuit_breaker: The circuit breaker instance to monitor
expected_state: The expected state (STATE_OPEN, STATE_CLOSED, STATE_HALF_OPEN)
timeout: Maximum time to wait in seconds

Returns:
True if state reached, False if timeout
"""
start = time.time()
while time.time() - start < timeout:
if circuit_breaker.current_state == expected_state:
return True
time.sleep(0.1) # Poll every 100ms
return False


def wait_for_circuit_state_multiple(circuit_breaker, expected_states, timeout=5):
"""
Wait for circuit breaker to reach one of multiple expected states.

Args:
circuit_breaker: The circuit breaker instance to monitor
expected_states: List of acceptable states
timeout: Maximum time to wait in seconds

Returns:
True if any state reached, False if timeout
"""
start = time.time()
while time.time() - start < timeout:
if circuit_breaker.current_state in expected_states:
return True
time.sleep(0.1)
return False


@pytest.fixture(autouse=True)
def aggressive_circuit_breaker_config():
"""
Expand Down Expand Up @@ -107,9 +147,13 @@ def mock_request(*args, **kwargs):
time.sleep(0.5)

if should_trigger:
# Circuit should be OPEN after 2 rate-limit failures
# Wait for circuit to open (async telemetry may take time)
assert wait_for_circuit_state(circuit_breaker, STATE_OPEN, timeout=5), \
f"Circuit didn't open within 5s, state: {circuit_breaker.current_state}"

# Circuit should be OPEN after rate-limit failures
assert circuit_breaker.current_state == STATE_OPEN
assert circuit_breaker.fail_counter == 2
assert circuit_breaker.fail_counter >= 2 # At least 2 failures

# Track requests before another query
requests_before = request_count["count"]
Expand Down Expand Up @@ -197,7 +241,9 @@ def mock_conditional_request(*args, **kwargs):
cursor.fetchone()
time.sleep(2)

assert circuit_breaker.current_state == STATE_OPEN
# Wait for circuit to open
assert wait_for_circuit_state(circuit_breaker, STATE_OPEN, timeout=5), \
f"Circuit didn't open, state: {circuit_breaker.current_state}"

# Wait for reset timeout (5 seconds in test)
time.sleep(6)
Expand All @@ -208,24 +254,20 @@ def mock_conditional_request(*args, **kwargs):
# Execute query to trigger HALF_OPEN state
cursor.execute("SELECT 3")
cursor.fetchone()
time.sleep(1)

# Circuit should be recovering
assert circuit_breaker.current_state in [
STATE_HALF_OPEN,
STATE_CLOSED,
], f"Circuit should be recovering, but is {circuit_breaker.current_state}"
# Wait for circuit to start recovering
assert wait_for_circuit_state_multiple(
circuit_breaker, [STATE_HALF_OPEN, STATE_CLOSED], timeout=5
), f"Circuit didn't recover, state: {circuit_breaker.current_state}"

# Execute more queries to fully recover
cursor.execute("SELECT 4")
cursor.fetchone()
time.sleep(1)

current_state = circuit_breaker.current_state
assert current_state in [
STATE_CLOSED,
STATE_HALF_OPEN,
], f"Circuit should recover to CLOSED or HALF_OPEN, got {current_state}"
# Wait for full recovery
assert wait_for_circuit_state_multiple(
circuit_breaker, [STATE_CLOSED, STATE_HALF_OPEN], timeout=5
), f"Circuit didn't fully recover, state: {circuit_breaker.current_state}"


if __name__ == "__main__":
Expand Down
36 changes: 32 additions & 4 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,13 +646,31 @@ class TransactionTestSuite(unittest.TestCase):
"access_token": "tok",
}

def _setup_mock_session_with_http_client(self, mock_session):
"""
Helper to configure a mock session with HTTP client mocks.
This prevents feature flag network requests during Connection initialization.
"""
mock_session.host = "foo"

# Mock HTTP client to prevent feature flag network requests
mock_http_client = Mock()
mock_session.http_client = mock_http_client

# Mock feature flag response to prevent blocking HTTP calls
mock_ff_response = Mock()
mock_ff_response.status = 200
mock_ff_response.data = b'{"flags": [], "ttl_seconds": 900}'
mock_http_client.request.return_value = mock_ff_response

def _create_mock_connection(self, mock_session_class):
"""Helper to create a mocked connection for transaction tests."""
# Mock session
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"
mock_session.get_autocommit.return_value = True

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

# Create connection with ignore_transactions=False to test actual transaction functionality
Expand Down Expand Up @@ -736,9 +754,7 @@ def test_autocommit_setter_preserves_exception_chain(self, mock_session_class):
conn = self._create_mock_connection(mock_session_class)

mock_cursor = Mock()
original_error = DatabaseError(
"Original error", host_url="test-host"
)
original_error = DatabaseError("Original error", host_url="test-host")
mock_cursor.execute.side_effect = original_error

with patch.object(conn, "cursor", return_value=mock_cursor):
Expand Down Expand Up @@ -927,6 +943,8 @@ def test_fetch_autocommit_from_server_queries_server(self, mock_session_class):
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

conn = client.Connection(
Expand Down Expand Up @@ -959,6 +977,8 @@ def test_fetch_autocommit_from_server_handles_false_value(self, mock_session_cla
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

conn = client.Connection(
Expand Down Expand Up @@ -986,6 +1006,8 @@ def test_fetch_autocommit_from_server_raises_on_no_result(self, mock_session_cla
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

conn = client.Connection(
Expand Down Expand Up @@ -1015,6 +1037,8 @@ def test_commit_is_noop_when_ignore_transactions_true(self, mock_session_class):
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

# Create connection with ignore_transactions=True (default)
Expand Down Expand Up @@ -1043,6 +1067,8 @@ def test_rollback_raises_not_supported_when_ignore_transactions_true(
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

# Create connection with ignore_transactions=True (default)
Expand All @@ -1068,6 +1094,8 @@ def test_autocommit_setter_is_noop_when_ignore_transactions_true(
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

# Create connection with ignore_transactions=True (default)
Expand Down
70 changes: 70 additions & 0 deletions tests/unit/test_url_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Tests for URL utility functions."""
import pytest
from databricks.sql.common.url_utils import normalize_host_with_protocol


class TestNormalizeHostWithProtocol:
"""Tests for normalize_host_with_protocol function."""

@pytest.mark.parametrize("input_host,expected_output", [
# Hostname without protocol - should add https://
("myserver.com", "https://myserver.com"),
("workspace.databricks.com", "https://workspace.databricks.com"),

# Hostname with https:// - should not duplicate
("https://myserver.com", "https://myserver.com"),
("https://workspace.databricks.com", "https://workspace.databricks.com"),

# Hostname with http:// - should preserve
("http://localhost", "http://localhost"),
("http://myserver.com:8080", "http://myserver.com:8080"),

# Hostname with port numbers
("myserver.com:443", "https://myserver.com:443"),
("https://myserver.com:443", "https://myserver.com:443"),
("http://localhost:8080", "http://localhost:8080"),

# Trailing slash - should be removed
("myserver.com/", "https://myserver.com"),
("https://myserver.com/", "https://myserver.com"),
("http://localhost/", "http://localhost"),

# Case-insensitive protocol handling - should normalize to lowercase
("HTTPS://myserver.com", "https://myserver.com"),
("HTTP://myserver.com", "http://myserver.com"),
("HttPs://workspace.databricks.com", "https://workspace.databricks.com"),
("HtTp://localhost:8080", "http://localhost:8080"),
("HTTPS://MYSERVER.COM", "https://MYSERVER.COM"), # Only protocol lowercased

# Case-insensitive with trailing slashes
("HTTPS://myserver.com/", "https://myserver.com"),
("HTTP://localhost:8080/", "http://localhost:8080"),
("HttPs://workspace.databricks.com//", "https://workspace.databricks.com"),

# Mixed case protocols with ports
("HTTPS://myserver.com:443", "https://myserver.com:443"),
("HtTp://myserver.com:8080", "http://myserver.com:8080"),

# Case preservation - only protocol lowercased, hostname case preserved
("HTTPS://MyServer.DataBricks.COM", "https://MyServer.DataBricks.COM"),
("HttPs://CamelCase.Server.com", "https://CamelCase.Server.com"),
("HTTP://UPPERCASE.COM:8080", "http://UPPERCASE.COM:8080"),
])
def test_normalize_host_with_protocol(self, input_host, expected_output):
"""Test host normalization with various input formats."""
result = normalize_host_with_protocol(input_host)
assert result == expected_output

# Additional assertion: verify protocol is always lowercase
assert result.startswith("https://") or result.startswith("http://")

@pytest.mark.parametrize("invalid_host", [
None,
"",
" ", # Whitespace only
])
def test_normalize_host_with_protocol_raises_on_invalid_input(self, invalid_host):
"""Test that function raises ValueError for None or empty host."""
with pytest.raises(ValueError, match="Host cannot be None or empty"):
normalize_host_with_protocol(invalid_host)

Loading