diff --git a/robosystems_client/api/ledger/get_account_rollups.py b/robosystems_client/api/ledger/get_account_rollups.py new file mode 100644 index 0000000..1c92f73 --- /dev/null +++ b/robosystems_client/api/ledger/get_account_rollups.py @@ -0,0 +1,251 @@ +import datetime +from http import HTTPStatus +from typing import Any +from urllib.parse import quote + +import httpx + +from ... import errors +from ...client import AuthenticatedClient, Client +from ...models.account_rollups_response import AccountRollupsResponse +from ...models.http_validation_error import HTTPValidationError +from ...types import UNSET, Response, Unset + + +def _get_kwargs( + graph_id: str, + *, + mapping_id: None | str | Unset = UNSET, + start_date: datetime.date | None | Unset = UNSET, + end_date: datetime.date | None | Unset = UNSET, +) -> dict[str, Any]: + params: dict[str, Any] = {} + + json_mapping_id: None | str | Unset + if isinstance(mapping_id, Unset): + json_mapping_id = UNSET + else: + json_mapping_id = mapping_id + params["mapping_id"] = json_mapping_id + + json_start_date: None | str | Unset + if isinstance(start_date, Unset): + json_start_date = UNSET + elif isinstance(start_date, datetime.date): + json_start_date = start_date.isoformat() + else: + json_start_date = start_date + params["start_date"] = json_start_date + + json_end_date: None | str | Unset + if isinstance(end_date, Unset): + json_end_date = UNSET + elif isinstance(end_date, datetime.date): + json_end_date = end_date.isoformat() + else: + json_end_date = end_date + params["end_date"] = json_end_date + + params = {k: v for k, v in params.items() if v is not UNSET and v is not None} + + _kwargs: dict[str, Any] = { + "method": "get", + "url": "/v1/ledger/{graph_id}/account-rollups".format( + graph_id=quote(str(graph_id), safe=""), + ), + "params": params, + } + + return _kwargs + + +def _parse_response( + *, client: AuthenticatedClient | Client, response: httpx.Response +) -> AccountRollupsResponse | 
HTTPValidationError | None: + if response.status_code == 200: + response_200 = AccountRollupsResponse.from_dict(response.json()) + + return response_200 + + if response.status_code == 422: + response_422 = HTTPValidationError.from_dict(response.json()) + + return response_422 + + if client.raise_on_unexpected_status: + raise errors.UnexpectedStatus(response.status_code, response.content) + else: + return None + + +def _build_response( + *, client: AuthenticatedClient | Client, response: httpx.Response +) -> Response[AccountRollupsResponse | HTTPValidationError]: + return Response( + status_code=HTTPStatus(response.status_code), + content=response.content, + headers=response.headers, + parsed=_parse_response(client=client, response=response), + ) + + +def sync_detailed( + graph_id: str, + *, + client: AuthenticatedClient, + mapping_id: None | str | Unset = UNSET, + start_date: datetime.date | None | Unset = UNSET, + end_date: datetime.date | None | Unset = UNSET, +) -> Response[AccountRollupsResponse | HTTPValidationError]: + """Account Rollups + + Account rollups — CoA accounts grouped by reporting element with balances. + + Shows how company-specific accounts roll up to standardized reporting + line items. Auto-discovers the mapping structure if mapping_id is not provided. + + Args: + graph_id (str): + mapping_id (None | str | Unset): Mapping structure ID (auto-discovers if omitted) + start_date (datetime.date | None | Unset): Start date (inclusive) + end_date (datetime.date | None | Unset): End date (inclusive) + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. 
+ + Returns: + Response[AccountRollupsResponse | HTTPValidationError] + """ + + kwargs = _get_kwargs( + graph_id=graph_id, + mapping_id=mapping_id, + start_date=start_date, + end_date=end_date, + ) + + response = client.get_httpx_client().request( + **kwargs, + ) + + return _build_response(client=client, response=response) + + +def sync( + graph_id: str, + *, + client: AuthenticatedClient, + mapping_id: None | str | Unset = UNSET, + start_date: datetime.date | None | Unset = UNSET, + end_date: datetime.date | None | Unset = UNSET, +) -> AccountRollupsResponse | HTTPValidationError | None: + """Account Rollups + + Account rollups — CoA accounts grouped by reporting element with balances. + + Shows how company-specific accounts roll up to standardized reporting + line items. Auto-discovers the mapping structure if mapping_id is not provided. + + Args: + graph_id (str): + mapping_id (None | str | Unset): Mapping structure ID (auto-discovers if omitted) + start_date (datetime.date | None | Unset): Start date (inclusive) + end_date (datetime.date | None | Unset): End date (inclusive) + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + AccountRollupsResponse | HTTPValidationError + """ + + return sync_detailed( + graph_id=graph_id, + client=client, + mapping_id=mapping_id, + start_date=start_date, + end_date=end_date, + ).parsed + + +async def asyncio_detailed( + graph_id: str, + *, + client: AuthenticatedClient, + mapping_id: None | str | Unset = UNSET, + start_date: datetime.date | None | Unset = UNSET, + end_date: datetime.date | None | Unset = UNSET, +) -> Response[AccountRollupsResponse | HTTPValidationError]: + """Account Rollups + + Account rollups — CoA accounts grouped by reporting element with balances. + + Shows how company-specific accounts roll up to standardized reporting + line items. 
Auto-discovers the mapping structure if mapping_id is not provided. + + Args: + graph_id (str): + mapping_id (None | str | Unset): Mapping structure ID (auto-discovers if omitted) + start_date (datetime.date | None | Unset): Start date (inclusive) + end_date (datetime.date | None | Unset): End date (inclusive) + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[AccountRollupsResponse | HTTPValidationError] + """ + + kwargs = _get_kwargs( + graph_id=graph_id, + mapping_id=mapping_id, + start_date=start_date, + end_date=end_date, + ) + + response = await client.get_async_httpx_client().request(**kwargs) + + return _build_response(client=client, response=response) + + +async def asyncio( + graph_id: str, + *, + client: AuthenticatedClient, + mapping_id: None | str | Unset = UNSET, + start_date: datetime.date | None | Unset = UNSET, + end_date: datetime.date | None | Unset = UNSET, +) -> AccountRollupsResponse | HTTPValidationError | None: + """Account Rollups + + Account rollups — CoA accounts grouped by reporting element with balances. + + Shows how company-specific accounts roll up to standardized reporting + line items. Auto-discovers the mapping structure if mapping_id is not provided. + + Args: + graph_id (str): + mapping_id (None | str | Unset): Mapping structure ID (auto-discovers if omitted) + start_date (datetime.date | None | Unset): Start date (inclusive) + end_date (datetime.date | None | Unset): End date (inclusive) + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. 
+ + Returns: + AccountRollupsResponse | HTTPValidationError + """ + + return ( + await asyncio_detailed( + graph_id=graph_id, + client=client, + mapping_id=mapping_id, + start_date=start_date, + end_date=end_date, + ) + ).parsed diff --git a/robosystems_client/api/ledger/get_closing_book_structures.py b/robosystems_client/api/ledger/get_closing_book_structures.py new file mode 100644 index 0000000..349ea45 --- /dev/null +++ b/robosystems_client/api/ledger/get_closing_book_structures.py @@ -0,0 +1,184 @@ +from http import HTTPStatus +from typing import Any +from urllib.parse import quote + +import httpx + +from ... import errors +from ...client import AuthenticatedClient, Client +from ...models.closing_book_structures_response import ClosingBookStructuresResponse +from ...models.http_validation_error import HTTPValidationError +from ...types import Response + + +def _get_kwargs( + graph_id: str, +) -> dict[str, Any]: + _kwargs: dict[str, Any] = { + "method": "get", + "url": "/v1/ledger/{graph_id}/closing-book/structures".format( + graph_id=quote(str(graph_id), safe=""), + ), + } + + return _kwargs + + +def _parse_response( + *, client: AuthenticatedClient | Client, response: httpx.Response +) -> ClosingBookStructuresResponse | HTTPValidationError | None: + if response.status_code == 200: + response_200 = ClosingBookStructuresResponse.from_dict(response.json()) + + return response_200 + + if response.status_code == 422: + response_422 = HTTPValidationError.from_dict(response.json()) + + return response_422 + + if client.raise_on_unexpected_status: + raise errors.UnexpectedStatus(response.status_code, response.content) + else: + return None + + +def _build_response( + *, client: AuthenticatedClient | Client, response: httpx.Response +) -> Response[ClosingBookStructuresResponse | HTTPValidationError]: + return Response( + status_code=HTTPStatus(response.status_code), + content=response.content, + headers=response.headers, + parsed=_parse_response(client=client, 
response=response), + ) + + +def sync_detailed( + graph_id: str, + *, + client: AuthenticatedClient, +) -> Response[ClosingBookStructuresResponse | HTTPValidationError]: + """Closing Book Structures + + Returns all structure categories for the closing book sidebar. + + Aggregates statements (from latest report), schedules, account rollups + (from mapping structures), and trial balance availability into a single + response for the viewer sidebar navigation. + + Args: + graph_id (str): + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[ClosingBookStructuresResponse | HTTPValidationError] + """ + + kwargs = _get_kwargs( + graph_id=graph_id, + ) + + response = client.get_httpx_client().request( + **kwargs, + ) + + return _build_response(client=client, response=response) + + +def sync( + graph_id: str, + *, + client: AuthenticatedClient, +) -> ClosingBookStructuresResponse | HTTPValidationError | None: + """Closing Book Structures + + Returns all structure categories for the closing book sidebar. + + Aggregates statements (from latest report), schedules, account rollups + (from mapping structures), and trial balance availability into a single + response for the viewer sidebar navigation. + + Args: + graph_id (str): + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. 
+ + Returns: + ClosingBookStructuresResponse | HTTPValidationError + """ + + return sync_detailed( + graph_id=graph_id, + client=client, + ).parsed + + +async def asyncio_detailed( + graph_id: str, + *, + client: AuthenticatedClient, +) -> Response[ClosingBookStructuresResponse | HTTPValidationError]: + """Closing Book Structures + + Returns all structure categories for the closing book sidebar. + + Aggregates statements (from latest report), schedules, account rollups + (from mapping structures), and trial balance availability into a single + response for the viewer sidebar navigation. + + Args: + graph_id (str): + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[ClosingBookStructuresResponse | HTTPValidationError] + """ + + kwargs = _get_kwargs( + graph_id=graph_id, + ) + + response = await client.get_async_httpx_client().request(**kwargs) + + return _build_response(client=client, response=response) + + +async def asyncio( + graph_id: str, + *, + client: AuthenticatedClient, +) -> ClosingBookStructuresResponse | HTTPValidationError | None: + """Closing Book Structures + + Returns all structure categories for the closing book sidebar. + + Aggregates statements (from latest report), schedules, account rollups + (from mapping structures), and trial balance availability into a single + response for the viewer sidebar navigation. + + Args: + graph_id (str): + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. 
+ + Returns: + ClosingBookStructuresResponse | HTTPValidationError + """ + + return ( + await asyncio_detailed( + graph_id=graph_id, + client=client, + ) + ).parsed diff --git a/robosystems_client/extensions/dataframe_utils.py b/robosystems_client/extensions/dataframe_utils.py index 8035ea4..a630b9b 100644 --- a/robosystems_client/extensions/dataframe_utils.py +++ b/robosystems_client/extensions/dataframe_utils.py @@ -105,7 +105,7 @@ def parse_datetime_columns( elif infer: # Infer datetime columns for col in df.columns: - if df[col].dtype == "object": + if pd.api.types.is_string_dtype(df[col]): # Check if column contains date-like strings sample = df[col].dropna().head(10) if len(sample) > 0: diff --git a/robosystems_client/extensions/ledger_client.py b/robosystems_client/extensions/ledger_client.py index 365810a..4cca900 100644 --- a/robosystems_client/extensions/ledger_client.py +++ b/robosystems_client/extensions/ledger_client.py @@ -10,13 +10,19 @@ from typing import Any from ..api.ledger.auto_map_elements import sync_detailed as auto_map_elements +from ..api.ledger.create_closing_entry import sync_detailed as create_closing_entry from ..api.ledger.create_mapping_association import ( sync_detailed as create_mapping_association, ) +from ..api.ledger.create_schedule import sync_detailed as create_schedule from ..api.ledger.create_structure import sync_detailed as create_structure from ..api.ledger.delete_mapping_association import ( sync_detailed as delete_mapping_association, ) +from ..api.ledger.get_account_rollups import sync_detailed as get_account_rollups +from ..api.ledger.get_closing_book_structures import ( + sync_detailed as get_closing_book_structures, +) from ..api.ledger.get_ledger_account_tree import ( sync_detailed as get_ledger_account_tree, ) @@ -35,9 +41,13 @@ sync_detailed as get_mapping_coverage, ) from ..api.ledger.get_mapping_detail import sync_detailed as get_mapping_detail +from ..api.ledger.get_period_close_status import ( + sync_detailed 
as get_period_close_status, +) from ..api.ledger.get_reporting_taxonomy import ( sync_detailed as get_reporting_taxonomy, ) +from ..api.ledger.get_schedule_facts import sync_detailed as get_schedule_facts from ..api.ledger.list_elements import sync_detailed as list_elements from ..api.ledger.list_ledger_accounts import ( sync_detailed as list_ledger_accounts, @@ -46,6 +56,7 @@ sync_detailed as list_ledger_transactions, ) from ..api.ledger.list_mappings import sync_detailed as list_mappings +from ..api.ledger.list_schedules import sync_detailed as list_schedules from ..api.ledger.list_structures import sync_detailed as list_structures from ..client import AuthenticatedClient @@ -280,11 +291,9 @@ def create_mapping( confidence: float = 1.0, ) -> None: """Create a manual mapping association (CoA element → GAAP element).""" - from ..models.create_mapping_association_request import ( - CreateMappingAssociationRequest, - ) + from ..models.create_association_request import CreateAssociationRequest - body = CreateMappingAssociationRequest( + body = CreateAssociationRequest( from_element_id=from_element_id, to_element_id=to_element_id, confidence=confidence, @@ -314,3 +323,165 @@ def auto_map(self, graph_id: str, mapping_id: str) -> dict[str, Any]: if response.status_code != HTTPStatus.ACCEPTED: raise RuntimeError(f"Auto-map failed: {response.status_code}") return response.parsed or {} + + # ── Schedules ────────────────────────────────────────────────────── + + def create_schedule( + self, + graph_id: str, + *, + name: str, + element_ids: list[str], + period_start: str, + period_end: str, + monthly_amount: int, + debit_element_id: str, + credit_element_id: str, + entry_type: str = "closing", + memo_template: str = "", + taxonomy_id: str | None = None, + method: str | None = None, + original_amount: int | None = None, + residual_value: int | None = None, + useful_life_months: int | None = None, + asset_element_id: str | None = None, + auto_reverse: bool = False, + ) -> Any: 
+ """Create a schedule with pre-generated monthly facts.""" + body: dict[str, Any] = { + "name": name, + "element_ids": element_ids, + "period_start": period_start, + "period_end": period_end, + "monthly_amount": monthly_amount, + "entry_template": { + "debit_element_id": debit_element_id, + "credit_element_id": credit_element_id, + "entry_type": entry_type, + "memo_template": memo_template, + "auto_reverse": auto_reverse, + }, + } + if taxonomy_id: + body["taxonomy_id"] = taxonomy_id + schedule_metadata: dict[str, Any] = {} + if method: + schedule_metadata["method"] = method + if original_amount is not None: + schedule_metadata["original_amount"] = original_amount + if residual_value is not None: + schedule_metadata["residual_value"] = residual_value + if useful_life_months is not None: + schedule_metadata["useful_life_months"] = useful_life_months + if asset_element_id: + schedule_metadata["asset_element_id"] = asset_element_id + if schedule_metadata: + body["schedule_metadata"] = schedule_metadata + + response = create_schedule(graph_id=graph_id, body=body, client=self._get_client()) + if response.status_code != HTTPStatus.CREATED: + raise RuntimeError(f"Create schedule failed: {response.status_code}") + return response.parsed + + def list_schedules(self, graph_id: str) -> Any: + """List all active schedule structures.""" + response = list_schedules(graph_id=graph_id, client=self._get_client()) + if response.status_code != HTTPStatus.OK: + raise RuntimeError(f"List schedules failed: {response.status_code}") + return response.parsed + + def get_schedule_facts( + self, + graph_id: str, + structure_id: str, + period_start: str | None = None, + period_end: str | None = None, + ) -> Any: + """Get fact values for a schedule, optionally filtered by period.""" + response = get_schedule_facts( + graph_id=graph_id, + structure_id=structure_id, + period_start=period_start, + period_end=period_end, + client=self._get_client(), + ) + if response.status_code != HTTPStatus.OK: 
+ raise RuntimeError(f"Get schedule facts failed: {response.status_code}") + return response.parsed + + def get_period_close_status( + self, + graph_id: str, + period_start: str, + period_end: str, + ) -> Any: + """Get close status for all schedules in a fiscal period.""" + response = get_period_close_status( + graph_id=graph_id, + period_start=period_start, + period_end=period_end, + client=self._get_client(), + ) + if response.status_code != HTTPStatus.OK: + raise RuntimeError(f"Get period close status failed: {response.status_code}") + return response.parsed + + def create_closing_entry( + self, + graph_id: str, + structure_id: str, + posting_date: str, + period_start: str, + period_end: str, + memo: str | None = None, + ) -> Any: + """Create a draft closing entry from a schedule's facts for a period.""" + body: dict[str, Any] = { + "posting_date": posting_date, + "period_start": period_start, + "period_end": period_end, + } + if memo: + body["memo"] = memo + + response = create_closing_entry( + graph_id=graph_id, + structure_id=structure_id, + body=body, + client=self._get_client(), + ) + if response.status_code != HTTPStatus.CREATED: + raise RuntimeError(f"Create closing entry failed: {response.status_code}") + return response.parsed + + # ── Closing Book ───────────────────────────────────────────────────── + + def get_closing_book_structures(self, graph_id: str) -> Any: + """Get all closing book structure categories for the sidebar.""" + response = get_closing_book_structures(graph_id=graph_id, client=self._get_client()) + if response.status_code != HTTPStatus.OK: + raise RuntimeError(f"Get closing book structures failed: {response.status_code}") + return response.parsed + + def get_account_rollups( + self, + graph_id: str, + mapping_id: str | None = None, + start_date: str | None = None, + end_date: str | None = None, + ) -> Any: + """Get account rollups — CoA accounts grouped by reporting element with balances. 
+ + Shows how company-specific accounts roll up to standardized reporting lines. + Auto-discovers the mapping structure if mapping_id is not provided. + """ + response = get_account_rollups( + graph_id=graph_id, + mapping_id=mapping_id, + start_date=start_date, + end_date=end_date, + client=self._get_client(), + ) + if response.status_code != HTTPStatus.OK: + raise RuntimeError(f"Get account rollups failed: {response.status_code}") + return response.parsed diff --git a/robosystems_client/extensions/tests/__init__.py b/robosystems_client/extensions/tests/__init__.py deleted file mode 100644 index 6e8da8f..0000000 --- a/robosystems_client/extensions/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for RoboSystems Client Extensions""" diff --git a/robosystems_client/models/__init__.py b/robosystems_client/models/__init__.py index f9caa2f..86285be 100644 --- a/robosystems_client/models/__init__.py +++ b/robosystems_client/models/__init__.py @@ -3,6 +3,9 @@ from .account_info import AccountInfo from .account_list_response import AccountListResponse from .account_response import AccountResponse +from .account_rollup_group import AccountRollupGroup +from .account_rollup_row import AccountRollupRow +from .account_rollups_response import AccountRollupsResponse from .account_tree_node import AccountTreeNode from .account_tree_response import AccountTreeResponse from .add_members_request import AddMembersRequest @@ -56,6 +59,9 @@ ) from .checkout_response import CheckoutResponse from .checkout_status_response import CheckoutStatusResponse +from .closing_book_category import ClosingBookCategory +from .closing_book_item import ClosingBookItem +from .closing_book_structures_response import ClosingBookStructuresResponse from .closing_entry_response import ClosingEntryResponse from .connection_options_response import ConnectionOptionsResponse from .connection_provider_info import ConnectionProviderInfo @@ -382,6 +388,9 @@ "AccountInfo", "AccountListResponse", 
"AccountResponse", + "AccountRollupGroup", + "AccountRollupRow", + "AccountRollupsResponse", "AccountTreeNode", "AccountTreeResponse", "AddMembersRequest", @@ -427,6 +436,9 @@ "CancelOperationResponseCanceloperation", "CheckoutResponse", "CheckoutStatusResponse", + "ClosingBookCategory", + "ClosingBookItem", + "ClosingBookStructuresResponse", "ClosingEntryResponse", "ConnectionOptionsResponse", "ConnectionProviderInfo", diff --git a/robosystems_client/models/account_rollup_group.py b/robosystems_client/models/account_rollup_group.py new file mode 100644 index 0000000..f7cffae --- /dev/null +++ b/robosystems_client/models/account_rollup_group.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, TypeVar + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +if TYPE_CHECKING: + from ..models.account_rollup_row import AccountRollupRow + + +T = TypeVar("T", bound="AccountRollupGroup") + + +@_attrs_define +class AccountRollupGroup: + """ + Attributes: + reporting_element_id (str): + reporting_name (str): + reporting_qname (str): + classification (str): + balance_type (str): + total (float): + accounts (list[AccountRollupRow]): + """ + + reporting_element_id: str + reporting_name: str + reporting_qname: str + classification: str + balance_type: str + total: float + accounts: list[AccountRollupRow] + additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> dict[str, Any]: + reporting_element_id = self.reporting_element_id + + reporting_name = self.reporting_name + + reporting_qname = self.reporting_qname + + classification = self.classification + + balance_type = self.balance_type + + total = self.total + + accounts = [] + for accounts_item_data in self.accounts: + accounts_item = accounts_item_data.to_dict() + accounts.append(accounts_item) + + field_dict: dict[str, Any] = {} + 
field_dict.update(self.additional_properties) + field_dict.update( + { + "reporting_element_id": reporting_element_id, + "reporting_name": reporting_name, + "reporting_qname": reporting_qname, + "classification": classification, + "balance_type": balance_type, + "total": total, + "accounts": accounts, + } + ) + + return field_dict + + @classmethod + def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: + from ..models.account_rollup_row import AccountRollupRow + + d = dict(src_dict) + reporting_element_id = d.pop("reporting_element_id") + + reporting_name = d.pop("reporting_name") + + reporting_qname = d.pop("reporting_qname") + + classification = d.pop("classification") + + balance_type = d.pop("balance_type") + + total = d.pop("total") + + accounts = [] + _accounts = d.pop("accounts") + for accounts_item_data in _accounts: + accounts_item = AccountRollupRow.from_dict(accounts_item_data) + + accounts.append(accounts_item) + + account_rollup_group = cls( + reporting_element_id=reporting_element_id, + reporting_name=reporting_name, + reporting_qname=reporting_qname, + classification=classification, + balance_type=balance_type, + total=total, + accounts=accounts, + ) + + account_rollup_group.additional_properties = d + return account_rollup_group + + @property + def additional_keys(self) -> list[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/robosystems_client/models/account_rollup_row.py b/robosystems_client/models/account_rollup_row.py new file mode 100644 index 0000000..2396990 --- /dev/null +++ b/robosystems_client/models/account_rollup_row.py @@ -0,0 +1,115 @@ +from __future__ import 
annotations + +from collections.abc import Mapping +from typing import Any, TypeVar, cast + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +from ..types import UNSET, Unset + +T = TypeVar("T", bound="AccountRollupRow") + + +@_attrs_define +class AccountRollupRow: + """ + Attributes: + element_id (str): + account_name (str): + total_debits (float): + total_credits (float): + net_balance (float): + account_code (None | str | Unset): + """ + + element_id: str + account_name: str + total_debits: float + total_credits: float + net_balance: float + account_code: None | str | Unset = UNSET + additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> dict[str, Any]: + element_id = self.element_id + + account_name = self.account_name + + total_debits = self.total_debits + + total_credits = self.total_credits + + net_balance = self.net_balance + + account_code: None | str | Unset + if isinstance(self.account_code, Unset): + account_code = UNSET + else: + account_code = self.account_code + + field_dict: dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "element_id": element_id, + "account_name": account_name, + "total_debits": total_debits, + "total_credits": total_credits, + "net_balance": net_balance, + } + ) + if account_code is not UNSET: + field_dict["account_code"] = account_code + + return field_dict + + @classmethod + def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: + d = dict(src_dict) + element_id = d.pop("element_id") + + account_name = d.pop("account_name") + + total_debits = d.pop("total_debits") + + total_credits = d.pop("total_credits") + + net_balance = d.pop("net_balance") + + def _parse_account_code(data: object) -> None | str | Unset: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(None | str | Unset, data) + + account_code = _parse_account_code(d.pop("account_code", UNSET)) 
+ + account_rollup_row = cls( + element_id=element_id, + account_name=account_name, + total_debits=total_debits, + total_credits=total_credits, + net_balance=net_balance, + account_code=account_code, + ) + + account_rollup_row.additional_properties = d + return account_rollup_row + + @property + def additional_keys(self) -> list[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/robosystems_client/models/account_rollups_response.py b/robosystems_client/models/account_rollups_response.py new file mode 100644 index 0000000..b90f429 --- /dev/null +++ b/robosystems_client/models/account_rollups_response.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, TypeVar + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +if TYPE_CHECKING: + from ..models.account_rollup_group import AccountRollupGroup + + +T = TypeVar("T", bound="AccountRollupsResponse") + + +@_attrs_define +class AccountRollupsResponse: + """ + Attributes: + mapping_id (str): + mapping_name (str): + groups (list[AccountRollupGroup]): + total_mapped (int): + total_unmapped (int): + """ + + mapping_id: str + mapping_name: str + groups: list[AccountRollupGroup] + total_mapped: int + total_unmapped: int + additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> dict[str, Any]: + mapping_id = self.mapping_id + + mapping_name = self.mapping_name + + groups = [] + for groups_item_data in self.groups: + groups_item = groups_item_data.to_dict() + groups.append(groups_item) + + 
total_mapped = self.total_mapped + + total_unmapped = self.total_unmapped + + field_dict: dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "mapping_id": mapping_id, + "mapping_name": mapping_name, + "groups": groups, + "total_mapped": total_mapped, + "total_unmapped": total_unmapped, + } + ) + + return field_dict + + @classmethod + def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: + from ..models.account_rollup_group import AccountRollupGroup + + d = dict(src_dict) + mapping_id = d.pop("mapping_id") + + mapping_name = d.pop("mapping_name") + + groups = [] + _groups = d.pop("groups") + for groups_item_data in _groups: + groups_item = AccountRollupGroup.from_dict(groups_item_data) + + groups.append(groups_item) + + total_mapped = d.pop("total_mapped") + + total_unmapped = d.pop("total_unmapped") + + account_rollups_response = cls( + mapping_id=mapping_id, + mapping_name=mapping_name, + groups=groups, + total_mapped=total_mapped, + total_unmapped=total_unmapped, + ) + + account_rollups_response.additional_properties = d + return account_rollups_response + + @property + def additional_keys(self) -> list[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/robosystems_client/models/closing_book_category.py b/robosystems_client/models/closing_book_category.py new file mode 100644 index 0000000..16f41a6 --- /dev/null +++ b/robosystems_client/models/closing_book_category.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, TypeVar + +from attrs import define as 
@_attrs_define
class ClosingBookCategory:
    """A labeled grouping of closing-book items.

    Attributes:
        label (str): Display label for the category.
        items (list[ClosingBookItem]): Items belonging to this category.
    """

    label: str
    items: list[ClosingBookItem]
    additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize this category (and any extra properties) to a plain dict."""
        serialized: dict[str, Any] = dict(self.additional_properties)
        serialized.update(
            {
                "label": self.label,
                "items": [entry.to_dict() for entry in self.items],
            }
        )
        return serialized

    @classmethod
    def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T:
        """Build a category from a raw mapping; unknown keys are kept as extras."""
        from ..models.closing_book_item import ClosingBookItem

        payload = dict(src_dict)
        category = cls(
            label=payload.pop("label"),
            items=[
                ClosingBookItem.from_dict(entry) for entry in payload.pop("items")
            ],
        )
        category.additional_properties = payload
        return category

    @property
    def additional_keys(self) -> list[str]:
        """Names of the extra (non-model) keys on this instance."""
        return list(self.additional_properties.keys())

    def __getitem__(self, key: str) -> Any:
        return self.additional_properties[key]

    def __setitem__(self, key: str, value: Any) -> None:
        self.additional_properties[key] = value

    def __delitem__(self, key: str) -> None:
        del self.additional_properties[key]

    def __contains__(self, key: str) -> bool:
        return key in self.additional_properties
@_attrs_define
class ClosingBookItem:
    """A single entry in a closing-book structure listing.

    Attributes:
        id (str): Unique identifier of the item.
        name (str): Display name of the item.
        item_type (str): Kind of item (e.g. structure vs. report).
        structure_type (None | str | Unset): Optional structure subtype.
        report_id (None | str | Unset): Optional linked report identifier.
        status (None | str | Unset): Optional processing status.
    """

    id: str
    name: str
    item_type: str
    structure_type: None | str | Unset = UNSET
    report_id: None | str | Unset = UNSET
    status: None | str | Unset = UNSET
    additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict, omitting fields that were never set."""
        serialized: dict[str, Any] = dict(self.additional_properties)
        serialized.update(
            {
                "id": self.id,
                "name": self.name,
                "item_type": self.item_type,
            }
        )
        # Optional fields are emitted only when they carry a real value
        # (including an explicit None); UNSET means "absent".
        for key, value in (
            ("structure_type", self.structure_type),
            ("report_id", self.report_id),
            ("status", self.status),
        ):
            if not isinstance(value, Unset):
                serialized[key] = value
        return serialized

    @classmethod
    def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T:
        """Build an item from a raw mapping; unknown keys are kept as extras."""
        payload = dict(src_dict)

        def _optional_str(raw: object) -> None | str | Unset:
            # Incoming values are a str, an explicit None, or UNSET when
            # the key was absent from the payload.
            if raw is None or isinstance(raw, Unset):
                return raw
            return cast(None | str | Unset, raw)

        item = cls(
            id=payload.pop("id"),
            name=payload.pop("name"),
            item_type=payload.pop("item_type"),
            structure_type=_optional_str(payload.pop("structure_type", UNSET)),
            report_id=_optional_str(payload.pop("report_id", UNSET)),
            status=_optional_str(payload.pop("status", UNSET)),
        )
        item.additional_properties = payload
        return item

    @property
    def additional_keys(self) -> list[str]:
        """Names of the extra (non-model) keys on this instance."""
        return list(self.additional_properties.keys())

    def __getitem__(self, key: str) -> Any:
        return self.additional_properties[key]

    def __setitem__(self, key: str, value: Any) -> None:
        self.additional_properties[key] = value

    def __delitem__(self, key: str) -> None:
        del self.additional_properties[key]

    def __contains__(self, key: str) -> bool:
        return key in self.additional_properties
@_attrs_define
class ClosingBookStructuresResponse:
    """Response payload listing closing-book categories.

    Attributes:
        categories (list[ClosingBookCategory]): Categories with their items.
        has_data (bool): Whether any closing-book data exists at all.
    """

    categories: list[ClosingBookCategory]
    has_data: bool
    additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize this response (plus any extra properties) to a plain dict."""
        serialized: dict[str, Any] = dict(self.additional_properties)
        serialized.update(
            {
                "categories": [entry.to_dict() for entry in self.categories],
                "has_data": self.has_data,
            }
        )
        return serialized

    @classmethod
    def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T:
        """Build a response from a raw mapping; unknown keys are kept as extras."""
        from ..models.closing_book_category import ClosingBookCategory

        payload = dict(src_dict)
        response = cls(
            categories=[
                ClosingBookCategory.from_dict(entry)
                for entry in payload.pop("categories")
            ],
            has_data=payload.pop("has_data"),
        )
        response.additional_properties = payload
        return response

    @property
    def additional_keys(self) -> list[str]:
        """Names of the extra (non-model) keys on this instance."""
        return list(self.additional_properties.keys())

    def __getitem__(self, key: str) -> Any:
        return self.additional_properties[key]

    def __setitem__(self, key: str, value: Any) -> None:
        self.additional_properties[key] = value

    def __delitem__(self, key: str) -> None:
        del self.additional_properties[key]

    def __contains__(self, key: str) -> bool:
        return key in self.additional_properties
import Mapping -from typing import Any, TypeVar +from typing import Any, TypeVar, cast from attrs import define as _attrs_define from attrs import field as _attrs_field from dateutil.parser import isoparse +from ..types import UNSET, Unset + T = TypeVar("T", bound="ClosingEntryResponse") @@ -22,6 +24,7 @@ class ClosingEntryResponse: debit_element_id (str): credit_element_id (str): amount (float): + reversal (ClosingEntryResponse | None | Unset): """ entry_id: str @@ -31,6 +34,7 @@ class ClosingEntryResponse: debit_element_id: str credit_element_id: str amount: float + reversal: ClosingEntryResponse | None | Unset = UNSET additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) def to_dict(self) -> dict[str, Any]: @@ -48,6 +52,14 @@ def to_dict(self) -> dict[str, Any]: amount = self.amount + reversal: dict[str, Any] | None | Unset + if isinstance(self.reversal, Unset): + reversal = UNSET + elif isinstance(self.reversal, ClosingEntryResponse): + reversal = self.reversal.to_dict() + else: + reversal = self.reversal + field_dict: dict[str, Any] = {} field_dict.update(self.additional_properties) field_dict.update( @@ -61,6 +73,8 @@ def to_dict(self) -> dict[str, Any]: "amount": amount, } ) + if reversal is not UNSET: + field_dict["reversal"] = reversal return field_dict @@ -81,6 +95,23 @@ def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: amount = d.pop("amount") + def _parse_reversal(data: object) -> ClosingEntryResponse | None | Unset: + if data is None: + return data + if isinstance(data, Unset): + return data + try: + if not isinstance(data, dict): + raise TypeError() + reversal_type_0 = ClosingEntryResponse.from_dict(data) + + return reversal_type_0 + except (TypeError, ValueError, AttributeError, KeyError): + pass + return cast(ClosingEntryResponse | None | Unset, data) + + reversal = _parse_reversal(d.pop("reversal", UNSET)) + closing_entry_response = cls( entry_id=entry_id, status=status, @@ -89,6 +120,7 @@ def from_dict(cls: 
type[T], src_dict: Mapping[str, Any]) -> T: debit_element_id=debit_element_id, credit_element_id=credit_element_id, amount=amount, + reversal=reversal, ) closing_entry_response.additional_properties = d diff --git a/robosystems_client/models/entry_template_request.py b/robosystems_client/models/entry_template_request.py index e46f5f8..2351740 100644 --- a/robosystems_client/models/entry_template_request.py +++ b/robosystems_client/models/entry_template_request.py @@ -19,12 +19,14 @@ class EntryTemplateRequest: credit_element_id (str): Element to credit (e.g., Accumulated Depreciation) entry_type (str | Unset): Entry type for generated entries Default: 'closing'. memo_template (str | Unset): Memo template ({structure_name} is replaced) Default: ''. + auto_reverse (bool | Unset): Auto-generate a reversing entry on the first day of the next period Default: False. """ debit_element_id: str credit_element_id: str entry_type: str | Unset = "closing" memo_template: str | Unset = "" + auto_reverse: bool | Unset = False additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) def to_dict(self) -> dict[str, Any]: @@ -36,6 +38,8 @@ def to_dict(self) -> dict[str, Any]: memo_template = self.memo_template + auto_reverse = self.auto_reverse + field_dict: dict[str, Any] = {} field_dict.update(self.additional_properties) field_dict.update( @@ -48,6 +52,8 @@ def to_dict(self) -> dict[str, Any]: field_dict["entry_type"] = entry_type if memo_template is not UNSET: field_dict["memo_template"] = memo_template + if auto_reverse is not UNSET: + field_dict["auto_reverse"] = auto_reverse return field_dict @@ -62,11 +68,14 @@ def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: memo_template = d.pop("memo_template", UNSET) + auto_reverse = d.pop("auto_reverse", UNSET) + entry_template_request = cls( debit_element_id=debit_element_id, credit_element_id=credit_element_id, entry_type=entry_type, memo_template=memo_template, + auto_reverse=auto_reverse, ) 
entry_template_request.additional_properties = d diff --git a/robosystems_client/models/period_close_item_response.py b/robosystems_client/models/period_close_item_response.py index f44515c..52218c1 100644 --- a/robosystems_client/models/period_close_item_response.py +++ b/robosystems_client/models/period_close_item_response.py @@ -20,6 +20,8 @@ class PeriodCloseItemResponse: amount (float): status (str): entry_id (None | str | Unset): + reversal_entry_id (None | str | Unset): + reversal_status (None | str | Unset): """ structure_id: str @@ -27,6 +29,8 @@ class PeriodCloseItemResponse: amount: float status: str entry_id: None | str | Unset = UNSET + reversal_entry_id: None | str | Unset = UNSET + reversal_status: None | str | Unset = UNSET additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) def to_dict(self) -> dict[str, Any]: @@ -44,6 +48,18 @@ def to_dict(self) -> dict[str, Any]: else: entry_id = self.entry_id + reversal_entry_id: None | str | Unset + if isinstance(self.reversal_entry_id, Unset): + reversal_entry_id = UNSET + else: + reversal_entry_id = self.reversal_entry_id + + reversal_status: None | str | Unset + if isinstance(self.reversal_status, Unset): + reversal_status = UNSET + else: + reversal_status = self.reversal_status + field_dict: dict[str, Any] = {} field_dict.update(self.additional_properties) field_dict.update( @@ -56,6 +72,10 @@ def to_dict(self) -> dict[str, Any]: ) if entry_id is not UNSET: field_dict["entry_id"] = entry_id + if reversal_entry_id is not UNSET: + field_dict["reversal_entry_id"] = reversal_entry_id + if reversal_status is not UNSET: + field_dict["reversal_status"] = reversal_status return field_dict @@ -79,12 +99,32 @@ def _parse_entry_id(data: object) -> None | str | Unset: entry_id = _parse_entry_id(d.pop("entry_id", UNSET)) + def _parse_reversal_entry_id(data: object) -> None | str | Unset: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(None | str | 
Unset, data) + + reversal_entry_id = _parse_reversal_entry_id(d.pop("reversal_entry_id", UNSET)) + + def _parse_reversal_status(data: object) -> None | str | Unset: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(None | str | Unset, data) + + reversal_status = _parse_reversal_status(d.pop("reversal_status", UNSET)) + period_close_item_response = cls( structure_id=structure_id, structure_name=structure_name, amount=amount, status=status, entry_id=entry_id, + reversal_entry_id=reversal_entry_id, + reversal_status=reversal_status, ) period_close_item_response.additional_properties = d diff --git a/robosystems_client/models/search_hit.py b/robosystems_client/models/search_hit.py index f642234..4776ffc 100644 --- a/robosystems_client/models/search_hit.py +++ b/robosystems_client/models/search_hit.py @@ -20,6 +20,7 @@ class SearchHit: score (float): source_type (str): snippet (str): + parent_document_id (None | str | Unset): entity_ticker (None | str | Unset): entity_name (None | str | Unset): section_label (None | str | Unset): @@ -40,6 +41,7 @@ class SearchHit: score: float source_type: str snippet: str + parent_document_id: None | str | Unset = UNSET entity_ticker: None | str | Unset = UNSET entity_name: None | str | Unset = UNSET section_label: None | str | Unset = UNSET @@ -65,6 +67,12 @@ def to_dict(self) -> dict[str, Any]: snippet = self.snippet + parent_document_id: None | str | Unset + if isinstance(self.parent_document_id, Unset): + parent_document_id = UNSET + else: + parent_document_id = self.parent_document_id + entity_ticker: None | str | Unset if isinstance(self.entity_ticker, Unset): entity_ticker = UNSET @@ -161,6 +169,8 @@ def to_dict(self) -> dict[str, Any]: "snippet": snippet, } ) + if parent_document_id is not UNSET: + field_dict["parent_document_id"] = parent_document_id if entity_ticker is not UNSET: field_dict["entity_ticker"] = entity_ticker if entity_name is not UNSET: @@ -203,6 +213,15 @@ def 
@pytest.mark.unit
class TestAgentDataclasses:
    """Tests covering the plain data containers used by the agent client."""

    def test_agent_query_request(self):
        """A fully-populated AgentQueryRequest keeps every field as given."""
        req = AgentQueryRequest(
            message="What is the revenue?",
            history=[{"role": "user", "content": "Hello"}],
            context={"fiscal_year": 2024},
            mode="standard",
            enable_rag=True,
            force_extended_analysis=False,
        )

        assert req.message == "What is the revenue?"
        assert len(req.history) == 1
        assert req.context == {"fiscal_year": 2024}
        assert req.mode == "standard"
        assert req.enable_rag is True
        assert req.force_extended_analysis is False

    def test_agent_query_request_defaults(self):
        """Only the message is required; every other field defaults to None."""
        req = AgentQueryRequest(message="Simple query")

        assert req.message == "Simple query"
        optional_fields = (
            "history",
            "context",
            "mode",
            "enable_rag",
            "force_extended_analysis",
        )
        for field_name in optional_fields:
            assert getattr(req, field_name) is None

    def test_agent_options_defaults(self):
        """AgentOptions defaults to auto mode with no wait cap or callback."""
        opts = AgentOptions()

        assert opts.mode == "auto"
        assert opts.max_wait is None
        assert opts.on_progress is None

    def test_agent_options_custom(self):
        """Custom AgentOptions values are stored verbatim."""
        callback = Mock()
        opts = AgentOptions(mode="sync", max_wait=30, on_progress=callback)

        assert opts.mode == "sync"
        assert opts.max_wait == 30
        assert opts.on_progress is callback

    def test_agent_result(self):
        """A fully-populated AgentResult round-trips all fields."""
        res = AgentResult(
            content="Revenue is $1B",
            agent_used="financial",
            mode_used="standard",
            metadata={"graph_id": "g-123"},
            tokens_used={"input": 100, "output": 50},
            confidence_score=0.95,
            execution_time=1.5,
            timestamp="2025-01-15T10:00:00Z",
        )

        assert res.content == "Revenue is $1B"
        assert res.agent_used == "financial"
        assert res.mode_used == "standard"
        assert res.confidence_score == 0.95
        assert res.execution_time == 1.5

    def test_agent_result_defaults(self):
        """Optional AgentResult fields default to None."""
        res = AgentResult(content="Answer", agent_used="rag", mode_used="quick")

        optional_fields = (
            "metadata",
            "tokens_used",
            "confidence_score",
            "execution_time",
            "timestamp",
        )
        for field_name in optional_fields:
            assert getattr(res, field_name) is None

    def test_queued_agent_response(self):
        """QueuedAgentResponse stores queue metadata including the SSE endpoint."""
        resp = QueuedAgentResponse(
            status="queued",
            operation_id="op-agent-1",
            message="Agent execution queued",
            sse_endpoint="/v1/operations/op-agent-1/stream",
        )

        assert resp.status == "queued"
        assert resp.operation_id == "op-agent-1"
        assert resp.message == "Agent execution queued"
        assert resp.sse_endpoint == "/v1/operations/op-agent-1/stream"

    def test_queued_agent_response_defaults(self):
        """The SSE endpoint defaults to None when omitted."""
        resp = QueuedAgentResponse(
            status="queued", operation_id="op-1", message="Queued"
        )

        assert resp.sse_endpoint is None

    def test_queued_agent_error(self):
        """QueuedAgentError carries its queue info and a fixed message."""
        info = QueuedAgentResponse(
            status="queued", operation_id="op-err", message="Queued"
        )
        err = QueuedAgentError(info)

        assert str(err) == "Agent execution was queued"
        assert err.queue_info.operation_id == "op-err"
@pytest.mark.unit
class TestAgentClientInit:
    """Tests for AgentClient construction and teardown."""

    def test_client_initialization(self, mock_config):
        """The config dict is translated into url/token/header state."""
        agent = AgentClient(mock_config)

        assert agent.base_url == "http://localhost:8000"
        assert agent.token == "test-api-key"
        assert agent.headers == {"X-API-Key": "test-api-key"}
        assert agent.sse_client is None

    def test_close_without_sse(self, mock_config):
        """close() is a no-op when no SSE client was ever created."""
        agent = AgentClient(mock_config)
        agent.close()  # must not raise

    def test_close_with_sse(self, mock_config):
        """close() discards an existing SSE client."""
        agent = AgentClient(mock_config)
        agent.sse_client = Mock()
        agent.close()

        assert agent.sse_client is None
@pytest.mark.unit
class TestAgentExecuteQuery:
    """Tests for AgentClient.execute_query."""

    @patch("robosystems_client.extensions.agent_client.auto_select_agent")
    def test_execute_query_dict_response(self, mock_auto, mock_config, graph_id):
        """A dict payload from the API is mapped onto an AgentResult."""
        stub = Mock()
        stub.parsed = {
            "content": "Revenue is $1B for FY2024.",
            "agent_used": "financial",
            "mode_used": "standard",
            "metadata": {},
            "tokens_used": {"input": 100, "output": 50},
            "confidence_score": 0.9,
            "execution_time": 2.5,
            "timestamp": "2025-01-15T10:00:00Z",
        }
        mock_auto.return_value = stub

        agent = AgentClient(mock_config)
        outcome = agent.execute_query(
            graph_id, AgentQueryRequest(message="What is revenue?")
        )

        assert outcome.content == "Revenue is $1B for FY2024."
        assert outcome.agent_used == "financial"
        assert outcome.confidence_score == 0.9

    @patch("robosystems_client.extensions.agent_client.auto_select_agent")
    def test_execute_query_queued_max_wait_zero(self, mock_auto, mock_config, graph_id):
        """A queued response with max_wait=0 surfaces as QueuedAgentError."""
        stub = Mock()
        stub.parsed = {
            "operation_id": "op-queued",
            "status": "queued",
            "message": "Agent execution queued",
        }
        mock_auto.return_value = stub

        agent = AgentClient(mock_config)

        with pytest.raises(QueuedAgentError) as exc_info:
            agent.execute_query(
                graph_id,
                AgentQueryRequest(message="Complex query"),
                AgentOptions(max_wait=0),
            )

        assert exc_info.value.queue_info.operation_id == "op-queued"

    @patch("robosystems_client.extensions.agent_client.auto_select_agent")
    def test_execute_query_no_token(self, mock_auto, mock_config, graph_id):
        """Execution fails up front when no API token is configured."""
        mock_config["token"] = None
        agent = AgentClient(mock_config)

        with pytest.raises(Exception, match="Authentication failed|No API key"):
            agent.execute_query(graph_id, AgentQueryRequest(message="test"))

    @patch("robosystems_client.extensions.agent_client.auto_select_agent")
    def test_execute_query_auth_error(self, mock_auto, mock_config, graph_id):
        """A 401 from the transport is wrapped as an authentication failure."""
        mock_auto.side_effect = Exception("401 Unauthorized")

        agent = AgentClient(mock_config)

        with pytest.raises(Exception, match="Authentication failed"):
            agent.execute_query(graph_id, AgentQueryRequest(message="test"))

    @patch("robosystems_client.extensions.agent_client.auto_select_agent")
    def test_execute_query_generic_error(self, mock_auto, mock_config, graph_id):
        """Any other transport error is wrapped as an execution failure."""
        mock_auto.side_effect = Exception("Connection timeout")

        agent = AgentClient(mock_config)

        with pytest.raises(Exception, match="Agent execution failed"):
            agent.execute_query(graph_id, AgentQueryRequest(message="test"))
@pytest.mark.unit
class TestAgentExecuteSpecific:
    """Tests for AgentClient.execute_agent."""

    @patch("robosystems_client.extensions.agent_client.execute_specific_agent")
    def test_execute_specific_agent(self, mock_exec, mock_config, graph_id):
        """A named agent type can be executed and its result parsed."""
        stub = Mock()
        stub.parsed = {
            "content": "Deep analysis of financials.",
            "agent_used": "financial",
            "mode_used": "extended",
        }
        mock_exec.return_value = stub

        agent = AgentClient(mock_config)
        outcome = agent.execute_agent(
            graph_id, "financial", AgentQueryRequest(message="Analyze financials")
        )

        assert outcome.content == "Deep analysis of financials."
        assert outcome.agent_used == "financial"

    @patch("robosystems_client.extensions.agent_client.execute_specific_agent")
    def test_execute_specific_agent_queued(self, mock_exec, mock_config, graph_id):
        """A queued response with max_wait=0 raises QueuedAgentError."""
        stub = Mock()
        stub.parsed = {
            "operation_id": "op-specific-queued",
            "status": "queued",
            "message": "Queued for processing",
        }
        mock_exec.return_value = stub

        agent = AgentClient(mock_config)

        with pytest.raises(QueuedAgentError) as exc_info:
            agent.execute_agent(
                graph_id,
                "research",
                AgentQueryRequest(message="Heavy analysis"),
                AgentOptions(max_wait=0),
            )

        assert exc_info.value.queue_info.operation_id == "op-specific-queued"
mock_exec.return_value = AgentResult( + content="Research results", agent_used="research", mode_used="extended" + ) + + client = AgentClient(mock_config) + result = client.research(graph_id, "Deep dive into market") + + assert result.agent_used == "research" + call_args = mock_exec.call_args + assert call_args[0][1] == "research" + + @patch.object(AgentClient, "execute_agent") + def test_rag_convenience(self, mock_exec, mock_config, graph_id): + """Test rag() convenience method.""" + mock_exec.return_value = AgentResult( + content="RAG answer", agent_used="rag", mode_used="quick" + ) + + client = AgentClient(mock_config) + result = client.rag(graph_id, "Quick lookup") + + assert result.agent_used == "rag" + call_args = mock_exec.call_args + assert call_args[0][1] == "rag" diff --git a/robosystems_client/extensions/tests/test_dataframe_utils.py b/tests/test_dataframe_utils.py similarity index 99% rename from robosystems_client/extensions/tests/test_dataframe_utils.py rename to tests/test_dataframe_utils.py index 4aad57e..05d9b24 100644 --- a/robosystems_client/extensions/tests/test_dataframe_utils.py +++ b/tests/test_dataframe_utils.py @@ -126,7 +126,7 @@ def test_parse_datetime_columns_infer(self): df = parse_datetime_columns(df, infer=True) assert pd.api.types.is_datetime64_any_dtype(df["timestamp"]) - assert df["not_a_date"].dtype == "object" + assert not pd.api.types.is_datetime64_any_dtype(df["not_a_date"]) class TestStreamToDataFrame: diff --git a/tests/test_document_client.py b/tests/test_document_client.py new file mode 100644 index 0000000..cf77c0c --- /dev/null +++ b/tests/test_document_client.py @@ -0,0 +1,433 @@ +"""Unit tests for DocumentClient.""" + +import pytest +from http import HTTPStatus +from unittest.mock import Mock, patch +from robosystems_client.extensions.document_client import DocumentClient + + +@pytest.mark.unit +class TestDocumentClientInit: + """Test suite for DocumentClient initialization.""" + + def test_client_initialization(self, 
@pytest.mark.unit
class TestDocumentClientInit:
    """Tests for DocumentClient construction and teardown."""

    def test_client_initialization(self, mock_config):
        """The config dict is translated into url/token/header/timeout state."""
        doc_client = DocumentClient(mock_config)

        assert doc_client.base_url == "http://localhost:8000"
        assert doc_client.token == "test-api-key"
        assert doc_client.headers == {"X-API-Key": "test-api-key"}
        assert doc_client.timeout == 60

    def test_client_custom_timeout(self, mock_config):
        """A timeout supplied in the config overrides the default."""
        mock_config["timeout"] = 120
        doc_client = DocumentClient(mock_config)

        assert doc_client.timeout == 120

    def test_get_client_no_token(self, mock_config):
        """_get_client refuses to build a client when no API key is set."""
        mock_config["token"] = None
        doc_client = DocumentClient(mock_config)

        with pytest.raises(Exception, match="No API key"):
            doc_client._get_client()

    def test_close_is_noop(self, mock_config):
        """close() is safe to call even though there is nothing to release."""
        doc_client = DocumentClient(mock_config)
        doc_client.close()  # must not raise
@pytest.mark.unit
class TestDocumentGet:
    """Tests for DocumentClient.get."""

    @patch("robosystems_client.extensions.document_client.get_document")
    def test_get_document(self, mock_get, mock_config, graph_id):
        """A 200 response yields the parsed document."""
        stub = Mock()
        stub.status_code = HTTPStatus.OK
        stub.parsed = Mock(
            document_id="doc-456", title="Found Document", content="# Content"
        )
        mock_get.return_value = stub

        doc_client = DocumentClient(mock_config)
        found = doc_client.get(graph_id=graph_id, document_id="doc-456")

        assert found is not None
        assert found.title == "Found Document"

    @patch("robosystems_client.extensions.document_client.get_document")
    def test_get_document_not_found(self, mock_get, mock_config, graph_id):
        """A 404 response yields None rather than raising."""
        stub = Mock()
        stub.status_code = HTTPStatus.NOT_FOUND
        mock_get.return_value = stub

        doc_client = DocumentClient(mock_config)
        found = doc_client.get(graph_id=graph_id, document_id="nonexistent")

        assert found is None

    @patch("robosystems_client.extensions.document_client.get_document")
    def test_get_document_server_error(self, mock_get, mock_config, graph_id):
        """A 5xx response raises an exception."""
        stub = Mock()
        stub.status_code = HTTPStatus.INTERNAL_SERVER_ERROR
        stub.content = b"Internal error"
        mock_get.return_value = stub

        doc_client = DocumentClient(mock_config)

        with pytest.raises(Exception, match="Get document failed"):
            doc_client.get(graph_id=graph_id, document_id="doc-err")
server error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.INTERNAL_SERVER_ERROR + mock_resp.content = b"Internal error" + mock_get.return_value = mock_resp + + client = DocumentClient(mock_config) + + with pytest.raises(Exception, match="Get document failed"): + client.get(graph_id=graph_id, document_id="doc-err") + + +@pytest.mark.unit +class TestDocumentUpdate: + """Test suite for DocumentClient.update method.""" + + @patch("robosystems_client.extensions.document_client.update_document") + def test_update_document(self, mock_update, mock_config, graph_id): + """Test updating a document.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(document_id="doc-upd", section_count=5) + mock_update.return_value = mock_resp + + client = DocumentClient(mock_config) + result = client.update( + graph_id=graph_id, + document_id="doc-upd", + title="Updated Title", + content="# Updated\n\nNew content.", + ) + + assert result.section_count == 5 + + @patch("robosystems_client.extensions.document_client.update_document") + def test_update_document_failure(self, mock_update, mock_config, graph_id): + """Test update failure raises exception.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.BAD_REQUEST + mock_resp.content = b"Bad request" + mock_update.return_value = mock_resp + + client = DocumentClient(mock_config) + + with pytest.raises(Exception, match="Update document failed"): + client.update(graph_id=graph_id, document_id="doc-bad", title="Bad") + + +@pytest.mark.unit +class TestDocumentSearch: + """Test suite for DocumentClient.search method.""" + + @patch("robosystems_client.extensions.document_client.search_documents") + def test_search_documents(self, mock_search, mock_config, graph_id): + """Test searching documents.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(total=2, hits=[Mock(), Mock()]) + mock_search.return_value = mock_resp + + client = 
DocumentClient(mock_config) + result = client.search(graph_id=graph_id, query="revenue growth") + + assert result.total == 2 + assert len(result.hits) == 2 + + @patch("robosystems_client.extensions.document_client.search_documents") + def test_search_with_filters(self, mock_search, mock_config, graph_id): + """Test searching with filters.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(total=1, hits=[Mock()]) + mock_search.return_value = mock_resp + + client = DocumentClient(mock_config) + result = client.search( + graph_id=graph_id, + query="risk factors", + source_type="xbrl_textblock", + form_type="10-K", + entity="AAPL", + fiscal_year=2024, + size=5, + ) + + assert result.total == 1 + mock_search.assert_called_once() + + @patch("robosystems_client.extensions.document_client.search_documents") + def test_search_failure(self, mock_search, mock_config, graph_id): + """Test search failure raises exception.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.INTERNAL_SERVER_ERROR + mock_resp.content = b"Search error" + mock_search.return_value = mock_resp + + client = DocumentClient(mock_config) + + with pytest.raises(Exception, match="Document search failed"): + client.search(graph_id=graph_id, query="test") + + +@pytest.mark.unit +class TestDocumentList: + """Test suite for DocumentClient.list method.""" + + @patch("robosystems_client.extensions.document_client.list_documents") + def test_list_documents(self, mock_list, mock_config, graph_id): + """Test listing documents.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(total=3, documents=[Mock(), Mock(), Mock()]) + mock_list.return_value = mock_resp + + client = DocumentClient(mock_config) + result = client.list(graph_id=graph_id) + + assert result.total == 3 + + @patch("robosystems_client.extensions.document_client.list_documents") + def test_list_documents_with_filter(self, mock_list, mock_config, graph_id): + """Test 
listing with source type filter.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(total=1, documents=[Mock()]) + mock_list.return_value = mock_resp + + client = DocumentClient(mock_config) + result = client.list(graph_id=graph_id, source_type="uploaded_doc") + + assert result.total == 1 + + +@pytest.mark.unit +class TestDocumentDelete: + """Test suite for DocumentClient.delete method.""" + + @patch("robosystems_client.extensions.document_client.delete_document") + def test_delete_document(self, mock_delete, mock_config, graph_id): + """Test deleting a document.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.NO_CONTENT + mock_delete.return_value = mock_resp + + client = DocumentClient(mock_config) + result = client.delete(graph_id=graph_id, document_id="doc-del") + + assert result is True + + @patch("robosystems_client.extensions.document_client.delete_document") + def test_delete_document_not_found(self, mock_delete, mock_config, graph_id): + """Test deleting a document that doesn't exist.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.NOT_FOUND + mock_delete.return_value = mock_resp + + client = DocumentClient(mock_config) + result = client.delete(graph_id=graph_id, document_id="doc-missing") + + assert result is False + + @patch("robosystems_client.extensions.document_client.delete_document") + def test_delete_document_server_error(self, mock_delete, mock_config, graph_id): + """Test delete raises on server error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.INTERNAL_SERVER_ERROR + mock_resp.content = b"Delete error" + mock_delete.return_value = mock_resp + + client = DocumentClient(mock_config) + + with pytest.raises(Exception, match="Delete document failed"): + client.delete(graph_id=graph_id, document_id="doc-err") + + +@pytest.mark.unit +class TestDocumentSection: + """Test suite for DocumentClient.get_section method.""" + + 
@patch("robosystems_client.extensions.document_client.get_document_section") + def test_get_section(self, mock_get_section, mock_config, graph_id): + """Test getting a document section.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock( + section_id="sec-1", title="Revenue", content="Revenue was $1B." + ) + mock_get_section.return_value = mock_resp + + client = DocumentClient(mock_config) + result = client.get_section(graph_id=graph_id, document_id="sec-1") + + assert result is not None + assert result.title == "Revenue" + + @patch("robosystems_client.extensions.document_client.get_document_section") + def test_get_section_not_found(self, mock_get_section, mock_config, graph_id): + """Test getting a section that doesn't exist.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.NOT_FOUND + mock_get_section.return_value = mock_resp + + client = DocumentClient(mock_config) + result = client.get_section(graph_id=graph_id, document_id="nonexistent") + + assert result is None + + +@pytest.mark.unit +class TestDocumentBulkUpload: + """Test suite for DocumentClient.upload_bulk method.""" + + @patch("robosystems_client.extensions.document_client.upload_documents_bulk") + def test_upload_bulk(self, mock_bulk, mock_config, graph_id): + """Test bulk uploading documents.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(total=2, succeeded=2, failed=0, results=[Mock(), Mock()]) + mock_bulk.return_value = mock_resp + + client = DocumentClient(mock_config) + docs = [ + {"title": "Doc 1", "content": "Content 1"}, + {"title": "Doc 2", "content": "Content 2", "tags": ["tag1"]}, + ] + result = client.upload_bulk(graph_id=graph_id, documents=docs) + + assert result.total == 2 + assert result.succeeded == 2 + + @patch("robosystems_client.extensions.document_client.upload_documents_bulk") + def test_upload_bulk_failure(self, mock_bulk, mock_config, graph_id): + """Test bulk upload failure raises 
exception.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.BAD_REQUEST + mock_resp.content = b"Bulk upload error" + mock_bulk.return_value = mock_resp + + client = DocumentClient(mock_config) + + with pytest.raises(Exception, match="Bulk upload failed"): + client.upload_bulk(graph_id=graph_id, documents=[{"title": "x", "content": "y"}]) + + +@pytest.mark.unit +class TestDocumentUploadFile: + """Test suite for DocumentClient.upload_file method.""" + + @patch("robosystems_client.extensions.document_client.upload_document") + def test_upload_file(self, mock_upload, mock_config, graph_id, tmp_path): + """Test uploading a file from disk.""" + md_file = tmp_path / "test-doc.md" + md_file.write_text("# Test\n\nContent here.") + + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(document_id="doc-file", section_count=2) + mock_upload.return_value = mock_resp + + client = DocumentClient(mock_config) + result = client.upload_file(graph_id=graph_id, file_path=md_file) + + assert result.document_id == "doc-file" + + @patch("robosystems_client.extensions.document_client.upload_document") + def test_upload_file_title_from_filename( + self, mock_upload, mock_config, graph_id, tmp_path + ): + """Test that title is derived from filename when not provided.""" + md_file = tmp_path / "revenue-analysis.md" + md_file.write_text("# Revenue\n\nAnalysis content.") + + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(document_id="doc-title", section_count=1) + mock_upload.return_value = mock_resp + + client = DocumentClient(mock_config) + client.upload_file(graph_id=graph_id, file_path=md_file) + + # Verify the body was created with title derived from filename + call_kwargs = mock_upload.call_args[1] + assert call_kwargs["body"].title == "Revenue Analysis" + + +@pytest.mark.unit +class TestDocumentUploadDirectory: + """Test suite for DocumentClient.upload_directory method.""" + + 
@patch("robosystems_client.extensions.document_client.upload_document") + def test_upload_directory(self, mock_upload, mock_config, graph_id, tmp_path): + """Test uploading all markdown files from a directory.""" + (tmp_path / "doc1.md").write_text("# Doc 1") + (tmp_path / "doc2.md").write_text("# Doc 2") + (tmp_path / "not-md.txt").write_text("Skip this") + + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(document_id="doc-dir", section_count=1) + mock_upload.return_value = mock_resp + + client = DocumentClient(mock_config) + results = client.upload_directory(graph_id=graph_id, directory=tmp_path) + + assert len(results) == 2 # Only .md files diff --git a/robosystems_client/extensions/tests/test_integration.py b/tests/test_extensions_integration.py similarity index 100% rename from robosystems_client/extensions/tests/test_integration.py rename to tests/test_extensions_integration.py diff --git a/tests/test_file_client.py b/tests/test_file_client.py new file mode 100644 index 0000000..ddbd73f --- /dev/null +++ b/tests/test_file_client.py @@ -0,0 +1,458 @@ +"""Unit tests for FileClient.""" + +import pytest +from io import BytesIO +from unittest.mock import Mock, patch +from robosystems_client.extensions.file_client import ( + FileClient, + FileUploadOptions, + FileUploadResult, + FileInfo, +) + + +@pytest.mark.unit +class TestFileClientInit: + """Test suite for FileClient initialization.""" + + def test_client_initialization(self, mock_config): + """Test that client initializes correctly with config.""" + client = FileClient(mock_config) + + assert client.base_url == "http://localhost:8000" + assert client.token == "test-api-key" + assert client.headers == {"X-API-Key": "test-api-key"} + assert client.s3_endpoint_url is None + + def test_client_with_s3_override(self, mock_config): + """Test client with S3 endpoint override.""" + mock_config["s3_endpoint_url"] = "http://localhost:4566" + client = FileClient(mock_config) + + assert 
client.s3_endpoint_url == "http://localhost:4566" + + def test_client_cleanup(self, mock_config): + """Test that cleanup closes HTTP client.""" + client = FileClient(mock_config) + assert hasattr(client, "_http_client") + client.__del__() + + +@pytest.mark.unit +class TestFileDataclasses: + """Test suite for file-related dataclasses.""" + + def test_file_upload_options_defaults(self): + """Test FileUploadOptions default values.""" + options = FileUploadOptions() + + assert options.on_progress is None + assert options.ingest_to_graph is False + + def test_file_upload_options_custom(self): + """Test FileUploadOptions with custom values.""" + progress_fn = Mock() + options = FileUploadOptions(on_progress=progress_fn, ingest_to_graph=True) + + assert options.on_progress is progress_fn + assert options.ingest_to_graph is True + + def test_file_upload_result(self): + """Test FileUploadResult dataclass creation.""" + result = FileUploadResult( + file_id="file-123", + file_size=50000, + row_count=100, + table_name="Entity", + file_name="data.parquet", + success=True, + ) + + assert result.file_id == "file-123" + assert result.file_size == 50000 + assert result.row_count == 100 + assert result.table_name == "Entity" + assert result.file_name == "data.parquet" + assert result.success is True + assert result.error is None + + def test_file_upload_result_with_error(self): + """Test FileUploadResult with error.""" + result = FileUploadResult( + file_id="", + file_size=0, + row_count=0, + table_name="Entity", + file_name="bad.parquet", + success=False, + error="Upload failed", + ) + + assert result.success is False + assert result.error == "Upload failed" + + def test_file_info(self): + """Test FileInfo dataclass creation.""" + info = FileInfo( + file_id="file-456", + file_name="transactions.parquet", + file_format="parquet", + size_bytes=120000, + row_count=500, + upload_status="uploaded", + table_name="Transaction", + created_at="2025-01-15T10:00:00Z", + 
uploaded_at="2025-01-15T10:01:00Z", + ) + + assert info.file_id == "file-456" + assert info.file_name == "transactions.parquet" + assert info.file_format == "parquet" + assert info.size_bytes == 120000 + assert info.row_count == 500 + assert info.upload_status == "uploaded" + assert info.table_name == "Transaction" + assert info.layers is None + + def test_file_info_with_layers(self): + """Test FileInfo with layer tracking.""" + layers = {"s3": "uploaded", "duckdb": "staged", "graph": "pending"} + info = FileInfo( + file_id="file-789", + file_name="data.parquet", + file_format="parquet", + size_bytes=1000, + row_count=10, + upload_status="uploaded", + table_name="Entity", + created_at=None, + uploaded_at=None, + layers=layers, + ) + + assert info.layers == layers + + +@pytest.mark.unit +class TestFileUpload: + """Test suite for FileClient.upload method.""" + + @patch("robosystems_client.extensions.file_client.update_file") + @patch("robosystems_client.extensions.file_client.create_file_upload") + def test_upload_bytesio(self, mock_create, mock_update, mock_config, graph_id): + """Test uploading a BytesIO buffer.""" + # Mock presigned URL response + mock_create_resp = Mock() + mock_create_resp.status_code = 200 + mock_create_resp.parsed = Mock() + mock_create_resp.parsed.upload_url = "http://s3.localhost/bucket/file.parquet" + mock_create_resp.parsed.file_id = "file-new-123" + mock_create.return_value = mock_create_resp + + # Mock update response + mock_update_resp = Mock() + mock_update_resp.status_code = 200 + mock_update_resp.parsed = Mock() + mock_update_resp.parsed.file_size_bytes = 1234 + mock_update_resp.parsed.row_count = 10 + mock_update.return_value = mock_update_resp + + client = FileClient(mock_config) + # Mock the S3 PUT + client._http_client = Mock() + client._http_client.put.return_value = Mock(status_code=200) + + buf = BytesIO(b"fake parquet data") + result = client.upload(graph_id, "Entity", buf) + + assert result.success is True + assert 
result.file_id == "file-new-123" + assert result.file_size == 1234 + assert result.row_count == 10 + assert result.table_name == "Entity" + + @patch("robosystems_client.extensions.file_client.create_file_upload") + def test_upload_presigned_url_failure(self, mock_create, mock_config, graph_id): + """Test upload failure when presigned URL request fails.""" + mock_create_resp = Mock() + mock_create_resp.status_code = 500 + mock_create_resp.parsed = None + mock_create.return_value = mock_create_resp + + client = FileClient(mock_config) + result = client.upload(graph_id, "Entity", BytesIO(b"data")) + + assert result.success is False + assert "Failed to get upload URL" in result.error + + @patch("robosystems_client.extensions.file_client.create_file_upload") + def test_upload_s3_failure(self, mock_create, mock_config, graph_id): + """Test upload failure when S3 PUT fails.""" + mock_create_resp = Mock() + mock_create_resp.status_code = 200 + mock_create_resp.parsed = Mock() + mock_create_resp.parsed.upload_url = "http://s3.localhost/bucket/file.parquet" + mock_create_resp.parsed.file_id = "file-s3-fail" + mock_create.return_value = mock_create_resp + + client = FileClient(mock_config) + client._http_client = Mock() + client._http_client.put.return_value = Mock(status_code=500) + + result = client.upload(graph_id, "Entity", BytesIO(b"data")) + + assert result.success is False + assert "S3 upload failed" in result.error + + @patch("robosystems_client.extensions.file_client.update_file") + @patch("robosystems_client.extensions.file_client.create_file_upload") + def test_upload_status_update_failure( + self, mock_create, mock_update, mock_config, graph_id + ): + """Test upload failure when status update fails.""" + mock_create_resp = Mock() + mock_create_resp.status_code = 200 + mock_create_resp.parsed = Mock() + mock_create_resp.parsed.upload_url = "http://s3.localhost/bucket/file.parquet" + mock_create_resp.parsed.file_id = "file-update-fail" + mock_create.return_value = 
mock_create_resp + + mock_update_resp = Mock() + mock_update_resp.status_code = 500 + mock_update_resp.parsed = None + mock_update.return_value = mock_update_resp + + client = FileClient(mock_config) + client._http_client = Mock() + client._http_client.put.return_value = Mock(status_code=200) + + result = client.upload(graph_id, "Entity", BytesIO(b"data")) + + assert result.success is False + assert "Failed to complete file upload" in result.error + + def test_upload_no_token(self, mock_config, graph_id): + """Test upload fails without API key.""" + mock_config["token"] = None + client = FileClient(mock_config) + + result = client.upload(graph_id, "Entity", BytesIO(b"data")) + + assert result.success is False + assert "No API key" in result.error + + def test_upload_unsupported_type(self, mock_config, graph_id): + """Test upload fails with unsupported file type.""" + client = FileClient(mock_config) + + result = client.upload(graph_id, "Entity", 12345) + + assert result.success is False + assert "Unsupported file type" in result.error + + @patch("robosystems_client.extensions.file_client.update_file") + @patch("robosystems_client.extensions.file_client.create_file_upload") + def test_upload_with_progress_callback( + self, mock_create, mock_update, mock_config, graph_id + ): + """Test upload calls progress callback at each step.""" + mock_create_resp = Mock() + mock_create_resp.status_code = 200 + mock_create_resp.parsed = Mock() + mock_create_resp.parsed.upload_url = "http://s3.localhost/bucket/file.parquet" + mock_create_resp.parsed.file_id = "file-progress" + mock_create.return_value = mock_create_resp + + mock_update_resp = Mock() + mock_update_resp.status_code = 200 + mock_update_resp.parsed = Mock() + mock_update_resp.parsed.file_size_bytes = 100 + mock_update_resp.parsed.row_count = 5 + mock_update.return_value = mock_update_resp + + client = FileClient(mock_config) + client._http_client = Mock() + client._http_client.put.return_value = Mock(status_code=200) 
+ + progress_messages = [] + options = FileUploadOptions(on_progress=lambda msg: progress_messages.append(msg)) + + client.upload(graph_id, "Entity", BytesIO(b"data"), options) + + assert len(progress_messages) >= 3 # URL, upload, mark uploaded + + @patch("robosystems_client.extensions.file_client.update_file") + @patch("robosystems_client.extensions.file_client.create_file_upload") + def test_upload_with_s3_endpoint_override( + self, mock_create, mock_update, mock_config, graph_id + ): + """Test upload rewrites S3 URL when s3_endpoint_url is set.""" + mock_config["s3_endpoint_url"] = "http://localhost:4566" + + mock_create_resp = Mock() + mock_create_resp.status_code = 200 + mock_create_resp.parsed = Mock() + mock_create_resp.parsed.upload_url = ( + "https://s3.amazonaws.com/bucket/file.parquet?sig=abc" + ) + mock_create_resp.parsed.file_id = "file-s3-override" + mock_create.return_value = mock_create_resp + + mock_update_resp = Mock() + mock_update_resp.status_code = 200 + mock_update_resp.parsed = Mock() + mock_update_resp.parsed.file_size_bytes = 100 + mock_update_resp.parsed.row_count = 5 + mock_update.return_value = mock_update_resp + + client = FileClient(mock_config) + client._http_client = Mock() + client._http_client.put.return_value = Mock(status_code=200) + + client.upload(graph_id, "Entity", BytesIO(b"data")) + + # Verify S3 PUT was called with overridden URL + put_url = client._http_client.put.call_args[0][0] + assert "localhost:4566" in put_url + + +@pytest.mark.unit +class TestFileList: + """Test suite for FileClient.list method.""" + + @patch("robosystems_client.extensions.file_client.list_files") + def test_list_files(self, mock_list, mock_config, graph_id): + """Test listing files returns FileInfo objects.""" + mock_file = Mock() + mock_file.file_id = "file-1" + mock_file.file_name = "data.parquet" + mock_file.file_format = "parquet" + mock_file.size_bytes = 5000 + mock_file.row_count = 50 + mock_file.upload_status = "uploaded" + 
mock_file.table_name = "Entity" + mock_file.created_at = "2025-01-01T00:00:00Z" + mock_file.uploaded_at = "2025-01-01T00:01:00Z" + + mock_resp = Mock() + mock_resp.status_code = 200 + mock_resp.parsed = Mock() + mock_resp.parsed.files = [mock_file] + mock_list.return_value = mock_resp + + client = FileClient(mock_config) + files = client.list(graph_id) + + assert len(files) == 1 + assert files[0].file_id == "file-1" + assert files[0].file_name == "data.parquet" + assert files[0].size_bytes == 5000 + + @patch("robosystems_client.extensions.file_client.list_files") + def test_list_files_failure(self, mock_list, mock_config, graph_id): + """Test listing files returns empty list on failure.""" + mock_resp = Mock() + mock_resp.status_code = 500 + mock_resp.parsed = None + mock_list.return_value = mock_resp + + client = FileClient(mock_config) + files = client.list(graph_id) + + assert files == [] + + def test_list_files_no_token(self, mock_config, graph_id): + """Test listing files returns empty list without token.""" + mock_config["token"] = None + client = FileClient(mock_config) + files = client.list(graph_id) + + assert files == [] + + +@pytest.mark.unit +class TestFileGet: + """Test suite for FileClient.get method.""" + + @patch("robosystems_client.extensions.file_client.get_file") + def test_get_file(self, mock_get, mock_config, graph_id): + """Test getting a specific file.""" + mock_file_data = Mock() + mock_file_data.file_id = "file-detail" + mock_file_data.file_name = "detail.parquet" + mock_file_data.file_format = "parquet" + mock_file_data.size_bytes = 8000 + mock_file_data.row_count = 80 + mock_file_data.upload_status = "uploaded" + mock_file_data.table_name = "Entity" + mock_file_data.created_at = "2025-01-01T00:00:00Z" + mock_file_data.uploaded_at = "2025-01-01T00:01:00Z" + mock_file_data.layers = {"s3": "uploaded", "duckdb": "staged"} + + mock_resp = Mock() + mock_resp.status_code = 200 + mock_resp.parsed = mock_file_data + mock_get.return_value = 
mock_resp + + client = FileClient(mock_config) + info = client.get(graph_id, "file-detail") + + assert info is not None + assert info.file_id == "file-detail" + assert info.layers == {"s3": "uploaded", "duckdb": "staged"} + + @patch("robosystems_client.extensions.file_client.get_file") + def test_get_file_not_found(self, mock_get, mock_config, graph_id): + """Test getting a file that doesn't exist.""" + mock_resp = Mock() + mock_resp.status_code = 404 + mock_resp.parsed = None + mock_get.return_value = mock_resp + + client = FileClient(mock_config) + info = client.get(graph_id, "nonexistent") + + assert info is None + + +@pytest.mark.unit +class TestFileDelete: + """Test suite for FileClient.delete method.""" + + @patch("robosystems_client.extensions.file_client.delete_file") + def test_delete_file(self, mock_delete, mock_config, graph_id): + """Test deleting a file.""" + mock_resp = Mock() + mock_resp.status_code = 204 + mock_delete.return_value = mock_resp + + client = FileClient(mock_config) + result = client.delete(graph_id, "file-to-delete") + + assert result is True + + @patch("robosystems_client.extensions.file_client.delete_file") + def test_delete_file_failure(self, mock_delete, mock_config, graph_id): + """Test delete failure returns False.""" + mock_resp = Mock() + mock_resp.status_code = 500 + mock_delete.return_value = mock_resp + + client = FileClient(mock_config) + result = client.delete(graph_id, "file-to-delete") + + assert result is False + + @patch("robosystems_client.extensions.file_client.delete_file") + def test_delete_file_with_cascade(self, mock_delete, mock_config, graph_id): + """Test cascade delete passes parameter.""" + mock_resp = Mock() + mock_resp.status_code = 200 + mock_delete.return_value = mock_resp + + client = FileClient(mock_config) + result = client.delete(graph_id, "file-cascade", cascade=True) + + assert result is True + call_kwargs = mock_delete.call_args[1] + assert call_kwargs["cascade"] is True diff --git 
a/tests/test_ledger_client.py b/tests/test_ledger_client.py new file mode 100644 index 0000000..7e075dd --- /dev/null +++ b/tests/test_ledger_client.py @@ -0,0 +1,533 @@ +"""Unit tests for LedgerClient.""" + +import pytest +from http import HTTPStatus +from unittest.mock import Mock, patch +from robosystems_client.extensions.ledger_client import LedgerClient + + +@pytest.mark.unit +class TestLedgerClientInit: + """Test suite for LedgerClient initialization.""" + + def test_client_initialization(self, mock_config): + """Test that client initializes correctly with config.""" + client = LedgerClient(mock_config) + + assert client.base_url == "http://localhost:8000" + assert client.token == "test-api-key" + assert client.headers == {"X-API-Key": "test-api-key"} + assert client.timeout == 60 + + def test_get_client_no_token(self, mock_config): + """Test _get_client raises without token.""" + mock_config["token"] = None + client = LedgerClient(mock_config) + + with pytest.raises(RuntimeError, match="No API key"): + client._get_client() + + +@pytest.mark.unit +class TestLedgerEntity: + """Test suite for entity operations.""" + + @patch("robosystems_client.extensions.ledger_client.get_ledger_entity") + def test_get_entity(self, mock_get, mock_config, graph_id): + """Test getting entity for a graph.""" + mock_parsed = Mock() + mock_parsed.entity_name = "ACME Corp" + mock_parsed.cik = "0001234567" + + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = mock_parsed + mock_get.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.get_entity(graph_id) + + assert result.entity_name == "ACME Corp" + + @patch("robosystems_client.extensions.ledger_client.get_ledger_entity") + def test_get_entity_not_found(self, mock_get, mock_config, graph_id): + """Test getting entity that doesn't exist.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.NOT_FOUND + mock_get.return_value = mock_resp + + client = 
LedgerClient(mock_config) + result = client.get_entity(graph_id) + + assert result is None + + @patch("robosystems_client.extensions.ledger_client.get_ledger_entity") + def test_get_entity_error(self, mock_get, mock_config, graph_id): + """Test get entity raises on error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.INTERNAL_SERVER_ERROR + mock_get.return_value = mock_resp + + client = LedgerClient(mock_config) + + with pytest.raises(RuntimeError, match="Get entity failed"): + client.get_entity(graph_id) + + +@pytest.mark.unit +class TestLedgerAccounts: + """Test suite for account operations.""" + + @patch("robosystems_client.extensions.ledger_client.list_ledger_accounts") + def test_list_accounts(self, mock_list, mock_config, graph_id): + """Test listing accounts.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(accounts=[Mock(), Mock()]) + mock_list.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.list_accounts(graph_id) + + assert len(result.accounts) == 2 + + @patch("robosystems_client.extensions.ledger_client.list_ledger_accounts") + def test_list_accounts_error(self, mock_list, mock_config, graph_id): + """Test list accounts raises on error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.INTERNAL_SERVER_ERROR + mock_list.return_value = mock_resp + + client = LedgerClient(mock_config) + + with pytest.raises(RuntimeError, match="List accounts failed"): + client.list_accounts(graph_id) + + @patch("robosystems_client.extensions.ledger_client.get_ledger_account_tree") + def test_get_account_tree(self, mock_tree, mock_config, graph_id): + """Test getting account tree.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(root=Mock(children=[Mock(), Mock()])) + mock_tree.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.get_account_tree(graph_id) + + assert len(result.root.children) == 2 + + 
+@pytest.mark.unit +class TestLedgerTransactions: + """Test suite for transaction operations.""" + + @patch("robosystems_client.extensions.ledger_client.list_ledger_transactions") + def test_list_transactions(self, mock_list, mock_config, graph_id): + """Test listing transactions.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(transactions=[Mock()]) + mock_list.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.list_transactions(graph_id) + + assert len(result.transactions) == 1 + + @patch("robosystems_client.extensions.ledger_client.list_ledger_transactions") + def test_list_transactions_with_filters(self, mock_list, mock_config, graph_id): + """Test listing transactions with date filters.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(transactions=[]) + mock_list.return_value = mock_resp + + client = LedgerClient(mock_config) + client.list_transactions( + graph_id, + start_date="2025-01-01", + end_date="2025-03-31", + limit=50, + offset=0, + ) + + mock_list.assert_called_once() + call_kwargs = mock_list.call_args[1] + assert call_kwargs["start_date"] == "2025-01-01" + assert call_kwargs["end_date"] == "2025-03-31" + assert call_kwargs["limit"] == 50 + + @patch("robosystems_client.extensions.ledger_client.get_ledger_transaction") + def test_get_transaction(self, mock_get, mock_config, graph_id): + """Test getting a specific transaction.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(transaction_id="txn-123", entries=[Mock()]) + mock_get.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.get_transaction(graph_id, "txn-123") + + assert result.transaction_id == "txn-123" + + @patch("robosystems_client.extensions.ledger_client.get_ledger_transaction") + def test_get_transaction_error(self, mock_get, mock_config, graph_id): + """Test get transaction raises on error.""" + mock_resp 
= Mock() + mock_resp.status_code = HTTPStatus.NOT_FOUND + mock_get.return_value = mock_resp + + client = LedgerClient(mock_config) + + with pytest.raises(RuntimeError, match="Get transaction failed"): + client.get_transaction(graph_id, "nonexistent") + + +@pytest.mark.unit +class TestLedgerTrialBalance: + """Test suite for trial balance operations.""" + + @patch("robosystems_client.extensions.ledger_client.get_ledger_trial_balance") + def test_get_trial_balance(self, mock_tb, mock_config, graph_id): + """Test getting trial balance.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock( + total_debits=100000, total_credits=100000, accounts=[Mock()] + ) + mock_tb.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.get_trial_balance(graph_id) + + assert result.total_debits == 100000 + assert result.total_credits == 100000 + + @patch("robosystems_client.extensions.ledger_client.get_ledger_trial_balance") + def test_get_trial_balance_with_dates(self, mock_tb, mock_config, graph_id): + """Test getting trial balance with date range.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock() + mock_tb.return_value = mock_resp + + client = LedgerClient(mock_config) + client.get_trial_balance(graph_id, start_date="2025-01-01", end_date="2025-03-31") + + call_kwargs = mock_tb.call_args[1] + assert call_kwargs["start_date"] == "2025-01-01" + assert call_kwargs["end_date"] == "2025-03-31" + + @patch("robosystems_client.extensions.ledger_client.get_mapped_trial_balance") + def test_get_mapped_trial_balance(self, mock_mtb, mock_config, graph_id): + """Test getting mapped trial balance.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(elements=[Mock()]) + mock_mtb.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.get_mapped_trial_balance(graph_id, mapping_id="map-1") + + assert len(result.elements) == 1 + + 
+@pytest.mark.unit +class TestLedgerSummary: + """Test suite for summary operations.""" + + @patch("robosystems_client.extensions.ledger_client.get_ledger_summary") + def test_get_summary(self, mock_summary, mock_config, graph_id): + """Test getting ledger summary.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock( + account_count=50, transaction_count=200, entity_name="ACME Corp" + ) + mock_summary.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.get_summary(graph_id) + + assert result.account_count == 50 + assert result.transaction_count == 200 + + +@pytest.mark.unit +class TestLedgerTaxonomy: + """Test suite for taxonomy operations.""" + + @patch("robosystems_client.extensions.ledger_client.get_reporting_taxonomy") + def test_get_reporting_taxonomy(self, mock_tax, mock_config, graph_id): + """Test getting reporting taxonomy.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(taxonomy_id="tax_usgaap_reporting") + mock_tax.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.get_reporting_taxonomy(graph_id) + + assert result.taxonomy_id == "tax_usgaap_reporting" + + @patch("robosystems_client.extensions.ledger_client.list_structures") + def test_list_structures(self, mock_structs, mock_config, graph_id): + """Test listing reporting structures.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(structures=[Mock(), Mock(), Mock()]) + mock_structs.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.list_structures(graph_id) + + assert len(result.structures) == 3 + + @patch("robosystems_client.extensions.ledger_client.list_elements") + def test_list_elements(self, mock_elems, mock_config, graph_id): + """Test listing elements.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(elements=[Mock()]) + mock_elems.return_value = 
mock_resp + + client = LedgerClient(mock_config) + result = client.list_elements(graph_id, source="coa", limit=10) + + assert len(result.elements) == 1 + call_kwargs = mock_elems.call_args[1] + assert call_kwargs["source"] == "coa" + assert call_kwargs["limit"] == 10 + + +@pytest.mark.unit +class TestLedgerMappings: + """Test suite for mapping operations.""" + + @patch("robosystems_client.extensions.ledger_client.list_mappings") + def test_list_mappings(self, mock_list, mock_config, graph_id): + """Test listing mappings.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(mappings=[Mock()]) + mock_list.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.list_mappings(graph_id) + + assert len(result.mappings) == 1 + + @patch("robosystems_client.extensions.ledger_client.get_mapping_detail") + def test_get_mapping_detail(self, mock_detail, mock_config, graph_id): + """Test getting mapping detail.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(mapping_id="map-1", associations=[Mock(), Mock()]) + mock_detail.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.get_mapping_detail(graph_id, "map-1") + + assert len(result.associations) == 2 + + @patch("robosystems_client.extensions.ledger_client.get_mapping_coverage") + def test_get_mapping_coverage(self, mock_cov, mock_config, graph_id): + """Test getting mapping coverage.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(total_elements=50, mapped_elements=45, coverage=0.9) + mock_cov.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.get_mapping_coverage(graph_id, "map-1") + + assert result.coverage == 0.9 + + @patch("robosystems_client.extensions.ledger_client.create_mapping_association") + def test_create_mapping(self, mock_create, mock_config, graph_id): + """Test creating a mapping association.""" + mock_resp 
= Mock() + mock_resp.status_code = HTTPStatus.CREATED + mock_create.return_value = mock_resp + + client = LedgerClient(mock_config) + client.create_mapping( + graph_id, "map-1", from_element_id="elem-a", to_element_id="elem-b" + ) + + mock_create.assert_called_once() + + @patch("robosystems_client.extensions.ledger_client.create_mapping_association") + def test_create_mapping_error(self, mock_create, mock_config, graph_id): + """Test create mapping raises on error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.BAD_REQUEST + mock_create.return_value = mock_resp + + client = LedgerClient(mock_config) + + with pytest.raises(RuntimeError, match="Create mapping failed"): + client.create_mapping(graph_id, "map-1", "elem-a", "elem-b") + + @patch("robosystems_client.extensions.ledger_client.delete_mapping_association") + def test_delete_mapping(self, mock_delete, mock_config, graph_id): + """Test deleting a mapping association.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.NO_CONTENT + mock_delete.return_value = mock_resp + + client = LedgerClient(mock_config) + client.delete_mapping(graph_id, "map-1", "assoc-1") + + mock_delete.assert_called_once() + + @patch("robosystems_client.extensions.ledger_client.auto_map_elements") + def test_auto_map(self, mock_auto, mock_config, graph_id): + """Test triggering auto-mapping.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.ACCEPTED + mock_resp.parsed = {"operation_id": "op-auto-1"} + mock_auto.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.auto_map(graph_id, "map-1") + + assert result["operation_id"] == "op-auto-1" + + @patch("robosystems_client.extensions.ledger_client.auto_map_elements") + def test_auto_map_error(self, mock_auto, mock_config, graph_id): + """Test auto-map raises on error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.INTERNAL_SERVER_ERROR + mock_auto.return_value = mock_resp + + client = LedgerClient(mock_config) + + 
with pytest.raises(RuntimeError, match="Auto-map failed"): + client.auto_map(graph_id, "map-1") + + @patch("robosystems_client.extensions.ledger_client.create_structure") + def test_create_mapping_structure(self, mock_create_struct, mock_config, graph_id): + """Test creating a mapping structure.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.CREATED + mock_resp.parsed = Mock(structure_id="struct-new") + mock_create_struct.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.create_mapping_structure(graph_id, name="Custom Mapping") + + assert result.structure_id == "struct-new" + + +@pytest.mark.unit +class TestLedgerSchedules: + """Test suite for schedule operations.""" + + @patch("robosystems_client.extensions.ledger_client.list_schedules") + def test_list_schedules(self, mock_list, mock_config, graph_id): + """Test listing schedules.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(schedules=[Mock()]) + mock_list.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.list_schedules(graph_id) + + assert len(result.schedules) == 1 + + @patch("robosystems_client.extensions.ledger_client.get_schedule_facts") + def test_get_schedule_facts(self, mock_facts, mock_config, graph_id): + """Test getting schedule facts.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(facts=[Mock(), Mock()]) + mock_facts.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.get_schedule_facts(graph_id, "struct-1") + + assert len(result.facts) == 2 + + @patch("robosystems_client.extensions.ledger_client.get_period_close_status") + def test_get_period_close_status(self, mock_status, mock_config, graph_id): + """Test getting period close status.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(status="open", schedules=[Mock()]) + mock_status.return_value = mock_resp + + client 
= LedgerClient(mock_config) + result = client.get_period_close_status( + graph_id, period_start="2025-01-01", period_end="2025-03-31" + ) + + assert result.status == "open" + + @patch("robosystems_client.extensions.ledger_client.create_closing_entry") + def test_create_closing_entry(self, mock_close, mock_config, graph_id): + """Test creating a closing entry.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.CREATED + mock_resp.parsed = Mock(transaction_id="txn-close-1") + mock_close.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.create_closing_entry( + graph_id, + structure_id="struct-1", + posting_date="2025-03-31", + period_start="2025-01-01", + period_end="2025-03-31", + memo="Q1 close", + ) + + assert result.transaction_id == "txn-close-1" + + @patch("robosystems_client.extensions.ledger_client.create_closing_entry") + def test_create_closing_entry_error(self, mock_close, mock_config, graph_id): + """Test create closing entry raises on error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.BAD_REQUEST + mock_close.return_value = mock_resp + + client = LedgerClient(mock_config) + + with pytest.raises(RuntimeError, match="Create closing entry failed"): + client.create_closing_entry( + graph_id, "struct-1", "2025-03-31", "2025-01-01", "2025-03-31" + ) + + +@pytest.mark.unit +class TestLedgerClosingBook: + """Test suite for closing book operations.""" + + @patch("robosystems_client.extensions.ledger_client.get_closing_book_structures") + def test_get_closing_book_structures(self, mock_cb, mock_config, graph_id): + """Test getting closing book structures.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(categories=[Mock(), Mock()]) + mock_cb.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.get_closing_book_structures(graph_id) + + assert len(result.categories) == 2 + + 
@patch("robosystems_client.extensions.ledger_client.get_account_rollups") + def test_get_account_rollups(self, mock_rollups, mock_config, graph_id): + """Test getting account rollups.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(rollups=[Mock()]) + mock_rollups.return_value = mock_resp + + client = LedgerClient(mock_config) + result = client.get_account_rollups( + graph_id, mapping_id="map-1", start_date="2025-01-01" + ) + + assert len(result.rollups) == 1 diff --git a/tests/test_materialization_client.py b/tests/test_materialization_client.py new file mode 100644 index 0000000..a3fa61a --- /dev/null +++ b/tests/test_materialization_client.py @@ -0,0 +1,340 @@ +"""Unit tests for MaterializationClient.""" + +import pytest +from unittest.mock import Mock, patch +from robosystems_client.extensions.materialization_client import ( + MaterializationClient, + MaterializationOptions, + MaterializationResult, + MaterializationStatus, +) +from robosystems_client.extensions.operation_client import ( + OperationResult, + OperationStatus, +) + + +@pytest.mark.unit +class TestMaterializationDataclasses: + """Test suite for materialization-related dataclasses.""" + + def test_materialization_options_defaults(self): + """Test MaterializationOptions default values.""" + options = MaterializationOptions() + + assert options.ignore_errors is True + assert options.rebuild is False + assert options.force is False + assert options.materialize_embeddings is False + assert options.on_progress is None + assert options.timeout == 600 + + def test_materialization_options_custom(self): + """Test MaterializationOptions with custom values.""" + progress_fn = Mock() + options = MaterializationOptions( + ignore_errors=False, + rebuild=True, + force=True, + materialize_embeddings=True, + on_progress=progress_fn, + timeout=300, + ) + + assert options.ignore_errors is False + assert options.rebuild is True + assert options.force is True + assert 
options.materialize_embeddings is True + assert options.on_progress is progress_fn + assert options.timeout == 300 + + def test_materialization_result(self): + """Test MaterializationResult dataclass.""" + result = MaterializationResult( + status="success", + was_stale=True, + stale_reason="New files uploaded", + tables_materialized=["Entity", "Transaction"], + total_rows=1500, + execution_time_ms=3000.0, + message="Graph materialized successfully", + ) + + assert result.status == "success" + assert result.was_stale is True + assert result.stale_reason == "New files uploaded" + assert len(result.tables_materialized) == 2 + assert result.total_rows == 1500 + assert result.execution_time_ms == 3000.0 + assert result.success is True + assert result.error is None + + def test_materialization_result_with_error(self): + """Test MaterializationResult with error.""" + result = MaterializationResult( + status="failed", + was_stale=False, + stale_reason=None, + tables_materialized=[], + total_rows=0, + execution_time_ms=0, + message="Failed to materialize", + success=False, + error="Connection timeout", + ) + + assert result.success is False + assert result.error == "Connection timeout" + + def test_materialization_status(self): + """Test MaterializationStatus dataclass.""" + status = MaterializationStatus( + graph_id="graph-123", + is_stale=True, + stale_reason="Files uploaded since last materialization", + stale_since="2025-01-15T10:00:00Z", + last_materialized_at="2025-01-14T08:00:00Z", + materialization_count=5, + hours_since_materialization=26.0, + message="Graph is stale", + ) + + assert status.graph_id == "graph-123" + assert status.is_stale is True + assert status.materialization_count == 5 + assert status.hours_since_materialization == 26.0 + + +@pytest.mark.unit +class TestMaterializationClientInit: + """Test suite for MaterializationClient initialization.""" + + def test_client_initialization(self, mock_config): + """Test that client initializes correctly with 
config.""" + client = MaterializationClient(mock_config) + + assert client.base_url == "http://localhost:8000" + assert client.token == "test-api-key" + assert client.headers == {"X-API-Key": "test-api-key"} + assert client._operation_client is None + + def test_operation_client_lazy_creation(self, mock_config): + """Test that operation client is created lazily.""" + client = MaterializationClient(mock_config) + + assert client._operation_client is None + op_client = client.operation_client + assert op_client is not None + # Second access returns same instance + assert client.operation_client is op_client + + +@pytest.mark.unit +class TestMaterialize: + """Test suite for MaterializationClient.materialize method.""" + + @patch("robosystems_client.extensions.materialization_client.materialize_graph") + def test_materialize_success(self, mock_mat, mock_config, graph_id): + """Test successful materialization.""" + # Mock initial response with operation_id + mock_resp = Mock() + mock_resp.status_code = 200 + mock_resp.parsed = Mock(operation_id="op-mat-123") + mock_mat.return_value = mock_resp + + # Mock the operation client monitoring + op_result = OperationResult( + operation_id="op-mat-123", + status=OperationStatus.COMPLETED, + result={ + "was_stale": True, + "stale_reason": "New files", + "tables_materialized": ["Entity"], + "total_rows": 100, + "execution_time_ms": 2000.0, + "message": "Done", + }, + execution_time_ms=2000.0, + ) + + client = MaterializationClient(mock_config) + client._operation_client = Mock() + client._operation_client.monitor_operation.return_value = op_result + + result = client.materialize(graph_id) + + assert result.success is True + assert result.status == "success" + assert result.tables_materialized == ["Entity"] + assert result.total_rows == 100 + + @patch("robosystems_client.extensions.materialization_client.materialize_graph") + def test_materialize_api_failure(self, mock_mat, mock_config, graph_id): + """Test materialization when API 
returns error.""" + mock_resp = Mock() + mock_resp.status_code = 500 + mock_resp.parsed = None + mock_resp.content = b'{"detail": "Internal error"}' + mock_mat.return_value = mock_resp + + client = MaterializationClient(mock_config) + result = client.materialize(graph_id) + + assert result.success is False + assert result.status == "failed" + + @patch("robosystems_client.extensions.materialization_client.materialize_graph") + def test_materialize_operation_failed(self, mock_mat, mock_config, graph_id): + """Test materialization when operation fails.""" + mock_resp = Mock() + mock_resp.status_code = 200 + mock_resp.parsed = Mock(operation_id="op-fail") + mock_mat.return_value = mock_resp + + op_result = OperationResult( + operation_id="op-fail", + status=OperationStatus.FAILED, + error="Dagster job failed", + execution_time_ms=1000.0, + ) + + client = MaterializationClient(mock_config) + client._operation_client = Mock() + client._operation_client.monitor_operation.return_value = op_result + + result = client.materialize(graph_id) + + assert result.success is False + assert result.error == "Dagster job failed" + + def test_materialize_no_token(self, mock_config, graph_id): + """Test materialize fails without API key.""" + mock_config["token"] = None + client = MaterializationClient(mock_config) + + result = client.materialize(graph_id) + + assert result.success is False + assert "No API key" in result.error + + @patch("robosystems_client.extensions.materialization_client.materialize_graph") + def test_materialize_with_progress(self, mock_mat, mock_config, graph_id): + """Test materialize calls progress callback.""" + mock_resp = Mock() + mock_resp.status_code = 200 + mock_resp.parsed = Mock(operation_id="op-progress") + mock_mat.return_value = mock_resp + + op_result = OperationResult( + operation_id="op-progress", + status=OperationStatus.COMPLETED, + result={ + "was_stale": False, + "tables_materialized": [], + "total_rows": 0, + "execution_time_ms": 500.0, + 
"message": "Already up to date", + }, + execution_time_ms=500.0, + ) + + progress_messages = [] + options = MaterializationOptions( + on_progress=lambda msg: progress_messages.append(msg) + ) + + client = MaterializationClient(mock_config) + client._operation_client = Mock() + client._operation_client.monitor_operation.return_value = op_result + + client.materialize(graph_id, options) + + assert len(progress_messages) >= 2 # At least "Submitting" and "queued" + + @patch("robosystems_client.extensions.materialization_client.materialize_graph") + def test_materialize_with_rebuild(self, mock_mat, mock_config, graph_id): + """Test materialize passes rebuild option.""" + mock_resp = Mock() + mock_resp.status_code = 200 + mock_resp.parsed = Mock(operation_id="op-rebuild") + mock_mat.return_value = mock_resp + + op_result = OperationResult( + operation_id="op-rebuild", + status=OperationStatus.COMPLETED, + result={ + "was_stale": True, + "tables_materialized": ["Entity"], + "total_rows": 50, + "execution_time_ms": 5000.0, + "message": "Rebuilt", + }, + execution_time_ms=5000.0, + ) + + client = MaterializationClient(mock_config) + client._operation_client = Mock() + client._operation_client.monitor_operation.return_value = op_result + + options = MaterializationOptions(rebuild=True, force=True) + result = client.materialize(graph_id, options) + + assert result.success is True + # Verify the request body had rebuild=True + call_kwargs = mock_mat.call_args[1] + assert call_kwargs["body"].rebuild is True + assert call_kwargs["body"].force is True + + +@pytest.mark.unit +class TestMaterializationStatus: + """Test suite for MaterializationClient.status method.""" + + @patch( + "robosystems_client.extensions.materialization_client.get_materialization_status" + ) + def test_get_status(self, mock_status, mock_config, graph_id): + """Test getting materialization status.""" + mock_resp = Mock() + mock_resp.status_code = 200 + mock_resp.parsed = Mock( + graph_id=graph_id, + 
is_stale=True, + stale_reason="New files uploaded", + stale_since="2025-01-15T10:00:00Z", + last_materialized_at="2025-01-14T08:00:00Z", + materialization_count=3, + hours_since_materialization=26.0, + message="Graph is stale", + ) + mock_status.return_value = mock_resp + + client = MaterializationClient(mock_config) + status = client.status(graph_id) + + assert status is not None + assert status.is_stale is True + assert status.materialization_count == 3 + + @patch( + "robosystems_client.extensions.materialization_client.get_materialization_status" + ) + def test_get_status_failure(self, mock_status, mock_config, graph_id): + """Test status returns None on failure.""" + mock_resp = Mock() + mock_resp.status_code = 500 + mock_resp.parsed = None + mock_status.return_value = mock_resp + + client = MaterializationClient(mock_config) + status = client.status(graph_id) + + assert status is None + + def test_status_no_token(self, mock_config, graph_id): + """Test status returns None without token.""" + mock_config["token"] = None + client = MaterializationClient(mock_config) + status = client.status(graph_id) + + assert status is None diff --git a/tests/test_operation_client_ops.py b/tests/test_operation_client_ops.py new file mode 100644 index 0000000..115ddc4 --- /dev/null +++ b/tests/test_operation_client_ops.py @@ -0,0 +1,350 @@ +"""Unit tests for OperationClient operational logic. + +Covers: monitor_operation (SSE event handling, completion, error, cancel, +timeout, progress callbacks, queue updates), get_operation_status, +cancel_operation, close_operation, close_all. + +Dataclass and enum tests already exist in tests/test_operation_client.py. 
+""" + +import pytest +from unittest.mock import Mock, patch, MagicMock +from robosystems_client.extensions.operation_client import ( + OperationClient, + OperationStatus, + OperationProgress, + MonitorOptions, +) +from robosystems_client.extensions.sse_client import SSEClient + + +# ── Helpers ────────────────────────────────────────────────────────── + + +def _fire_events(sse_client_instance, events): + """Simulate SSE events by directly calling registered listeners. + + Each event is a (event_type, data) tuple. + """ + for event_type, data in events: + sse_client_instance.emit(event_type, data) + + +# ── monitor_operation ──────────────────────────────────────────────── + + +@pytest.mark.unit +class TestMonitorOperation: + """Test OperationClient.monitor_operation via mocked SSE.""" + + @patch("robosystems_client.extensions.operation_client.SSEClient") + @patch("time.sleep") + def test_monitor_completed(self, mock_sleep, MockSSE, mock_config): + """Test monitoring an operation that completes successfully.""" + fake_sse = MagicMock(spec=SSEClient) + listeners = {} + + def capture_on(event, handler): + listeners[event] = handler + + fake_sse.on.side_effect = capture_on + + def fake_connect(op_id): + # Simulate: started → progress → completed + listeners["operation_started"]({"agent": "financial"}) + listeners["operation_progress"]({"message": "Processing", "percentage": 50}) + listeners["operation_completed"]( + {"result": {"rows": 100}, "execution_time_ms": 2000} + ) + + fake_sse.connect.side_effect = fake_connect + MockSSE.return_value = fake_sse + + client = OperationClient(mock_config) + result = client.monitor_operation("op-123") + + assert result.status == OperationStatus.COMPLETED + assert result.result == {"rows": 100} + assert result.execution_time_ms == 2000 + assert len(result.progress) == 1 + assert result.progress[0].message == "Processing" + assert result.progress[0].percentage == 50 + + 
@patch("robosystems_client.extensions.operation_client.SSEClient") + @patch("time.sleep") + def test_monitor_failed(self, mock_sleep, MockSSE, mock_config): + """Test monitoring an operation that fails.""" + fake_sse = MagicMock(spec=SSEClient) + listeners = {} + + def capture_on(event, handler): + listeners[event] = handler + + fake_sse.on.side_effect = capture_on + + def fake_connect(op_id): + listeners["operation_error"]({"message": "Database connection lost"}) + + fake_sse.connect.side_effect = fake_connect + MockSSE.return_value = fake_sse + + client = OperationClient(mock_config) + result = client.monitor_operation("op-fail") + + assert result.status == OperationStatus.FAILED + assert result.error == "Database connection lost" + + @patch("robosystems_client.extensions.operation_client.SSEClient") + @patch("time.sleep") + def test_monitor_cancelled(self, mock_sleep, MockSSE, mock_config): + """Test monitoring a cancelled operation.""" + fake_sse = MagicMock(spec=SSEClient) + listeners = {} + + def capture_on(event, handler): + listeners[event] = handler + + fake_sse.on.side_effect = capture_on + + def fake_connect(op_id): + listeners["operation_cancelled"]() + + fake_sse.connect.side_effect = fake_connect + MockSSE.return_value = fake_sse + + client = OperationClient(mock_config) + result = client.monitor_operation("op-cancel") + + assert result.status == OperationStatus.CANCELLED + + @patch("robosystems_client.extensions.operation_client.SSEClient") + @patch("time.sleep") + def test_monitor_with_progress_callback(self, mock_sleep, MockSSE, mock_config): + """Test that progress callback is invoked.""" + fake_sse = MagicMock(spec=SSEClient) + listeners = {} + + def capture_on(event, handler): + listeners[event] = handler + + fake_sse.on.side_effect = capture_on + + def fake_connect(op_id): + listeners["operation_progress"]( + {"message": "Step 1", "percentage": 25, "current_step": 1, "total_steps": 4} + ) + listeners["operation_progress"]( + {"message": "Step 
2", "percentage": 50, "current_step": 2, "total_steps": 4} + ) + listeners["operation_completed"]({"result": {}}) + + fake_sse.connect.side_effect = fake_connect + MockSSE.return_value = fake_sse + + progress_updates = [] + options = MonitorOptions(on_progress=lambda p: progress_updates.append(p)) + + client = OperationClient(mock_config) + client.monitor_operation("op-progress", options) + + assert len(progress_updates) == 2 + assert isinstance(progress_updates[0], OperationProgress) + assert progress_updates[0].message == "Step 1" + assert progress_updates[1].current_step == 2 + + @patch("robosystems_client.extensions.operation_client.SSEClient") + @patch("time.sleep") + def test_monitor_with_queue_update(self, mock_sleep, MockSSE, mock_config): + """Test queue update callback.""" + fake_sse = MagicMock(spec=SSEClient) + listeners = {} + + def capture_on(event, handler): + listeners[event] = handler + + fake_sse.on.side_effect = capture_on + + def fake_connect(op_id): + listeners["queue_update"]({"position": 3, "estimated_wait_seconds": 15}) + listeners["operation_completed"]({"result": {}}) + + fake_sse.connect.side_effect = fake_connect + MockSSE.return_value = fake_sse + + queue_updates = [] + options = MonitorOptions( + on_queue_update=lambda pos, wait: queue_updates.append((pos, wait)) + ) + + client = OperationClient(mock_config) + result = client.monitor_operation("op-queue", options) + + assert queue_updates == [(3, 15)] + assert result.status == OperationStatus.COMPLETED + + @patch("robosystems_client.extensions.operation_client.SSEClient") + @patch("time.sleep") + def test_monitor_error_uses_fallback_key(self, mock_sleep, MockSSE, mock_config): + """Test error event falls back to 'error' key when 'message' absent.""" + fake_sse = MagicMock(spec=SSEClient) + listeners = {} + + def capture_on(event, handler): + listeners[event] = handler + + fake_sse.on.side_effect = capture_on + + def fake_connect(op_id): + listeners["operation_error"]({"error": "Timeout 
exceeded"}) + + fake_sse.connect.side_effect = fake_connect + MockSSE.return_value = fake_sse + + client = OperationClient(mock_config) + result = client.monitor_operation("op-err") + + assert result.status == OperationStatus.FAILED + assert result.error == "Timeout exceeded" + + @patch("robosystems_client.extensions.operation_client.SSEClient") + @patch("time.sleep") + def test_monitor_cleanup_on_completion(self, mock_sleep, MockSSE, mock_config): + """Test that SSE client is cleaned up after completion.""" + fake_sse = MagicMock(spec=SSEClient) + listeners = {} + + def capture_on(event, handler): + listeners[event] = handler + + fake_sse.on.side_effect = capture_on + + def fake_connect(op_id): + listeners["operation_completed"]({"result": {}}) + + fake_sse.connect.side_effect = fake_connect + MockSSE.return_value = fake_sse + + client = OperationClient(mock_config) + client.monitor_operation("op-cleanup") + + # Operation should be removed from active_operations + assert "op-cleanup" not in client.active_operations + # SSE client should have been closed + fake_sse.close.assert_called() + + +# ── get_operation_status ───────────────────────────────────────────── + + +@pytest.mark.unit +class TestGetOperationStatus: + """Test OperationClient.get_operation_status.""" + + @patch("robosystems_client.api.operations.get_operation_status.sync_detailed") + def test_get_status_success(self, mock_get, mock_config): + """Test successful status retrieval.""" + mock_resp = Mock() + mock_resp.parsed = Mock() + mock_resp.parsed.status = "running" + mock_resp.parsed.progress = 50 + mock_resp.parsed.result = None + mock_resp.parsed.error = None + mock_get.return_value = mock_resp + + client = OperationClient(mock_config) + status = client.get_operation_status("op-status") + + assert status["status"] == "running" + assert status["progress"] == 50 + + @patch("robosystems_client.api.operations.get_operation_status.sync_detailed") + def test_get_status_error(self, mock_get, 
mock_config): + """Test status retrieval on error.""" + mock_get.side_effect = Exception("Network error") + + client = OperationClient(mock_config) + status = client.get_operation_status("op-err") + + assert status["status"] == "error" + assert "Network error" in status["error"] + + @patch("robosystems_client.api.operations.get_operation_status.sync_detailed") + def test_get_status_no_parsed(self, mock_get, mock_config): + """Test status when response has no parsed data.""" + mock_resp = Mock() + mock_resp.parsed = None + mock_get.return_value = mock_resp + + client = OperationClient(mock_config) + status = client.get_operation_status("op-none") + + assert status["status"] == "unknown" + + +# ── cancel_operation ───────────────────────────────────────────────── + + +@pytest.mark.unit +class TestCancelOperation: + """Test OperationClient.cancel_operation.""" + + @patch("robosystems_client.api.operations.cancel_operation.sync_detailed") + def test_cancel_success(self, mock_cancel, mock_config): + """Test successful cancellation.""" + mock_resp = Mock() + mock_resp.parsed = Mock() + mock_resp.parsed.cancelled = True + mock_cancel.return_value = mock_resp + + client = OperationClient(mock_config) + result = client.cancel_operation("op-cancel") + + assert result is True + + @patch("robosystems_client.api.operations.cancel_operation.sync_detailed") + def test_cancel_failure(self, mock_cancel, mock_config): + """Test cancellation failure.""" + mock_cancel.side_effect = Exception("Not found") + + client = OperationClient(mock_config) + result = client.cancel_operation("op-missing") + + assert result is False + + +# ── close_operation / close_all ────────────────────────────────────── + + +@pytest.mark.unit +class TestCloseOperations: + """Test close_operation and close_all.""" + + def test_close_operation(self, mock_config): + """Test closing a specific operation monitor.""" + client = OperationClient(mock_config) + mock_sse = Mock() + client.active_operations["op-1"] = 
mock_sse + + client.close_operation("op-1") + + mock_sse.close.assert_called_once() + assert "op-1" not in client.active_operations + + def test_close_operation_nonexistent(self, mock_config): + """Test closing a nonexistent operation is a no-op.""" + client = OperationClient(mock_config) + client.close_operation("op-missing") # Should not raise + + def test_close_all_multiple(self, mock_config): + """Test closing all active operations.""" + client = OperationClient(mock_config) + mock_sse1 = Mock() + mock_sse2 = Mock() + client.active_operations["op-1"] = mock_sse1 + client.active_operations["op-2"] = mock_sse2 + + client.close_all() + + mock_sse1.close.assert_called_once() + mock_sse2.close.assert_called_once() + assert len(client.active_operations) == 0 diff --git a/tests/test_query_client_ops.py b/tests/test_query_client_ops.py new file mode 100644 index 0000000..f4e2f4f --- /dev/null +++ b/tests/test_query_client_ops.py @@ -0,0 +1,524 @@ +"""Unit tests for QueryClient operational logic. + +Covers: execute_query (sync dict response, attrs response, queued response, +NDJSON streaming, error handling), _parse_ndjson_response, +_wait_for_query_completion, query convenience, query_batch, stream_query. + +Dataclass tests already exist in tests/test_query_client.py. 
+""" + +import pytest +from unittest.mock import Mock, patch +from robosystems_client.extensions.query_client import ( + QueryClient, + QueryRequest, + QueryOptions, + QueryResult, + QueuedQueryError, +) + + +# ── execute_query — sync dict response ─────────────────────────────── + + +@pytest.mark.unit +class TestExecuteQuerySync: + """Test execute_query with immediate sync responses.""" + + @patch("robosystems_client.extensions.query_client.execute_cypher_query") + def test_dict_response(self, mock_exec, mock_config, graph_id): + """Test execute_query with a dict response.""" + mock_resp = Mock() + mock_resp.parsed = { + "data": [{"name": "ACME"}], + "columns": ["name"], + "row_count": 1, + "execution_time_ms": 42, + "timestamp": "2025-01-15T10:00:00Z", + } + # No NDJSON headers + mock_resp.headers = {"content-type": "application/json"} + mock_resp.status_code = 200 + mock_exec.return_value = mock_resp + + client = QueryClient(mock_config) + request = QueryRequest(query="MATCH (n) RETURN n.name") + result = client.execute_query(graph_id, request) + + assert isinstance(result, QueryResult) + assert result.data == [{"name": "ACME"}] + assert result.columns == ["name"] + assert result.row_count == 1 + assert result.execution_time_ms == 42 + assert result.graph_id == graph_id + + @patch("robosystems_client.extensions.query_client.execute_cypher_query") + def test_attrs_response(self, mock_exec, mock_config, graph_id): + """Test execute_query with an attrs object response.""" + mock_data_item = Mock() + mock_data_item.to_dict.return_value = {"name": "Beta Corp"} + + mock_parsed = Mock() + mock_parsed.data = [mock_data_item] + mock_parsed.columns = ["name"] + mock_parsed.row_count = 1 + mock_parsed.execution_time_ms = 30 + mock_parsed.timestamp = "2025-01-15T10:00:00Z" + + mock_resp = Mock() + mock_resp.parsed = mock_parsed + mock_resp.headers = {"content-type": "application/json"} + mock_resp.status_code = 200 + mock_exec.return_value = mock_resp + + client = 
QueryClient(mock_config) + request = QueryRequest(query="MATCH (n) RETURN n.name") + result = client.execute_query(graph_id, request) + + assert isinstance(result, QueryResult) + assert result.data == [{"name": "Beta Corp"}] + + @patch("robosystems_client.extensions.query_client.execute_cypher_query") + def test_attrs_with_additional_properties(self, mock_exec, mock_config, graph_id): + """Test attrs response items using additional_properties fallback.""" + mock_data_item = Mock(spec=[]) + mock_data_item.additional_properties = {"ticker": "ACM"} + # No to_dict method + + mock_parsed = Mock() + mock_parsed.data = [mock_data_item] + mock_parsed.columns = ["ticker"] + mock_parsed.row_count = 1 + mock_parsed.execution_time_ms = 10 + mock_parsed.timestamp = "2025-01-15T10:00:00Z" + + mock_resp = Mock() + mock_resp.parsed = mock_parsed + mock_resp.headers = {"content-type": "application/json"} + mock_resp.status_code = 200 + mock_exec.return_value = mock_resp + + client = QueryClient(mock_config) + request = QueryRequest(query="MATCH (n) RETURN n.ticker") + result = client.execute_query(graph_id, request) + + assert result.data == [{"ticker": "ACM"}] + + +# ── execute_query — queued response ────────────────────────────────── + + +@pytest.mark.unit +class TestExecuteQueryQueued: + """Test execute_query with queued responses.""" + + @patch("robosystems_client.extensions.query_client.execute_cypher_query") + def test_queued_max_wait_zero_raises(self, mock_exec, mock_config, graph_id): + """Test queued response with max_wait=0 raises QueuedQueryError.""" + mock_resp = Mock() + mock_resp.parsed = { + "status": "queued", + "operation_id": "op-q1", + "queue_position": 3, + "estimated_wait_seconds": 15, + "message": "Query queued", + } + mock_resp.headers = {"content-type": "application/json"} + mock_resp.status_code = 200 + mock_exec.return_value = mock_resp + + client = QueryClient(mock_config) + request = QueryRequest(query="MATCH (n) RETURN n") + options = 
QueryOptions(max_wait=0) + + with pytest.raises(QueuedQueryError) as exc_info: + client.execute_query(graph_id, request, options) + + assert exc_info.value.queue_info.operation_id == "op-q1" + assert exc_info.value.queue_info.queue_position == 3 + + @patch("robosystems_client.extensions.query_client.execute_cypher_query") + def test_queued_calls_queue_update_callback(self, mock_exec, mock_config, graph_id): + """Test queued response invokes on_queue_update.""" + mock_resp = Mock() + mock_resp.parsed = { + "status": "queued", + "operation_id": "op-q2", + "queue_position": 5, + "estimated_wait_seconds": 30, + "message": "Queued", + } + mock_resp.headers = {"content-type": "application/json"} + mock_resp.status_code = 200 + mock_exec.return_value = mock_resp + + queue_updates = [] + options = QueryOptions( + max_wait=0, + on_queue_update=lambda pos, wait: queue_updates.append((pos, wait)), + ) + + client = QueryClient(mock_config) + request = QueryRequest(query="MATCH (n) RETURN n") + + with pytest.raises(QueuedQueryError): + client.execute_query(graph_id, request, options) + + assert queue_updates == [(5, 30)] + + +# ── execute_query — error handling ─────────────────────────────────── + + +@pytest.mark.unit +class TestExecuteQueryErrors: + """Test error handling in execute_query.""" + + def test_no_token_raises(self, mock_config, graph_id): + """Test execute_query fails without token.""" + mock_config["token"] = None + client = QueryClient(mock_config) + request = QueryRequest(query="MATCH (n) RETURN n") + + with pytest.raises(Exception, match="Authentication failed|No API key"): + client.execute_query(graph_id, request) + + @patch("robosystems_client.extensions.query_client.execute_cypher_query") + def test_auth_error_wrapped(self, mock_exec, mock_config, graph_id): + """Test 401/403 errors are wrapped as auth errors.""" + mock_exec.side_effect = Exception("401 Unauthorized") + + client = QueryClient(mock_config) + request = QueryRequest(query="MATCH (n) RETURN n") 
+ + with pytest.raises(Exception, match="Authentication failed"): + client.execute_query(graph_id, request) + + @patch("robosystems_client.extensions.query_client.execute_cypher_query") + def test_generic_error_wrapped(self, mock_exec, mock_config, graph_id): + """Test generic errors are wrapped.""" + mock_exec.side_effect = Exception("Connection timeout") + + client = QueryClient(mock_config) + request = QueryRequest(query="MATCH (n) RETURN n") + + with pytest.raises(Exception, match="Query execution failed"): + client.execute_query(graph_id, request) + + @patch("robosystems_client.extensions.query_client.execute_cypher_query") + def test_http_error_response(self, mock_exec, mock_config, graph_id): + """Test HTTP 4xx/5xx with error body.""" + mock_resp = Mock() + mock_resp.parsed = None + mock_resp.status_code = 400 + mock_resp.headers = {"content-type": "application/json"} + mock_resp.content = b'{"detail": "Invalid query syntax"}' + mock_exec.return_value = mock_resp + + client = QueryClient(mock_config) + request = QueryRequest(query="BAD QUERY") + + with pytest.raises(Exception, match="Invalid query syntax"): + client.execute_query(graph_id, request) + + +# ── NDJSON parsing ─────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestNDJSON: + """Test _parse_ndjson_response.""" + + def test_single_chunk(self, mock_config, graph_id): + """Test parsing a single NDJSON chunk.""" + content = '{"columns": ["name"], "rows": [["ACME"]], "execution_time_ms": 10}\n' + + mock_resp = Mock() + mock_resp.content = content.encode() + + client = QueryClient(mock_config) + result = client._parse_ndjson_response(mock_resp, graph_id) + + assert result.columns == ["name"] + assert result.data == [["ACME"]] + assert result.row_count == 1 + assert result.execution_time_ms == 10 + + def test_multiple_chunks(self, mock_config, graph_id): + """Test parsing multiple NDJSON chunks.""" + content = ( + '{"columns": ["id"], "rows": [[1], [2]], 
"execution_time_ms": 5}\n' + '{"rows": [[3], [4]], "execution_time_ms": 12}\n' + ) + + mock_resp = Mock() + mock_resp.content = content.encode() + + client = QueryClient(mock_config) + result = client._parse_ndjson_response(mock_resp, graph_id) + + assert result.columns == ["id"] + assert result.data == [[1], [2], [3], [4]] + assert result.row_count == 4 + assert result.execution_time_ms == 12 # max of 5, 12 + + def test_data_key_fallback(self, mock_config, graph_id): + """Test NDJSON with 'data' key instead of 'rows'.""" + content = '{"columns": ["x"], "data": [{"x": 1}], "execution_time_ms": 3}\n' + + mock_resp = Mock() + mock_resp.content = content.encode() + + client = QueryClient(mock_config) + result = client._parse_ndjson_response(mock_resp, graph_id) + + assert result.data == [{"x": 1}] + + def test_empty_lines_skipped(self, mock_config, graph_id): + """Test that empty lines in NDJSON are skipped.""" + content = '{"columns": ["a"], "rows": [[1]]}\n\n\n' + + mock_resp = Mock() + mock_resp.content = content.encode() + + client = QueryClient(mock_config) + result = client._parse_ndjson_response(mock_resp, graph_id) + + assert result.row_count == 1 + + def test_invalid_json_raises(self, mock_config, graph_id): + """Test invalid JSON in NDJSON raises error.""" + content = "not json at all\n" + + mock_resp = Mock() + mock_resp.content = content.encode() + + client = QueryClient(mock_config) + + with pytest.raises(Exception, match="Failed to parse NDJSON"): + client._parse_ndjson_response(mock_resp, graph_id) + + def test_no_columns_defaults_to_empty(self, mock_config, graph_id): + """Test NDJSON without columns field defaults to empty list.""" + content = '{"rows": [[1]]}\n' + + mock_resp = Mock() + mock_resp.content = content.encode() + + client = QueryClient(mock_config) + result = client._parse_ndjson_response(mock_resp, graph_id) + + assert result.columns == [] + + +# ── NDJSON detection in execute_query ──────────────────────────────── + + +@pytest.mark.unit 
+class TestNDJSONDetection: + """Test that NDJSON responses are detected and parsed.""" + + @patch("robosystems_client.extensions.query_client.execute_cypher_query") + def test_ndjson_content_type_detected(self, mock_exec, mock_config, graph_id): + """Test NDJSON response detected by content-type header.""" + mock_resp = Mock() + mock_resp.parsed = None + mock_resp.status_code = 200 + mock_resp.headers = {"content-type": "application/x-ndjson"} + mock_resp.content = b'{"columns": ["n"], "rows": [["x"]]}\n' + mock_exec.return_value = mock_resp + + client = QueryClient(mock_config) + request = QueryRequest(query="MATCH (n) RETURN n") + result = client.execute_query(graph_id, request) + + assert isinstance(result, QueryResult) + assert result.data == [["x"]] + + @patch("robosystems_client.extensions.query_client.execute_cypher_query") + def test_ndjson_stream_format_header(self, mock_exec, mock_config, graph_id): + """Test NDJSON detected by x-stream-format header.""" + mock_resp = Mock() + mock_resp.parsed = None + mock_resp.status_code = 200 + mock_resp.headers = { + "content-type": "application/json", + "x-stream-format": "ndjson", + } + mock_resp.content = b'{"columns": ["a"], "rows": [[1]]}\n' + mock_exec.return_value = mock_resp + + client = QueryClient(mock_config) + request = QueryRequest(query="MATCH (n) RETURN n") + result = client.execute_query(graph_id, request) + + assert result.data == [[1]] + + +# ── query convenience ──────────────────────────────────────────────── + + +@pytest.mark.unit +class TestQueryConvenience: + """Test query() convenience method.""" + + @patch.object(QueryClient, "execute_query") + def test_query_returns_result(self, mock_exec, mock_config, graph_id): + """Test query() returns QueryResult directly.""" + mock_exec.return_value = QueryResult( + data=[{"n": 1}], columns=["n"], row_count=1, execution_time_ms=5 + ) + + client = QueryClient(mock_config) + result = client.query(graph_id, "MATCH (n) RETURN n") + + assert 
result.row_count == 1 + + @patch.object(QueryClient, "execute_query") + def test_query_collects_iterator(self, mock_exec, mock_config, graph_id): + """Test query() collects iterator results into QueryResult.""" + mock_exec.return_value = iter([{"n": 1}, {"n": 2}]) + + client = QueryClient(mock_config) + result = client.query(graph_id, "MATCH (n) RETURN n") + + assert result.row_count == 2 + assert result.data == [{"n": 1}, {"n": 2}] + + @patch.object(QueryClient, "execute_query") + def test_query_passes_parameters(self, mock_exec, mock_config, graph_id): + """Test query() forwards parameters.""" + mock_exec.return_value = QueryResult( + data=[], columns=[], row_count=0, execution_time_ms=0 + ) + + client = QueryClient(mock_config) + client.query(graph_id, "MATCH (n) WHERE n.x > $v RETURN n", {"v": 10}) + + call_args = mock_exec.call_args + assert call_args[0][1].parameters == {"v": 10} + + +# ── query_batch ────────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestQueryBatch: + """Test query_batch method.""" + + @patch.object(QueryClient, "query") + def test_batch_success(self, mock_query, mock_config, graph_id): + """Test batch execution of multiple queries.""" + mock_query.side_effect = [ + QueryResult( + data=[{"count": 5}], columns=["count"], row_count=1, execution_time_ms=10 + ), + QueryResult( + data=[{"count": 3}], columns=["count"], row_count=1, execution_time_ms=8 + ), + ] + + client = QueryClient(mock_config) + results = client.query_batch( + graph_id, + ["MATCH (n:Person) RETURN count(n)", "MATCH (c:Company) RETURN count(c)"], + ) + + assert len(results) == 2 + assert all(isinstance(r, QueryResult) for r in results) + + @patch.object(QueryClient, "query") + def test_batch_with_error(self, mock_query, mock_config, graph_id): + """Test batch handles individual query failures.""" + mock_query.side_effect = [ + QueryResult(data=[], columns=[], row_count=0, execution_time_ms=0), + Exception("Query 2 failed"), + ] + + client = 
QueryClient(mock_config) + results = client.query_batch(graph_id, ["query1", "query2"]) + + assert isinstance(results[0], QueryResult) + assert isinstance(results[1], dict) + assert "error" in results[1] + + def test_batch_length_mismatch(self, mock_config, graph_id): + """Test batch raises on mismatched query/params lengths.""" + client = QueryClient(mock_config) + + with pytest.raises(ValueError, match="same length"): + client.query_batch(graph_id, ["q1", "q2"], [{"a": 1}]) + + @patch.object(QueryClient, "query") + def test_batch_with_parameters(self, mock_query, mock_config, graph_id): + """Test batch passes parameters per query.""" + mock_query.return_value = QueryResult( + data=[], columns=[], row_count=0, execution_time_ms=0 + ) + + client = QueryClient(mock_config) + client.query_batch( + graph_id, + ["MATCH (n) WHERE n.x > $v RETURN n", "MATCH (n) WHERE n.y = $w RETURN n"], + [{"v": 10}, {"w": "abc"}], + ) + + assert mock_query.call_count == 2 + assert mock_query.call_args_list[0][0][2] == {"v": 10} + assert mock_query.call_args_list[1][0][2] == {"w": "abc"} + + +# ── stream_query ───────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestStreamQuery: + """Test stream_query method.""" + + @patch.object(QueryClient, "execute_query") + def test_stream_yields_from_iterator(self, mock_exec, mock_config, graph_id): + """Test stream_query yields items from iterator result.""" + mock_exec.return_value = iter([{"id": 1}, {"id": 2}, {"id": 3}]) + + client = QueryClient(mock_config) + items = list(client.stream_query(graph_id, "MATCH (n) RETURN n")) + + assert items == [{"id": 1}, {"id": 2}, {"id": 3}] + + @patch.object(QueryClient, "execute_query") + def test_stream_yields_from_query_result(self, mock_exec, mock_config, graph_id): + """Test stream_query yields items when result is QueryResult.""" + mock_exec.return_value = QueryResult( + data=[{"id": 1}, {"id": 2}], + columns=["id"], + row_count=2, + execution_time_ms=10, + ) + + client 
= QueryClient(mock_config) + items = list(client.stream_query(graph_id, "MATCH (n) RETURN n")) + + assert items == [{"id": 1}, {"id": 2}] + + @patch.object(QueryClient, "execute_query") + def test_stream_calls_progress(self, mock_exec, mock_config, graph_id): + """Test stream_query calls on_progress callback.""" + mock_exec.return_value = QueryResult( + data=[{"id": i} for i in range(5)], + columns=["id"], + row_count=5, + execution_time_ms=10, + ) + + progress_calls = [] + client = QueryClient(mock_config) + list( + client.stream_query( + graph_id, + "MATCH (n) RETURN n", + on_progress=lambda cur, tot: progress_calls.append((cur, tot)), + ) + ) + + assert len(progress_calls) == 5 + assert progress_calls[-1] == (5, 5) diff --git a/tests/test_report_client.py b/tests/test_report_client.py new file mode 100644 index 0000000..e99e0d6 --- /dev/null +++ b/tests/test_report_client.py @@ -0,0 +1,302 @@ +"""Unit tests for ReportClient.""" + +import pytest +from http import HTTPStatus +from unittest.mock import Mock, patch +from robosystems_client.extensions.report_client import ReportClient + + +@pytest.mark.unit +class TestReportClientInit: + """Test suite for ReportClient initialization.""" + + def test_client_initialization(self, mock_config): + """Test that client initializes correctly with config.""" + client = ReportClient(mock_config) + + assert client.base_url == "http://localhost:8000" + assert client.token == "test-api-key" + assert client.headers == {"X-API-Key": "test-api-key"} + assert client.timeout == 60 + + def test_get_client_no_token(self, mock_config): + """Test _get_client raises without token.""" + mock_config["token"] = None + client = ReportClient(mock_config) + + with pytest.raises(RuntimeError, match="No API key"): + client._get_client() + + +@pytest.mark.unit +class TestReportCreate: + """Test suite for ReportClient.create method.""" + + @patch("robosystems_client.extensions.report_client.create_report") + def test_create_report(self, mock_create, 
mock_config, graph_id): + """Test creating a report.""" + mock_parsed = Mock() + mock_parsed.report_id = "rpt-123" + mock_parsed.report_name = "Q1 Report" + + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.CREATED + mock_resp.parsed = mock_parsed + mock_create.return_value = mock_resp + + client = ReportClient(mock_config) + result = client.create( + graph_id=graph_id, + name="Q1 Report", + mapping_id="map-1", + period_start="2025-01-01", + period_end="2025-03-31", + ) + + assert result.report_id == "rpt-123" + assert result.report_name == "Q1 Report" + + @patch("robosystems_client.extensions.report_client.create_report") + def test_create_report_with_options(self, mock_create, mock_config, graph_id): + """Test creating a report with all options.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.CREATED + mock_resp.parsed = Mock(report_id="rpt-full") + mock_create.return_value = mock_resp + + client = ReportClient(mock_config) + result = client.create( + graph_id=graph_id, + name="Annual Report", + mapping_id="map-2", + period_start="2024-01-01", + period_end="2024-12-31", + taxonomy_id="tax_usgaap_reporting", + period_type="annual", + comparative=False, + ) + + assert result.report_id == "rpt-full" + mock_create.assert_called_once() + + @patch("robosystems_client.extensions.report_client.create_report") + def test_create_report_error(self, mock_create, mock_config, graph_id): + """Test create report raises on error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.BAD_REQUEST + mock_create.return_value = mock_resp + + client = ReportClient(mock_config) + + with pytest.raises(RuntimeError, match="Create report failed"): + client.create( + graph_id=graph_id, + name="Bad", + mapping_id="map-x", + period_start="bad", + period_end="bad", + ) + + +@pytest.mark.unit +class TestReportList: + """Test suite for ReportClient.list method.""" + + @patch("robosystems_client.extensions.report_client.list_reports") + def test_list_reports(self, 
mock_list, mock_config, graph_id): + """Test listing reports.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(reports=[Mock(), Mock()]) + mock_list.return_value = mock_resp + + client = ReportClient(mock_config) + result = client.list(graph_id) + + assert len(result.reports) == 2 + + @patch("robosystems_client.extensions.report_client.list_reports") + def test_list_reports_error(self, mock_list, mock_config, graph_id): + """Test list reports raises on error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.INTERNAL_SERVER_ERROR + mock_list.return_value = mock_resp + + client = ReportClient(mock_config) + + with pytest.raises(RuntimeError, match="List reports failed"): + client.list(graph_id) + + +@pytest.mark.unit +class TestReportGet: + """Test suite for ReportClient.get method.""" + + @patch("robosystems_client.extensions.report_client.get_report") + def test_get_report(self, mock_get, mock_config, graph_id): + """Test getting a report.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(report_id="rpt-456", name="Q2 Report", structures=[Mock()]) + mock_get.return_value = mock_resp + + client = ReportClient(mock_config) + result = client.get(graph_id, "rpt-456") + + assert result.report_id == "rpt-456" + assert len(result.structures) == 1 + + @patch("robosystems_client.extensions.report_client.get_report") + def test_get_report_error(self, mock_get, mock_config, graph_id): + """Test get report raises on error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.NOT_FOUND + mock_get.return_value = mock_resp + + client = ReportClient(mock_config) + + with pytest.raises(RuntimeError, match="Get report failed"): + client.get(graph_id, "nonexistent") + + +@pytest.mark.unit +class TestReportStatement: + """Test suite for ReportClient.statement method.""" + + @patch("robosystems_client.extensions.report_client.get_statement") + def test_get_statement(self, mock_stmt, 
mock_config, graph_id): + """Test getting a financial statement.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock( + structure_type="income_statement", line_items=[Mock(), Mock()] + ) + mock_stmt.return_value = mock_resp + + client = ReportClient(mock_config) + result = client.statement(graph_id, "rpt-123", "income_statement") + + assert result.structure_type == "income_statement" + assert len(result.line_items) == 2 + + @patch("robosystems_client.extensions.report_client.get_statement") + def test_get_statement_error(self, mock_stmt, mock_config, graph_id): + """Test statement raises on error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.NOT_FOUND + mock_stmt.return_value = mock_resp + + client = ReportClient(mock_config) + + with pytest.raises(RuntimeError, match="Get statement failed"): + client.statement(graph_id, "rpt-bad", "income_statement") + + +@pytest.mark.unit +class TestReportRegenerate: + """Test suite for ReportClient.regenerate method.""" + + @patch("robosystems_client.extensions.report_client.regenerate_report") + def test_regenerate_report(self, mock_regen, mock_config, graph_id): + """Test regenerating a report.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(report_id="rpt-regen") + mock_regen.return_value = mock_resp + + client = ReportClient(mock_config) + result = client.regenerate( + graph_id, "rpt-regen", period_start="2025-04-01", period_end="2025-06-30" + ) + + assert result.report_id == "rpt-regen" + + @patch("robosystems_client.extensions.report_client.regenerate_report") + def test_regenerate_report_error(self, mock_regen, mock_config, graph_id): + """Test regenerate raises on error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.BAD_REQUEST + mock_regen.return_value = mock_resp + + client = ReportClient(mock_config) + + with pytest.raises(RuntimeError, match="Regenerate report failed"): + client.regenerate(graph_id, 
"rpt-bad", "bad", "bad") + + +@pytest.mark.unit +class TestReportDelete: + """Test suite for ReportClient.delete method.""" + + @patch("robosystems_client.extensions.report_client.delete_report") + def test_delete_report(self, mock_delete, mock_config, graph_id): + """Test deleting a report.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.NO_CONTENT + mock_delete.return_value = mock_resp + + client = ReportClient(mock_config) + client.delete(graph_id, "rpt-del") # Should not raise + + mock_delete.assert_called_once() + + @patch("robosystems_client.extensions.report_client.delete_report") + def test_delete_report_error(self, mock_delete, mock_config, graph_id): + """Test delete raises on error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.INTERNAL_SERVER_ERROR + mock_delete.return_value = mock_resp + + client = ReportClient(mock_config) + + with pytest.raises(RuntimeError, match="Delete report failed"): + client.delete(graph_id, "rpt-err") + + +@pytest.mark.unit +class TestReportShare: + """Test suite for ReportClient.share method.""" + + @patch("robosystems_client.extensions.report_client.share_report") + def test_share_report(self, mock_share, mock_config, graph_id): + """Test sharing a report.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.OK + mock_resp.parsed = Mock(shared_count=3) + mock_share.return_value = mock_resp + + client = ReportClient(mock_config) + result = client.share(graph_id, "rpt-share", "pub-list-1") + + assert result.shared_count == 3 + + @patch("robosystems_client.extensions.report_client.share_report") + def test_share_report_error(self, mock_share, mock_config, graph_id): + """Test share raises on error.""" + mock_resp = Mock() + mock_resp.status_code = HTTPStatus.NOT_FOUND + mock_share.return_value = mock_resp + + client = ReportClient(mock_config) + + with pytest.raises(RuntimeError, match="Share report failed"): + client.share(graph_id, "rpt-bad", "pub-bad") + + +@pytest.mark.unit +class 
TestReportIsShared: + """Test suite for ReportClient.is_shared_report method.""" + + def test_is_shared_report_true(self, mock_config): + """Test detecting a shared report.""" + client = ReportClient(mock_config) + report = Mock(source_graph_id="other-graph-123") + + assert client.is_shared_report(report) is True + + def test_is_shared_report_false(self, mock_config): + """Test detecting a non-shared report.""" + client = ReportClient(mock_config) + report = Mock(spec=[]) # No source_graph_id attribute + + assert client.is_shared_report(report) is False diff --git a/tests/test_sse_client.py b/tests/test_sse_client.py new file mode 100644 index 0000000..f27f209 --- /dev/null +++ b/tests/test_sse_client.py @@ -0,0 +1,503 @@ +"""Unit tests for SSEClient operational logic. + +Covers: connection, event processing, reconnection/retry, heartbeat, +completion-event auto-close, and close/cleanup. Dataclass and listener +tests already exist in extensions/tests/test_unit.py. +""" + +import pytest +from unittest.mock import Mock, patch, MagicMock +from robosystems_client.extensions.sse_client import ( + SSEClient, + SSEConfig, +) + + +@pytest.fixture +def sse_config(): + return SSEConfig( + base_url="http://localhost:8000", + headers={"X-API-Key": "test-key"}, + max_retries=3, + retry_delay=100, + timeout=10, + ) + + +@pytest.fixture +def sse_client(sse_config): + return SSEClient(sse_config) + + +# ── Event processing ──────────────────────────────────────────────── + + +@pytest.mark.unit +class TestProcessEvents: + """Test _process_events with mocked SSE streams.""" + + def _make_client_with_lines(self, sse_config, lines): + """Create an SSEClient with a fake response that yields lines.""" + client = SSEClient(sse_config) + mock_response = Mock() + mock_response.iter_lines.return_value = iter(lines) + client._response = mock_response + client.closed = False + return client + + def test_simple_event(self, sse_config): + """Test processing a single complete event.""" + lines = 
[ + "event: operation_progress", + 'data: {"message": "Working"}', + "", + ] + client = self._make_client_with_lines(sse_config, lines) + events = [] + client.on("operation_progress", lambda d: events.append(d)) + + client._process_events() + + assert len(events) == 1 + assert events[0] == {"message": "Working"} + + def test_multiline_data(self, sse_config): + """Test multiline data is joined with newlines.""" + lines = [ + "event: data_chunk", + 'data: {"rows": [', + 'data: {"id": 1}', + "data: ]}", + "", + ] + client = self._make_client_with_lines(sse_config, lines) + events = [] + client.on("data_chunk", lambda d: events.append(d)) + + client._process_events() + + assert len(events) == 1 + assert events[0] == {"rows": [{"id": 1}]} + + def test_comment_lines_skipped(self, sse_config): + """Test that comment lines (starting with :) are skipped.""" + lines = [ + ": this is a comment", + "event: operation_progress", + 'data: {"message": "hello"}', + "", + ] + client = self._make_client_with_lines(sse_config, lines) + events = [] + client.on("operation_progress", lambda d: events.append(d)) + + client._process_events() + + assert len(events) == 1 + + def test_id_field_tracked(self, sse_config): + """Test that id field is tracked for reconnection.""" + lines = [ + "event: operation_progress", + 'data: {"message": "step 1"}', + "id: 42", + "", + ] + client = self._make_client_with_lines(sse_config, lines) + client._process_events() + + assert client.last_event_id == "42" + + def test_retry_field_parsed(self, sse_config): + """Test retry field is parsed from events.""" + lines = [ + "event: operation_progress", + 'data: {"message": "retry test"}', + "retry: 5000", + "", + ] + client = self._make_client_with_lines(sse_config, lines) + # The retry value is parsed in the event buffer but not stored on client + # Just verify it doesn't crash + client._process_events() + + def test_invalid_retry_ignored(self, sse_config): + """Test invalid retry value is silently ignored.""" 
+ lines = [ + "event: operation_progress", + 'data: {"message": "test"}', + "retry: not-a-number", + "", + ] + client = self._make_client_with_lines(sse_config, lines) + client._process_events() # Should not raise + + def test_multiple_events(self, sse_config): + """Test processing multiple events in sequence.""" + lines = [ + "event: operation_progress", + 'data: {"message": "step 1"}', + "", + "event: operation_progress", + 'data: {"message": "step 2"}', + "", + "event: operation_completed", + 'data: {"result": {"done": true}}', + "", + ] + client = self._make_client_with_lines(sse_config, lines) + progress = [] + completed = [] + client.on("operation_progress", lambda d: progress.append(d)) + client.on("operation_completed", lambda d: completed.append(d)) + + client._process_events() + + assert len(progress) == 2 + assert len(completed) == 1 + + def test_default_event_type_is_message(self, sse_config): + """Test events without event: field default to 'message'.""" + lines = [ + 'data: {"hello": "world"}', + "", + ] + client = self._make_client_with_lines(sse_config, lines) + events = [] + client.on("message", lambda d: events.append(d)) + + client._process_events() + + assert len(events) == 1 + assert events[0] == {"hello": "world"} + + def test_non_json_data_kept_as_string(self, sse_config): + """Test that non-JSON data is kept as a string.""" + lines = [ + "event: operation_progress", + "data: plain text message", + "", + ] + client = self._make_client_with_lines(sse_config, lines) + events = [] + client.on("operation_progress", lambda d: events.append(d)) + + client._process_events() + + assert events[0] == "plain text message" + + def test_final_event_without_trailing_newline(self, sse_config): + """Test that a final event is dispatched even without trailing empty line.""" + lines = [ + "event: operation_completed", + 'data: {"result": {}}', + ] + client = self._make_client_with_lines(sse_config, lines) + events = [] + client.on("operation_completed", lambda 
d: events.append(d)) + + client._process_events() + + assert len(events) == 1 + + def test_closed_flag_stops_processing(self, sse_config): + """Test that setting closed=True mid-stream stops processing.""" + call_count = 0 + + def close_on_second(data): + nonlocal call_count + call_count += 1 + + lines = [ + "event: operation_progress", + 'data: {"message": "first"}', + "", + "event: operation_progress", + 'data: {"message": "second"}', + "", + ] + client = self._make_client_with_lines(sse_config, lines) + client.on("operation_progress", close_on_second) + + # Close after first event dispatch + original_dispatch = client._dispatch_event + + def dispatch_and_close(buf): + original_dispatch(buf) + client.closed = True + + client._dispatch_event = dispatch_and_close + client._process_events() + + assert call_count == 1 + + def test_data_field_with_no_value(self, sse_config): + """Test 'data' line without colon appends empty string.""" + lines = [ + "event: operation_progress", + "data", + "", + ] + client = self._make_client_with_lines(sse_config, lines) + events = [] + client.on("operation_progress", lambda d: events.append(d)) + + client._process_events() + + assert len(events) == 1 + assert events[0] == "" # Empty string, not JSON + + +# ── Completion auto-close ──────────────────────────────────────────── + + +@pytest.mark.unit +class TestCompletionAutoClose: + """Test that terminal events set closed flag.""" + + def _dispatch(self, sse_config, event_type): + client = SSEClient(sse_config) + buf = {"event": event_type, "data": ["{}"], "id": None, "retry": None} + client._dispatch_event(buf) + return client.closed + + def test_completed_sets_closed(self, sse_config): + assert self._dispatch(sse_config, "operation_completed") is True + + def test_error_sets_closed(self, sse_config): + assert self._dispatch(sse_config, "operation_error") is True + + def test_cancelled_sets_closed(self, sse_config): + assert self._dispatch(sse_config, "operation_cancelled") is True + 
+ def test_progress_does_not_close(self, sse_config): + assert self._dispatch(sse_config, "operation_progress") is False + + def test_data_chunk_does_not_close(self, sse_config): + assert self._dispatch(sse_config, "data_chunk") is False + + +# ── Reconnection / retry ──────────────────────────────────────────── + + +@pytest.mark.unit +class TestReconnection: + """Test _handle_error retry logic.""" + + @patch("robosystems_client.extensions.sse_client.time.sleep") + @patch("robosystems_client.extensions.sse_client.httpx") + def test_retry_with_backoff(self, mock_httpx, mock_sleep, sse_config): + """Test exponential backoff on retry.""" + # Make httpx.Client always raise to trigger retry chain + mock_httpx.Client.side_effect = Exception("connection refused") + + client = SSEClient(sse_config) + + reconnect_events = [] + exceeded_events = [] + client.on("reconnecting", lambda d: reconnect_events.append(d)) + client.on("max_retries_exceeded", lambda d: exceeded_events.append(d)) + + # connect → fail → _handle_error → retry connect → fail → ... 
→ max retries + client.connect("op-1") + + # Should have attempted max_retries reconnections + assert len(reconnect_events) == sse_config.max_retries + assert len(exceeded_events) == 1 + + # Verify backoff: delays should double each time + # retry_delay=100ms, so: 100ms, 200ms, 400ms → sleep(0.1), sleep(0.2), sleep(0.4) + sleep_calls = [call.args[0] for call in mock_sleep.call_args_list] + assert sleep_calls[0] == pytest.approx(0.1) + assert sleep_calls[1] == pytest.approx(0.2) + assert sleep_calls[2] == pytest.approx(0.4) + + def test_max_retries_exceeded_emits_event(self, sse_config): + """Test that exceeding max retries emits event and closes.""" + sse_config.max_retries = 0 # No retries allowed + client = SSEClient(sse_config) + + exceeded_events = [] + client.on("max_retries_exceeded", lambda d: exceeded_events.append(d)) + + client._handle_error(Exception("fail"), "op-1", 0) + + assert len(exceeded_events) == 1 + assert client.closed is True + + @patch("robosystems_client.extensions.sse_client.time.sleep") + def test_resume_from_last_event_id(self, mock_sleep, sse_config): + """Test reconnection resumes from last_event_id.""" + sse_config.max_retries = 1 + client = SSEClient(sse_config) + client.last_event_id = "10" + + connect_calls = [] + # Make connect succeed (no-op) to stop recursion + with patch.object( + client, "connect", side_effect=lambda op, seq: connect_calls.append((op, seq)) + ): + client._handle_error(Exception("lost"), "op-1", 0) + + assert connect_calls == [("op-1", 11)] # last_event_id + 1 + + @patch("robosystems_client.extensions.sse_client.time.sleep") + def test_resume_with_non_numeric_event_id(self, mock_sleep, sse_config): + """Test reconnection falls back to from_sequence for non-numeric IDs.""" + sse_config.max_retries = 1 + client = SSEClient(sse_config) + client.last_event_id = "not-a-number" + + connect_calls = [] + with patch.object( + client, "connect", side_effect=lambda op, seq: connect_calls.append((op, seq)) + ): + 
client._handle_error(Exception("lost"), "op-1", 5) + + assert connect_calls == [("op-1", 5)] # Falls back to from_sequence + + def test_no_retry_when_closed(self, sse_config): + """Test no retry when client is already closed.""" + client = SSEClient(sse_config) + client.closed = True + + reconnect_events = [] + client.on("reconnecting", lambda d: reconnect_events.append(d)) + + client._handle_error(Exception("fail"), "op-1", 0) + + assert len(reconnect_events) == 0 + + +# ── Connect ────────────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestConnect: + """Test connect method.""" + + @patch("robosystems_client.extensions.sse_client.httpx") + def test_connect_sets_up_stream(self, mock_httpx, sse_config): + """Test connect creates httpx client and starts streaming.""" + mock_http_client = MagicMock() + mock_context = MagicMock() + mock_response = MagicMock() + mock_response.iter_lines.return_value = iter( + [ + "event: operation_completed", + 'data: {"result": {}}', + "", + ] + ) + mock_context.__enter__ = Mock(return_value=mock_response) + mock_http_client.stream.return_value = mock_context + mock_httpx.Client.return_value = mock_http_client + + client = SSEClient(sse_config) + connected_events = [] + client.on("connected", lambda d: connected_events.append(d)) + + client.connect("op-123") + + assert len(connected_events) == 1 + assert client.reconnect_attempts == 0 + + @patch("robosystems_client.extensions.sse_client.httpx") + def test_connect_error_triggers_retry(self, mock_httpx, sse_config): + """Test that connection error triggers _handle_error.""" + sse_config.max_retries = 0 # Don't actually retry + mock_httpx.Client.side_effect = Exception("Connection refused") + + client = SSEClient(sse_config) + exceeded = [] + client.on("max_retries_exceeded", lambda d: exceeded.append(d)) + + client.connect("op-123") + + assert len(exceeded) == 1 + + +# ── Close / cleanup ───────────────────────────────────────────────── + + 
+@pytest.mark.unit +class TestClose: + """Test close and cleanup.""" + + def test_close_clears_listeners(self, sse_config): + """Test close removes all listeners.""" + client = SSEClient(sse_config) + client.on("test", lambda d: None) + assert len(client.listeners) > 0 + + client.close() + + assert len(client.listeners) == 0 + assert client.closed is True + + def test_close_emits_closed_event(self, sse_config): + """Test close emits 'closed' event before clearing listeners.""" + client = SSEClient(sse_config) + closed_events = [] + client.on("closed", lambda d: closed_events.append(d)) + + client.close() + + assert len(closed_events) == 1 + + def test_close_cleans_up_http_client(self, sse_config): + """Test close closes the httpx client.""" + client = SSEClient(sse_config) + mock_http = Mock() + client.client = mock_http + + client.close() + + mock_http.close.assert_called_once() + assert client.client is None + + def test_close_cleans_up_context_manager(self, sse_config): + """Test close exits the stream context manager.""" + client = SSEClient(sse_config) + mock_ctx = MagicMock() + client._context_manager = mock_ctx + + client.close() + + mock_ctx.__exit__.assert_called_once_with(None, None, None) + + def test_is_connected(self, sse_config): + """Test is_connected reflects state.""" + client = SSEClient(sse_config) + assert client.is_connected() is False + + client.client = Mock() + assert client.is_connected() is True + + client.closed = True + assert client.is_connected() is False + + +# ── Listener error handling ────────────────────────────────────────── + + +@pytest.mark.unit +class TestListenerErrors: + """Test that listener errors don't break other listeners.""" + + def test_error_in_listener_doesnt_stop_others(self, sse_config): + """Test that a failing listener doesn't prevent others from running.""" + client = SSEClient(sse_config) + + results = [] + client.on("test", lambda d: (_ for _ in ()).throw(ValueError("boom"))) + client.on("test", lambda d: 
results.append(d)) + + client.emit("test", "data") + + assert results == ["data"] + + def test_emit_to_nonexistent_event_is_noop(self, sse_config): + """Test emitting to an event with no listeners doesn't raise.""" + client = SSEClient(sse_config) + client.emit("nonexistent", "data") # Should not raise diff --git a/tests/test_table_client.py b/tests/test_table_client.py new file mode 100644 index 0000000..0188f98 --- /dev/null +++ b/tests/test_table_client.py @@ -0,0 +1,240 @@ +"""Unit tests for TableClient.""" + +import pytest +from unittest.mock import Mock, patch +from robosystems_client.extensions.table_client import ( + TableClient, + TableInfo, + QueryResult, +) + + +@pytest.mark.unit +class TestTableDataclasses: + """Test suite for table-related dataclasses.""" + + def test_table_info(self): + """Test TableInfo dataclass.""" + info = TableInfo( + table_name="Entity", + table_type="parquet", + row_count=100, + file_count=3, + total_size_bytes=50000, + ) + + assert info.table_name == "Entity" + assert info.table_type == "parquet" + assert info.row_count == 100 + assert info.file_count == 3 + assert info.total_size_bytes == 50000 + + def test_query_result(self): + """Test QueryResult dataclass.""" + result = QueryResult( + columns=["name", "revenue"], + rows=[["ACME", 1000000], ["Beta Corp", 2000000]], + row_count=2, + execution_time_ms=45.0, + ) + + assert result.columns == ["name", "revenue"] + assert len(result.rows) == 2 + assert result.row_count == 2 + assert result.execution_time_ms == 45.0 + assert result.success is True + assert result.error is None + + def test_query_result_with_error(self): + """Test QueryResult with error.""" + result = QueryResult( + columns=[], + rows=[], + row_count=0, + execution_time_ms=0, + success=False, + error="SQL syntax error", + ) + + assert result.success is False + assert result.error == "SQL syntax error" + + +@pytest.mark.unit +class TestTableClientInit: + """Test suite for TableClient initialization.""" + + def 
test_client_initialization(self, mock_config): + """Test that client initializes correctly with config.""" + client = TableClient(mock_config) + + assert client.base_url == "http://localhost:8000" + assert client.token == "test-api-key" + assert client.headers == {"X-API-Key": "test-api-key"} + + +@pytest.mark.unit +class TestTableList: + """Test suite for TableClient.list method.""" + + @patch("robosystems_client.extensions.table_client.list_tables") + def test_list_tables(self, mock_list, mock_config, graph_id): + """Test listing tables.""" + mock_table_1 = Mock() + mock_table_1.table_name = "Entity" + mock_table_1.table_type = "parquet" + mock_table_1.row_count = 100 + mock_table_1.file_count = 2 + mock_table_1.total_size_bytes = 50000 + + mock_table_2 = Mock() + mock_table_2.table_name = "Transaction" + mock_table_2.table_type = "parquet" + mock_table_2.row_count = 500 + mock_table_2.file_count = 1 + mock_table_2.total_size_bytes = 120000 + + mock_resp = Mock() + mock_resp.status_code = 200 + mock_resp.parsed = Mock() + mock_resp.parsed.tables = [mock_table_1, mock_table_2] + mock_list.return_value = mock_resp + + client = TableClient(mock_config) + tables = client.list(graph_id) + + assert len(tables) == 2 + assert tables[0].table_name == "Entity" + assert tables[0].row_count == 100 + assert tables[1].table_name == "Transaction" + assert tables[1].row_count == 500 + + @patch("robosystems_client.extensions.table_client.list_tables") + def test_list_tables_failure(self, mock_list, mock_config, graph_id): + """Test list tables returns empty list on failure.""" + mock_resp = Mock() + mock_resp.status_code = 500 + mock_resp.parsed = None + mock_list.return_value = mock_resp + + client = TableClient(mock_config) + tables = client.list(graph_id) + + assert tables == [] + + def test_list_tables_no_token(self, mock_config, graph_id): + """Test list tables returns empty list without token.""" + mock_config["token"] = None + client = TableClient(mock_config) + tables = 
client.list(graph_id) + + assert tables == [] + + @patch("robosystems_client.extensions.table_client.list_tables") + def test_list_tables_empty(self, mock_list, mock_config, graph_id): + """Test listing when no tables exist.""" + mock_resp = Mock() + mock_resp.status_code = 200 + mock_resp.parsed = Mock() + mock_resp.parsed.tables = [] + mock_list.return_value = mock_resp + + client = TableClient(mock_config) + tables = client.list(graph_id) + + assert tables == [] + + +@pytest.mark.unit +class TestTableQuery: + """Test suite for TableClient.query method.""" + + @patch("robosystems_client.extensions.table_client.query_tables") + def test_query_success(self, mock_query, mock_config, graph_id): + """Test successful SQL query.""" + mock_resp = Mock() + mock_resp.status_code = 200 + mock_resp.parsed = Mock() + mock_resp.parsed.columns = ["name", "ticker"] + mock_resp.parsed.rows = [["ACME Corp", "ACM"], ["Beta Inc", "BET"]] + mock_resp.parsed.execution_time_ms = 30.0 + mock_query.return_value = mock_resp + + client = TableClient(mock_config) + result = client.query(graph_id, "SELECT name, ticker FROM Entity") + + assert result.success is True + assert result.columns == ["name", "ticker"] + assert len(result.rows) == 2 + assert result.row_count == 2 + assert result.execution_time_ms == 30.0 + + @patch("robosystems_client.extensions.table_client.query_tables") + def test_query_with_limit(self, mock_query, mock_config, graph_id): + """Test query with limit appended.""" + mock_resp = Mock() + mock_resp.status_code = 200 + mock_resp.parsed = Mock() + mock_resp.parsed.columns = ["name"] + mock_resp.parsed.rows = [["ACME"]] + mock_resp.parsed.execution_time_ms = 10.0 + mock_query.return_value = mock_resp + + client = TableClient(mock_config) + client.query(graph_id, "SELECT name FROM Entity", limit=5) + + call_kwargs = mock_query.call_args[1] + assert "LIMIT 5" in call_kwargs["body"].sql + + @patch("robosystems_client.extensions.table_client.query_tables") + def 
test_query_strips_trailing_semicolon(self, mock_query, mock_config, graph_id): + """Test query strips trailing semicolon before adding LIMIT.""" + mock_resp = Mock() + mock_resp.status_code = 200 + mock_resp.parsed = Mock() + mock_resp.parsed.columns = ["name"] + mock_resp.parsed.rows = [] + mock_resp.parsed.execution_time_ms = 5.0 + mock_query.return_value = mock_resp + + client = TableClient(mock_config) + client.query(graph_id, "SELECT name FROM Entity;", limit=10) + + call_kwargs = mock_query.call_args[1] + sql = call_kwargs["body"].sql + assert not sql.endswith("; LIMIT 10") + assert sql.endswith("LIMIT 10") + + @patch("robosystems_client.extensions.table_client.query_tables") + def test_query_failure(self, mock_query, mock_config, graph_id): + """Test query failure.""" + mock_resp = Mock() + mock_resp.status_code = 400 + mock_resp.parsed = None + mock_query.return_value = mock_resp + + client = TableClient(mock_config) + result = client.query(graph_id, "INVALID SQL") + + assert result.success is False + assert "Query failed" in result.error + + def test_query_no_token(self, mock_config, graph_id): + """Test query fails without token.""" + mock_config["token"] = None + client = TableClient(mock_config) + result = client.query(graph_id, "SELECT 1") + + assert result.success is False + assert "No API key" in result.error + + @patch("robosystems_client.extensions.table_client.query_tables") + def test_query_exception_handling(self, mock_query, mock_config, graph_id): + """Test query handles unexpected exceptions.""" + mock_query.side_effect = Exception("Network error") + + client = TableClient(mock_config) + result = client.query(graph_id, "SELECT 1") + + assert result.success is False + assert "Network error" in result.error diff --git a/robosystems_client/extensions/tests/test_token_utils.py b/tests/test_token_utils.py similarity index 100% rename from robosystems_client/extensions/tests/test_token_utils.py rename to tests/test_token_utils.py diff --git 
a/robosystems_client/extensions/tests/test_unit.py b/tests/test_utils_unit.py similarity index 98% rename from robosystems_client/extensions/tests/test_unit.py rename to tests/test_utils_unit.py index 8a21f97..1cb6090 100644 --- a/robosystems_client/extensions/tests/test_unit.py +++ b/tests/test_utils_unit.py @@ -7,13 +7,6 @@ from datetime import datetime from unittest.mock import Mock -import sys -import os - -# Add parent directories to path for imports -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../..")) -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) - from robosystems_client.extensions import ( SSEClient, SSEConfig,