From 24910e95ffeb79fba65f8712e31e6f08d02ba5aa Mon Sep 17 00:00:00 2001 From: "Joseph T. French" Date: Fri, 10 Apr 2026 19:56:46 -0500 Subject: [PATCH 1/4] Add new models and fields for closing entries and search hits - Introduced `AccountRollupGroup`, `AccountRollupRow`, and `AccountRollupsResponse` to the models. - Enhanced `ClosingEntryResponse` with a new `reversal` field to track reversal entries. - Updated `EntryTemplateRequest` to include an `auto_reverse` option for automatic reversing entries. - Added `reversal_entry_id` and `reversal_status` fields to `PeriodCloseItemResponse` for better tracking of entry reversals. - Included `parent_document_id` in `SearchHit` to associate search results with their parent documents. These changes improve the data model's capability to handle financial entries and enhance search functionalities. --- .../api/ledger/get_account_rollups.py | 251 ++++++++++++++++++ .../api/ledger/get_closing_book_structures.py | 184 +++++++++++++ robosystems_client/models/__init__.py | 12 + .../models/account_rollup_group.py | 123 +++++++++ .../models/account_rollup_row.py | 115 ++++++++ .../models/account_rollups_response.py | 107 ++++++++ .../models/closing_book_category.py | 83 ++++++ .../models/closing_book_item.py | 139 ++++++++++ .../closing_book_structures_response.py | 83 ++++++ .../models/closing_entry_response.py | 34 ++- .../models/entry_template_request.py | 9 + .../models/period_close_item_response.py | 40 +++ robosystems_client/models/search_hit.py | 20 ++ 13 files changed, 1199 insertions(+), 1 deletion(-) create mode 100644 robosystems_client/api/ledger/get_account_rollups.py create mode 100644 robosystems_client/api/ledger/get_closing_book_structures.py create mode 100644 robosystems_client/models/account_rollup_group.py create mode 100644 robosystems_client/models/account_rollup_row.py create mode 100644 robosystems_client/models/account_rollups_response.py create mode 100644 robosystems_client/models/closing_book_category.py create mode 100644 robosystems_client/models/closing_book_item.py create mode 100644 robosystems_client/models/closing_book_structures_response.py diff --git a/robosystems_client/api/ledger/get_account_rollups.py b/robosystems_client/api/ledger/get_account_rollups.py new file mode 100644 index 0000000..1c92f73 --- /dev/null +++ b/robosystems_client/api/ledger/get_account_rollups.py @@ -0,0 +1,251 @@ +import datetime +from http import HTTPStatus +from typing import Any +from urllib.parse import quote + +import httpx + +from ... import errors +from ...client import AuthenticatedClient, Client +from ...models.account_rollups_response import AccountRollupsResponse +from ...models.http_validation_error import HTTPValidationError +from ...types import UNSET, Response, Unset + + +def _get_kwargs( + graph_id: str, + *, + mapping_id: None | str | Unset = UNSET, + start_date: datetime.date | None | Unset = UNSET, + end_date: datetime.date | None | Unset = UNSET, +) -> dict[str, Any]: + params: dict[str, Any] = {} + + json_mapping_id: None | str | Unset + if isinstance(mapping_id, Unset): + json_mapping_id = UNSET + else: + json_mapping_id = mapping_id + params["mapping_id"] = json_mapping_id + + json_start_date: None | str | Unset + if isinstance(start_date, Unset): + json_start_date = UNSET + elif isinstance(start_date, datetime.date): + json_start_date = start_date.isoformat() + else: + json_start_date = start_date + params["start_date"] = json_start_date + + json_end_date: None | str | Unset + if isinstance(end_date, Unset): + json_end_date = UNSET + elif isinstance(end_date, datetime.date): + json_end_date = end_date.isoformat() + else: + json_end_date = end_date + params["end_date"] = json_end_date + + params = {k: v for k, v in params.items() if v is not UNSET and v is not None} + + _kwargs: dict[str, Any] = { + "method": "get", + "url": "/v1/ledger/{graph_id}/account-rollups".format( + graph_id=quote(str(graph_id), safe=""), + ), + "params": params, + } + + return _kwargs + + +def _parse_response( + *, client: AuthenticatedClient | Client, response: httpx.Response +) -> AccountRollupsResponse | HTTPValidationError | None: + if response.status_code == 200: + response_200 = AccountRollupsResponse.from_dict(response.json()) + + return response_200 + + if response.status_code == 422: + response_422 = HTTPValidationError.from_dict(response.json()) + + return response_422 + + if client.raise_on_unexpected_status: + raise errors.UnexpectedStatus(response.status_code, response.content) + else: + return None + + +def _build_response( + *, client: AuthenticatedClient | Client, response: httpx.Response +) -> Response[AccountRollupsResponse | HTTPValidationError]: + return Response( + status_code=HTTPStatus(response.status_code), + content=response.content, + headers=response.headers, + parsed=_parse_response(client=client, response=response), + ) + + +def sync_detailed( + graph_id: str, + *, + client: AuthenticatedClient, + mapping_id: None | str | Unset = UNSET, + start_date: datetime.date | None | Unset = UNSET, + end_date: datetime.date | None | Unset = UNSET, +) -> Response[AccountRollupsResponse | HTTPValidationError]: + """Account Rollups + + Account rollups — CoA accounts grouped by reporting element with balances. + + Shows how company-specific accounts roll up to standardized reporting + line items. Auto-discovers the mapping structure if mapping_id is not provided. + + Args: + graph_id (str): + mapping_id (None | str | Unset): Mapping structure ID (auto-discovers if omitted) + start_date (datetime.date | None | Unset): Start date (inclusive) + end_date (datetime.date | None | Unset): End date (inclusive) + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[AccountRollupsResponse | HTTPValidationError] + """ + + kwargs = _get_kwargs( + graph_id=graph_id, + mapping_id=mapping_id, + start_date=start_date, + end_date=end_date, + ) + + response = client.get_httpx_client().request( + **kwargs, + ) + + return _build_response(client=client, response=response) + + +def sync( + graph_id: str, + *, + client: AuthenticatedClient, + mapping_id: None | str | Unset = UNSET, + start_date: datetime.date | None | Unset = UNSET, + end_date: datetime.date | None | Unset = UNSET, +) -> AccountRollupsResponse | HTTPValidationError | None: + """Account Rollups + + Account rollups — CoA accounts grouped by reporting element with balances. + + Shows how company-specific accounts roll up to standardized reporting + line items. Auto-discovers the mapping structure if mapping_id is not provided. + + Args: + graph_id (str): + mapping_id (None | str | Unset): Mapping structure ID (auto-discovers if omitted) + start_date (datetime.date | None | Unset): Start date (inclusive) + end_date (datetime.date | None | Unset): End date (inclusive) + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + AccountRollupsResponse | HTTPValidationError + """ + + return sync_detailed( + graph_id=graph_id, + client=client, + mapping_id=mapping_id, + start_date=start_date, + end_date=end_date, + ).parsed + + +async def asyncio_detailed( + graph_id: str, + *, + client: AuthenticatedClient, + mapping_id: None | str | Unset = UNSET, + start_date: datetime.date | None | Unset = UNSET, + end_date: datetime.date | None | Unset = UNSET, +) -> Response[AccountRollupsResponse | HTTPValidationError]: + """Account Rollups + + Account rollups — CoA accounts grouped by reporting element with balances. + + Shows how company-specific accounts roll up to standardized reporting + line items. Auto-discovers the mapping structure if mapping_id is not provided. + + Args: + graph_id (str): + mapping_id (None | str | Unset): Mapping structure ID (auto-discovers if omitted) + start_date (datetime.date | None | Unset): Start date (inclusive) + end_date (datetime.date | None | Unset): End date (inclusive) + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[AccountRollupsResponse | HTTPValidationError] + """ + + kwargs = _get_kwargs( + graph_id=graph_id, + mapping_id=mapping_id, + start_date=start_date, + end_date=end_date, + ) + + response = await client.get_async_httpx_client().request(**kwargs) + + return _build_response(client=client, response=response) + + +async def asyncio( + graph_id: str, + *, + client: AuthenticatedClient, + mapping_id: None | str | Unset = UNSET, + start_date: datetime.date | None | Unset = UNSET, + end_date: datetime.date | None | Unset = UNSET, +) -> AccountRollupsResponse | HTTPValidationError | None: + """Account Rollups + + Account rollups — CoA accounts grouped by reporting element with balances. + + Shows how company-specific accounts roll up to standardized reporting + line items. Auto-discovers the mapping structure if mapping_id is not provided. + + Args: + graph_id (str): + mapping_id (None | str | Unset): Mapping structure ID (auto-discovers if omitted) + start_date (datetime.date | None | Unset): Start date (inclusive) + end_date (datetime.date | None | Unset): End date (inclusive) + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + AccountRollupsResponse | HTTPValidationError + """ + + return ( + await asyncio_detailed( + graph_id=graph_id, + client=client, + mapping_id=mapping_id, + start_date=start_date, + end_date=end_date, + ) + ).parsed diff --git a/robosystems_client/api/ledger/get_closing_book_structures.py b/robosystems_client/api/ledger/get_closing_book_structures.py new file mode 100644 index 0000000..349ea45 --- /dev/null +++ b/robosystems_client/api/ledger/get_closing_book_structures.py @@ -0,0 +1,184 @@ +from http import HTTPStatus +from typing import Any +from urllib.parse import quote + +import httpx + +from ... import errors +from ...client import AuthenticatedClient, Client +from ...models.closing_book_structures_response import ClosingBookStructuresResponse +from ...models.http_validation_error import HTTPValidationError +from ...types import Response + + +def _get_kwargs( + graph_id: str, +) -> dict[str, Any]: + _kwargs: dict[str, Any] = { + "method": "get", + "url": "/v1/ledger/{graph_id}/closing-book/structures".format( + graph_id=quote(str(graph_id), safe=""), + ), + } + + return _kwargs + + +def _parse_response( + *, client: AuthenticatedClient | Client, response: httpx.Response +) -> ClosingBookStructuresResponse | HTTPValidationError | None: + if response.status_code == 200: + response_200 = ClosingBookStructuresResponse.from_dict(response.json()) + + return response_200 + + if response.status_code == 422: + response_422 = HTTPValidationError.from_dict(response.json()) + + return response_422 + + if client.raise_on_unexpected_status: + raise errors.UnexpectedStatus(response.status_code, response.content) + else: + return None + + +def _build_response( + *, client: AuthenticatedClient | Client, response: httpx.Response +) -> Response[ClosingBookStructuresResponse | HTTPValidationError]: + return Response( + status_code=HTTPStatus(response.status_code), + content=response.content, + headers=response.headers, + parsed=_parse_response(client=client, response=response), + ) + + +def sync_detailed( + graph_id: str, + *, + client: AuthenticatedClient, +) -> Response[ClosingBookStructuresResponse | HTTPValidationError]: + """Closing Book Structures + + Returns all structure categories for the closing book sidebar. + + Aggregates statements (from latest report), schedules, account rollups + (from mapping structures), and trial balance availability into a single + response for the viewer sidebar navigation. + + Args: + graph_id (str): + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[ClosingBookStructuresResponse | HTTPValidationError] + """ + + kwargs = _get_kwargs( + graph_id=graph_id, + ) + + response = client.get_httpx_client().request( + **kwargs, + ) + + return _build_response(client=client, response=response) + + +def sync( + graph_id: str, + *, + client: AuthenticatedClient, +) -> ClosingBookStructuresResponse | HTTPValidationError | None: + """Closing Book Structures + + Returns all structure categories for the closing book sidebar. + + Aggregates statements (from latest report), schedules, account rollups + (from mapping structures), and trial balance availability into a single + response for the viewer sidebar navigation. + + Args: + graph_id (str): + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + ClosingBookStructuresResponse | HTTPValidationError + """ + + return sync_detailed( + graph_id=graph_id, + client=client, + ).parsed + + +async def asyncio_detailed( + graph_id: str, + *, + client: AuthenticatedClient, +) -> Response[ClosingBookStructuresResponse | HTTPValidationError]: + """Closing Book Structures + + Returns all structure categories for the closing book sidebar. + + Aggregates statements (from latest report), schedules, account rollups + (from mapping structures), and trial balance availability into a single + response for the viewer sidebar navigation. + + Args: + graph_id (str): + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[ClosingBookStructuresResponse | HTTPValidationError] + """ + + kwargs = _get_kwargs( + graph_id=graph_id, + ) + + response = await client.get_async_httpx_client().request(**kwargs) + + return _build_response(client=client, response=response) + + +async def asyncio( + graph_id: str, + *, + client: AuthenticatedClient, +) -> ClosingBookStructuresResponse | HTTPValidationError | None: + """Closing Book Structures + + Returns all structure categories for the closing book sidebar. + + Aggregates statements (from latest report), schedules, account rollups + (from mapping structures), and trial balance availability into a single + response for the viewer sidebar navigation. + + Args: + graph_id (str): + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + ClosingBookStructuresResponse | HTTPValidationError + """ + + return ( + await asyncio_detailed( + graph_id=graph_id, + client=client, + ) + ).parsed diff --git a/robosystems_client/models/__init__.py b/robosystems_client/models/__init__.py index f9caa2f..86285be 100644 --- a/robosystems_client/models/__init__.py +++ b/robosystems_client/models/__init__.py @@ -3,6 +3,9 @@ from .account_info import AccountInfo from .account_list_response import AccountListResponse from .account_response import AccountResponse +from .account_rollup_group import AccountRollupGroup +from .account_rollup_row import AccountRollupRow +from .account_rollups_response import AccountRollupsResponse from .account_tree_node import AccountTreeNode from .account_tree_response import AccountTreeResponse from .add_members_request import AddMembersRequest @@ -56,6 +59,9 @@ ) from .checkout_response import CheckoutResponse from .checkout_status_response import CheckoutStatusResponse +from .closing_book_category import ClosingBookCategory +from .closing_book_item import ClosingBookItem +from .closing_book_structures_response import ClosingBookStructuresResponse from .closing_entry_response import ClosingEntryResponse from .connection_options_response import ConnectionOptionsResponse from .connection_provider_info import ConnectionProviderInfo @@ -382,6 +388,9 @@ "AccountInfo", "AccountListResponse", "AccountResponse", + "AccountRollupGroup", + "AccountRollupRow", + "AccountRollupsResponse", "AccountTreeNode", "AccountTreeResponse", "AddMembersRequest", @@ -427,6 +436,9 @@ "CancelOperationResponseCanceloperation", "CheckoutResponse", "CheckoutStatusResponse", + "ClosingBookCategory", + "ClosingBookItem", + "ClosingBookStructuresResponse", "ClosingEntryResponse", "ConnectionOptionsResponse", "ConnectionProviderInfo", diff --git a/robosystems_client/models/account_rollup_group.py b/robosystems_client/models/account_rollup_group.py new file mode 100644 index 0000000..f7cffae --- /dev/null +++ b/robosystems_client/models/account_rollup_group.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, TypeVar + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +if TYPE_CHECKING: + from ..models.account_rollup_row import AccountRollupRow + + +T = TypeVar("T", bound="AccountRollupGroup") + + +@_attrs_define +class AccountRollupGroup: + """ + Attributes: + reporting_element_id (str): + reporting_name (str): + reporting_qname (str): + classification (str): + balance_type (str): + total (float): + accounts (list[AccountRollupRow]): + """ + + reporting_element_id: str + reporting_name: str + reporting_qname: str + classification: str + balance_type: str + total: float + accounts: list[AccountRollupRow] + additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> dict[str, Any]: + reporting_element_id = self.reporting_element_id + + reporting_name = self.reporting_name + + reporting_qname = self.reporting_qname + + classification = self.classification + + balance_type = self.balance_type + + total = self.total + + accounts = [] + for accounts_item_data in self.accounts: + accounts_item = accounts_item_data.to_dict() + accounts.append(accounts_item) + + field_dict: dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "reporting_element_id": reporting_element_id, + "reporting_name": reporting_name, + "reporting_qname": reporting_qname, + "classification": classification, + "balance_type": balance_type, + "total": total, + "accounts": accounts, + } + ) + + return field_dict + + @classmethod + def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: + from ..models.account_rollup_row import AccountRollupRow + + d = dict(src_dict) + reporting_element_id = d.pop("reporting_element_id") + + reporting_name = d.pop("reporting_name") + + reporting_qname = d.pop("reporting_qname") + + classification = d.pop("classification") + + balance_type = d.pop("balance_type") + + total = d.pop("total") + + accounts = [] + _accounts = d.pop("accounts") + for accounts_item_data in _accounts: + accounts_item = AccountRollupRow.from_dict(accounts_item_data) + + accounts.append(accounts_item) + + account_rollup_group = cls( + reporting_element_id=reporting_element_id, + reporting_name=reporting_name, + reporting_qname=reporting_qname, + classification=classification, + balance_type=balance_type, + total=total, + accounts=accounts, + ) + + account_rollup_group.additional_properties = d + return account_rollup_group + + @property + def additional_keys(self) -> list[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/robosystems_client/models/account_rollup_row.py b/robosystems_client/models/account_rollup_row.py new file mode 100644 index 0000000..2396990 --- /dev/null +++ b/robosystems_client/models/account_rollup_row.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any, TypeVar, cast + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +from ..types import UNSET, Unset + +T = TypeVar("T", bound="AccountRollupRow") + + +@_attrs_define +class AccountRollupRow: + """ + Attributes: + element_id (str): + account_name (str): + total_debits (float): + total_credits (float): + net_balance (float): + account_code (None | str | Unset): + """ + + element_id: str + account_name: str + total_debits: float + total_credits: float + net_balance: float + account_code: None | str | Unset = UNSET + additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> dict[str, Any]: + element_id = self.element_id + + account_name = self.account_name + + total_debits = self.total_debits + + total_credits = self.total_credits + + net_balance = self.net_balance + + account_code: None | str | Unset + if isinstance(self.account_code, Unset): + account_code = UNSET + else: + account_code = self.account_code + + field_dict: dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "element_id": element_id, + "account_name": account_name, + "total_debits": total_debits, + "total_credits": total_credits, + "net_balance": net_balance, + } + ) + if account_code is not UNSET: + field_dict["account_code"] = account_code + + return field_dict + + @classmethod + def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: + d = dict(src_dict) + element_id = d.pop("element_id") + + account_name = d.pop("account_name") + + total_debits = d.pop("total_debits") + + total_credits = d.pop("total_credits") + + net_balance = d.pop("net_balance") + + def _parse_account_code(data: object) -> None | str | Unset: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(None | str | Unset, data) + + account_code = _parse_account_code(d.pop("account_code", UNSET)) + + account_rollup_row = cls( + element_id=element_id, + account_name=account_name, + total_debits=total_debits, + total_credits=total_credits, + net_balance=net_balance, + account_code=account_code, + ) + + account_rollup_row.additional_properties = d + return account_rollup_row + + @property + def additional_keys(self) -> list[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/robosystems_client/models/account_rollups_response.py b/robosystems_client/models/account_rollups_response.py new file mode 100644 index 0000000..b90f429 --- /dev/null +++ b/robosystems_client/models/account_rollups_response.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, TypeVar + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +if TYPE_CHECKING: + from ..models.account_rollup_group import AccountRollupGroup + + +T = TypeVar("T", bound="AccountRollupsResponse") + + +@_attrs_define +class AccountRollupsResponse: + """ + Attributes: + mapping_id (str): + mapping_name (str): + groups (list[AccountRollupGroup]): + total_mapped (int): + total_unmapped (int): + """ + + mapping_id: str + mapping_name: str + groups: list[AccountRollupGroup] + total_mapped: int + total_unmapped: int + additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> dict[str, Any]: + mapping_id = self.mapping_id + + mapping_name = self.mapping_name + + groups = [] + for groups_item_data in self.groups: + groups_item = groups_item_data.to_dict() + groups.append(groups_item) + + total_mapped = self.total_mapped + + total_unmapped = self.total_unmapped + + field_dict: dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "mapping_id": mapping_id, + "mapping_name": mapping_name, + "groups": groups, + "total_mapped": total_mapped, + "total_unmapped": total_unmapped, + } + ) + + return field_dict + + @classmethod + def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: + from ..models.account_rollup_group import AccountRollupGroup + + d = dict(src_dict) + mapping_id = d.pop("mapping_id") + + mapping_name = d.pop("mapping_name") + + groups = [] + _groups = d.pop("groups") + for groups_item_data in _groups: + groups_item = AccountRollupGroup.from_dict(groups_item_data) + + groups.append(groups_item) + + total_mapped = d.pop("total_mapped") + + total_unmapped = d.pop("total_unmapped") + + account_rollups_response = cls( + mapping_id=mapping_id, + mapping_name=mapping_name, + groups=groups, + total_mapped=total_mapped, + total_unmapped=total_unmapped, + ) + + account_rollups_response.additional_properties = d + return account_rollups_response + + @property + def additional_keys(self) -> list[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/robosystems_client/models/closing_book_category.py b/robosystems_client/models/closing_book_category.py new file mode 100644 index 0000000..16f41a6 --- /dev/null +++ b/robosystems_client/models/closing_book_category.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, TypeVar + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +if TYPE_CHECKING: + from ..models.closing_book_item import ClosingBookItem + + +T = TypeVar("T", bound="ClosingBookCategory") + + +@_attrs_define +class ClosingBookCategory: + """ + Attributes: + label (str): + items (list[ClosingBookItem]): + """ + + label: str + items: list[ClosingBookItem] + additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> dict[str, Any]: + label = self.label + + items = [] + for items_item_data in self.items: + items_item = items_item_data.to_dict() + items.append(items_item) + + field_dict: dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "label": label, + "items": items, + } + ) + + return field_dict + + @classmethod + def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: + from ..models.closing_book_item import ClosingBookItem + + d = dict(src_dict) + label = d.pop("label") + + items = [] + _items = d.pop("items") + for items_item_data in _items: + items_item = ClosingBookItem.from_dict(items_item_data) + + items.append(items_item) + + closing_book_category = cls( + label=label, + items=items, + ) + + closing_book_category.additional_properties = d + return closing_book_category + + @property + def additional_keys(self) -> list[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/robosystems_client/models/closing_book_item.py b/robosystems_client/models/closing_book_item.py new file mode 100644 index 0000000..37f1247 --- /dev/null +++ b/robosystems_client/models/closing_book_item.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any, TypeVar, cast + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +from ..types import UNSET, Unset + +T = TypeVar("T", bound="ClosingBookItem") + + +@_attrs_define +class ClosingBookItem: + """ + Attributes: + id (str): + name (str): + item_type (str): + structure_type (None | str | Unset): + report_id (None | str | Unset): + status (None | str | Unset): + """ + + id: str + name: str + item_type: str + structure_type: None | str | Unset = UNSET + report_id: None | str | Unset = UNSET + status: None | str | Unset = UNSET + additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> dict[str, Any]: + id = self.id + + name = self.name + + item_type = self.item_type + + structure_type: None | str | Unset + if isinstance(self.structure_type, Unset): + structure_type = UNSET + else: + structure_type = self.structure_type + + report_id: None | str | Unset + if isinstance(self.report_id, Unset): + report_id = UNSET + else: + report_id = self.report_id + + status: None | str | Unset + if isinstance(self.status, Unset): + status = UNSET + else: + status = self.status + + field_dict: dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "id": id, + "name": name, + "item_type": item_type, + } + ) + if structure_type is not UNSET: + field_dict["structure_type"] = structure_type + if report_id is not UNSET: + field_dict["report_id"] = report_id + if status is not UNSET: + field_dict["status"] = status + + return field_dict + + @classmethod + def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: + d = dict(src_dict) + id = d.pop("id") + + name = d.pop("name") + + item_type = d.pop("item_type") + + def _parse_structure_type(data: object) -> None | str | Unset: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(None | str | Unset, data) + + structure_type = _parse_structure_type(d.pop("structure_type", UNSET)) + + def _parse_report_id(data: object) -> None | str | Unset: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(None | str | Unset, data) + + report_id = _parse_report_id(d.pop("report_id", UNSET)) + + def _parse_status(data: object) -> None | str | Unset: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(None | str | Unset, data) + + status = _parse_status(d.pop("status", UNSET)) + + closing_book_item = cls( + id=id, + name=name, + item_type=item_type, + structure_type=structure_type, + report_id=report_id, + status=status, + ) + + closing_book_item.additional_properties = d + return closing_book_item + + @property + def additional_keys(self) -> list[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/robosystems_client/models/closing_book_structures_response.py b/robosystems_client/models/closing_book_structures_response.py new file mode 100644 index 0000000..2d6631a --- /dev/null +++ b/robosystems_client/models/closing_book_structures_response.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, TypeVar + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +if TYPE_CHECKING: + from ..models.closing_book_category import ClosingBookCategory + + +T = TypeVar("T", bound="ClosingBookStructuresResponse") + + +@_attrs_define +class ClosingBookStructuresResponse: + """ + Attributes: + categories (list[ClosingBookCategory]): + has_data (bool): + """ + + categories: list[ClosingBookCategory] + has_data: bool + additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> dict[str, Any]: + categories = [] + for categories_item_data in self.categories: + categories_item = categories_item_data.to_dict() + categories.append(categories_item) + + has_data = self.has_data + + field_dict: dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "categories": categories, + "has_data": has_data, + } + ) + + return field_dict + + @classmethod + def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: + from ..models.closing_book_category import ClosingBookCategory + + d = dict(src_dict) + categories = [] + _categories = d.pop("categories") + for categories_item_data in _categories: + categories_item = ClosingBookCategory.from_dict(categories_item_data) + + categories.append(categories_item) + + has_data = d.pop("has_data") + + closing_book_structures_response = cls( + categories=categories, + has_data=has_data, + ) + + closing_book_structures_response.additional_properties = d + return closing_book_structures_response + + @property + def additional_keys(self) -> list[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/robosystems_client/models/closing_entry_response.py b/robosystems_client/models/closing_entry_response.py index a8d33c4..9cbe8e5 100644 --- a/robosystems_client/models/closing_entry_response.py +++ b/robosystems_client/models/closing_entry_response.py @@ -2,12 +2,14 @@ import datetime from collections.abc import Mapping -from typing import Any, TypeVar +from typing import Any, TypeVar, cast from attrs import define as _attrs_define from attrs import field as _attrs_field from dateutil.parser import isoparse +from ..types import UNSET, Unset + T = TypeVar("T", bound="ClosingEntryResponse") @@ -22,6 +24,7 @@ class ClosingEntryResponse: debit_element_id (str): credit_element_id (str): amount (float): + reversal (ClosingEntryResponse | None | Unset): """ entry_id: str @@ -31,6 +34,7 @@ class ClosingEntryResponse: debit_element_id: str credit_element_id: str amount: float + reversal: ClosingEntryResponse | None | Unset = UNSET additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) def to_dict(self) -> dict[str, Any]: @@ -48,6 +52,14 @@ def to_dict(self) -> dict[str, Any]: amount = self.amount + reversal: dict[str, Any] | None | Unset + if isinstance(self.reversal, Unset): + reversal = UNSET + elif isinstance(self.reversal, ClosingEntryResponse): + reversal = self.reversal.to_dict() + else: + reversal = self.reversal + field_dict: dict[str, Any] = {} field_dict.update(self.additional_properties) field_dict.update( @@ -61,6 +73,8 @@ def to_dict(self) -> dict[str, Any]: "amount": amount, } ) + if reversal is not UNSET: + field_dict["reversal"] = reversal return field_dict @@ -81,6 +95,23 @@ def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: amount = d.pop("amount") + def _parse_reversal(data: object) -> ClosingEntryResponse | None | Unset: + if data is None: + return data + if isinstance(data, Unset): + return data + try: + if not isinstance(data, dict): + raise TypeError() + reversal_type_0 = ClosingEntryResponse.from_dict(data) + + return reversal_type_0 + except (TypeError, ValueError, AttributeError, KeyError): + pass + return cast(ClosingEntryResponse | None | Unset, data) + + reversal = _parse_reversal(d.pop("reversal", UNSET)) + closing_entry_response = cls( entry_id=entry_id, status=status, @@ -89,6 +120,7 @@ def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: debit_element_id=debit_element_id, credit_element_id=credit_element_id, amount=amount, + reversal=reversal, ) closing_entry_response.additional_properties = d diff --git a/robosystems_client/models/entry_template_request.py b/robosystems_client/models/entry_template_request.py index e46f5f8..2351740 100644 --- a/robosystems_client/models/entry_template_request.py +++ b/robosystems_client/models/entry_template_request.py @@ -19,12 +19,14 @@ class EntryTemplateRequest: credit_element_id (str): Element to credit (e.g., Accumulated Depreciation) entry_type (str | Unset): Entry type for generated entries Default: 'closing'. memo_template (str | Unset): Memo template ({structure_name} is replaced) Default: ''. + auto_reverse (bool | Unset): Auto-generate a reversing entry on the first day of the next period Default: False. """ debit_element_id: str credit_element_id: str entry_type: str | Unset = "closing" memo_template: str | Unset = "" + auto_reverse: bool | Unset = False additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) def to_dict(self) -> dict[str, Any]: @@ -36,6 +38,8 @@ def to_dict(self) -> dict[str, Any]: memo_template = self.memo_template + auto_reverse = self.auto_reverse + field_dict: dict[str, Any] = {} field_dict.update(self.additional_properties) field_dict.update( @@ -48,6 +52,8 @@ def to_dict(self) -> dict[str, Any]: field_dict["entry_type"] = entry_type if memo_template is not UNSET: field_dict["memo_template"] = memo_template + if auto_reverse is not UNSET: + field_dict["auto_reverse"] = auto_reverse return field_dict @@ -62,11 +68,14 @@ def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: memo_template = d.pop("memo_template", UNSET) + auto_reverse = d.pop("auto_reverse", UNSET) + entry_template_request = cls( debit_element_id=debit_element_id, credit_element_id=credit_element_id, entry_type=entry_type, memo_template=memo_template, + auto_reverse=auto_reverse, ) entry_template_request.additional_properties = d diff --git a/robosystems_client/models/period_close_item_response.py b/robosystems_client/models/period_close_item_response.py index f44515c..52218c1 100644 --- a/robosystems_client/models/period_close_item_response.py +++ b/robosystems_client/models/period_close_item_response.py @@ -20,6 +20,8 @@ class PeriodCloseItemResponse: amount (float): status (str): entry_id (None | str | Unset): + reversal_entry_id (None | str | Unset): + reversal_status (None | str | Unset): """ structure_id: str @@ -27,6 +29,8 @@ class PeriodCloseItemResponse: amount: float status: str entry_id: None | str | Unset = UNSET + reversal_entry_id: None | str | Unset = UNSET + reversal_status: None | str | Unset = UNSET additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) def to_dict(self) -> dict[str, Any]: @@ -44,6 +48,18 @@ def to_dict(self) -> dict[str, Any]: else: entry_id = self.entry_id + reversal_entry_id: None | str | Unset + if isinstance(self.reversal_entry_id, Unset): + reversal_entry_id = UNSET + else: + reversal_entry_id = self.reversal_entry_id + + reversal_status: None | str | Unset + if isinstance(self.reversal_status, Unset): + reversal_status = UNSET + else: + reversal_status = self.reversal_status + field_dict: dict[str, Any] = {} field_dict.update(self.additional_properties) field_dict.update( @@ -56,6 +72,10 @@ def to_dict(self) -> dict[str, Any]: ) if entry_id is not UNSET: field_dict["entry_id"] = entry_id + if reversal_entry_id is not UNSET: + field_dict["reversal_entry_id"] = reversal_entry_id + if reversal_status is not UNSET: + field_dict["reversal_status"] = reversal_status return field_dict @@ -79,12 +99,32 @@ def _parse_entry_id(data: object) -> None | str | Unset: entry_id = _parse_entry_id(d.pop("entry_id", UNSET)) + def _parse_reversal_entry_id(data: object) -> None | str | Unset: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(None | str | Unset, data) + + reversal_entry_id = _parse_reversal_entry_id(d.pop("reversal_entry_id", UNSET)) + + def _parse_reversal_status(data: object) -> None | str | Unset: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(None | str | Unset, data) + + reversal_status = _parse_reversal_status(d.pop("reversal_status", UNSET)) + period_close_item_response = cls( structure_id=structure_id, structure_name=structure_name, amount=amount, status=status, entry_id=entry_id, + reversal_entry_id=reversal_entry_id, + reversal_status=reversal_status, ) period_close_item_response.additional_properties = d diff --git a/robosystems_client/models/search_hit.py b/robosystems_client/models/search_hit.py index f642234..4776ffc 100644 --- a/robosystems_client/models/search_hit.py +++ b/robosystems_client/models/search_hit.py @@ -20,6 +20,7 @@ class SearchHit: score (float): source_type (str): snippet (str): + parent_document_id (None | str | Unset): entity_ticker (None | str | Unset): entity_name (None | str | Unset): section_label (None | str | Unset): @@ -40,6 +41,7 @@ class SearchHit: score: float source_type: str snippet: str + parent_document_id: None | str | Unset = UNSET entity_ticker: None | str | Unset = UNSET entity_name: None | str | Unset = UNSET section_label: None | str | Unset = UNSET @@ -65,6 +67,12 @@ def to_dict(self) -> dict[str, Any]: snippet = self.snippet + parent_document_id: None | str | Unset + if isinstance(self.parent_document_id, Unset): + parent_document_id = UNSET + else: + parent_document_id = self.parent_document_id + entity_ticker: None | str | Unset if isinstance(self.entity_ticker, Unset): entity_ticker = UNSET @@ -161,6 +169,8 @@ def to_dict(self) -> dict[str, Any]: "snippet": snippet, } ) + if parent_document_id is not UNSET: + field_dict["parent_document_id"] = parent_document_id if entity_ticker is not UNSET: field_dict["entity_ticker"] = entity_ticker if entity_name is not UNSET: @@ -203,6 +213,15 @@ def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: snippet = d.pop("snippet") + def _parse_parent_document_id(data: object) -> None | str | Unset: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(None | str | Unset, data) + + parent_document_id = _parse_parent_document_id(d.pop("parent_document_id", UNSET)) + def _parse_entity_ticker(data: object) -> None | str | Unset: if data is None: return data @@ -343,6 +362,7 @@ def _parse_folder(data: object) -> None | str | Unset: score=score, source_type=source_type, snippet=snippet, + parent_document_id=parent_document_id, entity_ticker=entity_ticker, entity_name=entity_name, section_label=section_label, From fef49cc6f4ca50fa67352c7f9674a9eee3d0dfcf Mon Sep 17 00:00:00 2001 From: "Joseph T. French" Date: Fri, 10 Apr 2026 20:43:36 -0500 Subject: [PATCH 2/4] Enhance LedgerClient with new schedule and closing entry functionalities - Added methods for creating and listing schedules, retrieving schedule facts, and checking period close status in the LedgerClient. - Introduced a method for creating draft closing entries from schedule facts. - Updated imports to include new API endpoints related to schedules and closing entries. - Removed the old mapping association request model in favor of a more generic association request model. These enhancements improve the LedgerClient's capabilities for managing financial schedules and closing entries. --- .../extensions/ledger_client.py | 179 +++++- .../extensions/tests/__init__.py | 1 - tests/test_agent_client.py | 335 +++++++++++ .../tests => tests}/test_dataframe_utils.py | 0 tests/test_document_client.py | 433 ++++++++++++++ .../test_extensions_integration.py | 0 tests/test_file_client.py | 458 +++++++++++++++ tests/test_ledger_client.py | 533 ++++++++++++++++++ tests/test_materialization_client.py | 340 +++++++++++ tests/test_operation_client_ops.py | 350 ++++++++++++ tests/test_query_client_ops.py | 524 +++++++++++++++++ tests/test_report_client.py | 302 ++++++++++ tests/test_sse_client.py | 503 +++++++++++++++++ tests/test_table_client.py | 240 ++++++++ .../tests => tests}/test_token_utils.py | 0 .../test_unit.py => tests/test_utils_unit.py | 7 - 16 files changed, 4193 insertions(+), 12 deletions(-) delete mode 100644 robosystems_client/extensions/tests/__init__.py create mode 100644 tests/test_agent_client.py rename {robosystems_client/extensions/tests => tests}/test_dataframe_utils.py (100%) create mode 100644 tests/test_document_client.py rename robosystems_client/extensions/tests/test_integration.py => tests/test_extensions_integration.py (100%) create mode 100644 tests/test_file_client.py create mode 100644 tests/test_ledger_client.py create mode 100644 tests/test_materialization_client.py create mode 100644 tests/test_operation_client_ops.py create mode 100644 tests/test_query_client_ops.py create mode 100644 tests/test_report_client.py create mode 100644 tests/test_sse_client.py create mode 100644 tests/test_table_client.py rename {robosystems_client/extensions/tests => tests}/test_token_utils.py (100%) rename robosystems_client/extensions/tests/test_unit.py => tests/test_utils_unit.py (98%) diff --git a/robosystems_client/extensions/ledger_client.py b/robosystems_client/extensions/ledger_client.py index 365810a..4cca900 100644 --- a/robosystems_client/extensions/ledger_client.py +++ b/robosystems_client/extensions/ledger_client.py @@ -10,13 +10,19 @@ from typing import Any from ..api.ledger.auto_map_elements import sync_detailed as auto_map_elements +from ..api.ledger.create_closing_entry import sync_detailed as create_closing_entry from ..api.ledger.create_mapping_association import ( sync_detailed as create_mapping_association, ) +from ..api.ledger.create_schedule import sync_detailed as create_schedule from ..api.ledger.create_structure import sync_detailed as create_structure from ..api.ledger.delete_mapping_association import ( sync_detailed as delete_mapping_association, ) +from ..api.ledger.get_account_rollups import sync_detailed as get_account_rollups +from ..api.ledger.get_closing_book_structures import ( + sync_detailed as get_closing_book_structures, +) from ..api.ledger.get_ledger_account_tree import ( sync_detailed as get_ledger_account_tree, ) @@ -35,9 +41,13 @@ sync_detailed as get_mapping_coverage, ) from ..api.ledger.get_mapping_detail import sync_detailed as get_mapping_detail +from ..api.ledger.get_period_close_status import ( + sync_detailed as get_period_close_status, +) from ..api.ledger.get_reporting_taxonomy import ( sync_detailed as get_reporting_taxonomy, ) +from ..api.ledger.get_schedule_facts import sync_detailed as get_schedule_facts from ..api.ledger.list_elements import sync_detailed as list_elements from ..api.ledger.list_ledger_accounts import ( sync_detailed as list_ledger_accounts, @@ -46,6 +56,7 @@ sync_detailed as list_ledger_transactions, ) from ..api.ledger.list_mappings import sync_detailed as list_mappings +from ..api.ledger.list_schedules import sync_detailed as list_schedules from ..api.ledger.list_structures import sync_detailed as list_structures from ..client import AuthenticatedClient @@ -280,11 +291,9 @@ def create_mapping( confidence: float = 1.0, ) -> None: """Create a manual mapping association (CoA element → GAAP element).""" - from ..models.create_mapping_association_request import ( - CreateMappingAssociationRequest, - ) + from ..models.create_association_request import CreateAssociationRequest - body = CreateMappingAssociationRequest( + body = CreateAssociationRequest( from_element_id=from_element_id, to_element_id=to_element_id, confidence=confidence, @@ -314,3 +323,165 @@ def auto_map(self, graph_id: str, mapping_id: str) -> dict[str, Any]: if response.status_code != HTTPStatus.ACCEPTED: raise RuntimeError(f"Auto-map failed: {response.status_code}") return response.parsed or {} + + # ── Schedules ────────────────────────────────────────────────────── + + def create_schedule( + self, + graph_id: str, + *, + name: str, + element_ids: list[str], + period_start: str, + period_end: str, + monthly_amount: int, + debit_element_id: str, + credit_element_id: str, + entry_type: str = "closing", + memo_template: str = "", + taxonomy_id: str | None = None, + method: str | None = None, + original_amount: int | None = None, + residual_value: int | None = None, + useful_life_months: int | None = None, + asset_element_id: str | None = None, + auto_reverse: bool = False, + ) -> Any: + """Create a schedule with pre-generated monthly facts.""" + body: dict[str, Any] = { + "name": name, + "element_ids": element_ids, + "period_start": period_start, + "period_end": period_end, + "monthly_amount": monthly_amount, + "entry_template": { + "debit_element_id": debit_element_id, + "credit_element_id": credit_element_id, + "entry_type": entry_type, + "memo_template": memo_template, + "auto_reverse": auto_reverse, + }, + } + if taxonomy_id: + body["taxonomy_id"] = taxonomy_id + schedule_metadata: dict[str, Any] = {} + if method: + schedule_metadata["method"] = method + if original_amount is not None: + schedule_metadata["original_amount"] = original_amount + if residual_value is not None: + schedule_metadata["residual_value"] = residual_value + if useful_life_months is not None: + schedule_metadata["useful_life_months"] = useful_life_months + if asset_element_id: + schedule_metadata["asset_element_id"] = asset_element_id + if schedule_metadata: + body["schedule_metadata"] = schedule_metadata + + response = create_schedule(graph_id=graph_id, body=body, client=self._get_client()) + if response.status_code != HTTPStatus.CREATED: + raise RuntimeError(f"Create schedule failed: {response.status_code}") + return response.parsed + + def list_schedules(self, graph_id: str) -> Any: + """List all active schedule structures.""" + response = list_schedules(graph_id=graph_id, client=self._get_client()) + if response.status_code != HTTPStatus.OK: + raise RuntimeError(f"List schedules failed: {response.status_code}") + return response.parsed + + def get_schedule_facts( + self, + graph_id: str, + structure_id: str, + period_start: str | None = None, + period_end: str | None = None, + ) -> Any: + """Get fact values for a schedule, optionally filtered by period.""" + response = get_schedule_facts( + graph_id=graph_id, + structure_id=structure_id, + period_start=period_start, + period_end=period_end, + client=self._get_client(), + ) + if response.status_code != HTTPStatus.OK: + raise RuntimeError(f"Get schedule facts failed: {response.status_code}") + return response.parsed + + def get_period_close_status( + self, + graph_id: str, + period_start: str, + period_end: str, + ) -> Any: + """Get close status for all schedules in a fiscal period.""" + response = get_period_close_status( + graph_id=graph_id, + period_start=period_start, + period_end=period_end, + client=self._get_client(), + ) + if response.status_code != HTTPStatus.OK: + raise RuntimeError(f"Get period close status failed: {response.status_code}") + return response.parsed + + def create_closing_entry( + self, + graph_id: str, + structure_id: str, + posting_date: str, + period_start: str, + period_end: str, + memo: str | None = None, + ) -> Any: + """Create a draft closing entry from a schedule's facts for a period.""" + body: dict[str, Any] = { + "posting_date": posting_date, + "period_start": period_start, + "period_end": period_end, + } + if memo: + body["memo"] = memo + + response = create_closing_entry( + graph_id=graph_id, + structure_id=structure_id, + body=body, + client=self._get_client(), + ) + if response.status_code != HTTPStatus.CREATED: + raise RuntimeError(f"Create closing entry failed: {response.status_code}") + return response.parsed + + # ── Closing Book ───────────────────────────────────────────────────── + + def get_closing_book_structures(self, graph_id: str) -> Any: + """Get all closing book structure categories for the sidebar.""" + response = get_closing_book_structures(graph_id=graph_id, client=self._get_client()) + if response.status_code != HTTPStatus.OK: + raise RuntimeError(f"Get closing book structures failed: {response.status_code}") + return response.parsed + + def get_account_rollups( + self, + graph_id: str, + mapping_id: str | None = None, + start_date: str | None = None, + end_date: str | None = None, + ) -> Any: + """Get account rollups — CoA accounts grouped by reporting element with balances. + + Shows how company-specific accounts roll up to standardized reporting lines. + Auto-discovers the mapping structure if mapping_id is not provided. + """ + response = get_account_rollups( + graph_id=graph_id, + mapping_id=mapping_id, + start_date=start_date, + end_date=end_date, + client=self._get_client(), + ) + if response.status_code != HTTPStatus.OK: + raise RuntimeError(f"Get account rollups failed: {response.status_code}") + return response.parsed diff --git a/robosystems_client/extensions/tests/__init__.py b/robosystems_client/extensions/tests/__init__.py deleted file mode 100644 index 6e8da8f..0000000 --- a/robosystems_client/extensions/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for RoboSystems Client Extensions""" diff --git a/tests/test_agent_client.py b/tests/test_agent_client.py new file mode 100644 index 0000000..85b160f --- /dev/null +++ b/tests/test_agent_client.py @@ -0,0 +1,335 @@ +"""Unit tests for AgentClient.""" + +import pytest +from unittest.mock import Mock, patch +from robosystems_client.extensions.agent_client import ( + AgentClient, + AgentQueryRequest, + AgentOptions, + AgentResult, + QueuedAgentResponse, + QueuedAgentError, +) + + +@pytest.mark.unit +class TestAgentDataclasses: + """Test suite for agent-related dataclasses.""" + + def test_agent_query_request(self): + """Test AgentQueryRequest dataclass.""" + request = AgentQueryRequest( + message="What is the revenue?", + history=[{"role": "user", "content": "Hello"}], + context={"fiscal_year": 2024}, + mode="standard", + enable_rag=True, + force_extended_analysis=False, + ) + + assert request.message == "What is the revenue?" + assert len(request.history) == 1 + assert request.context == {"fiscal_year": 2024} + assert request.mode == "standard" + assert request.enable_rag is True + assert request.force_extended_analysis is False + + def test_agent_query_request_defaults(self): + """Test AgentQueryRequest default values.""" + request = AgentQueryRequest(message="Simple query") + + assert request.message == "Simple query" + assert request.history is None + assert request.context is None + assert request.mode is None + assert request.enable_rag is None + assert request.force_extended_analysis is None + + def test_agent_options_defaults(self): + """Test AgentOptions default values.""" + options = AgentOptions() + + assert options.mode == "auto" + assert options.max_wait is None + assert options.on_progress is None + + def test_agent_options_custom(self): + """Test AgentOptions with custom values.""" + progress_fn = Mock() + options = AgentOptions(mode="sync", max_wait=30, on_progress=progress_fn) + + assert options.mode == "sync" + assert options.max_wait == 30 + assert options.on_progress is progress_fn + + def test_agent_result(self): + """Test AgentResult dataclass.""" + result = AgentResult( + content="Revenue is $1B", + agent_used="financial", + mode_used="standard", + metadata={"graph_id": "g-123"}, + tokens_used={"input": 100, "output": 50}, + confidence_score=0.95, + execution_time=1.5, + timestamp="2025-01-15T10:00:00Z", + ) + + assert result.content == "Revenue is $1B" + assert result.agent_used == "financial" + assert result.mode_used == "standard" + assert result.confidence_score == 0.95 + assert result.execution_time == 1.5 + + def test_agent_result_defaults(self): + """Test AgentResult with minimal fields.""" + result = AgentResult(content="Answer", agent_used="rag", mode_used="quick") + + assert result.metadata is None + assert result.tokens_used is None + assert result.confidence_score is None + assert result.execution_time is None + assert result.timestamp is None + + def test_queued_agent_response(self): + """Test QueuedAgentResponse dataclass.""" + response = QueuedAgentResponse( + status="queued", + operation_id="op-agent-1", + message="Agent execution queued", + sse_endpoint="/v1/operations/op-agent-1/stream", + ) + + assert response.status == "queued" + assert response.operation_id == "op-agent-1" + assert response.message == "Agent execution queued" + assert response.sse_endpoint == "/v1/operations/op-agent-1/stream" + + def test_queued_agent_response_defaults(self): + """Test QueuedAgentResponse defaults.""" + response = QueuedAgentResponse( + status="queued", operation_id="op-1", message="Queued" + ) + + assert response.sse_endpoint is None + + def test_queued_agent_error(self): + """Test QueuedAgentError exception.""" + queue_info = QueuedAgentResponse( + status="queued", operation_id="op-err", message="Queued" + ) + error = QueuedAgentError(queue_info) + + assert str(error) == "Agent execution was queued" + assert error.queue_info.operation_id == "op-err" + + +@pytest.mark.unit +class TestAgentClientInit: + """Test suite for AgentClient initialization.""" + + def test_client_initialization(self, mock_config): + """Test that client initializes correctly with config.""" + client = AgentClient(mock_config) + + assert client.base_url == "http://localhost:8000" + assert client.token == "test-api-key" + assert client.headers == {"X-API-Key": "test-api-key"} + assert client.sse_client is None + + def test_close_without_sse(self, mock_config): + """Test close when no SSE client exists.""" + client = AgentClient(mock_config) + client.close() # Should not raise + + def test_close_with_sse(self, mock_config): + """Test close cleans up SSE client.""" + client = AgentClient(mock_config) + client.sse_client = Mock() + client.close() + + assert client.sse_client is None + + +@pytest.mark.unit +class TestAgentExecuteQuery: + """Test suite for AgentClient.execute_query method.""" + + @patch("robosystems_client.extensions.agent_client.auto_select_agent") + def test_execute_query_dict_response(self, mock_auto, mock_config, graph_id): + """Test execute_query with dict response.""" + mock_resp = Mock() + mock_resp.parsed = { + "content": "Revenue is $1B for FY2024.", + "agent_used": "financial", + "mode_used": "standard", + "metadata": {}, + "tokens_used": {"input": 100, "output": 50}, + "confidence_score": 0.9, + "execution_time": 2.5, + "timestamp": "2025-01-15T10:00:00Z", + } + mock_auto.return_value = mock_resp + + client = AgentClient(mock_config) + request = AgentQueryRequest(message="What is revenue?") + result = client.execute_query(graph_id, request) + + assert result.content == "Revenue is $1B for FY2024." + assert result.agent_used == "financial" + assert result.confidence_score == 0.9 + + @patch("robosystems_client.extensions.agent_client.auto_select_agent") + def test_execute_query_queued_max_wait_zero(self, mock_auto, mock_config, graph_id): + """Test execute_query raises QueuedAgentError when max_wait=0.""" + mock_resp = Mock() + mock_resp.parsed = { + "operation_id": "op-queued", + "status": "queued", + "message": "Agent execution queued", + } + mock_auto.return_value = mock_resp + + client = AgentClient(mock_config) + request = AgentQueryRequest(message="Complex query") + options = AgentOptions(max_wait=0) + + with pytest.raises(QueuedAgentError) as exc_info: + client.execute_query(graph_id, request, options) + + assert exc_info.value.queue_info.operation_id == "op-queued" + + @patch("robosystems_client.extensions.agent_client.auto_select_agent") + def test_execute_query_no_token(self, mock_auto, mock_config, graph_id): + """Test execute_query fails without token.""" + mock_config["token"] = None + client = AgentClient(mock_config) + request = AgentQueryRequest(message="test") + + with pytest.raises(Exception, match="Authentication failed|No API key"): + client.execute_query(graph_id, request) + + @patch("robosystems_client.extensions.agent_client.auto_select_agent") + def test_execute_query_auth_error(self, mock_auto, mock_config, graph_id): + """Test execute_query wraps 401 errors.""" + mock_auto.side_effect = Exception("401 Unauthorized") + + client = AgentClient(mock_config) + request = AgentQueryRequest(message="test") + + with pytest.raises(Exception, match="Authentication failed"): + client.execute_query(graph_id, request) + + @patch("robosystems_client.extensions.agent_client.auto_select_agent") + def test_execute_query_generic_error(self, mock_auto, mock_config, graph_id): + """Test execute_query wraps generic errors.""" + mock_auto.side_effect = Exception("Connection timeout") + + client = AgentClient(mock_config) + request = AgentQueryRequest(message="test") + + with pytest.raises(Exception, match="Agent execution failed"): + client.execute_query(graph_id, request) + + +@pytest.mark.unit +class TestAgentExecuteSpecific: + """Test suite for AgentClient.execute_agent method.""" + + @patch("robosystems_client.extensions.agent_client.execute_specific_agent") + def test_execute_specific_agent(self, mock_exec, mock_config, graph_id): + """Test executing a specific agent type.""" + mock_resp = Mock() + mock_resp.parsed = { + "content": "Deep analysis of financials.", + "agent_used": "financial", + "mode_used": "extended", + } + mock_exec.return_value = mock_resp + + client = AgentClient(mock_config) + request = AgentQueryRequest(message="Analyze financials") + result = client.execute_agent(graph_id, "financial", request) + + assert result.content == "Deep analysis of financials." + assert result.agent_used == "financial" + + @patch("robosystems_client.extensions.agent_client.execute_specific_agent") + def test_execute_specific_agent_queued(self, mock_exec, mock_config, graph_id): + """Test specific agent returns queued response.""" + mock_resp = Mock() + mock_resp.parsed = { + "operation_id": "op-specific-queued", + "status": "queued", + "message": "Queued for processing", + } + mock_exec.return_value = mock_resp + + client = AgentClient(mock_config) + request = AgentQueryRequest(message="Heavy analysis") + options = AgentOptions(max_wait=0) + + with pytest.raises(QueuedAgentError) as exc_info: + client.execute_agent(graph_id, "research", request, options) + + assert exc_info.value.queue_info.operation_id == "op-specific-queued" + + +@pytest.mark.unit +class TestAgentConvenienceMethods: + """Test suite for convenience methods.""" + + @patch.object(AgentClient, "execute_query") + def test_query_convenience(self, mock_exec, mock_config, graph_id): + """Test query() convenience method.""" + mock_exec.return_value = AgentResult( + content="Answer", agent_used="rag", mode_used="quick" + ) + + client = AgentClient(mock_config) + result = client.query(graph_id, "What is X?") + + assert result.content == "Answer" + mock_exec.assert_called_once() + + @patch.object(AgentClient, "execute_agent") + def test_analyze_financials_convenience(self, mock_exec, mock_config, graph_id): + """Test analyze_financials() convenience method.""" + mock_exec.return_value = AgentResult( + content="Financial analysis", agent_used="financial", mode_used="standard" + ) + + client = AgentClient(mock_config) + result = client.analyze_financials(graph_id, "Analyze revenue trends") + + assert result.agent_used == "financial" + call_args = mock_exec.call_args + assert call_args[0][1] == "financial" # agent_type + + @patch.object(AgentClient, "execute_agent") + def test_research_convenience(self, mock_exec, mock_config, graph_id): + """Test research() convenience method.""" + mock_exec.return_value = AgentResult( + content="Research results", agent_used="research", mode_used="extended" + ) + + client = AgentClient(mock_config) + result = client.research(graph_id, "Deep dive into market") + + assert result.agent_used == "research" + call_args = mock_exec.call_args + assert call_args[0][1] == "research" + + @patch.object(AgentClient, "execute_agent") + def test_rag_convenience(self, mock_exec, mock_config, graph_id): + """Test rag() convenience method.""" + mock_exec.return_value = AgentResult( + content="RAG answer", agent_used="rag", mode_used="quick" + ) + + client = AgentClient(mock_config) + result = client.rag(graph_id, "Quick lookup") + + assert result.agent_used == "rag" + call_args = mock_exec.call_args + assert call_args[0][1] == "rag" diff --git a/robosystems_client/extensions/tests/test_dataframe_utils.py b/tests/test_dataframe_utils.py similarity index 100% rename from robosystems_client/extensions/tests/test_dataframe_utils.py rename to tests/test_dataframe_utils.py diff --git a/tests/test_document_client.py b/tests/test_document_client.py new file mode 100644 index 0000000..cf77c0c --- /dev/null +++ b/tests/test_document_client.py @@ -0,0 +1,433 @@ +"""Unit tests for DocumentClient.""" + +import pytest +from http import HTTPStatus +from unittest.mock import Mock, patch +from robosystems_client.extensions.document_client import DocumentClient + + +@pytest.mark.unit +class TestDocumentClientInit: + """Test suite for DocumentClient initialization.""" + + def test_client_initialization(self, mock_config): + """Test that client initializes correctly with config.""" + client = DocumentClient(mock_config) + + assert client.base_url == "http://localhost:8000" + assert client.token == "test-api-key" + assert client.headers == {"X-API-Key": "test-api-key"} + assert client.timeout == 60 + + def test_client_custom_timeout(self, mock_config): + """Test client with custom timeout.""" + mock_config["timeout"] = 120 + client = DocumentClient(mock_config) + + assert client.timeout == 120 + + def test_get_client_no_token(self, mock_config): + """Test _get_client raises without token.""" + mock_config["token"] = None + client = DocumentClient(mock_config) + + with pytest.raises(Exception, match="No API key"): + client._get_client() + + def test_close_is_noop(self, mock_config): + """Test close method doesn't raise.""" + client = DocumentClient(mock_config) + client.close() # Should not raise + + +@pytest.mark.unit +class TestDocumentUpload: + """Test suite for DocumentClient.upload method.""" + + @patch("robosystems_client.extensions.document_client.upload_document") + def test_upload_document(self, mock_upload, mock_config, graph_id): + """Test uploading a markdown document.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(document_id="doc-123", section_count=3) + mock_upload.return_value = mock_resp + + client = DocumentClient(mock_config) + result = client.upload( + graph_id=graph_id, + title="Test Document", + content="# Heading\n\nSome content.", + tags=["test", "demo"], + folder="reports", + ) + + assert result.document_id == "doc-123" + assert result.section_count == 3 + mock_upload.assert_called_once() + + @patch("robosystems_client.extensions.document_client.upload_document") + def test_upload_document_failure(self, mock_upload, mock_config, graph_id): + """Test upload failure raises exception.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.INTERNAL_SERVER_ERROR + mock_resp.content = b"Server error" + mock_upload.return_value = mock_resp + + client = DocumentClient(mock_config) + + with pytest.raises(Exception, match="Document upload failed"): + client.upload(graph_id=graph_id, title="Bad", content="content") + + @patch("robosystems_client.extensions.document_client.upload_document") + def test_upload_minimal(self, mock_upload, mock_config, graph_id): + """Test upload with only required fields.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(document_id="doc-min", section_count=1) + mock_upload.return_value = mock_resp + + client = DocumentClient(mock_config) + result = client.upload(graph_id=graph_id, title="Minimal", content="Just text.") + + assert result.document_id == "doc-min" + + +@pytest.mark.unit +class TestDocumentGet: + """Test suite for DocumentClient.get method.""" + + @patch("robosystems_client.extensions.document_client.get_document") + def test_get_document(self, mock_get, mock_config, graph_id): + """Test getting a document by ID.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock( + document_id="doc-456", title="Found Document", content="# Content" + ) + mock_get.return_value = mock_resp + + client = DocumentClient(mock_config) + result = client.get(graph_id=graph_id, document_id="doc-456") + + assert result is not None + assert result.title == "Found Document" + + @patch("robosystems_client.extensions.document_client.get_document") + def test_get_document_not_found(self, mock_get, mock_config, graph_id): + """Test getting a document that doesn't exist.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.NOT_FOUND + mock_get.return_value = mock_resp + + client = DocumentClient(mock_config) + result = client.get(graph_id=graph_id, document_id="nonexistent") + + assert result is None + + @patch("robosystems_client.extensions.document_client.get_document") + def test_get_document_server_error(self, mock_get, mock_config, graph_id): + """Test get raises on server error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.INTERNAL_SERVER_ERROR + mock_resp.content = b"Internal error" + mock_get.return_value = mock_resp + + client = DocumentClient(mock_config) + + with pytest.raises(Exception, match="Get document failed"): + client.get(graph_id=graph_id, document_id="doc-err") + + +@pytest.mark.unit +class TestDocumentUpdate: + """Test suite for DocumentClient.update method.""" + + @patch("robosystems_client.extensions.document_client.update_document") + def test_update_document(self, mock_update, mock_config, graph_id): + """Test updating a document.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(document_id="doc-upd", section_count=5) + mock_update.return_value = mock_resp + + client = DocumentClient(mock_config) + result = client.update( + graph_id=graph_id, + document_id="doc-upd", + title="Updated Title", + content="# Updated\n\nNew content.", + ) + + assert result.section_count == 5 + + @patch("robosystems_client.extensions.document_client.update_document") + def test_update_document_failure(self, mock_update, mock_config, graph_id): + """Test update failure raises exception.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.BAD_REQUEST + mock_resp.content = b"Bad request" + mock_update.return_value = mock_resp + + client = DocumentClient(mock_config) + + with pytest.raises(Exception, match="Update document failed"): + client.update(graph_id=graph_id, document_id="doc-bad", title="Bad") + + +@pytest.mark.unit +class TestDocumentSearch: + """Test suite for DocumentClient.search method.""" + + @patch("robosystems_client.extensions.document_client.search_documents") + def test_search_documents(self, mock_search, mock_config, graph_id): + """Test searching documents.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(total=2, hits=[Mock(), Mock()]) + mock_search.return_value = mock_resp + + client = DocumentClient(mock_config) + result = client.search(graph_id=graph_id, query="revenue growth") + + assert result.total == 2 + assert len(result.hits) == 2 + + @patch("robosystems_client.extensions.document_client.search_documents") + def test_search_with_filters(self, mock_search, mock_config, graph_id): + """Test searching with filters.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(total=1, hits=[Mock()]) + mock_search.return_value = mock_resp + + client = DocumentClient(mock_config) + result = client.search( + graph_id=graph_id, + query="risk factors", + source_type="xbrl_textblock", + form_type="10-K", + entity="AAPL", + fiscal_year=2024, + size=5, + ) + + assert result.total == 1 + mock_search.assert_called_once() + + @patch("robosystems_client.extensions.document_client.search_documents") + def test_search_failure(self, mock_search, mock_config, graph_id): + """Test search failure raises exception.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.INTERNAL_SERVER_ERROR + mock_resp.content = b"Search error" + mock_search.return_value = mock_resp + + client = DocumentClient(mock_config) + + with pytest.raises(Exception, match="Document search failed"): + client.search(graph_id=graph_id, query="test") + + +@pytest.mark.unit +class TestDocumentList: + """Test suite for DocumentClient.list method.""" + + @patch("robosystems_client.extensions.document_client.list_documents") + def test_list_documents(self, mock_list, mock_config, graph_id): + """Test listing documents.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(total=3, documents=[Mock(), Mock(), Mock()]) + mock_list.return_value = mock_resp + + client = DocumentClient(mock_config) + result = client.list(graph_id=graph_id) + + assert result.total == 3 + + @patch("robosystems_client.extensions.document_client.list_documents") + def test_list_documents_with_filter(self, mock_list, mock_config, graph_id): + """Test listing with source type filter.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(total=1, documents=[Mock()]) + mock_list.return_value = mock_resp + + client = DocumentClient(mock_config) + result = client.list(graph_id=graph_id, source_type="uploaded_doc") + + assert result.total == 1 + + +@pytest.mark.unit +class TestDocumentDelete: + """Test suite for DocumentClient.delete method.""" + + @patch("robosystems_client.extensions.document_client.delete_document") + def test_delete_document(self, mock_delete, mock_config, graph_id): + """Test deleting a document.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.NO_CONTENT + mock_delete.return_value = mock_resp + + client = DocumentClient(mock_config) + result = client.delete(graph_id=graph_id, document_id="doc-del") + + assert result is True + + @patch("robosystems_client.extensions.document_client.delete_document") + def test_delete_document_not_found(self, mock_delete, mock_config, graph_id): + """Test deleting a document that doesn't exist.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.NOT_FOUND + mock_delete.return_value = mock_resp + + client = DocumentClient(mock_config) + result = client.delete(graph_id=graph_id, document_id="doc-missing") + + assert result is False + + @patch("robosystems_client.extensions.document_client.delete_document") + def test_delete_document_server_error(self, mock_delete, mock_config, graph_id): + """Test delete raises on server error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.INTERNAL_SERVER_ERROR + mock_resp.content = b"Delete error" + mock_delete.return_value = mock_resp + + client = DocumentClient(mock_config) + + with pytest.raises(Exception, match="Delete document failed"): + client.delete(graph_id=graph_id, document_id="doc-err") + + +@pytest.mark.unit +class TestDocumentSection: + """Test suite for DocumentClient.get_section method.""" + + @patch("robosystems_client.extensions.document_client.get_document_section") + def test_get_section(self, mock_get_section, mock_config, graph_id): + """Test getting a document section.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock( + section_id="sec-1", title="Revenue", content="Revenue was $1B." + ) + mock_get_section.return_value = mock_resp + + client = DocumentClient(mock_config) + result = client.get_section(graph_id=graph_id, document_id="sec-1") + + assert result is not None + assert result.title == "Revenue" + + @patch("robosystems_client.extensions.document_client.get_document_section") + def test_get_section_not_found(self, mock_get_section, mock_config, graph_id): + """Test getting a section that doesn't exist.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.NOT_FOUND + mock_get_section.return_value = mock_resp + + client = DocumentClient(mock_config) + result = client.get_section(graph_id=graph_id, document_id="nonexistent") + + assert result is None + + +@pytest.mark.unit +class TestDocumentBulkUpload: + """Test suite for DocumentClient.upload_bulk method.""" + + @patch("robosystems_client.extensions.document_client.upload_documents_bulk") + def test_upload_bulk(self, mock_bulk, mock_config, graph_id): + """Test bulk uploading documents.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(total=2, succeeded=2, failed=0, results=[Mock(), Mock()]) + mock_bulk.return_value = mock_resp + + client = DocumentClient(mock_config) + docs = [ + {"title": "Doc 1", "content": "Content 1"}, + {"title": "Doc 2", "content": "Content 2", "tags": ["tag1"]}, + ] + result = client.upload_bulk(graph_id=graph_id, documents=docs) + + assert result.total == 2 + assert result.succeeded == 2 + + @patch("robosystems_client.extensions.document_client.upload_documents_bulk") + def test_upload_bulk_failure(self, mock_bulk, mock_config, graph_id): + """Test bulk upload failure raises exception.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.BAD_REQUEST + mock_resp.content = b"Bulk upload error" + mock_bulk.return_value = mock_resp + + client = DocumentClient(mock_config) + + with pytest.raises(Exception, match="Bulk upload failed"): + client.upload_bulk(graph_id=graph_id, documents=[{"title": "x", "content": "y"}]) + + +@pytest.mark.unit +class TestDocumentUploadFile: + """Test suite for DocumentClient.upload_file method.""" + + @patch("robosystems_client.extensions.document_client.upload_document") + def test_upload_file(self, mock_upload, mock_config, graph_id, tmp_path): + """Test uploading a file from disk.""" + md_file = tmp_path / "test-doc.md" + md_file.write_text("# Test\n\nContent here.") + + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(document_id="doc-file", section_count=2) + mock_upload.return_value = mock_resp + + client = DocumentClient(mock_config) + result = client.upload_file(graph_id=graph_id, file_path=md_file) + + assert result.document_id == "doc-file" + + @patch("robosystems_client.extensions.document_client.upload_document") + def test_upload_file_title_from_filename( + self, mock_upload, mock_config, graph_id, tmp_path + ): + """Test that title is derived from filename when not provided.""" + md_file = tmp_path / "revenue-analysis.md" + md_file.write_text("# Revenue\n\nAnalysis content.") + + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(document_id="doc-title", section_count=1) + mock_upload.return_value = mock_resp + + client = DocumentClient(mock_config) + client.upload_file(graph_id=graph_id, file_path=md_file) + + # Verify the body was created with title derived from filename + call_kwargs = mock_upload.call_args[1] + assert call_kwargs["body"].title == "Revenue Analysis" + + +@pytest.mark.unit +class TestDocumentUploadDirectory: + """Test suite for DocumentClient.upload_directory method.""" + + @patch("robosystems_client.extensions.document_client.upload_document") + def test_upload_directory(self, mock_upload, mock_config, graph_id, tmp_path): + """Test uploading all markdown files from a directory.""" + (tmp_path / "doc1.md").write_text("# Doc 1") + (tmp_path / "doc2.md").write_text("# Doc 2") + (tmp_path / "not-md.txt").write_text("Skip this") + + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(document_id="doc-dir", section_count=1) + mock_upload.return_value = mock_resp + + client = DocumentClient(mock_config) + results = client.upload_directory(graph_id=graph_id, directory=tmp_path) + + assert len(results) == 2 # Only .md files diff --git a/robosystems_client/extensions/tests/test_integration.py b/tests/test_extensions_integration.py similarity index 100% rename from robosystems_client/extensions/tests/test_integration.py rename to tests/test_extensions_integration.py diff --git a/tests/test_file_client.py b/tests/test_file_client.py new file mode 100644 index 0000000..ddbd73f --- /dev/null +++ b/tests/test_file_client.py @@ -0,0 +1,458 @@ +"""Unit tests for FileClient.""" + +import pytest +from io import BytesIO +from unittest.mock import Mock, patch +from robosystems_client.extensions.file_client import ( + FileClient, + FileUploadOptions, + FileUploadResult, + FileInfo, +) + + +@pytest.mark.unit +class TestFileClientInit: + """Test suite for FileClient initialization.""" + + def test_client_initialization(self, mock_config): + """Test that client initializes correctly with config.""" + client = FileClient(mock_config) + + assert client.base_url == "http://localhost:8000" + assert client.token == "test-api-key" + assert client.headers == {"X-API-Key": "test-api-key"} + assert client.s3_endpoint_url is None + + def test_client_with_s3_override(self, mock_config): + """Test client with S3 endpoint override.""" + mock_config["s3_endpoint_url"] = "http://localhost:4566" + client = FileClient(mock_config) + + assert client.s3_endpoint_url == "http://localhost:4566" + + def test_client_cleanup(self, mock_config): + """Test that cleanup closes HTTP client.""" + client = FileClient(mock_config) + assert hasattr(client, "_http_client") + client.__del__() + + +@pytest.mark.unit +class TestFileDataclasses: + """Test suite for file-related dataclasses.""" + + def test_file_upload_options_defaults(self): + """Test FileUploadOptions default values.""" + options = FileUploadOptions() + + assert options.on_progress is None + assert options.ingest_to_graph is False + + def test_file_upload_options_custom(self): + """Test FileUploadOptions with custom values.""" + progress_fn = Mock() + options = FileUploadOptions(on_progress=progress_fn, ingest_to_graph=True) + + assert options.on_progress is progress_fn + assert options.ingest_to_graph is True + + def test_file_upload_result(self): + """Test FileUploadResult dataclass creation.""" + result = FileUploadResult( + file_id="file-123", + file_size=50000, + row_count=100, + table_name="Entity", + file_name="data.parquet", + success=True, + ) + + assert result.file_id == "file-123" + assert result.file_size == 50000 + assert result.row_count == 100 + assert result.table_name == "Entity" + assert result.file_name == "data.parquet" + assert result.success is True + assert result.error is None + + def test_file_upload_result_with_error(self): + """Test FileUploadResult with error.""" + result = FileUploadResult( + file_id="", + file_size=0, + row_count=0, + table_name="Entity", + file_name="bad.parquet", + success=False, + error="Upload failed", + ) + + assert result.success is False + assert result.error == "Upload failed" + + def test_file_info(self): + """Test FileInfo dataclass creation.""" + info = FileInfo( + file_id="file-456", + file_name="transactions.parquet", + file_format="parquet", + size_bytes=120000, + row_count=500, + upload_status="uploaded", + table_name="Transaction", + created_at="2025-01-15T10:00:00Z", + uploaded_at="2025-01-15T10:01:00Z", + ) + + assert info.file_id == "file-456" + assert info.file_name == "transactions.parquet" + assert info.file_format == "parquet" + assert info.size_bytes == 120000 + assert info.row_count == 500 + assert info.upload_status == "uploaded" + assert info.table_name == "Transaction" + assert info.layers is None + + def test_file_info_with_layers(self): + """Test FileInfo with layer tracking.""" + layers = {"s3": "uploaded", "duckdb": "staged", "graph": "pending"} + info = FileInfo( + file_id="file-789", + file_name="data.parquet", + file_format="parquet", + size_bytes=1000, + row_count=10, + upload_status="uploaded", + table_name="Entity", + created_at=None, + uploaded_at=None, + layers=layers, + ) + + assert info.layers == layers + + +@pytest.mark.unit +class TestFileUpload: + """Test suite for FileClient.upload method.""" + + @patch("robosystems_client.extensions.file_client.update_file") + @patch("robosystems_client.extensions.file_client.create_file_upload") + def test_upload_bytesio(self, mock_create, mock_update, mock_config, graph_id): + """Test uploading a BytesIO buffer.""" + # Mock presigned URL response + mock_create_resp = Mock() + mock_create_resp.status_code = 200 + mock_create_resp.parsed = Mock() + mock_create_resp.parsed.upload_url = "http://s3.localhost/bucket/file.parquet" + mock_create_resp.parsed.file_id = "file-new-123" + mock_create.return_value = mock_create_resp + + # Mock update response + mock_update_resp = Mock() + mock_update_resp.status_code = 200 + mock_update_resp.parsed = Mock() + mock_update_resp.parsed.file_size_bytes = 1234 + mock_update_resp.parsed.row_count = 10 + mock_update.return_value = mock_update_resp + + client = FileClient(mock_config) + # Mock the S3 PUT + client._http_client = Mock() + client._http_client.put.return_value = Mock(status_code=200) + + buf = BytesIO(b"fake parquet data") + result = client.upload(graph_id, "Entity", buf) + + assert result.success is True + assert result.file_id == "file-new-123" + assert result.file_size == 1234 + assert result.row_count == 10 + assert result.table_name == "Entity" + + @patch("robosystems_client.extensions.file_client.create_file_upload") + def test_upload_presigned_url_failure(self, mock_create, mock_config, graph_id): + """Test upload failure when presigned URL request fails.""" + mock_create_resp = Mock() + mock_create_resp.status_code = 500 + mock_create_resp.parsed = None + mock_create.return_value = mock_create_resp + + client = FileClient(mock_config) + result = client.upload(graph_id, "Entity", BytesIO(b"data")) + + assert result.success is False + assert "Failed to get upload URL" in result.error + + @patch("robosystems_client.extensions.file_client.create_file_upload") + def test_upload_s3_failure(self, mock_create, mock_config, graph_id): + """Test upload failure when S3 PUT fails.""" + mock_create_resp = Mock() + mock_create_resp.status_code = 200 + mock_create_resp.parsed = Mock() + mock_create_resp.parsed.upload_url = "http://s3.localhost/bucket/file.parquet" + mock_create_resp.parsed.file_id = "file-s3-fail" + mock_create.return_value = mock_create_resp + + client = FileClient(mock_config) + client._http_client = Mock() + client._http_client.put.return_value = Mock(status_code=500) + + result = client.upload(graph_id, "Entity", BytesIO(b"data")) + + assert result.success is False + assert "S3 upload failed" in result.error + + @patch("robosystems_client.extensions.file_client.update_file") + @patch("robosystems_client.extensions.file_client.create_file_upload") + def test_upload_status_update_failure( + self, mock_create, mock_update, mock_config, graph_id + ): + """Test upload failure when status update fails.""" + mock_create_resp = Mock() + mock_create_resp.status_code = 200 + mock_create_resp.parsed = Mock() + mock_create_resp.parsed.upload_url = "http://s3.localhost/bucket/file.parquet" + mock_create_resp.parsed.file_id = "file-update-fail" + mock_create.return_value = mock_create_resp + + mock_update_resp = Mock() + mock_update_resp.status_code = 500 + mock_update_resp.parsed = None + mock_update.return_value = mock_update_resp + + client = FileClient(mock_config) + client._http_client = Mock() + client._http_client.put.return_value = Mock(status_code=200) + + result = client.upload(graph_id, "Entity", BytesIO(b"data")) + + assert result.success is False + assert "Failed to complete file upload" in result.error + + def test_upload_no_token(self, mock_config, graph_id): + """Test upload fails without API key.""" + mock_config["token"] = None + client = FileClient(mock_config) + + result = client.upload(graph_id, "Entity", BytesIO(b"data")) + + assert result.success is False + assert "No API key" in result.error + + def test_upload_unsupported_type(self, mock_config, graph_id): + """Test upload fails with unsupported file type.""" + client = FileClient(mock_config) + + result = client.upload(graph_id, "Entity", 12345) + + assert result.success is False + assert "Unsupported file type" in result.error + + @patch("robosystems_client.extensions.file_client.update_file") + @patch("robosystems_client.extensions.file_client.create_file_upload") + def test_upload_with_progress_callback( + self, mock_create, mock_update, mock_config, graph_id + ): + """Test upload calls progress callback at each step.""" + mock_create_resp = Mock() + mock_create_resp.status_code = 200 + mock_create_resp.parsed = Mock() + mock_create_resp.parsed.upload_url = "http://s3.localhost/bucket/file.parquet" + mock_create_resp.parsed.file_id = "file-progress" + mock_create.return_value = mock_create_resp + + mock_update_resp = Mock() + mock_update_resp.status_code = 200 + mock_update_resp.parsed = Mock() + mock_update_resp.parsed.file_size_bytes = 100 + mock_update_resp.parsed.row_count = 5 + mock_update.return_value = mock_update_resp + + client = FileClient(mock_config) + client._http_client = Mock() + client._http_client.put.return_value = Mock(status_code=200) + + progress_messages = [] + options = FileUploadOptions(on_progress=lambda msg: progress_messages.append(msg)) + + client.upload(graph_id, "Entity", BytesIO(b"data"), options) + + assert len(progress_messages) >= 3 # URL, upload, mark uploaded + + @patch("robosystems_client.extensions.file_client.update_file") + @patch("robosystems_client.extensions.file_client.create_file_upload") + def test_upload_with_s3_endpoint_override( + self, mock_create, mock_update, mock_config, graph_id + ): + """Test upload rewrites S3 URL when s3_endpoint_url is set.""" + mock_config["s3_endpoint_url"] = "http://localhost:4566" + + mock_create_resp = Mock() + mock_create_resp.status_code = 200 + mock_create_resp.parsed = Mock() + mock_create_resp.parsed.upload_url = ( + "https://s3.amazonaws.com/bucket/file.parquet?sig=abc" + ) + mock_create_resp.parsed.file_id = "file-s3-override" + mock_create.return_value = mock_create_resp + + mock_update_resp = Mock() + mock_update_resp.status_code = 200 + mock_update_resp.parsed = Mock() + mock_update_resp.parsed.file_size_bytes = 100 + mock_update_resp.parsed.row_count = 5 + mock_update.return_value = mock_update_resp + + client = FileClient(mock_config) + client._http_client = Mock() + client._http_client.put.return_value = Mock(status_code=200) + + client.upload(graph_id, "Entity", BytesIO(b"data")) + + # Verify S3 PUT was called with overridden URL + put_url = client._http_client.put.call_args[0][0] + assert "localhost:4566" in put_url + + +@pytest.mark.unit +class TestFileList: + """Test suite for FileClient.list method.""" + + @patch("robosystems_client.extensions.file_client.list_files") + def test_list_files(self, mock_list, mock_config, graph_id): + """Test listing files returns FileInfo objects.""" + mock_file = Mock() + mock_file.file_id = "file-1" + mock_file.file_name = "data.parquet" + mock_file.file_format = "parquet" + mock_file.size_bytes = 5000 + mock_file.row_count = 50 + mock_file.upload_status = "uploaded" + mock_file.table_name = "Entity" + mock_file.created_at = "2025-01-01T00:00:00Z" + mock_file.uploaded_at = "2025-01-01T00:01:00Z" + + mock_resp = Mock() + mock_resp.status_code = 200 + mock_resp.parsed = Mock() + mock_resp.parsed.files = [mock_file] + mock_list.return_value = mock_resp + + client = FileClient(mock_config) + files = client.list(graph_id) + + assert len(files) == 1 + assert files[0].file_id == "file-1" + assert files[0].file_name == "data.parquet" + assert files[0].size_bytes == 5000 + + @patch("robosystems_client.extensions.file_client.list_files") + def test_list_files_failure(self, mock_list, mock_config, graph_id): + """Test listing files returns empty list on failure.""" + mock_resp = Mock() + mock_resp.status_code = 500 + mock_resp.parsed = None + mock_list.return_value = mock_resp + + client = FileClient(mock_config) + files = client.list(graph_id) + + assert files == [] + + def test_list_files_no_token(self, mock_config, graph_id): + """Test listing files returns empty list without token.""" + mock_config["token"] = None + client = FileClient(mock_config) + files = client.list(graph_id) + + assert files == [] + + +@pytest.mark.unit +class TestFileGet: + """Test suite for FileClient.get method.""" + + @patch("robosystems_client.extensions.file_client.get_file") + def test_get_file(self, mock_get, mock_config, graph_id): + """Test getting a specific file.""" + mock_file_data = Mock() + mock_file_data.file_id = "file-detail" + mock_file_data.file_name = "detail.parquet" + mock_file_data.file_format = "parquet" + mock_file_data.size_bytes = 8000 + mock_file_data.row_count = 80 + mock_file_data.upload_status = "uploaded" + mock_file_data.table_name = "Entity" + mock_file_data.created_at = "2025-01-01T00:00:00Z" + mock_file_data.uploaded_at = "2025-01-01T00:01:00Z" + mock_file_data.layers = {"s3": "uploaded", "duckdb": "staged"} + + mock_resp = Mock() + mock_resp.status_code = 200 + mock_resp.parsed = mock_file_data + mock_get.return_value = mock_resp + + client = FileClient(mock_config) + info = client.get(graph_id, "file-detail") + + assert info is not None + assert info.file_id == "file-detail" + assert info.layers == {"s3": "uploaded", "duckdb": "staged"} + + @patch("robosystems_client.extensions.file_client.get_file") + def test_get_file_not_found(self, mock_get, mock_config, graph_id): + """Test getting a file that doesn't exist.""" + mock_resp = Mock() + mock_resp.status_code = 404 + mock_resp.parsed = None + mock_get.return_value = mock_resp + + client = FileClient(mock_config) + info = client.get(graph_id, "nonexistent") + + assert info is None + + +@pytest.mark.unit +class TestFileDelete: + """Test suite for FileClient.delete method.""" + + @patch("robosystems_client.extensions.file_client.delete_file") + def test_delete_file(self, mock_delete, mock_config, graph_id): + """Test deleting a file.""" + mock_resp = Mock() + mock_resp.status_code = 204 + mock_delete.return_value = mock_resp + + client = FileClient(mock_config) + result = client.delete(graph_id, "file-to-delete") + + assert result is True + + @patch("robosystems_client.extensions.file_client.delete_file") + def test_delete_file_failure(self, mock_delete, mock_config, graph_id): + """Test delete failure returns False.""" + mock_resp = Mock() + mock_resp.status_code = 500 + mock_delete.return_value = mock_resp + + client = FileClient(mock_config) + result = client.delete(graph_id, "file-to-delete") + + assert result is False + + @patch("robosystems_client.extensions.file_client.delete_file") + def test_delete_file_with_cascade(self, mock_delete, mock_config, graph_id): + """Test cascade delete passes parameter.""" + mock_resp = Mock() + mock_resp.status_code = 200 + mock_delete.return_value = mock_resp + + client = FileClient(mock_config) + result = client.delete(graph_id, "file-cascade", cascade=True) + + assert result is True + call_kwargs = mock_delete.call_args[1] + assert call_kwargs["cascade"] is True diff --git a/tests/test_ledger_client.py b/tests/test_ledger_client.py new file mode 100644 index 0000000..7e075dd --- /dev/null +++ b/tests/test_ledger_client.py @@ -0,0 +1,533 @@ +"""Unit tests for LedgerClient.""" + +import pytest +from http import HTTPStatus +from unittest.mock import Mock, patch +from robosystems_client.extensions.ledger_client import LedgerClient + + +@pytest.mark.unit +class TestLedgerClientInit: + """Test suite for LedgerClient initialization.""" + + def test_client_initialization(self, mock_config): + """Test that client initializes correctly with config.""" + client = LedgerClient(mock_config) + + assert client.base_url == "http://localhost:8000" + assert client.token == "test-api-key" + assert client.headers == {"X-API-Key": "test-api-key"} + assert client.timeout == 60 + + def test_get_client_no_token(self, mock_config): + """Test _get_client raises without token.""" + mock_config["token"] = None + client = LedgerClient(mock_config) + + with pytest.raises(RuntimeError, match="No API key"): + client._get_client() + + +@pytest.mark.unit +class TestLedgerEntity: + """Test suite for entity operations.""" + + @patch("robosystems_client.extensions.ledger_client.get_ledger_entity") + def test_get_entity(self, mock_get, mock_config, graph_id): + """Test getting entity for a graph.""" + mock_parsed = Mock() + mock_parsed.entity_name = "ACME Corp" + mock_parsed.cik = "0001234567" + + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = mock_parsed + mock_get.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.get_entity(graph_id) + + assert result.entity_name == "ACME Corp" + + @patch("robosystems_client.extensions.ledger_client.get_ledger_entity") + def test_get_entity_not_found(self, mock_get, mock_config, graph_id): + """Test getting entity that doesn't exist.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.NOT_FOUND + mock_get.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.get_entity(graph_id) + + assert result is None + + @patch("robosystems_client.extensions.ledger_client.get_ledger_entity") + def test_get_entity_error(self, mock_get, mock_config, graph_id): + """Test get entity raises on error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.INTERNAL_SERVER_ERROR + mock_get.return_value = mock_resp + + client = LedgerClient(mock_config) + + with pytest.raises(RuntimeError, match="Get entity failed"): + client.get_entity(graph_id) + + +@pytest.mark.unit +class TestLedgerAccounts: + """Test suite for account operations.""" + + @patch("robosystems_client.extensions.ledger_client.list_ledger_accounts") + def test_list_accounts(self, mock_list, mock_config, graph_id): + """Test listing accounts.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(accounts=[Mock(), Mock()]) + mock_list.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.list_accounts(graph_id) + + assert len(result.accounts) == 2 + + @patch("robosystems_client.extensions.ledger_client.list_ledger_accounts") + def test_list_accounts_error(self, mock_list, mock_config, graph_id): + """Test list accounts raises on error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.INTERNAL_SERVER_ERROR + mock_list.return_value = mock_resp + + client = LedgerClient(mock_config) + + with pytest.raises(RuntimeError, match="List accounts failed"): + client.list_accounts(graph_id) + + @patch("robosystems_client.extensions.ledger_client.get_ledger_account_tree") + def test_get_account_tree(self, mock_tree, mock_config, graph_id): + """Test getting account tree.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(root=Mock(children=[Mock(), Mock()])) + mock_tree.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.get_account_tree(graph_id) + + assert len(result.root.children) == 2 + + +@pytest.mark.unit +class TestLedgerTransactions: + """Test suite for transaction operations.""" + + @patch("robosystems_client.extensions.ledger_client.list_ledger_transactions") + def test_list_transactions(self, mock_list, mock_config, graph_id): + """Test listing transactions.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(transactions=[Mock()]) + mock_list.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.list_transactions(graph_id) + + assert len(result.transactions) == 1 + + @patch("robosystems_client.extensions.ledger_client.list_ledger_transactions") + def test_list_transactions_with_filters(self, mock_list, mock_config, graph_id): + """Test listing transactions with date filters.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(transactions=[]) + mock_list.return_value = mock_resp + + client = LedgerClient(mock_config) + client.list_transactions( + graph_id, + start_date="2025-01-01", + end_date="2025-03-31", + limit=50, + offset=0, + ) + + mock_list.assert_called_once() + call_kwargs = mock_list.call_args[1] + assert call_kwargs["start_date"] == "2025-01-01" + assert call_kwargs["end_date"] == "2025-03-31" + assert call_kwargs["limit"] == 50 + + @patch("robosystems_client.extensions.ledger_client.get_ledger_transaction") + def test_get_transaction(self, mock_get, mock_config, graph_id): + """Test getting a specific transaction.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(transaction_id="txn-123", entries=[Mock()]) + mock_get.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.get_transaction(graph_id, "txn-123") + + assert result.transaction_id == "txn-123" + + @patch("robosystems_client.extensions.ledger_client.get_ledger_transaction") + def test_get_transaction_error(self, mock_get, mock_config, graph_id): + """Test get transaction raises on error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.NOT_FOUND + mock_get.return_value = mock_resp + + client = LedgerClient(mock_config) + + with pytest.raises(RuntimeError, match="Get transaction failed"): + client.get_transaction(graph_id, "nonexistent") + + +@pytest.mark.unit +class TestLedgerTrialBalance: + """Test suite for trial balance operations.""" + + @patch("robosystems_client.extensions.ledger_client.get_ledger_trial_balance") + def test_get_trial_balance(self, mock_tb, mock_config, graph_id): + """Test getting trial balance.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock( + total_debits=100000, total_credits=100000, accounts=[Mock()] + ) + mock_tb.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.get_trial_balance(graph_id) + + assert result.total_debits == 100000 + assert result.total_credits == 100000 + + @patch("robosystems_client.extensions.ledger_client.get_ledger_trial_balance") + def test_get_trial_balance_with_dates(self, mock_tb, mock_config, graph_id): + """Test getting trial balance with date range.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock() + mock_tb.return_value = mock_resp + + client = LedgerClient(mock_config) + client.get_trial_balance(graph_id, start_date="2025-01-01", end_date="2025-03-31") + + call_kwargs = mock_tb.call_args[1] + assert call_kwargs["start_date"] == "2025-01-01" + assert call_kwargs["end_date"] == "2025-03-31" + + @patch("robosystems_client.extensions.ledger_client.get_mapped_trial_balance") + def test_get_mapped_trial_balance(self, mock_mtb, mock_config, graph_id): + """Test getting mapped trial balance.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(elements=[Mock()]) + mock_mtb.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.get_mapped_trial_balance(graph_id, mapping_id="map-1") + + assert len(result.elements) == 1 + + +@pytest.mark.unit +class TestLedgerSummary: + """Test suite for summary operations.""" + + @patch("robosystems_client.extensions.ledger_client.get_ledger_summary") + def test_get_summary(self, mock_summary, mock_config, graph_id): + """Test getting ledger summary.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock( + account_count=50, transaction_count=200, entity_name="ACME Corp" + ) + mock_summary.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.get_summary(graph_id) + + assert result.account_count == 50 + assert result.transaction_count == 200 + + +@pytest.mark.unit +class TestLedgerTaxonomy: + """Test suite for taxonomy operations.""" + + @patch("robosystems_client.extensions.ledger_client.get_reporting_taxonomy") + def test_get_reporting_taxonomy(self, mock_tax, mock_config, graph_id): + """Test getting reporting taxonomy.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(taxonomy_id="tax_usgaap_reporting") + mock_tax.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.get_reporting_taxonomy(graph_id) + + assert result.taxonomy_id == "tax_usgaap_reporting" + + @patch("robosystems_client.extensions.ledger_client.list_structures") + def test_list_structures(self, mock_structs, mock_config, graph_id): + """Test listing reporting structures.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(structures=[Mock(), Mock(), Mock()]) + mock_structs.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.list_structures(graph_id) + + assert len(result.structures) == 3 + + @patch("robosystems_client.extensions.ledger_client.list_elements") + def test_list_elements(self, mock_elems, mock_config, graph_id): + """Test listing elements.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(elements=[Mock()]) + mock_elems.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.list_elements(graph_id, source="coa", limit=10) + + assert len(result.elements) == 1 + call_kwargs = mock_elems.call_args[1] + assert call_kwargs["source"] == "coa" + assert call_kwargs["limit"] == 10 + + +@pytest.mark.unit +class TestLedgerMappings: + """Test suite for mapping operations.""" + + @patch("robosystems_client.extensions.ledger_client.list_mappings") + def test_list_mappings(self, mock_list, mock_config, graph_id): + """Test listing mappings.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(mappings=[Mock()]) + mock_list.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.list_mappings(graph_id) + + assert len(result.mappings) == 1 + + @patch("robosystems_client.extensions.ledger_client.get_mapping_detail") + def test_get_mapping_detail(self, mock_detail, mock_config, graph_id): + """Test getting mapping detail.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(mapping_id="map-1", associations=[Mock(), Mock()]) + mock_detail.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.get_mapping_detail(graph_id, "map-1") + + assert len(result.associations) == 2 + + @patch("robosystems_client.extensions.ledger_client.get_mapping_coverage") + def test_get_mapping_coverage(self, mock_cov, mock_config, graph_id): + """Test getting mapping coverage.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(total_elements=50, mapped_elements=45, coverage=0.9) + mock_cov.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.get_mapping_coverage(graph_id, "map-1") + + assert result.coverage == 0.9 + + @patch("robosystems_client.extensions.ledger_client.create_mapping_association") + def test_create_mapping(self, mock_create, mock_config, graph_id): + """Test creating a mapping association.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.CREATED + mock_create.return_value = mock_resp + + client = LedgerClient(mock_config) + client.create_mapping( + graph_id, "map-1", from_element_id="elem-a", to_element_id="elem-b" + ) + + mock_create.assert_called_once() + + @patch("robosystems_client.extensions.ledger_client.create_mapping_association") + def test_create_mapping_error(self, mock_create, mock_config, graph_id): + """Test create mapping raises on error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.BAD_REQUEST + mock_create.return_value = mock_resp + + client = LedgerClient(mock_config) + + with pytest.raises(RuntimeError, match="Create mapping failed"): + client.create_mapping(graph_id, "map-1", "elem-a", "elem-b") + + @patch("robosystems_client.extensions.ledger_client.delete_mapping_association") + def test_delete_mapping(self, mock_delete, mock_config, graph_id): + """Test deleting a mapping association.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.NO_CONTENT + mock_delete.return_value = mock_resp + + client = LedgerClient(mock_config) + client.delete_mapping(graph_id, "map-1", "assoc-1") + + mock_delete.assert_called_once() + + @patch("robosystems_client.extensions.ledger_client.auto_map_elements") + def test_auto_map(self, mock_auto, mock_config, graph_id): + """Test triggering auto-mapping.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.ACCEPTED + mock_resp.parsed = {"operation_id": "op-auto-1"} + mock_auto.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.auto_map(graph_id, "map-1") + + assert result["operation_id"] == "op-auto-1" + + @patch("robosystems_client.extensions.ledger_client.auto_map_elements") + def test_auto_map_error(self, mock_auto, mock_config, graph_id): + """Test auto-map raises on error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.INTERNAL_SERVER_ERROR + mock_auto.return_value = mock_resp + + client = LedgerClient(mock_config) + + with pytest.raises(RuntimeError, match="Auto-map failed"): + client.auto_map(graph_id, "map-1") + + @patch("robosystems_client.extensions.ledger_client.create_structure") + def test_create_mapping_structure(self, mock_create_struct, mock_config, graph_id): + """Test creating a mapping structure.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.CREATED + mock_resp.parsed = Mock(structure_id="struct-new") + mock_create_struct.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.create_mapping_structure(graph_id, name="Custom Mapping") + + assert result.structure_id == "struct-new" + + +@pytest.mark.unit +class TestLedgerSchedules: + """Test suite for schedule operations.""" + + @patch("robosystems_client.extensions.ledger_client.list_schedules") + def test_list_schedules(self, mock_list, mock_config, graph_id): + """Test listing schedules.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(schedules=[Mock()]) + mock_list.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.list_schedules(graph_id) + + assert len(result.schedules) == 1 + + @patch("robosystems_client.extensions.ledger_client.get_schedule_facts") + def test_get_schedule_facts(self, mock_facts, mock_config, graph_id): + """Test getting schedule facts.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(facts=[Mock(), Mock()]) + mock_facts.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.get_schedule_facts(graph_id, "struct-1") + + assert len(result.facts) == 2 + + @patch("robosystems_client.extensions.ledger_client.get_period_close_status") + def test_get_period_close_status(self, mock_status, mock_config, graph_id): + """Test getting period close status.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(status="open", schedules=[Mock()]) + mock_status.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.get_period_close_status( + graph_id, period_start="2025-01-01", period_end="2025-03-31" + ) + + assert result.status == "open" + + @patch("robosystems_client.extensions.ledger_client.create_closing_entry") + def test_create_closing_entry(self, mock_close, mock_config, graph_id): + """Test creating a closing entry.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.CREATED + mock_resp.parsed = Mock(transaction_id="txn-close-1") + mock_close.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.create_closing_entry( + graph_id, + structure_id="struct-1", + posting_date="2025-03-31", + period_start="2025-01-01", + period_end="2025-03-31", + memo="Q1 close", + ) + + assert result.transaction_id == "txn-close-1" + + @patch("robosystems_client.extensions.ledger_client.create_closing_entry") + def test_create_closing_entry_error(self, mock_close, mock_config, graph_id): + """Test create closing entry raises on error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.BAD_REQUEST + mock_close.return_value = mock_resp + + client = LedgerClient(mock_config) + + with pytest.raises(RuntimeError, match="Create closing entry failed"): + client.create_closing_entry( + graph_id, "struct-1", "2025-03-31", "2025-01-01", "2025-03-31" + ) + + +@pytest.mark.unit +class TestLedgerClosingBook: + """Test suite for closing book operations.""" + + @patch("robosystems_client.extensions.ledger_client.get_closing_book_structures") + def test_get_closing_book_structures(self, mock_cb, mock_config, graph_id): + """Test getting closing book structures.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(categories=[Mock(), Mock()]) + mock_cb.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.get_closing_book_structures(graph_id) + + assert len(result.categories) == 2 + + @patch("robosystems_client.extensions.ledger_client.get_account_rollups") + def test_get_account_rollups(self, mock_rollups, mock_config, graph_id): + """Test getting account rollups.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(rollups=[Mock()]) + mock_rollups.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.get_account_rollups( + graph_id, mapping_id="map-1", start_date="2025-01-01" + ) + + assert len(result.rollups) == 1 diff --git a/tests/test_materialization_client.py b/tests/test_materialization_client.py new file mode 100644 index 0000000..a3fa61a --- /dev/null +++ b/tests/test_materialization_client.py @@ -0,0 +1,340 @@ +"""Unit tests for MaterializationClient.""" + +import pytest +from unittest.mock import Mock, patch +from robosystems_client.extensions.materialization_client import ( + MaterializationClient, + MaterializationOptions, + MaterializationResult, + MaterializationStatus, +) +from robosystems_client.extensions.operation_client import ( + OperationResult, + OperationStatus, +) + + +@pytest.mark.unit +class TestMaterializationDataclasses: + """Test suite for materialization-related dataclasses.""" + + def test_materialization_options_defaults(self): + """Test MaterializationOptions default values.""" + options = MaterializationOptions() + + assert options.ignore_errors is True + assert options.rebuild is False + assert options.force is False + assert options.materialize_embeddings is False + assert options.on_progress is None + assert options.timeout == 600 + + def test_materialization_options_custom(self): + """Test MaterializationOptions with custom values.""" + progress_fn = Mock() + options = MaterializationOptions( + ignore_errors=False, + rebuild=True, + force=True, + materialize_embeddings=True, + on_progress=progress_fn, + timeout=300, + ) + + assert options.ignore_errors is False + assert options.rebuild is True + assert options.force is True + assert options.materialize_embeddings is True + assert options.on_progress is progress_fn + assert options.timeout == 300 + + def test_materialization_result(self): + """Test MaterializationResult dataclass.""" + result = MaterializationResult( + status="success", + was_stale=True, + stale_reason="New files uploaded", + tables_materialized=["Entity", "Transaction"], + total_rows=1500, + execution_time_ms=3000.0, + message="Graph materialized successfully", + ) + + assert result.status == "success" + assert result.was_stale is True + assert result.stale_reason == "New files uploaded" + assert len(result.tables_materialized) == 2 + assert result.total_rows == 1500 + assert result.execution_time_ms == 3000.0 + assert result.success is True + assert result.error is None + + def test_materialization_result_with_error(self): + """Test MaterializationResult with error.""" + result = MaterializationResult( + status="failed", + was_stale=False, + stale_reason=None, + tables_materialized=[], + total_rows=0, + execution_time_ms=0, + message="Failed to materialize", + success=False, + error="Connection timeout", + ) + + assert result.success is False + assert result.error == "Connection timeout" + + def test_materialization_status(self): + """Test MaterializationStatus dataclass.""" + status = MaterializationStatus( + graph_id="graph-123", + is_stale=True, + stale_reason="Files uploaded since last materialization", + stale_since="2025-01-15T10:00:00Z", + last_materialized_at="2025-01-14T08:00:00Z", + materialization_count=5, + hours_since_materialization=26.0, + message="Graph is stale", + ) + + assert status.graph_id == "graph-123" + assert status.is_stale is True + assert status.materialization_count == 5 + assert status.hours_since_materialization == 26.0 + + +@pytest.mark.unit +class TestMaterializationClientInit: + """Test suite for MaterializationClient initialization.""" + + def test_client_initialization(self, mock_config): + """Test that client initializes correctly with config.""" + client = MaterializationClient(mock_config) + + assert client.base_url == "http://localhost:8000" + assert client.token == "test-api-key" + assert client.headers == {"X-API-Key": "test-api-key"} + assert client._operation_client is None + + def test_operation_client_lazy_creation(self, mock_config): + """Test that operation client is created lazily.""" + client = MaterializationClient(mock_config) + + assert client._operation_client is None + op_client = client.operation_client + assert op_client is not None + # Second access returns same instance + assert client.operation_client is op_client + + +@pytest.mark.unit +class TestMaterialize: + """Test suite for MaterializationClient.materialize method.""" + + @patch("robosystems_client.extensions.materialization_client.materialize_graph") + def test_materialize_success(self, mock_mat, mock_config, graph_id): + """Test successful materialization.""" + # Mock initial response with operation_id + mock_resp = Mock() + mock_resp.status_code = 200 + mock_resp.parsed = Mock(operation_id="op-mat-123") + mock_mat.return_value = mock_resp + + # Mock the operation client monitoring + op_result = OperationResult( + operation_id="op-mat-123", + status=OperationStatus.COMPLETED, + result={ + "was_stale": True, + "stale_reason": "New files", + "tables_materialized": ["Entity"], + "total_rows": 100, + "execution_time_ms": 2000.0, + "message": "Done", + }, + execution_time_ms=2000.0, + ) + + client = MaterializationClient(mock_config) + client._operation_client = Mock() + client._operation_client.monitor_operation.return_value = op_result + + result = client.materialize(graph_id) + + assert result.success is True + assert result.status == "success" + assert result.tables_materialized == ["Entity"] + assert result.total_rows == 100 + + @patch("robosystems_client.extensions.materialization_client.materialize_graph") + def test_materialize_api_failure(self, mock_mat, mock_config, graph_id): + """Test materialization when API returns error.""" + mock_resp = Mock() + mock_resp.status_code = 500 + mock_resp.parsed = None + mock_resp.content = b'{"detail": "Internal error"}' + mock_mat.return_value = mock_resp + + client = MaterializationClient(mock_config) + result = client.materialize(graph_id) + + assert result.success is False + assert result.status == "failed" + + @patch("robosystems_client.extensions.materialization_client.materialize_graph") + def test_materialize_operation_failed(self, mock_mat, mock_config, graph_id): + """Test materialization when operation fails.""" + mock_resp = Mock() + mock_resp.status_code = 200 + mock_resp.parsed = Mock(operation_id="op-fail") + mock_mat.return_value = mock_resp + + op_result = OperationResult( + operation_id="op-fail", + status=OperationStatus.FAILED, + error="Dagster job failed", + execution_time_ms=1000.0, + ) + + client = MaterializationClient(mock_config) + client._operation_client = Mock() + client._operation_client.monitor_operation.return_value = op_result + + result = client.materialize(graph_id) + + assert result.success is False + assert result.error == "Dagster job failed" + + def test_materialize_no_token(self, mock_config, graph_id): + """Test materialize fails without API key.""" + mock_config["token"] = None + client = MaterializationClient(mock_config) + + result = client.materialize(graph_id) + + assert result.success is False + assert "No API key" in result.error + + @patch("robosystems_client.extensions.materialization_client.materialize_graph") + def test_materialize_with_progress(self, mock_mat, mock_config, graph_id): + """Test materialize calls progress callback.""" + mock_resp = Mock() + mock_resp.status_code = 200 + mock_resp.parsed = Mock(operation_id="op-progress") + mock_mat.return_value = mock_resp + + op_result = OperationResult( + operation_id="op-progress", + status=OperationStatus.COMPLETED, + result={ + "was_stale": False, + "tables_materialized": [], + "total_rows": 0, + "execution_time_ms": 500.0, + "message": "Already up to date", + }, + execution_time_ms=500.0, + ) + + progress_messages = [] + options = MaterializationOptions( + on_progress=lambda msg: progress_messages.append(msg) + ) + + client = MaterializationClient(mock_config) + client._operation_client = Mock() + client._operation_client.monitor_operation.return_value = op_result + + client.materialize(graph_id, options) + + assert len(progress_messages) >= 2 # At least "Submitting" and "queued" + + @patch("robosystems_client.extensions.materialization_client.materialize_graph") + def test_materialize_with_rebuild(self, mock_mat, mock_config, graph_id): + """Test materialize passes rebuild option.""" + mock_resp = Mock() + mock_resp.status_code = 200 + mock_resp.parsed = Mock(operation_id="op-rebuild") + mock_mat.return_value = mock_resp + + op_result = OperationResult( + operation_id="op-rebuild", + status=OperationStatus.COMPLETED, + result={ + "was_stale": True, + "tables_materialized": ["Entity"], + "total_rows": 50, + "execution_time_ms": 5000.0, + "message": "Rebuilt", + }, + execution_time_ms=5000.0, + ) + + client = MaterializationClient(mock_config) + client._operation_client = Mock() + client._operation_client.monitor_operation.return_value = op_result + + options = MaterializationOptions(rebuild=True, force=True) + result = client.materialize(graph_id, options) + + assert result.success is True + # Verify the request body had rebuild=True + call_kwargs = mock_mat.call_args[1] + assert call_kwargs["body"].rebuild is True + assert call_kwargs["body"].force is True + + +@pytest.mark.unit +class TestMaterializationStatus: + """Test suite for MaterializationClient.status method.""" + + @patch( + "robosystems_client.extensions.materialization_client.get_materialization_status" + ) + def test_get_status(self, mock_status, mock_config, graph_id): + """Test getting materialization status.""" + mock_resp = Mock() + mock_resp.status_code = 200 + mock_resp.parsed = Mock( + graph_id=graph_id, + is_stale=True, + stale_reason="New files uploaded", + stale_since="2025-01-15T10:00:00Z", + last_materialized_at="2025-01-14T08:00:00Z", + materialization_count=3, + hours_since_materialization=26.0, + message="Graph is stale", + ) + mock_status.return_value = mock_resp + + client = MaterializationClient(mock_config) + status = client.status(graph_id) + + assert status is not None + assert status.is_stale is True + assert status.materialization_count == 3 + + @patch( + "robosystems_client.extensions.materialization_client.get_materialization_status" + ) + def test_get_status_failure(self, mock_status, mock_config, graph_id): + """Test status returns None on failure.""" + mock_resp = Mock() + mock_resp.status_code = 500 + mock_resp.parsed = None + mock_status.return_value = mock_resp + + client = MaterializationClient(mock_config) + status = client.status(graph_id) + + assert status is None + + def test_status_no_token(self, mock_config, graph_id): + """Test status returns None without token.""" + mock_config["token"] = None + client = MaterializationClient(mock_config) + status = client.status(graph_id) + + assert status is None diff --git a/tests/test_operation_client_ops.py b/tests/test_operation_client_ops.py new file mode 100644 index 0000000..115ddc4 --- /dev/null +++ b/tests/test_operation_client_ops.py @@ -0,0 +1,350 @@ +"""Unit tests for OperationClient operational logic. + +Covers: monitor_operation (SSE event handling, completion, error, cancel, +timeout, progress callbacks, queue updates), get_operation_status, +cancel_operation, close_operation, close_all. + +Dataclass and enum tests already exist in tests/test_operation_client.py. +""" + +import pytest +from unittest.mock import Mock, patch, MagicMock +from robosystems_client.extensions.operation_client import ( + OperationClient, + OperationStatus, + OperationProgress, + MonitorOptions, +) +from robosystems_client.extensions.sse_client import SSEClient + + +# ── Helpers ────────────────────────────────────────────────────────── + + +def _fire_events(sse_client_instance, events): + """Simulate SSE events by directly calling registered listeners. + + Each event is a (event_type, data) tuple. + """ + for event_type, data in events: + sse_client_instance.emit(event_type, data) + + +# ── monitor_operation ──────────────────────────────────────────────── + + +@pytest.mark.unit +class TestMonitorOperation: + """Test OperationClient.monitor_operation via mocked SSE.""" + + @patch("robosystems_client.extensions.operation_client.SSEClient") + @patch("time.sleep") + def test_monitor_completed(self, mock_sleep, MockSSE, mock_config): + """Test monitoring an operation that completes successfully.""" + fake_sse = MagicMock(spec=SSEClient) + listeners = {} + + def capture_on(event, handler): + listeners[event] = handler + + fake_sse.on.side_effect = capture_on + + def fake_connect(op_id): + # Simulate: started → progress → completed + listeners["operation_started"]({"agent": "financial"}) + listeners["operation_progress"]({"message": "Processing", "percentage": 50}) + listeners["operation_completed"]( + {"result": {"rows": 100}, "execution_time_ms": 2000} + ) + + fake_sse.connect.side_effect = fake_connect + MockSSE.return_value = fake_sse + + client = OperationClient(mock_config) + result = client.monitor_operation("op-123") + + assert result.status == OperationStatus.COMPLETED + assert result.result == {"rows": 100} + assert result.execution_time_ms == 2000 + assert len(result.progress) == 1 + assert result.progress[0].message == "Processing" + assert result.progress[0].percentage == 50 + + @patch("robosystems_client.extensions.operation_client.SSEClient") + @patch("time.sleep") + def test_monitor_failed(self, mock_sleep, MockSSE, mock_config): + """Test monitoring an operation that fails.""" + fake_sse = MagicMock(spec=SSEClient) + listeners = {} + + def capture_on(event, handler): + listeners[event] = handler + + fake_sse.on.side_effect = capture_on + + def fake_connect(op_id): + listeners["operation_error"]({"message": "Database connection lost"}) + + fake_sse.connect.side_effect = fake_connect + MockSSE.return_value = fake_sse + + client = OperationClient(mock_config) + result = client.monitor_operation("op-fail") + + assert result.status == OperationStatus.FAILED + assert result.error == "Database connection lost" + + @patch("robosystems_client.extensions.operation_client.SSEClient") + @patch("time.sleep") + def test_monitor_cancelled(self, mock_sleep, MockSSE, mock_config): + """Test monitoring a cancelled operation.""" + fake_sse = MagicMock(spec=SSEClient) + listeners = {} + + def capture_on(event, handler): + listeners[event] = handler + + fake_sse.on.side_effect = capture_on + + def fake_connect(op_id): + listeners["operation_cancelled"]() + + fake_sse.connect.side_effect = fake_connect + MockSSE.return_value = fake_sse + + client = OperationClient(mock_config) + result = client.monitor_operation("op-cancel") + + assert result.status == OperationStatus.CANCELLED + + @patch("robosystems_client.extensions.operation_client.SSEClient") + @patch("time.sleep") + def test_monitor_with_progress_callback(self, mock_sleep, MockSSE, mock_config): + """Test that progress callback is invoked.""" + fake_sse = MagicMock(spec=SSEClient) + listeners = {} + + def capture_on(event, handler): + listeners[event] = handler + + fake_sse.on.side_effect = capture_on + + def fake_connect(op_id): + listeners["operation_progress"]( + {"message": "Step 1", "percentage": 25, "current_step": 1, "total_steps": 4} + ) + listeners["operation_progress"]( + {"message": "Step 2", "percentage": 50, "current_step": 2, "total_steps": 4} + ) + listeners["operation_completed"]({"result": {}}) + + fake_sse.connect.side_effect = fake_connect + MockSSE.return_value = fake_sse + + progress_updates = [] + options = MonitorOptions(on_progress=lambda p: progress_updates.append(p)) + + client = OperationClient(mock_config) + client.monitor_operation("op-progress", options) + + assert len(progress_updates) == 2 + assert isinstance(progress_updates[0], OperationProgress) + assert progress_updates[0].message == "Step 1" + assert progress_updates[1].current_step == 2 + + @patch("robosystems_client.extensions.operation_client.SSEClient") + @patch("time.sleep") + def test_monitor_with_queue_update(self, mock_sleep, MockSSE, mock_config): + """Test queue update callback.""" + fake_sse = MagicMock(spec=SSEClient) + listeners = {} + + def capture_on(event, handler): + listeners[event] = handler + + fake_sse.on.side_effect = capture_on + + def fake_connect(op_id): + listeners["queue_update"]({"position": 3, "estimated_wait_seconds": 15}) + listeners["operation_completed"]({"result": {}}) + + fake_sse.connect.side_effect = fake_connect + MockSSE.return_value = fake_sse + + queue_updates = [] + options = MonitorOptions( + on_queue_update=lambda pos, wait: queue_updates.append((pos, wait)) + ) + + client = OperationClient(mock_config) + result = client.monitor_operation("op-queue", options) + + assert queue_updates == [(3, 15)] + assert result.status == OperationStatus.COMPLETED + + @patch("robosystems_client.extensions.operation_client.SSEClient") + @patch("time.sleep") + def test_monitor_error_uses_fallback_key(self, mock_sleep, MockSSE, mock_config): + """Test error event falls back to 'error' key when 'message' absent.""" + fake_sse = MagicMock(spec=SSEClient) + listeners = {} + + def capture_on(event, handler): + listeners[event] = handler + + fake_sse.on.side_effect = capture_on + + def fake_connect(op_id): + listeners["operation_error"]({"error": "Timeout exceeded"}) + + fake_sse.connect.side_effect = fake_connect + MockSSE.return_value = fake_sse + + client = OperationClient(mock_config) + result = client.monitor_operation("op-err") + + assert result.status == OperationStatus.FAILED + assert result.error == "Timeout exceeded" + + @patch("robosystems_client.extensions.operation_client.SSEClient") + @patch("time.sleep") + def test_monitor_cleanup_on_completion(self, mock_sleep, MockSSE, mock_config): + """Test that SSE client is cleaned up after completion.""" + fake_sse = MagicMock(spec=SSEClient) + listeners = {} + + def capture_on(event, handler): + listeners[event] = handler + + fake_sse.on.side_effect = capture_on + + def fake_connect(op_id): + listeners["operation_completed"]({"result": {}}) + + fake_sse.connect.side_effect = fake_connect + MockSSE.return_value = fake_sse + + client = OperationClient(mock_config) + client.monitor_operation("op-cleanup") + + # Operation should be removed from active_operations + assert "op-cleanup" not in client.active_operations + # SSE client should have been closed + fake_sse.close.assert_called() + + +# ── get_operation_status ───────────────────────────────────────────── + + +@pytest.mark.unit +class TestGetOperationStatus: + """Test OperationClient.get_operation_status.""" + + @patch("robosystems_client.api.operations.get_operation_status.sync_detailed") + def test_get_status_success(self, mock_get, mock_config): + """Test successful status retrieval.""" + mock_resp = Mock() + mock_resp.parsed = Mock() + mock_resp.parsed.status = "running" + mock_resp.parsed.progress = 50 + mock_resp.parsed.result = None + mock_resp.parsed.error = None + mock_get.return_value = mock_resp + + client = OperationClient(mock_config) + status = client.get_operation_status("op-status") + + assert status["status"] == "running" + assert status["progress"] == 50 + + @patch("robosystems_client.api.operations.get_operation_status.sync_detailed") + def test_get_status_error(self, mock_get, mock_config): + """Test status retrieval on error.""" + mock_get.side_effect = Exception("Network error") + + client = OperationClient(mock_config) + status = client.get_operation_status("op-err") + + assert status["status"] == "error" + assert "Network error" in status["error"] + + @patch("robosystems_client.api.operations.get_operation_status.sync_detailed") + def test_get_status_no_parsed(self, mock_get, mock_config): + """Test status when response has no parsed data.""" + mock_resp = Mock() + mock_resp.parsed = None + mock_get.return_value = mock_resp + + client = OperationClient(mock_config) + status = client.get_operation_status("op-none") + + assert status["status"] == "unknown" + + +# ── cancel_operation ───────────────────────────────────────────────── + + +@pytest.mark.unit +class TestCancelOperation: + """Test OperationClient.cancel_operation.""" + + @patch("robosystems_client.api.operations.cancel_operation.sync_detailed") + def test_cancel_success(self, mock_cancel, mock_config): + """Test successful cancellation.""" + mock_resp = Mock() + mock_resp.parsed = Mock() + mock_resp.parsed.cancelled = True + mock_cancel.return_value = mock_resp + + client = OperationClient(mock_config) + result = client.cancel_operation("op-cancel") + + assert result is True + + @patch("robosystems_client.api.operations.cancel_operation.sync_detailed") + def test_cancel_failure(self, mock_cancel, mock_config): + """Test cancellation failure.""" + mock_cancel.side_effect = Exception("Not found") + + client = OperationClient(mock_config) + result = client.cancel_operation("op-missing") + + assert result is False + + +# ── close_operation / close_all ────────────────────────────────────── + + +@pytest.mark.unit +class TestCloseOperations: + """Test close_operation and close_all.""" + + def test_close_operation(self, mock_config): + """Test closing a specific operation monitor.""" + client = OperationClient(mock_config) + mock_sse = Mock() + client.active_operations["op-1"] = mock_sse + + client.close_operation("op-1") + + mock_sse.close.assert_called_once() + assert "op-1" not in client.active_operations + + def test_close_operation_nonexistent(self, mock_config): + """Test closing a nonexistent operation is a no-op.""" + client = OperationClient(mock_config) + client.close_operation("op-missing") # Should not raise + + def test_close_all_multiple(self, mock_config): + """Test closing all active operations.""" + client = OperationClient(mock_config) + mock_sse1 = Mock() + mock_sse2 = Mock() + client.active_operations["op-1"] = mock_sse1 + client.active_operations["op-2"] = mock_sse2 + + client.close_all() + + mock_sse1.close.assert_called_once() + mock_sse2.close.assert_called_once() + assert len(client.active_operations) == 0 diff --git a/tests/test_query_client_ops.py b/tests/test_query_client_ops.py new file mode 100644 index 0000000..f4e2f4f --- /dev/null +++ b/tests/test_query_client_ops.py @@ -0,0 +1,524 @@ +"""Unit tests for QueryClient operational logic. + +Covers: execute_query (sync dict response, attrs response, queued response, +NDJSON streaming, error handling), _parse_ndjson_response, +_wait_for_query_completion, query convenience, query_batch, stream_query. + +Dataclass tests already exist in tests/test_query_client.py. +""" + +import pytest +from unittest.mock import Mock, patch +from robosystems_client.extensions.query_client import ( + QueryClient, + QueryRequest, + QueryOptions, + QueryResult, + QueuedQueryError, +) + + +# ── execute_query — sync dict response ─────────────────────────────── + + +@pytest.mark.unit +class TestExecuteQuerySync: + """Test execute_query with immediate sync responses.""" + + @patch("robosystems_client.extensions.query_client.execute_cypher_query") + def test_dict_response(self, mock_exec, mock_config, graph_id): + """Test execute_query with a dict response.""" + mock_resp = Mock() + mock_resp.parsed = { + "data": [{"name": "ACME"}], + "columns": ["name"], + "row_count": 1, + "execution_time_ms": 42, + "timestamp": "2025-01-15T10:00:00Z", + } + # No NDJSON headers + mock_resp.headers = {"content-type": "application/json"} + mock_resp.status_code = 200 + mock_exec.return_value = mock_resp + + client = QueryClient(mock_config) + request = QueryRequest(query="MATCH (n) RETURN n.name") + result = client.execute_query(graph_id, request) + + assert isinstance(result, QueryResult) + assert result.data == [{"name": "ACME"}] + assert result.columns == ["name"] + assert result.row_count == 1 + assert result.execution_time_ms == 42 + assert result.graph_id == graph_id + + @patch("robosystems_client.extensions.query_client.execute_cypher_query") + def test_attrs_response(self, mock_exec, mock_config, graph_id): + """Test execute_query with an attrs object response.""" + mock_data_item = Mock() + mock_data_item.to_dict.return_value = {"name": "Beta Corp"} + + mock_parsed = Mock() + mock_parsed.data = [mock_data_item] + mock_parsed.columns = ["name"] + mock_parsed.row_count = 1 + mock_parsed.execution_time_ms = 30 + mock_parsed.timestamp = "2025-01-15T10:00:00Z" + + mock_resp = Mock() + mock_resp.parsed = mock_parsed + mock_resp.headers = {"content-type": "application/json"} + mock_resp.status_code = 200 + mock_exec.return_value = mock_resp + + client = QueryClient(mock_config) + request = QueryRequest(query="MATCH (n) RETURN n.name") + result = client.execute_query(graph_id, request) + + assert isinstance(result, QueryResult) + assert result.data == [{"name": "Beta Corp"}] + + @patch("robosystems_client.extensions.query_client.execute_cypher_query") + def test_attrs_with_additional_properties(self, mock_exec, mock_config, graph_id): + """Test attrs response items using additional_properties fallback.""" + mock_data_item = Mock(spec=[]) + mock_data_item.additional_properties = {"ticker": "ACM"} + # No to_dict method + + mock_parsed = Mock() + mock_parsed.data = [mock_data_item] + mock_parsed.columns = ["ticker"] + mock_parsed.row_count = 1 + mock_parsed.execution_time_ms = 10 + mock_parsed.timestamp = "2025-01-15T10:00:00Z" + + mock_resp = Mock() + mock_resp.parsed = mock_parsed + mock_resp.headers = {"content-type": "application/json"} + mock_resp.status_code = 200 + mock_exec.return_value = mock_resp + + client = QueryClient(mock_config) + request = QueryRequest(query="MATCH (n) RETURN n.ticker") + result = client.execute_query(graph_id, request) + + assert result.data == [{"ticker": "ACM"}] + + +# ── execute_query — queued response ────────────────────────────────── + + +@pytest.mark.unit +class TestExecuteQueryQueued: + """Test execute_query with queued responses.""" + + @patch("robosystems_client.extensions.query_client.execute_cypher_query") + def test_queued_max_wait_zero_raises(self, mock_exec, mock_config, graph_id): + """Test queued response with max_wait=0 raises QueuedQueryError.""" + mock_resp = Mock() + mock_resp.parsed = { + "status": "queued", + "operation_id": "op-q1", + "queue_position": 3, + "estimated_wait_seconds": 15, + "message": "Query queued", + } + mock_resp.headers = {"content-type": "application/json"} + mock_resp.status_code = 200 + mock_exec.return_value = mock_resp + + client = QueryClient(mock_config) + request = QueryRequest(query="MATCH (n) RETURN n") + options = QueryOptions(max_wait=0) + + with pytest.raises(QueuedQueryError) as exc_info: + client.execute_query(graph_id, request, options) + + assert exc_info.value.queue_info.operation_id == "op-q1" + assert exc_info.value.queue_info.queue_position == 3 + + @patch("robosystems_client.extensions.query_client.execute_cypher_query") + def test_queued_calls_queue_update_callback(self, mock_exec, mock_config, graph_id): + """Test queued response invokes on_queue_update.""" + mock_resp = Mock() + mock_resp.parsed = { + "status": "queued", + "operation_id": "op-q2", + "queue_position": 5, + "estimated_wait_seconds": 30, + "message": "Queued", + } + mock_resp.headers = {"content-type": "application/json"} + mock_resp.status_code = 200 + mock_exec.return_value = mock_resp + + queue_updates = [] + options = QueryOptions( + max_wait=0, + on_queue_update=lambda pos, wait: queue_updates.append((pos, wait)), + ) + + client = QueryClient(mock_config) + request = QueryRequest(query="MATCH (n) RETURN n") + + with pytest.raises(QueuedQueryError): + client.execute_query(graph_id, request, options) + + assert queue_updates == [(5, 30)] + + +# ── execute_query — error handling ─────────────────────────────────── + + +@pytest.mark.unit +class TestExecuteQueryErrors: + """Test error handling in execute_query.""" + + def test_no_token_raises(self, mock_config, graph_id): + """Test execute_query fails without token.""" + mock_config["token"] = None + client = QueryClient(mock_config) + request = QueryRequest(query="MATCH (n) RETURN n") + + with pytest.raises(Exception, match="Authentication failed|No API key"): + client.execute_query(graph_id, request) + + @patch("robosystems_client.extensions.query_client.execute_cypher_query") + def test_auth_error_wrapped(self, mock_exec, mock_config, graph_id): + """Test 401/403 errors are wrapped as auth errors.""" + mock_exec.side_effect = Exception("401 Unauthorized") + + client = QueryClient(mock_config) + request = QueryRequest(query="MATCH (n) RETURN n") + + with pytest.raises(Exception, match="Authentication failed"): + client.execute_query(graph_id, request) + + @patch("robosystems_client.extensions.query_client.execute_cypher_query") + def test_generic_error_wrapped(self, mock_exec, mock_config, graph_id): + """Test generic errors are wrapped.""" + mock_exec.side_effect = Exception("Connection timeout") + + client = QueryClient(mock_config) + request = QueryRequest(query="MATCH (n) RETURN n") + + with pytest.raises(Exception, match="Query execution failed"): + client.execute_query(graph_id, request) + + @patch("robosystems_client.extensions.query_client.execute_cypher_query") + def test_http_error_response(self, mock_exec, mock_config, graph_id): + """Test HTTP 4xx/5xx with error body.""" + mock_resp = Mock() + mock_resp.parsed = None + mock_resp.status_code = 400 + mock_resp.headers = {"content-type": "application/json"} + mock_resp.content = b'{"detail": "Invalid query syntax"}' + mock_exec.return_value = mock_resp + + client = QueryClient(mock_config) + request = QueryRequest(query="BAD QUERY") + + with pytest.raises(Exception, match="Invalid query syntax"): + client.execute_query(graph_id, request) + + +# ── NDJSON parsing ─────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestNDJSON: + """Test _parse_ndjson_response.""" + + def test_single_chunk(self, mock_config, graph_id): + """Test parsing a single NDJSON chunk.""" + content = '{"columns": ["name"], "rows": [["ACME"]], "execution_time_ms": 10}\n' + + mock_resp = Mock() + mock_resp.content = content.encode() + + client = QueryClient(mock_config) + result = client._parse_ndjson_response(mock_resp, graph_id) + + assert result.columns == ["name"] + assert result.data == [["ACME"]] + assert result.row_count == 1 + assert result.execution_time_ms == 10 + + def test_multiple_chunks(self, mock_config, graph_id): + """Test parsing multiple NDJSON chunks.""" + content = ( + '{"columns": ["id"], "rows": [[1], [2]], "execution_time_ms": 5}\n' + '{"rows": [[3], [4]], "execution_time_ms": 12}\n' + ) + + mock_resp = Mock() + mock_resp.content = content.encode() + + client = QueryClient(mock_config) + result = client._parse_ndjson_response(mock_resp, graph_id) + + assert result.columns == ["id"] + assert result.data == [[1], [2], [3], [4]] + assert result.row_count == 4 + assert result.execution_time_ms == 12 # max of 5, 12 + + def test_data_key_fallback(self, mock_config, graph_id): + """Test NDJSON with 'data' key instead of 'rows'.""" + content = '{"columns": ["x"], "data": [{"x": 1}], "execution_time_ms": 3}\n' + + mock_resp = Mock() + mock_resp.content = content.encode() + + client = QueryClient(mock_config) + result = client._parse_ndjson_response(mock_resp, graph_id) + + assert result.data == [{"x": 1}] + + def test_empty_lines_skipped(self, mock_config, graph_id): + """Test that empty lines in NDJSON are skipped.""" + content = '{"columns": ["a"], "rows": [[1]]}\n\n\n' + + mock_resp = Mock() + mock_resp.content = content.encode() + + client = QueryClient(mock_config) + result = client._parse_ndjson_response(mock_resp, graph_id) + + assert result.row_count == 1 + + def test_invalid_json_raises(self, mock_config, graph_id): + """Test invalid JSON in NDJSON raises error.""" + content = "not json at all\n" + + mock_resp = Mock() + mock_resp.content = content.encode() + + client = QueryClient(mock_config) + + with pytest.raises(Exception, match="Failed to parse NDJSON"): + client._parse_ndjson_response(mock_resp, graph_id) + + def test_no_columns_defaults_to_empty(self, mock_config, graph_id): + """Test NDJSON without columns field defaults to empty list.""" + content = '{"rows": [[1]]}\n' + + mock_resp = Mock() + mock_resp.content = content.encode() + + client = QueryClient(mock_config) + result = client._parse_ndjson_response(mock_resp, graph_id) + + assert result.columns == [] + + +# ── NDJSON detection in execute_query ──────────────────────────────── + + +@pytest.mark.unit +class TestNDJSONDetection: + """Test that NDJSON responses are detected and parsed.""" + + @patch("robosystems_client.extensions.query_client.execute_cypher_query") + def test_ndjson_content_type_detected(self, mock_exec, mock_config, graph_id): + """Test NDJSON response detected by content-type header.""" + mock_resp = Mock() + mock_resp.parsed = None + mock_resp.status_code = 200 + mock_resp.headers = {"content-type": "application/x-ndjson"} + mock_resp.content = b'{"columns": ["n"], "rows": [["x"]]}\n' + mock_exec.return_value = mock_resp + + client = QueryClient(mock_config) + request = QueryRequest(query="MATCH (n) RETURN n") + result = client.execute_query(graph_id, request) + + assert isinstance(result, QueryResult) + assert result.data == [["x"]] + + @patch("robosystems_client.extensions.query_client.execute_cypher_query") + def test_ndjson_stream_format_header(self, mock_exec, mock_config, graph_id): + """Test NDJSON detected by x-stream-format header.""" + mock_resp = Mock() + mock_resp.parsed = None + mock_resp.status_code = 200 + mock_resp.headers = { + "content-type": "application/json", + "x-stream-format": "ndjson", + } + mock_resp.content = b'{"columns": ["a"], "rows": [[1]]}\n' + mock_exec.return_value = mock_resp + + client = QueryClient(mock_config) + request = QueryRequest(query="MATCH (n) RETURN n") + result = client.execute_query(graph_id, request) + + assert result.data == [[1]] + + +# ── query convenience ──────────────────────────────────────────────── + + +@pytest.mark.unit +class TestQueryConvenience: + """Test query() convenience method.""" + + @patch.object(QueryClient, "execute_query") + def test_query_returns_result(self, mock_exec, mock_config, graph_id): + """Test query() returns QueryResult directly.""" + mock_exec.return_value = QueryResult( + data=[{"n": 1}], columns=["n"], row_count=1, execution_time_ms=5 + ) + + client = QueryClient(mock_config) + result = client.query(graph_id, "MATCH (n) RETURN n") + + assert result.row_count == 1 + + @patch.object(QueryClient, "execute_query") + def test_query_collects_iterator(self, mock_exec, mock_config, graph_id): + """Test query() collects iterator results into QueryResult.""" + mock_exec.return_value = iter([{"n": 1}, {"n": 2}]) + + client = QueryClient(mock_config) + result = client.query(graph_id, "MATCH (n) RETURN n") + + assert result.row_count == 2 + assert result.data == [{"n": 1}, {"n": 2}] + + @patch.object(QueryClient, "execute_query") + def test_query_passes_parameters(self, mock_exec, mock_config, graph_id): + """Test query() forwards parameters.""" + mock_exec.return_value = QueryResult( + data=[], columns=[], row_count=0, execution_time_ms=0 + ) + + client = QueryClient(mock_config) + client.query(graph_id, "MATCH (n) WHERE n.x > $v RETURN n", {"v": 10}) + + call_args = mock_exec.call_args + assert call_args[0][1].parameters == {"v": 10} + + +# ── query_batch ────────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestQueryBatch: + """Test query_batch method.""" + + @patch.object(QueryClient, "query") + def test_batch_success(self, mock_query, mock_config, graph_id): + """Test batch execution of multiple queries.""" + mock_query.side_effect = [ + QueryResult( + data=[{"count": 5}], columns=["count"], row_count=1, execution_time_ms=10 + ), + QueryResult( + data=[{"count": 3}], columns=["count"], row_count=1, execution_time_ms=8 + ), + ] + + client = QueryClient(mock_config) + results = client.query_batch( + graph_id, + ["MATCH (n:Person) RETURN count(n)", "MATCH (c:Company) RETURN count(c)"], + ) + + assert len(results) == 2 + assert all(isinstance(r, QueryResult) for r in results) + + @patch.object(QueryClient, "query") + def test_batch_with_error(self, mock_query, mock_config, graph_id): + """Test batch handles individual query failures.""" + mock_query.side_effect = [ + QueryResult(data=[], columns=[], row_count=0, execution_time_ms=0), + Exception("Query 2 failed"), + ] + + client = QueryClient(mock_config) + results = client.query_batch(graph_id, ["query1", "query2"]) + + assert isinstance(results[0], QueryResult) + assert isinstance(results[1], dict) + assert "error" in results[1] + + def test_batch_length_mismatch(self, mock_config, graph_id): + """Test batch raises on mismatched query/params lengths.""" + client = QueryClient(mock_config) + + with pytest.raises(ValueError, match="same length"): + client.query_batch(graph_id, ["q1", "q2"], [{"a": 1}]) + + @patch.object(QueryClient, "query") + def test_batch_with_parameters(self, mock_query, mock_config, graph_id): + """Test batch passes parameters per query.""" + mock_query.return_value = QueryResult( + data=[], columns=[], row_count=0, execution_time_ms=0 + ) + + client = QueryClient(mock_config) + client.query_batch( + graph_id, + ["MATCH (n) WHERE n.x > $v RETURN n", "MATCH (n) WHERE n.y = $w RETURN n"], + [{"v": 10}, {"w": "abc"}], + ) + + assert mock_query.call_count == 2 + assert mock_query.call_args_list[0][0][2] == {"v": 10} + assert mock_query.call_args_list[1][0][2] == {"w": "abc"} + + +# ── stream_query ───────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestStreamQuery: + """Test stream_query method.""" + + @patch.object(QueryClient, "execute_query") + def test_stream_yields_from_iterator(self, mock_exec, mock_config, graph_id): + """Test stream_query yields items from iterator result.""" + mock_exec.return_value = iter([{"id": 1}, {"id": 2}, {"id": 3}]) + + client = QueryClient(mock_config) + items = list(client.stream_query(graph_id, "MATCH (n) RETURN n")) + + assert items == [{"id": 1}, {"id": 2}, {"id": 3}] + + @patch.object(QueryClient, "execute_query") + def test_stream_yields_from_query_result(self, mock_exec, mock_config, graph_id): + """Test stream_query yields items when result is QueryResult.""" + mock_exec.return_value = QueryResult( + data=[{"id": 1}, {"id": 2}], + columns=["id"], + row_count=2, + execution_time_ms=10, + ) + + client = QueryClient(mock_config) + items = list(client.stream_query(graph_id, "MATCH (n) RETURN n")) + + assert items == [{"id": 1}, {"id": 2}] + + @patch.object(QueryClient, "execute_query") + def test_stream_calls_progress(self, mock_exec, mock_config, graph_id): + """Test stream_query calls on_progress callback.""" + mock_exec.return_value = QueryResult( + data=[{"id": i} for i in range(5)], + columns=["id"], + row_count=5, + execution_time_ms=10, + ) + + progress_calls = [] + client = QueryClient(mock_config) + list( + client.stream_query( + graph_id, + "MATCH (n) RETURN n", + on_progress=lambda cur, tot: progress_calls.append((cur, tot)), + ) + ) + + assert len(progress_calls) == 5 + assert progress_calls[-1] == (5, 5) diff --git a/tests/test_report_client.py b/tests/test_report_client.py new file mode 100644 index 0000000..e99e0d6 --- /dev/null +++ b/tests/test_report_client.py @@ -0,0 +1,302 @@ +"""Unit tests for ReportClient.""" + +import pytest +from http import HTTPStatus +from unittest.mock import Mock, patch +from robosystems_client.extensions.report_client import ReportClient + + +@pytest.mark.unit +class TestReportClientInit: + """Test suite for ReportClient initialization.""" + + def test_client_initialization(self, mock_config): + """Test that client initializes correctly with config.""" + client = ReportClient(mock_config) + + assert client.base_url == "http://localhost:8000" + assert client.token == "test-api-key" + assert client.headers == {"X-API-Key": "test-api-key"} + assert client.timeout == 60 + + def test_get_client_no_token(self, mock_config): + """Test _get_client raises without token.""" + mock_config["token"] = None + client = ReportClient(mock_config) + + with pytest.raises(RuntimeError, match="No API key"): + client._get_client() + + +@pytest.mark.unit +class TestReportCreate: + """Test suite for ReportClient.create method.""" + + @patch("robosystems_client.extensions.report_client.create_report") + def test_create_report(self, mock_create, mock_config, graph_id): + """Test creating a report.""" + mock_parsed = Mock() + mock_parsed.report_id = "rpt-123" + mock_parsed.report_name = "Q1 Report" + + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.CREATED + mock_resp.parsed = mock_parsed + mock_create.return_value = mock_resp + + client = ReportClient(mock_config) + result = client.create( + graph_id=graph_id, + name="Q1 Report", + mapping_id="map-1", + period_start="2025-01-01", + period_end="2025-03-31", + ) + + assert result.report_id == "rpt-123" + assert result.report_name == "Q1 Report" + + @patch("robosystems_client.extensions.report_client.create_report") + def test_create_report_with_options(self, mock_create, mock_config, graph_id): + """Test creating a report with all options.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.CREATED + mock_resp.parsed = Mock(report_id="rpt-full") + mock_create.return_value = mock_resp + + client = ReportClient(mock_config) + result = client.create( + graph_id=graph_id, + name="Annual Report", + mapping_id="map-2", + period_start="2024-01-01", + period_end="2024-12-31", + taxonomy_id="tax_usgaap_reporting", + period_type="annual", + comparative=False, + ) + + assert result.report_id == "rpt-full" + mock_create.assert_called_once() + + @patch("robosystems_client.extensions.report_client.create_report") + def test_create_report_error(self, mock_create, mock_config, graph_id): + """Test create report raises on error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.BAD_REQUEST + mock_create.return_value = mock_resp + + client = ReportClient(mock_config) + + with pytest.raises(RuntimeError, match="Create report failed"): + client.create( + graph_id=graph_id, + name="Bad", + mapping_id="map-x", + period_start="bad", + period_end="bad", + ) + + +@pytest.mark.unit +class TestReportList: + """Test suite for ReportClient.list method.""" + + @patch("robosystems_client.extensions.report_client.list_reports") + def test_list_reports(self, mock_list, mock_config, graph_id): + """Test listing reports.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(reports=[Mock(), Mock()]) + mock_list.return_value = mock_resp + + client = ReportClient(mock_config) + result = client.list(graph_id) + + assert len(result.reports) == 2 + + @patch("robosystems_client.extensions.report_client.list_reports") + def test_list_reports_error(self, mock_list, mock_config, graph_id): + """Test list reports raises on error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.INTERNAL_SERVER_ERROR + mock_list.return_value = mock_resp + + client = ReportClient(mock_config) + + with pytest.raises(RuntimeError, match="List reports failed"): + client.list(graph_id) + + +@pytest.mark.unit +class TestReportGet: + """Test suite for ReportClient.get method.""" + + @patch("robosystems_client.extensions.report_client.get_report") + def test_get_report(self, mock_get, mock_config, graph_id): + """Test getting a report.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(report_id="rpt-456", name="Q2 Report", structures=[Mock()]) + mock_get.return_value = mock_resp + + client = ReportClient(mock_config) + result = client.get(graph_id, "rpt-456") + + assert result.report_id == "rpt-456" + assert len(result.structures) == 1 + + @patch("robosystems_client.extensions.report_client.get_report") + def test_get_report_error(self, mock_get, mock_config, graph_id): + """Test get report raises on error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.NOT_FOUND + mock_get.return_value = mock_resp + + client = ReportClient(mock_config) + + with pytest.raises(RuntimeError, match="Get report failed"): + client.get(graph_id, "nonexistent") + + +@pytest.mark.unit +class TestReportStatement: + """Test suite for ReportClient.statement method.""" + + @patch("robosystems_client.extensions.report_client.get_statement") + def test_get_statement(self, mock_stmt, mock_config, graph_id): + """Test getting a financial statement.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock( + structure_type="income_statement", line_items=[Mock(), Mock()] + ) + mock_stmt.return_value = mock_resp + + client = ReportClient(mock_config) + result = client.statement(graph_id, "rpt-123", "income_statement") + + assert result.structure_type == "income_statement" + assert len(result.line_items) == 2 + + @patch("robosystems_client.extensions.report_client.get_statement") + def test_get_statement_error(self, mock_stmt, mock_config, graph_id): + """Test statement raises on error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.NOT_FOUND + mock_stmt.return_value = mock_resp + + client = ReportClient(mock_config) + + with pytest.raises(RuntimeError, match="Get statement failed"): + client.statement(graph_id, "rpt-bad", "income_statement") + + +@pytest.mark.unit +class TestReportRegenerate: + """Test suite for ReportClient.regenerate method.""" + + @patch("robosystems_client.extensions.report_client.regenerate_report") + def test_regenerate_report(self, mock_regen, mock_config, graph_id): + """Test regenerating a report.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(report_id="rpt-regen") + mock_regen.return_value = mock_resp + + client = ReportClient(mock_config) + result = client.regenerate( + graph_id, "rpt-regen", period_start="2025-04-01", period_end="2025-06-30" + ) + + assert result.report_id == "rpt-regen" + + @patch("robosystems_client.extensions.report_client.regenerate_report") + def test_regenerate_report_error(self, mock_regen, mock_config, graph_id): + """Test regenerate raises on error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.BAD_REQUEST + mock_regen.return_value = mock_resp + + client = ReportClient(mock_config) + + with pytest.raises(RuntimeError, match="Regenerate report failed"): + client.regenerate(graph_id, "rpt-bad", "bad", "bad") + + +@pytest.mark.unit +class TestReportDelete: + """Test suite for ReportClient.delete method.""" + + @patch("robosystems_client.extensions.report_client.delete_report") + def test_delete_report(self, mock_delete, mock_config, graph_id): + """Test deleting a report.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.NO_CONTENT + mock_delete.return_value = mock_resp + + client = ReportClient(mock_config) + client.delete(graph_id, "rpt-del") # Should not raise + + mock_delete.assert_called_once() + + @patch("robosystems_client.extensions.report_client.delete_report") + def test_delete_report_error(self, mock_delete, mock_config, graph_id): + """Test delete raises on error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.INTERNAL_SERVER_ERROR + mock_delete.return_value = mock_resp + + client = ReportClient(mock_config) + + with pytest.raises(RuntimeError, match="Delete report failed"): + client.delete(graph_id, "rpt-err") + + +@pytest.mark.unit +class TestReportShare: + """Test suite for ReportClient.share method.""" + + @patch("robosystems_client.extensions.report_client.share_report") + def test_share_report(self, mock_share, mock_config, graph_id): + """Test sharing a report.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(shared_count=3) + mock_share.return_value = mock_resp + + client = ReportClient(mock_config) + result = client.share(graph_id, "rpt-share", "pub-list-1") + + assert result.shared_count == 3 + + @patch("robosystems_client.extensions.report_client.share_report") + def test_share_report_error(self, mock_share, mock_config, graph_id): + """Test share raises on error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.NOT_FOUND + mock_share.return_value = mock_resp + + client = ReportClient(mock_config) + + with pytest.raises(RuntimeError, match="Share report failed"): + client.share(graph_id, "rpt-bad", "pub-bad") + + +@pytest.mark.unit +class TestReportIsShared: + """Test suite for ReportClient.is_shared_report method.""" + + def test_is_shared_report_true(self, mock_config): + """Test detecting a shared report.""" + client = ReportClient(mock_config) + report = Mock(source_graph_id="other-graph-123") + + assert client.is_shared_report(report) is True + + def test_is_shared_report_false(self, mock_config): + """Test detecting a non-shared report.""" + client = ReportClient(mock_config) + report = Mock(spec=[]) # No source_graph_id attribute + + assert client.is_shared_report(report) is False diff --git a/tests/test_sse_client.py b/tests/test_sse_client.py new file mode 100644 index 0000000..f27f209 --- /dev/null +++ b/tests/test_sse_client.py @@ -0,0 +1,503 @@ +"""Unit tests for SSEClient operational logic. + +Covers: connection, event processing, reconnection/retry, heartbeat, +completion-event auto-close, and close/cleanup. Dataclass and listener +tests already exist in extensions/tests/test_unit.py. +""" + +import pytest +from unittest.mock import Mock, patch, MagicMock +from robosystems_client.extensions.sse_client import ( + SSEClient, + SSEConfig, +) + + +@pytest.fixture +def sse_config(): + return SSEConfig( + base_url="http://localhost:8000", + headers={"X-API-Key": "test-key"}, + max_retries=3, + retry_delay=100, + timeout=10, + ) + + +@pytest.fixture +def sse_client(sse_config): + return SSEClient(sse_config) + + +# ── Event processing ──────────────────────────────────────────────── + + +@pytest.mark.unit +class TestProcessEvents: + """Test _process_events with mocked SSE streams.""" + + def _make_client_with_lines(self, sse_config, lines): + """Create an SSEClient with a fake response that yields lines.""" + client = SSEClient(sse_config) + mock_response = Mock() + mock_response.iter_lines.return_value = iter(lines) + client._response = mock_response + client.closed = False + return client + + def test_simple_event(self, sse_config): + """Test processing a single complete event.""" + lines = [ + "event: operation_progress", + 'data: {"message": "Working"}', + "", + ] + client = self._make_client_with_lines(sse_config, lines) + events = [] + client.on("operation_progress", lambda d: events.append(d)) + + client._process_events() + + assert len(events) == 1 + assert events[0] == {"message": "Working"} + + def test_multiline_data(self, sse_config): + """Test multiline data is joined with newlines.""" + lines = [ + "event: data_chunk", + 'data: {"rows": [', + 'data: {"id": 1}', + "data: ]}", + "", + ] + client = self._make_client_with_lines(sse_config, lines) + events = [] + client.on("data_chunk", lambda d: events.append(d)) + + client._process_events() + + assert len(events) == 1 + assert events[0] == {"rows": [{"id": 1}]} + + def test_comment_lines_skipped(self, sse_config): + """Test that comment lines (starting with :) are skipped.""" + lines = [ + ": this is a comment", + "event: operation_progress", + 'data: {"message": "hello"}', + "", + ] + client = self._make_client_with_lines(sse_config, lines) + events = [] + client.on("operation_progress", lambda d: events.append(d)) + + client._process_events() + + assert len(events) == 1 + + def test_id_field_tracked(self, sse_config): + """Test that id field is tracked for reconnection.""" + lines = [ + "event: operation_progress", + 'data: {"message": "step 1"}', + "id: 42", + "", + ] + client = self._make_client_with_lines(sse_config, lines) + client._process_events() + + assert client.last_event_id == "42" + + def test_retry_field_parsed(self, sse_config): + """Test retry field is parsed from events.""" + lines = [ + "event: operation_progress", + 'data: {"message": "retry test"}', + "retry: 5000", + "", + ] + client = self._make_client_with_lines(sse_config, lines) + # The retry value is parsed in the event buffer but not stored on client + # Just verify it doesn't crash + client._process_events() + + def test_invalid_retry_ignored(self, sse_config): + """Test invalid retry value is silently ignored.""" + lines = [ + "event: operation_progress", + 'data: {"message": "test"}', + "retry: not-a-number", + "", + ] + client = self._make_client_with_lines(sse_config, lines) + client._process_events() # Should not raise + + def test_multiple_events(self, sse_config): + """Test processing multiple events in sequence.""" + lines = [ + "event: operation_progress", + 'data: {"message": "step 1"}', + "", + "event: operation_progress", + 'data: {"message": "step 2"}', + "", + "event: operation_completed", + 'data: {"result": {"done": true}}', + "", + ] + client = self._make_client_with_lines(sse_config, lines) + progress = [] + completed = [] + client.on("operation_progress", lambda d: progress.append(d)) + client.on("operation_completed", lambda d: completed.append(d)) + + client._process_events() + + assert len(progress) == 2 + assert len(completed) == 1 + + def test_default_event_type_is_message(self, sse_config): + """Test events without event: field default to 'message'.""" + lines = [ + 'data: {"hello": "world"}', + "", + ] + client = self._make_client_with_lines(sse_config, lines) + events = [] + client.on("message", lambda d: events.append(d)) + + client._process_events() + + assert len(events) == 1 + assert events[0] == {"hello": "world"} + + def test_non_json_data_kept_as_string(self, sse_config): + """Test that non-JSON data is kept as a string.""" + lines = [ + "event: operation_progress", + "data: plain text message", + "", + ] + client = self._make_client_with_lines(sse_config, lines) + events = [] + client.on("operation_progress", lambda d: events.append(d)) + + client._process_events() + + assert events[0] == "plain text message" + + def test_final_event_without_trailing_newline(self, sse_config): + """Test that a final event is dispatched even without trailing empty line.""" + lines = [ + "event: operation_completed", + 'data: {"result": {}}', + ] + client = self._make_client_with_lines(sse_config, lines) + events = [] + client.on("operation_completed", lambda d: events.append(d)) + + client._process_events() + + assert len(events) == 1 + + def test_closed_flag_stops_processing(self, sse_config): + """Test that setting closed=True mid-stream stops processing.""" + call_count = 0 + + def close_on_second(data): + nonlocal call_count + call_count += 1 + + lines = [ + "event: operation_progress", + 'data: {"message": "first"}', + "", + "event: operation_progress", + 'data: {"message": "second"}', + "", + ] + client = self._make_client_with_lines(sse_config, lines) + client.on("operation_progress", close_on_second) + + # Close after first event dispatch + original_dispatch = client._dispatch_event + + def dispatch_and_close(buf): + original_dispatch(buf) + client.closed = True + + client._dispatch_event = dispatch_and_close + client._process_events() + + assert call_count == 1 + + def test_data_field_with_no_value(self, sse_config): + """Test 'data' line without colon appends empty string.""" + lines = [ + "event: operation_progress", + "data", + "", + ] + client = self._make_client_with_lines(sse_config, lines) + events = [] + client.on("operation_progress", lambda d: events.append(d)) + + client._process_events() + + assert len(events) == 1 + assert events[0] == "" # Empty string, not JSON + + +# ── Completion auto-close ──────────────────────────────────────────── + + +@pytest.mark.unit +class TestCompletionAutoClose: + """Test that terminal events set closed flag.""" + + def _dispatch(self, sse_config, event_type): + client = SSEClient(sse_config) + buf = {"event": event_type, "data": ["{}"], "id": None, "retry": None} + client._dispatch_event(buf) + return client.closed + + def test_completed_sets_closed(self, sse_config): + assert self._dispatch(sse_config, "operation_completed") is True + + def test_error_sets_closed(self, sse_config): + assert self._dispatch(sse_config, "operation_error") is True + + def test_cancelled_sets_closed(self, sse_config): + assert self._dispatch(sse_config, "operation_cancelled") is True + + def test_progress_does_not_close(self, sse_config): + assert self._dispatch(sse_config, "operation_progress") is False + + def test_data_chunk_does_not_close(self, sse_config): + assert self._dispatch(sse_config, "data_chunk") is False + + +# ── Reconnection / retry ──────────────────────────────────────────── + + +@pytest.mark.unit +class TestReconnection: + """Test _handle_error retry logic.""" + + @patch("robosystems_client.extensions.sse_client.time.sleep") + @patch("robosystems_client.extensions.sse_client.httpx") + def test_retry_with_backoff(self, mock_httpx, mock_sleep, sse_config): + """Test exponential backoff on retry.""" + # Make httpx.Client always raise to trigger retry chain + mock_httpx.Client.side_effect = Exception("connection refused") + + client = SSEClient(sse_config) + + reconnect_events = [] + exceeded_events = [] + client.on("reconnecting", lambda d: reconnect_events.append(d)) + client.on("max_retries_exceeded", lambda d: exceeded_events.append(d)) + + # connect → fail → _handle_error → retry connect → fail → ... → max retries + client.connect("op-1") + + # Should have attempted max_retries reconnections + assert len(reconnect_events) == sse_config.max_retries + assert len(exceeded_events) == 1 + + # Verify backoff: delays should double each time + # retry_delay=100ms, so: 100ms, 200ms, 400ms → sleep(0.1), sleep(0.2), sleep(0.4) + sleep_calls = [call.args[0] for call in mock_sleep.call_args_list] + assert sleep_calls[0] == pytest.approx(0.1) + assert sleep_calls[1] == pytest.approx(0.2) + assert sleep_calls[2] == pytest.approx(0.4) + + def test_max_retries_exceeded_emits_event(self, sse_config): + """Test that exceeding max retries emits event and closes.""" + sse_config.max_retries = 0 # No retries allowed + client = SSEClient(sse_config) + + exceeded_events = [] + client.on("max_retries_exceeded", lambda d: exceeded_events.append(d)) + + client._handle_error(Exception("fail"), "op-1", 0) + + assert len(exceeded_events) == 1 + assert client.closed is True + + @patch("robosystems_client.extensions.sse_client.time.sleep") + def test_resume_from_last_event_id(self, mock_sleep, sse_config): + """Test reconnection resumes from last_event_id.""" + sse_config.max_retries = 1 + client = SSEClient(sse_config) + client.last_event_id = "10" + + connect_calls = [] + # Make connect succeed (no-op) to stop recursion + with patch.object( + client, "connect", side_effect=lambda op, seq: connect_calls.append((op, seq)) + ): + client._handle_error(Exception("lost"), "op-1", 0) + + assert connect_calls == [("op-1", 11)] # last_event_id + 1 + + @patch("robosystems_client.extensions.sse_client.time.sleep") + def test_resume_with_non_numeric_event_id(self, mock_sleep, sse_config): + """Test reconnection falls back to from_sequence for non-numeric IDs.""" + sse_config.max_retries = 1 + client = SSEClient(sse_config) + client.last_event_id = "not-a-number" + + connect_calls = [] + with patch.object( + client, "connect", side_effect=lambda op, seq: connect_calls.append((op, seq)) + ): + client._handle_error(Exception("lost"), "op-1", 5) + + assert connect_calls == [("op-1", 5)] # Falls back to from_sequence + + def test_no_retry_when_closed(self, sse_config): + """Test no retry when client is already closed.""" + client = SSEClient(sse_config) + client.closed = True + + reconnect_events = [] + client.on("reconnecting", lambda d: reconnect_events.append(d)) + + client._handle_error(Exception("fail"), "op-1", 0) + + assert len(reconnect_events) == 0 + + +# ── Connect ────────────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestConnect: + """Test connect method.""" + + @patch("robosystems_client.extensions.sse_client.httpx") + def test_connect_sets_up_stream(self, mock_httpx, sse_config): + """Test connect creates httpx client and starts streaming.""" + mock_http_client = MagicMock() + mock_context = MagicMock() + mock_response = MagicMock() + mock_response.iter_lines.return_value = iter( + [ + "event: operation_completed", + 'data: {"result": {}}', + "", + ] + ) + mock_context.__enter__ = Mock(return_value=mock_response) + mock_http_client.stream.return_value = mock_context + mock_httpx.Client.return_value = mock_http_client + + client = SSEClient(sse_config) + connected_events = [] + client.on("connected", lambda d: connected_events.append(d)) + + client.connect("op-123") + + assert len(connected_events) == 1 + assert client.reconnect_attempts == 0 + + @patch("robosystems_client.extensions.sse_client.httpx") + def test_connect_error_triggers_retry(self, mock_httpx, sse_config): + """Test that connection error triggers _handle_error.""" + sse_config.max_retries = 0 # Don't actually retry + mock_httpx.Client.side_effect = Exception("Connection refused") + + client = SSEClient(sse_config) + exceeded = [] + client.on("max_retries_exceeded", lambda d: exceeded.append(d)) + + client.connect("op-123") + + assert len(exceeded) == 1 + + +# ── Close / cleanup ───────────────────────────────────────────────── + + +@pytest.mark.unit +class TestClose: + """Test close and cleanup.""" + + def test_close_clears_listeners(self, sse_config): + """Test close removes all listeners.""" + client = SSEClient(sse_config) + client.on("test", lambda d: None) + assert len(client.listeners) > 0 + + client.close() + + assert len(client.listeners) == 0 + assert client.closed is True + + def test_close_emits_closed_event(self, sse_config): + """Test close emits 'closed' event before clearing listeners.""" + client = SSEClient(sse_config) + closed_events = [] + client.on("closed", lambda d: closed_events.append(d)) + + client.close() + + assert len(closed_events) == 1 + + def test_close_cleans_up_http_client(self, sse_config): + """Test close closes the httpx client.""" + client = SSEClient(sse_config) + mock_http = Mock() + client.client = mock_http + + client.close() + + mock_http.close.assert_called_once() + assert client.client is None + + def test_close_cleans_up_context_manager(self, sse_config): + """Test close exits the stream context manager.""" + client = SSEClient(sse_config) + mock_ctx = MagicMock() + client._context_manager = mock_ctx + + client.close() + + mock_ctx.__exit__.assert_called_once_with(None, None, None) + + def test_is_connected(self, sse_config): + """Test is_connected reflects state.""" + client = SSEClient(sse_config) + assert client.is_connected() is False + + client.client = Mock() + assert client.is_connected() is True + + client.closed = True + assert client.is_connected() is False + + +# ── Listener error handling ────────────────────────────────────────── + + +@pytest.mark.unit +class TestListenerErrors: + """Test that listener errors don't break other listeners.""" + + def test_error_in_listener_doesnt_stop_others(self, sse_config): + """Test that a failing listener doesn't prevent others from running.""" + client = SSEClient(sse_config) + + results = [] + client.on("test", lambda d: (_ for _ in ()).throw(ValueError("boom"))) + client.on("test", lambda d: results.append(d)) + + client.emit("test", "data") + + assert results == ["data"] + + def test_emit_to_nonexistent_event_is_noop(self, sse_config): + """Test emitting to an event with no listeners doesn't raise.""" + client = SSEClient(sse_config) + client.emit("nonexistent", "data") # Should not raise diff --git a/tests/test_table_client.py b/tests/test_table_client.py new file mode 100644 index 0000000..0188f98 --- /dev/null +++ b/tests/test_table_client.py @@ -0,0 +1,240 @@ +"""Unit tests for TableClient.""" + +import pytest +from unittest.mock import Mock, patch +from robosystems_client.extensions.table_client import ( + TableClient, + TableInfo, + QueryResult, +) + + +@pytest.mark.unit +class TestTableDataclasses: + """Test suite for table-related dataclasses.""" + + def test_table_info(self): + """Test TableInfo dataclass.""" + info = TableInfo( + table_name="Entity", + table_type="parquet", + row_count=100, + file_count=3, + total_size_bytes=50000, + ) + + assert info.table_name == "Entity" + assert info.table_type == "parquet" + assert info.row_count == 100 + assert info.file_count == 3 + assert info.total_size_bytes == 50000 + + def test_query_result(self): + """Test QueryResult dataclass.""" + result = QueryResult( + columns=["name", "revenue"], + rows=[["ACME", 1000000], ["Beta Corp", 2000000]], + row_count=2, + execution_time_ms=45.0, + ) + + assert result.columns == ["name", "revenue"] + assert len(result.rows) == 2 + assert result.row_count == 2 + assert result.execution_time_ms == 45.0 + assert result.success is True + assert result.error is None + + def test_query_result_with_error(self): + """Test QueryResult with error.""" + result = QueryResult( + columns=[], + rows=[], + row_count=0, + execution_time_ms=0, + success=False, + error="SQL syntax error", + ) + + assert result.success is False + assert result.error == "SQL syntax error" + + +@pytest.mark.unit +class TestTableClientInit: + """Test suite for TableClient initialization.""" + + def test_client_initialization(self, mock_config): + """Test that client initializes correctly with config.""" + client = TableClient(mock_config) + + assert client.base_url == "http://localhost:8000" + assert client.token == "test-api-key" + assert client.headers == {"X-API-Key": "test-api-key"} + + +@pytest.mark.unit +class TestTableList: + """Test suite for TableClient.list method.""" + + @patch("robosystems_client.extensions.table_client.list_tables") + def test_list_tables(self, mock_list, mock_config, graph_id): + """Test listing tables.""" + mock_table_1 = Mock() + mock_table_1.table_name = "Entity" + mock_table_1.table_type = "parquet" + mock_table_1.row_count = 100 + mock_table_1.file_count = 2 + mock_table_1.total_size_bytes = 50000 + + mock_table_2 = Mock() + mock_table_2.table_name = "Transaction" + mock_table_2.table_type = "parquet" + mock_table_2.row_count = 500 + mock_table_2.file_count = 1 + mock_table_2.total_size_bytes = 120000 + + mock_resp = Mock() + mock_resp.status_code = 200 + mock_resp.parsed = Mock() + mock_resp.parsed.tables = [mock_table_1, mock_table_2] + mock_list.return_value = mock_resp + + client = TableClient(mock_config) + tables = client.list(graph_id) + + assert len(tables) == 2 + assert tables[0].table_name == "Entity" + assert tables[0].row_count == 100 + assert tables[1].table_name == "Transaction" + assert tables[1].row_count == 500 + + @patch("robosystems_client.extensions.table_client.list_tables") + def test_list_tables_failure(self, mock_list, mock_config, graph_id): + """Test list tables returns empty list on failure.""" + mock_resp = Mock() + mock_resp.status_code = 500 + mock_resp.parsed = None + mock_list.return_value = mock_resp + + client = TableClient(mock_config) + tables = client.list(graph_id) + + assert tables == [] + + def test_list_tables_no_token(self, mock_config, graph_id): + """Test list tables returns empty list without token.""" + mock_config["token"] = None + client = TableClient(mock_config) + tables = client.list(graph_id) + + assert tables == [] + + @patch("robosystems_client.extensions.table_client.list_tables") + def test_list_tables_empty(self, mock_list, mock_config, graph_id): + """Test listing when no tables exist.""" + mock_resp = Mock() + mock_resp.status_code = 200 + mock_resp.parsed = Mock() + mock_resp.parsed.tables = [] + mock_list.return_value = mock_resp + + client = TableClient(mock_config) + tables = client.list(graph_id) + + assert tables == [] + + +@pytest.mark.unit +class TestTableQuery: + """Test suite for TableClient.query method.""" + + @patch("robosystems_client.extensions.table_client.query_tables") + def test_query_success(self, mock_query, mock_config, graph_id): + """Test successful SQL query.""" + mock_resp = Mock() + mock_resp.status_code = 200 + mock_resp.parsed = Mock() + mock_resp.parsed.columns = ["name", "ticker"] + mock_resp.parsed.rows = [["ACME Corp", "ACM"], ["Beta Inc", "BET"]] + mock_resp.parsed.execution_time_ms = 30.0 + mock_query.return_value = mock_resp + + client = TableClient(mock_config) + result = client.query(graph_id, "SELECT name, ticker FROM Entity") + + assert result.success is True + assert result.columns == ["name", "ticker"] + assert len(result.rows) == 2 + assert result.row_count == 2 + assert result.execution_time_ms == 30.0 + + @patch("robosystems_client.extensions.table_client.query_tables") + def test_query_with_limit(self, mock_query, mock_config, graph_id): + """Test query with limit appended.""" + mock_resp = Mock() + mock_resp.status_code = 200 + mock_resp.parsed = Mock() + mock_resp.parsed.columns = ["name"] + mock_resp.parsed.rows = [["ACME"]] + mock_resp.parsed.execution_time_ms = 10.0 + mock_query.return_value = mock_resp + + client = TableClient(mock_config) + client.query(graph_id, "SELECT name FROM Entity", limit=5) + + call_kwargs = mock_query.call_args[1] + assert "LIMIT 5" in call_kwargs["body"].sql + + @patch("robosystems_client.extensions.table_client.query_tables") + def test_query_strips_trailing_semicolon(self, mock_query, mock_config, graph_id): + """Test query strips trailing semicolon before adding LIMIT.""" + mock_resp = Mock() + mock_resp.status_code = 200 + mock_resp.parsed = Mock() + mock_resp.parsed.columns = ["name"] + mock_resp.parsed.rows = [] + mock_resp.parsed.execution_time_ms = 5.0 + mock_query.return_value = mock_resp + + client = TableClient(mock_config) + client.query(graph_id, "SELECT name FROM Entity;", limit=10) + + call_kwargs = mock_query.call_args[1] + sql = call_kwargs["body"].sql + assert not sql.endswith("; LIMIT 10") + assert sql.endswith("LIMIT 10") + + @patch("robosystems_client.extensions.table_client.query_tables") + def test_query_failure(self, mock_query, mock_config, graph_id): + """Test query failure.""" + mock_resp = Mock() + mock_resp.status_code = 400 + mock_resp.parsed = None + mock_query.return_value = mock_resp + + client = TableClient(mock_config) + result = client.query(graph_id, "INVALID SQL") + + assert result.success is False + assert "Query failed" in result.error + + def test_query_no_token(self, mock_config, graph_id): + """Test query fails without token.""" + mock_config["token"] = None + client = TableClient(mock_config) + result = client.query(graph_id, "SELECT 1") + + assert result.success is False + assert "No API key" in result.error + + @patch("robosystems_client.extensions.table_client.query_tables") + def test_query_exception_handling(self, mock_query, mock_config, graph_id): + """Test query handles unexpected exceptions.""" + mock_query.side_effect = Exception("Network error") + + client = TableClient(mock_config) + result = client.query(graph_id, "SELECT 1") + + assert result.success is False + assert "Network error" in result.error diff --git a/robosystems_client/extensions/tests/test_token_utils.py b/tests/test_token_utils.py similarity index 100% rename from robosystems_client/extensions/tests/test_token_utils.py rename to tests/test_token_utils.py diff --git a/robosystems_client/extensions/tests/test_unit.py b/tests/test_utils_unit.py similarity index 98% rename from robosystems_client/extensions/tests/test_unit.py rename to tests/test_utils_unit.py index 8a21f97..1cb6090 100644 --- a/robosystems_client/extensions/tests/test_unit.py +++ b/tests/test_utils_unit.py @@ -7,13 +7,6 @@ from datetime import datetime from unittest.mock import Mock -import sys -import os - -# Add parent directories to path for imports -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../..")) -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) - from robosystems_client.extensions import ( SSEClient, SSEConfig, From 064eabb9f40a873fa126946fdb629d1b3f0bec98 Mon Sep 17 00:00:00 2001 From: "Joseph T. French" Date: Fri, 10 Apr 2026 20:51:03 -0500 Subject: [PATCH 3/4] Refactor datetime column inference in parse_datetime_columns function - Updated the condition for identifying string columns in the `parse_datetime_columns` function to use `pd.api.types.is_string_dtype` instead of checking for the "object" dtype. This change improves the accuracy of datetime column inference by leveraging pandas' type checking capabilities. These modifications enhance the robustness of the datetime parsing logic in the dataframe utilities. --- robosystems_client/extensions/dataframe_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/robosystems_client/extensions/dataframe_utils.py b/robosystems_client/extensions/dataframe_utils.py index 8035ea4..a630b9b 100644 --- a/robosystems_client/extensions/dataframe_utils.py +++ b/robosystems_client/extensions/dataframe_utils.py @@ -105,7 +105,7 @@ def parse_datetime_columns( elif infer: # Infer datetime columns for col in df.columns: - if df[col].dtype == "object": + if pd.api.types.is_string_dtype(df[col]): # Check if column contains date-like strings sample = df[col].dropna().head(10) if len(sample) > 0: From f6a31ab525700eba3c9c04f5d20f8f94a75bc2db Mon Sep 17 00:00:00 2001 From: "Joseph T. French" Date: Fri, 10 Apr 2026 20:52:49 -0500 Subject: [PATCH 4/4] Update test for parse_datetime_columns to assert non-datetime dtype - Modified the test case in `test_dataframe_utils.py` to check that the "not_a_date" column is not of datetime type, enhancing the accuracy of the test for the `parse_datetime_columns` function. This change improves the reliability of the test suite by ensuring proper validation of non-datetime columns. --- tests/test_dataframe_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_dataframe_utils.py b/tests/test_dataframe_utils.py index 4aad57e..05d9b24 100644 --- a/tests/test_dataframe_utils.py +++ b/tests/test_dataframe_utils.py @@ -126,7 +126,7 @@ def test_parse_datetime_columns_infer(self): df = parse_datetime_columns(df, infer=True) assert pd.api.types.is_datetime64_any_dtype(df["timestamp"]) - assert df["not_a_date"].dtype == "object" + assert not pd.api.types.is_datetime64_any_dtype(df["not_a_date"]) class TestStreamToDataFrame: