diff --git a/sentry_sdk/integrations/pydantic_ai/__init__.py b/sentry_sdk/integrations/pydantic_ai/__init__.py index 0f0de53fa5..7b6efb7ff1 100644 --- a/sentry_sdk/integrations/pydantic_ai/__init__.py +++ b/sentry_sdk/integrations/pydantic_ai/__init__.py @@ -1,8 +1,10 @@ -from sentry_sdk.integrations import DidNotEnable, Integration +import functools +from sentry_sdk.integrations import DidNotEnable, Integration try: import pydantic_ai # type: ignore # noqa: F401 + from pydantic_ai import Agent except ImportError: raise DidNotEnable("pydantic-ai not installed") @@ -14,10 +16,20 @@ _patch_tool_execution, ) +from .spans.ai_client import ai_client_span, update_ai_client_span + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any + from pydantic_ai import ModelRequestContext, RunContext + from pydantic_ai.messages import ModelResponse # type: ignore + class PydanticAIIntegration(Integration): identifier = "pydantic_ai" origin = f"auto.ai.{identifier}" + are_request_hooks_available = True def __init__( self, include_prompts: bool = True, handled_tool_call_exceptions: bool = True @@ -45,6 +57,88 @@ def setup_once() -> None: - Tool executions """ _patch_agent_run() - _patch_graph_nodes() - _patch_model_request() + + try: + from pydantic_ai.capabilities import Hooks # type: ignore + except ImportError: + Hooks = None + PydanticAIIntegration.are_request_hooks_available = False + + if Hooks is None: + _patch_graph_nodes() + _patch_model_request() + _patch_tool_execution() + return + _patch_tool_execution() + + # Assumptions: + # - Model requests within a run are sequential. + # - ctx.metadata is a shared dict instance between hooks. 
+ hooks = Hooks() + + @hooks.on.before_model_request # type: ignore + async def on_request( + ctx: "RunContext[None]", request_context: "ModelRequestContext" + ) -> "ModelRequestContext": + span = ai_client_span( + messages=request_context.messages, + agent=None, + model=request_context.model, + model_settings=request_context.model_settings, + ) + run_context_metadata = ctx.metadata + if isinstance(run_context_metadata, dict): + run_context_metadata["_sentry_span"] = span + + span.__enter__() + + return request_context + + @hooks.on.after_model_request # type: ignore + async def on_response( + ctx: "RunContext[None]", + *, + request_context: "ModelRequestContext", + response: "ModelResponse", + ) -> "ModelResponse": + run_context_metadata = ctx.metadata + if not isinstance(run_context_metadata, dict): + return response + + span = run_context_metadata.get("_sentry_span") + if span is None: + return response + + update_ai_client_span(span, response) + span.__exit__(None, None, None) + del run_context_metadata["_sentry_span"] + + return response + + @hooks.on.model_request_error # type: ignore + async def on_error( + ctx: "RunContext[None]", + *, + request_context: "ModelRequestContext", + error: "Exception", + ) -> "ModelResponse": + run_context_metadata = ctx.metadata + if isinstance(run_context_metadata, dict): + span = run_context_metadata.pop("_sentry_span", None) + if span is not None: + span.__exit__(type(error), error, error.__traceback__) + raise error + + original_init = Agent.__init__ + + @functools.wraps(original_init) + def patched_init( + self: "Agent[Any, Any]", *args: "Any", **kwargs: "Any" + ) -> None: + caps = list(kwargs.get("capabilities") or []) + caps.append(hooks) + kwargs["capabilities"] = caps + return original_init(self, *args, **kwargs) + + Agent.__init__ = patched_init diff --git a/sentry_sdk/integrations/pydantic_ai/patches/agent_run.py b/sentry_sdk/integrations/pydantic_ai/patches/agent_run.py index eaa4385834..df0cec07e2 100644 --- 
a/sentry_sdk/integrations/pydantic_ai/patches/agent_run.py +++ b/sentry_sdk/integrations/pydantic_ai/patches/agent_run.py @@ -96,6 +96,9 @@ def _create_run_wrapper( original_func: The original run method is_streaming: Whether this is a streaming method (for future use) """ + from sentry_sdk.integrations.pydantic_ai import ( + PydanticAIIntegration, + ) # Required to avoid circular import @wraps(original_func) async def wrapper(self: "Any", *args: "Any", **kwargs: "Any") -> "Any": @@ -107,6 +110,11 @@ async def wrapper(self: "Any", *args: "Any", **kwargs: "Any") -> "Any": model = kwargs.get("model") model_settings = kwargs.get("model_settings") + if PydanticAIIntegration.are_request_hooks_available: + metadata = kwargs.get("metadata") + if not metadata: + kwargs["metadata"] = {"_sentry_span": None} + # Create invoke_agent span with invoke_agent_span( user_prompt, self, model, model_settings, is_streaming @@ -140,6 +148,9 @@ def _create_streaming_wrapper( """ Wraps run_stream method that returns an async context manager. 
""" + from sentry_sdk.integrations.pydantic_ai import ( + PydanticAIIntegration, + ) # Required to avoid circular import @wraps(original_func) def wrapper(self: "Any", *args: "Any", **kwargs: "Any") -> "Any": @@ -148,6 +159,11 @@ def wrapper(self: "Any", *args: "Any", **kwargs: "Any") -> "Any": model = kwargs.get("model") model_settings = kwargs.get("model_settings") + if PydanticAIIntegration.are_request_hooks_available: + metadata = kwargs.get("metadata") + if not metadata: + kwargs["metadata"] = {"_sentry_span": None} + # Call original function to get the context manager original_ctx_manager = original_func(self, *args, **kwargs) diff --git a/tests/integrations/pydantic_ai/test_pydantic_ai.py b/tests/integrations/pydantic_ai/test_pydantic_ai.py index f0ddc6c4ed..e64a6d3e52 100644 --- a/tests/integrations/pydantic_ai/test_pydantic_ai.py +++ b/tests/integrations/pydantic_ai/test_pydantic_ai.py @@ -19,34 +19,40 @@ @pytest.fixture -def test_agent(): - """Create a test agent with model settings.""" - return Agent( - "test", - name="test_agent", - system_prompt="You are a helpful test assistant.", - ) +def get_test_agent(): + def inner(): + """Create a test agent with model settings.""" + return Agent( + "test", + name="test_agent", + system_prompt="You are a helpful test assistant.", + ) + + return inner @pytest.fixture -def test_agent_with_settings(): - """Create a test agent with explicit model settings.""" - from pydantic_ai import ModelSettings +def get_test_agent_with_settings(): + def inner(): + """Create a test agent with explicit model settings.""" + from pydantic_ai import ModelSettings + + return Agent( + "test", + name="test_agent_settings", + system_prompt="You are a test assistant with settings.", + model_settings=ModelSettings( + temperature=0.7, + max_tokens=100, + top_p=0.9, + ), + ) - return Agent( - "test", - name="test_agent_settings", - system_prompt="You are a test assistant with settings.", - model_settings=ModelSettings( - temperature=0.7, - 
max_tokens=100, - top_p=0.9, - ), - ) + return inner @pytest.mark.asyncio -async def test_agent_run_async(sentry_init, capture_events, test_agent): +async def test_agent_run_async(sentry_init, capture_events, get_test_agent): """ Test that the integration creates spans for async agent runs. """ @@ -58,6 +64,7 @@ async def test_agent_run_async(sentry_init, capture_events, test_agent): events = capture_events() + test_agent = get_test_agent() result = await test_agent.run("Test input") assert result is not None @@ -88,7 +95,7 @@ async def test_agent_run_async(sentry_init, capture_events, test_agent): @pytest.mark.asyncio -async def test_agent_run_async_usage_data(sentry_init, capture_events, test_agent): +async def test_agent_run_async_usage_data(sentry_init, capture_events, get_test_agent): """ Test that the invoke_agent span includes token usage and model data. """ @@ -100,6 +107,7 @@ async def test_agent_run_async_usage_data(sentry_init, capture_events, test_agen events = capture_events() + test_agent = get_test_agent() result = await test_agent.run("Test input") assert result is not None @@ -132,7 +140,7 @@ async def test_agent_run_async_usage_data(sentry_init, capture_events, test_agen assert trace_data["gen_ai.response.model"] == "test" # Test model name -def test_agent_run_sync(sentry_init, capture_events, test_agent): +def test_agent_run_sync(sentry_init, capture_events, get_test_agent): """ Test that the integration creates spans for sync agent runs. 
""" @@ -144,6 +152,7 @@ def test_agent_run_sync(sentry_init, capture_events, test_agent): events = capture_events() + test_agent = get_test_agent() result = test_agent.run_sync("Test input") assert result is not None @@ -166,7 +175,7 @@ def test_agent_run_sync(sentry_init, capture_events, test_agent): @pytest.mark.asyncio -async def test_agent_run_stream(sentry_init, capture_events, test_agent): +async def test_agent_run_stream(sentry_init, capture_events, get_test_agent): """ Test that the integration creates spans for streaming agent runs. """ @@ -178,6 +187,7 @@ async def test_agent_run_stream(sentry_init, capture_events, test_agent): events = capture_events() + test_agent = get_test_agent() async with test_agent.run_stream("Test input") as result: # Consume the stream async for _ in result.stream_output(): @@ -207,7 +217,7 @@ async def test_agent_run_stream(sentry_init, capture_events, test_agent): @pytest.mark.asyncio -async def test_agent_run_stream_events(sentry_init, capture_events, test_agent): +async def test_agent_run_stream_events(sentry_init, capture_events, get_test_agent): """ Test that run_stream_events creates spans (it uses run internally, so non-streaming). """ @@ -220,6 +230,7 @@ async def test_agent_run_stream_events(sentry_init, capture_events, test_agent): events = capture_events() # Consume all events + test_agent = get_test_agent() async for _ in test_agent.run_stream_events("Test input"): pass @@ -239,11 +250,13 @@ async def test_agent_run_stream_events(sentry_init, capture_events, test_agent): @pytest.mark.asyncio -async def test_agent_with_tools(sentry_init, capture_events, test_agent): +async def test_agent_with_tools(sentry_init, capture_events, get_test_agent): """ Test that tool execution creates execute_tool spans. 
""" + test_agent = get_test_agent() + @test_agent.tool_plain def add_numbers(a: int, b: int) -> int: """Add two numbers together.""" @@ -294,7 +307,7 @@ def add_numbers(a: int, b: int) -> int: ) @pytest.mark.asyncio async def test_agent_with_tool_model_retry( - sentry_init, capture_events, test_agent, handled_tool_call_exceptions + sentry_init, capture_events, get_test_agent, handled_tool_call_exceptions ): """ Test that a handled exception is captured when a tool raises ModelRetry. @@ -302,6 +315,8 @@ async def test_agent_with_tool_model_retry( retries = 0 + test_agent = get_test_agent() + @test_agent.tool_plain def add_numbers(a: int, b: int) -> float: """Add two numbers together, but raises an exception on the first attempt.""" @@ -374,12 +389,14 @@ def add_numbers(a: int, b: int) -> float: ) @pytest.mark.asyncio async def test_agent_with_tool_validation_error( - sentry_init, capture_events, test_agent, handled_tool_call_exceptions + sentry_init, capture_events, get_test_agent, handled_tool_call_exceptions ): """ Test that a handled exception is captured when a tool has unsatisfiable constraints. """ + test_agent = get_test_agent() + @test_agent.tool_plain def add_numbers(a: Annotated[int, Field(gt=0, lt=0)], b: int) -> int: """Add two numbers together.""" @@ -440,11 +457,13 @@ def add_numbers(a: Annotated[int, Field(gt=0, lt=0)], b: int) -> int: @pytest.mark.asyncio -async def test_agent_with_tools_streaming(sentry_init, capture_events, test_agent): +async def test_agent_with_tools_streaming(sentry_init, capture_events, get_test_agent): """ Test that tool execution works correctly with streaming. 
""" + test_agent = get_test_agent() + @test_agent.tool_plain def multiply(a: int, b: int) -> int: """Multiply two numbers.""" @@ -484,7 +503,9 @@ def multiply(a: int, b: int) -> int: @pytest.mark.asyncio -async def test_model_settings(sentry_init, capture_events, test_agent_with_settings): +async def test_model_settings( + sentry_init, capture_events, get_test_agent_with_settings +): """ Test that model settings are captured in spans. """ @@ -495,6 +516,7 @@ async def test_model_settings(sentry_init, capture_events, test_agent_with_setti events = capture_events() + test_agent_with_settings = get_test_agent_with_settings() await test_agent_with_settings.run("Test input") (transaction,) = events @@ -596,7 +618,7 @@ async def test_error_handling(sentry_init, capture_events): @pytest.mark.asyncio -async def test_without_pii(sentry_init, capture_events, test_agent): +async def test_without_pii(sentry_init, capture_events, get_test_agent): """ Test that PII is not captured when send_default_pii is False. """ @@ -608,6 +630,7 @@ async def test_without_pii(sentry_init, capture_events, test_agent): events = capture_events() + test_agent = get_test_agent() await test_agent.run("Sensitive input") (transaction,) = events @@ -623,11 +646,13 @@ async def test_without_pii(sentry_init, capture_events, test_agent): @pytest.mark.asyncio -async def test_without_pii_tools(sentry_init, capture_events, test_agent): +async def test_without_pii_tools(sentry_init, capture_events, get_test_agent): """ Test that tool input/output are not captured when send_default_pii is False. 
""" + test_agent = get_test_agent() + @test_agent.tool_plain def sensitive_tool(data: str) -> str: """A tool with sensitive data.""" @@ -656,7 +681,7 @@ def sensitive_tool(data: str) -> str: @pytest.mark.asyncio -async def test_multiple_agents_concurrent(sentry_init, capture_events, test_agent): +async def test_multiple_agents_concurrent(sentry_init, capture_events, get_test_agent): """ Test that multiple agents can run concurrently without interfering. """ @@ -667,6 +692,8 @@ async def test_multiple_agents_concurrent(sentry_init, capture_events, test_agen events = capture_events() + test_agent = get_test_agent() + async def run_agent(input_text): return await test_agent.run(input_text) @@ -737,7 +764,7 @@ async def test_message_history(sentry_init, capture_events): @pytest.mark.asyncio -async def test_gen_ai_system(sentry_init, capture_events, test_agent): +async def test_gen_ai_system(sentry_init, capture_events, get_test_agent): """ Test that gen_ai.system is set from the model. """ @@ -748,6 +775,7 @@ async def test_gen_ai_system(sentry_init, capture_events, test_agent): events = capture_events() + test_agent = get_test_agent() await test_agent.run("Test input") (transaction,) = events @@ -764,7 +792,7 @@ async def test_gen_ai_system(sentry_init, capture_events, test_agent): @pytest.mark.asyncio -async def test_include_prompts_false(sentry_init, capture_events, test_agent): +async def test_include_prompts_false(sentry_init, capture_events, get_test_agent): """ Test that prompts are not captured when include_prompts=False. 
""" @@ -776,6 +804,7 @@ async def test_include_prompts_false(sentry_init, capture_events, test_agent): events = capture_events() + test_agent = get_test_agent() await test_agent.run("Sensitive prompt") (transaction,) = events @@ -791,7 +820,7 @@ async def test_include_prompts_false(sentry_init, capture_events, test_agent): @pytest.mark.asyncio -async def test_include_prompts_true(sentry_init, capture_events, test_agent): +async def test_include_prompts_true(sentry_init, capture_events, get_test_agent): """ Test that prompts are captured when include_prompts=True (default). """ @@ -803,6 +832,7 @@ async def test_include_prompts_true(sentry_init, capture_events, test_agent): events = capture_events() + test_agent = get_test_agent() await test_agent.run("Test prompt") (transaction,) = events @@ -819,12 +849,14 @@ async def test_include_prompts_true(sentry_init, capture_events, test_agent): @pytest.mark.asyncio async def test_include_prompts_false_with_tools( - sentry_init, capture_events, test_agent + sentry_init, capture_events, get_test_agent ): """ Test that tool input/output are not captured when include_prompts=False. """ + test_agent = get_test_agent() + @test_agent.tool_plain def test_tool(value: int) -> int: """A test tool.""" @@ -853,7 +885,9 @@ def test_tool(value: int) -> int: @pytest.mark.asyncio -async def test_include_prompts_requires_pii(sentry_init, capture_events, test_agent): +async def test_include_prompts_requires_pii( + sentry_init, capture_events, get_test_agent +): """ Test that include_prompts requires send_default_pii=True. 
""" @@ -865,6 +899,7 @@ async def test_include_prompts_requires_pii(sentry_init, capture_events, test_ag events = capture_events() + test_agent = get_test_agent() await test_agent.run("Test prompt") (transaction,) = events @@ -1015,7 +1050,7 @@ async def mock_map_tool_result_part(part): @pytest.mark.asyncio -async def test_context_cleanup_after_run(sentry_init, test_agent): +async def test_context_cleanup_after_run(sentry_init, get_test_agent): """ Test that the pydantic_ai_agent context is properly cleaned up after agent execution. """ @@ -1031,13 +1066,14 @@ async def test_context_cleanup_after_run(sentry_init, test_agent): assert "pydantic_ai_agent" not in scope._contexts # Run the agent + test_agent = get_test_agent() await test_agent.run("Test input") # Verify context is cleaned up after run assert "pydantic_ai_agent" not in scope._contexts -def test_context_cleanup_after_run_sync(sentry_init, test_agent): +def test_context_cleanup_after_run_sync(sentry_init, get_test_agent): """ Test that the pydantic_ai_agent context is properly cleaned up after sync agent execution. """ @@ -1053,6 +1089,7 @@ def test_context_cleanup_after_run_sync(sentry_init, test_agent): assert "pydantic_ai_agent" not in scope._contexts # Run the agent synchronously + test_agent = get_test_agent() test_agent.run_sync("Test input") # Verify context is cleaned up after run @@ -1060,7 +1097,7 @@ def test_context_cleanup_after_run_sync(sentry_init, test_agent): @pytest.mark.asyncio -async def test_context_cleanup_after_streaming(sentry_init, test_agent): +async def test_context_cleanup_after_streaming(sentry_init, get_test_agent): """ Test that the pydantic_ai_agent context is properly cleaned up after streaming execution. 
""" @@ -1075,6 +1112,7 @@ async def test_context_cleanup_after_streaming(sentry_init, test_agent): scope = sentry_sdk.get_current_scope() assert "pydantic_ai_agent" not in scope._contexts + test_agent = get_test_agent() # Run the agent with streaming async with test_agent.run_stream("Test input") as result: async for _ in result.stream_output(): @@ -1085,12 +1123,14 @@ async def test_context_cleanup_after_streaming(sentry_init, test_agent): @pytest.mark.asyncio -async def test_context_cleanup_on_error(sentry_init, test_agent): +async def test_context_cleanup_on_error(sentry_init, get_test_agent): """ Test that the pydantic_ai_agent context is cleaned up even when an error occurs. """ import sentry_sdk + test_agent = get_test_agent() + # Create an agent with a tool that raises an error @test_agent.tool_plain def failing_tool() -> str: @@ -1117,7 +1157,7 @@ def failing_tool() -> str: @pytest.mark.asyncio -async def test_context_isolation_concurrent_agents(sentry_init, test_agent): +async def test_context_isolation_concurrent_agents(sentry_init, get_test_agent): """ Test that concurrent agent executions maintain isolated contexts. """ @@ -1150,6 +1190,7 @@ async def run_and_check_context(agent, agent_name): return agent_name + test_agent = get_test_agent() # Run both agents concurrently results = await asyncio.gather( run_and_check_context(test_agent, "agent1"), @@ -1403,12 +1444,14 @@ async def test_agent_data_from_scope(sentry_init, capture_events): @pytest.mark.asyncio async def test_available_tools_without_description( - sentry_init, capture_events, test_agent + sentry_init, capture_events, get_test_agent ): """ Test that available tools are captured even when description is missing. 
""" + test_agent = get_test_agent() + @test_agent.tool_plain def tool_without_desc(x: int) -> int: # No docstring = no description @@ -1435,11 +1478,13 @@ def tool_without_desc(x: int) -> int: @pytest.mark.asyncio -async def test_output_with_tool_calls(sentry_init, capture_events, test_agent): +async def test_output_with_tool_calls(sentry_init, capture_events, get_test_agent): """ Test that tool calls in model response are captured correctly. """ + test_agent = get_test_agent() + @test_agent.tool_plain def calc_tool(value: int) -> int: """Calculate something.""" @@ -1637,7 +1682,6 @@ async def test_input_messages_error_handling(sentry_init, capture_events): Test that _set_input_messages handles errors gracefully. """ import sentry_sdk - from sentry_sdk.integrations.pydantic_ai.spans.ai_client import _set_input_messages sentry_init( integrations=[PydanticAIIntegration()], @@ -1791,7 +1835,6 @@ async def test_message_parts_with_list_content(sentry_init, capture_events): """ import sentry_sdk from unittest.mock import MagicMock - from sentry_sdk.integrations.pydantic_ai.spans.ai_client import _set_input_messages sentry_init( integrations=[PydanticAIIntegration()], @@ -1898,7 +1941,6 @@ async def test_message_with_system_prompt_part(sentry_init, capture_events): """ import sentry_sdk from unittest.mock import MagicMock - from sentry_sdk.integrations.pydantic_ai.spans.ai_client import _set_input_messages from pydantic_ai import messages sentry_init( @@ -1935,7 +1977,6 @@ async def test_message_with_instructions(sentry_init, capture_events): """ import sentry_sdk from unittest.mock import MagicMock - from sentry_sdk.integrations.pydantic_ai.spans.ai_client import _set_input_messages sentry_init( integrations=[PydanticAIIntegration()], @@ -1970,7 +2011,6 @@ async def test_set_input_messages_without_prompts(sentry_init, capture_events): Test that _set_input_messages respects _should_send_prompts(). 
""" import sentry_sdk - from sentry_sdk.integrations.pydantic_ai.spans.ai_client import _set_input_messages sentry_init( integrations=[PydanticAIIntegration(include_prompts=False)],