diff --git a/src/datajoint/connection.py b/src/datajoint/connection.py index e9eab0921..160ae8449 100644 --- a/src/datajoint/connection.py +++ b/src/datajoint/connection.py @@ -168,6 +168,7 @@ def __init__( port: int | None = None, use_tls: bool | dict | None = None, *, + database_name: str | None = None, backend: str | None = None, config_override: "Config | None" = None, ) -> None: @@ -180,7 +181,9 @@ def __init__( port = int(port) elif port is None: port = self._config["database.port"] - self.conn_info = dict(host=host, port=port, user=user, passwd=password) + if database_name is None: + database_name = self._config.get("database.name") + self.conn_info = dict(host=host, port=port, user=user, passwd=password, database_name=database_name) if use_tls is not False: # use_tls can be: None (auto-detect), True (enable), False (disable), or dict (custom config) if isinstance(use_tls, dict): @@ -201,12 +204,27 @@ def __init__( backend = self._config["database.backend"] self.adapter = get_adapter(backend) + if database_name and self.adapter.backend == "mysql": + warnings.warn( + "database.name is set but the MySQL backend does not support database selection. " + "This setting only applies to PostgreSQL connections.", + UserWarning, + stacklevel=2, + ) + self.connect() if self.is_connected: - logger.info("DataJoint {version} connected to {user}@{host}:{port}".format(version=__version__, **self.conn_info)) + db = self.conn_info.get("database_name") + db_str = f"/{db}" if db else "" + logger.info( + f"DataJoint {__version__} connected to " + f"{self.conn_info['user']}@{self.conn_info['host']}:{self.conn_info['port']}{db_str}" + ) self.connection_id = self.adapter.get_connection_id(self._conn) else: - raise errors.LostConnectionError("Connection failed {user}@{host}:{port}".format(**self.conn_info)) + raise errors.LostConnectionError( + f"Connection failed {self.conn_info['user']}@{self.conn_info['host']}:{self.conn_info['port']}" + ) self._in_transaction = False self.schemas = dict() self.dependencies = Dependencies(self) @@ -216,22 +234,33 @@ def __eq__(self, other): def __repr__(self): connected = "connected" if self.is_connected else "disconnected" - return "DataJoint connection ({connected}) {user}@{host}:{port}".format(connected=connected, **self.conn_info) + user = self.conn_info["user"] + host = self.conn_info["host"] + port = self.conn_info["port"] + db = self.conn_info.get("database_name") + db_str = f"/{db}" if db else "" + return f"DataJoint connection ({connected}) {user}@{host}:{port}{db_str}" + + def _build_connect_kwargs(self, use_tls=None): + """Build kwargs dict for adapter.connect().""" + kwargs = dict( + host=self.conn_info["host"], + port=self.conn_info["port"], + user=self.conn_info["user"], + password=self.conn_info["passwd"], + charset=self._config["connection.charset"], + use_tls=use_tls if use_tls is not None else self.conn_info.get("ssl"), + ) + if self.conn_info.get("database_name"): + kwargs["dbname"] = self.conn_info["database_name"] + return kwargs def connect(self) -> None: """Establish or re-establish connection to the database server.""" with warnings.catch_warnings(): warnings.filterwarnings("ignore", ".*deprecated.*") try: - # Use adapter to create connection - self._conn = self.adapter.connect( - host=self.conn_info["host"], - port=self.conn_info["port"], - user=self.conn_info["user"], - password=self.conn_info["passwd"], - charset=self._config["connection.charset"], - use_tls=self.conn_info.get("ssl"), - ) + self._conn = self.adapter.connect(**self._build_connect_kwargs()) except Exception as ssl_error: # If SSL fails, retry without SSL (if it was auto-detected) if self.conn_info.get("ssl_input") is None: @@ -240,14 +269,7 @@ def connect(self) -> None: "To require SSL, set use_tls=True explicitly.", ssl_error, ) - self._conn = self.adapter.connect( - host=self.conn_info["host"], - port=self.conn_info["port"], - user=self.conn_info["user"], - password=self.conn_info["passwd"], - charset=self._config["connection.charset"], - use_tls=False, # Explicitly disable SSL for fallback - ) + self._conn = self.adapter.connect(**self._build_connect_kwargs(use_tls=False)) else: raise self._is_closed = False # Mark as connected after successful connection diff --git a/src/datajoint/schemas.py b/src/datajoint/schemas.py index aad69102b..ff1b0e234 100644 --- a/src/datajoint/schemas.py +++ b/src/datajoint/schemas.py @@ -172,6 +172,13 @@ def activate( self.connection = connection if self.connection is None: self.connection = _get_singleton_connection() + if self.connection._config.get("database.database_prefix"): + warnings.warn( + "database_prefix is deprecated and will be removed in DataJoint 2.3. " + "Use database.name to select a PostgreSQL database instead.", + DeprecationWarning, + stacklevel=2, + ) self.database = schema_name if create_schema is not None: self.create_schema = create_schema diff --git a/src/datajoint/settings.py b/src/datajoint/settings.py index 7019d8345..f5881793f 100644 --- a/src/datajoint/settings.py +++ b/src/datajoint/settings.py @@ -65,6 +65,7 @@ "database.password": "DJ_PASS", "database.backend": "DJ_BACKEND", "database.port": "DJ_PORT", + "database.name": "DJ_DATABASE_NAME", "database.database_prefix": "DJ_DATABASE_PREFIX", "database.create_tables": "DJ_CREATE_TABLES", "loglevel": "DJ_LOG_LEVEL", @@ -196,13 +197,17 @@ class DatabaseSettings(BaseSettings): description="Database backend: 'mysql' or 'postgresql'", ) port: int | None = Field(default=None, validation_alias="DJ_PORT") + name: str | None = Field( + default=None, + validation_alias="DJ_DATABASE_NAME", + description="Database name for PostgreSQL connections. Defaults to 'postgres' if not set.", + ) reconnect: bool = True use_tls: bool | None = Field(default=None, validation_alias="DJ_USE_TLS") database_prefix: str = Field( default="", validation_alias="DJ_DATABASE_PREFIX", - description="Prefix for database/schema names. " - "Not automatically applied; use dj.config.database.database_prefix when creating schemas.", + description="Deprecated. Use database.name instead.", ) create_tables: bool = Field( default=True, diff --git a/tests/unit/test_settings.py b/tests/unit/test_settings.py index 475d96df9..83bb1d67c 100644 --- a/tests/unit/test_settings.py +++ b/tests/unit/test_settings.py @@ -750,6 +750,71 @@ def test_similar_prefix_names_allowed(self): dj.config.stores.update(original_stores) +class TestDatabaseNameConfiguration: + """Test database.name configuration.""" + + def test_database_name_default_is_none(self): + """Database name defaults to None when not configured.""" + from datajoint.settings import DatabaseSettings + + s = DatabaseSettings() + assert s.name is None + + def test_database_name_env_var(self, monkeypatch): + """DJ_DATABASE_NAME environment variable sets database name.""" + from datajoint.settings import DatabaseSettings + + monkeypatch.setenv("DJ_DATABASE_NAME", "my_database") + s = DatabaseSettings() + assert s.name == "my_database" + + def test_database_name_from_config_file(self, tmp_path, monkeypatch): + """Load database name from config file.""" + import json + + from datajoint.settings import Config + + config_file = tmp_path / "test_config.json" + config_file.write_text(json.dumps({"database": {"name": "custom_db", "host": "localhost"}})) + + monkeypatch.delenv("DJ_DATABASE_NAME", raising=False) + monkeypatch.delenv("DJ_HOST", raising=False) + + cfg = Config() + cfg.load(config_file) + assert cfg.database.name == "custom_db" + + def test_database_name_dict_access(self): + """Dict-style access reads and writes database name.""" + original = dj.config.database.name + try: + dj.config.database.name = "test_db" + assert dj.config["database.name"] == "test_db" + finally: + dj.config.database.name = original + + def test_database_name_override_context_manager(self): + """Override context manager temporarily sets database name.""" + original = dj.config.database.name + with dj.config.override(database__name="override_db"): + assert dj.config.database.name == "override_db" + assert dj.config.database.name == original + + def test_database_prefix_empty_no_warning(self): + """Empty database_prefix does not emit DeprecationWarning at config load.""" + import warnings + + from datajoint.settings import DatabaseSettings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + DatabaseSettings() + deprecation_warnings = [ + x for x in w if issubclass(x.category, DeprecationWarning) and "database_prefix" in str(x.message) + ] + assert len(deprecation_warnings) == 0 + + class TestBackendConfiguration: """Test database backend configuration and port auto-detection."""