Skip to content

Commit b658e4b

Browse files
authored
feat(integrations): Add async support for ai_track decorator
This commit adds capabilities to support async functions for the `ai_track` decorator
1 parent fc5db4f commit b658e4b

File tree

2 files changed

+97
-3
lines changed

2 files changed

+97
-3
lines changed

sentry_sdk/ai/monitoring.py

+35-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
from functools import wraps
23

34
import sentry_sdk.utils
@@ -26,8 +27,7 @@ def ai_track(description, **span_kwargs):
2627
# type: (str, Any) -> Callable[..., Any]
2728
def decorator(f):
2829
# type: (Callable[..., Any]) -> Callable[..., Any]
29-
@wraps(f)
30-
def wrapped(*args, **kwargs):
30+
def sync_wrapped(*args, **kwargs):
3131
# type: (Any, Any) -> Any
3232
curr_pipeline = _ai_pipeline_name.get()
3333
op = span_kwargs.get("op", "ai.run" if curr_pipeline else "ai.pipeline")
@@ -56,7 +56,39 @@ def wrapped(*args, **kwargs):
5656
_ai_pipeline_name.set(None)
5757
return res
5858

59-
return wrapped
59+
async def async_wrapped(*args, **kwargs):
60+
# type: (Any, Any) -> Any
61+
curr_pipeline = _ai_pipeline_name.get()
62+
op = span_kwargs.get("op", "ai.run" if curr_pipeline else "ai.pipeline")
63+
64+
with start_span(description=description, op=op, **span_kwargs) as span:
65+
for k, v in kwargs.pop("sentry_tags", {}).items():
66+
span.set_tag(k, v)
67+
for k, v in kwargs.pop("sentry_data", {}).items():
68+
span.set_data(k, v)
69+
if curr_pipeline:
70+
span.set_data("ai.pipeline.name", curr_pipeline)
71+
return await f(*args, **kwargs)
72+
else:
73+
_ai_pipeline_name.set(description)
74+
try:
75+
res = await f(*args, **kwargs)
76+
except Exception as e:
77+
event, hint = sentry_sdk.utils.event_from_exception(
78+
e,
79+
client_options=sentry_sdk.get_client().options,
80+
mechanism={"type": "ai_monitoring", "handled": False},
81+
)
82+
sentry_sdk.capture_event(event, hint=hint)
83+
raise e from None
84+
finally:
85+
_ai_pipeline_name.set(None)
86+
return res
87+
88+
if inspect.iscoroutinefunction(f):
89+
return wraps(f)(async_wrapped)
90+
else:
91+
return wraps(f)(sync_wrapped)
6092

6193
return decorator
6294

tests/test_ai_monitoring.py

+62
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import pytest
2+
13
import sentry_sdk
24
from sentry_sdk.ai.monitoring import ai_track
35

@@ -57,3 +59,63 @@ def pipeline():
5759
assert ai_pipeline_span["tags"]["user"] == "colin"
5860
assert ai_pipeline_span["data"]["some_data"] == "value"
5961
assert ai_run_span["description"] == "my tool"
62+
63+
64+
@pytest.mark.asyncio
65+
async def test_ai_track_async(sentry_init, capture_events):
66+
sentry_init(traces_sample_rate=1.0)
67+
events = capture_events()
68+
69+
@ai_track("my async tool")
70+
async def async_tool(**kwargs):
71+
pass
72+
73+
@ai_track("some async test pipeline")
74+
async def async_pipeline():
75+
await async_tool()
76+
77+
with sentry_sdk.start_transaction():
78+
await async_pipeline()
79+
80+
transaction = events[0]
81+
assert transaction["type"] == "transaction"
82+
assert len(transaction["spans"]) == 2
83+
spans = transaction["spans"]
84+
85+
ai_pipeline_span = spans[0] if spans[0]["op"] == "ai.pipeline" else spans[1]
86+
ai_run_span = spans[0] if spans[0]["op"] == "ai.run" else spans[1]
87+
88+
assert ai_pipeline_span["description"] == "some async test pipeline"
89+
assert ai_run_span["description"] == "my async tool"
90+
91+
92+
@pytest.mark.asyncio
93+
async def test_ai_track_async_with_tags(sentry_init, capture_events):
94+
sentry_init(traces_sample_rate=1.0)
95+
events = capture_events()
96+
97+
@ai_track("my async tool")
98+
async def async_tool(**kwargs):
99+
pass
100+
101+
@ai_track("some async test pipeline")
102+
async def async_pipeline():
103+
await async_tool()
104+
105+
with sentry_sdk.start_transaction():
106+
await async_pipeline(
107+
sentry_tags={"user": "czyber"}, sentry_data={"some_data": "value"}
108+
)
109+
110+
transaction = events[0]
111+
assert transaction["type"] == "transaction"
112+
assert len(transaction["spans"]) == 2
113+
spans = transaction["spans"]
114+
115+
ai_pipeline_span = spans[0] if spans[0]["op"] == "ai.pipeline" else spans[1]
116+
ai_run_span = spans[0] if spans[0]["op"] == "ai.run" else spans[1]
117+
118+
assert ai_pipeline_span["description"] == "some async test pipeline"
119+
assert ai_pipeline_span["tags"]["user"] == "czyber"
120+
assert ai_pipeline_span["data"]["some_data"] == "value"
121+
assert ai_run_span["description"] == "my async tool"

0 commit comments

Comments
 (0)