Skip to content

Commit 0687d7f

Browse files
committed
Make id required (non-nullable) for tool/subagent/output calls
Also while here, move the id field to be first.
1 parent 9e45daf commit 0687d7f

File tree

2 files changed

+43
-8
lines changed

2 files changed

+43
-8
lines changed

splunklib/ai/engines/langchain.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -481,10 +481,45 @@ def unpack_tool_call(self, call: LC_ToolCall) -> LC_ToolCall:
481481

482482
return call
483483

484+
class _CheckCallIDMiddleware(LC_AgentMiddleware):
485+
def _check_has_call_id(self, msg: LC_AIMessage) -> None:
486+
for call in msg.tool_calls:
487+
if call["id"] is None:
488+
# If we ever hit this with real model, just generate a random call_id here.
489+
raise Exception("LLM returned a Tool Call without a call_id")
490+
491+
@override
492+
async def awrap_model_call(
493+
self,
494+
request: LC_ModelRequest,
495+
handler: Callable[[LC_ModelRequest], Awaitable[LC_ModelCallResult]],
496+
) -> LC_ModelCallResult:
497+
try:
498+
resp = await handler(request)
499+
ai_message = resp
500+
if isinstance(ai_message, LC_ExtendedModelResponse):
501+
ai_message = ai_message.model_response
502+
if isinstance(ai_message, LC_ModelResponse):
503+
ai_message = next(
504+
(
505+
m
506+
for m in ai_message.result
507+
if isinstance(m, LC_AIMessage)
508+
),
509+
None,
510+
)
511+
assert ai_message, "AIMessage not found found in response"
512+
self._check_has_call_id(ai_message)
513+
return resp
514+
except LC_StructuredOutputError as e:
515+
self._check_has_call_id(e.ai_message)
516+
raise
517+
484518
lc_middleware.append(_ToolFailureArtifact())
485519
if len(conversational_subagents) > 0:
486520
lc_middleware.append(_ThreadIDMiddleware())
487521
lc_middleware.append(_SubagentArgumentPacker())
522+
lc_middleware.append(_CheckCallIDMiddleware())
488523

489524
class _DEBUGMiddleware(LC_AgentMiddleware):
490525
@override
@@ -1254,7 +1289,7 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe
12541289
StructuredOutputCall(
12551290
name=tc["name"].removeprefix(TOOL_STRATEGY_TOOL_PREFIX),
12561291
args=tc["args"],
1257-
id=tc["id"],
1292+
id=tc["id"] or "",
12581293
)
12591294
for tc in ai_message.tool_calls
12601295
if tc["name"].startswith(TOOL_STRATEGY_TOOL_PREFIX)
@@ -1529,7 +1564,7 @@ def _map_tool_call_from_langchain(tool_call: LC_ToolCall) -> ToolCall | Subagent
15291564
name=_denormalize_agent_name(name),
15301565
args=SubagentLCArgs(**tool_call["args"]).args,
15311566
thread_id=SubagentLCArgs(**tool_call["args"]).thread_id,
1532-
id=tool_call["id"],
1567+
id=tool_call["id"] or "",
15331568
)
15341569

15351570
tool_type: ToolType = (
@@ -1538,7 +1573,7 @@ def _map_tool_call_from_langchain(tool_call: LC_ToolCall) -> ToolCall | Subagent
15381573
return ToolCall(
15391574
name=_denormalize_tool_name(name),
15401575
args=tool_call["args"],
1541-
id=tool_call["id"],
1576+
id=tool_call["id"] or "",
15421577
type=tool_type,
15431578
)
15441579

@@ -1567,9 +1602,9 @@ def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage:
15671602
],
15681603
structured_output_calls=[
15691604
StructuredOutputCall(
1605+
tc["id"] or "",
15701606
tc["name"].removeprefix(TOOL_STRATEGY_TOOL_PREFIX),
15711607
tc["args"],
1572-
tc["id"],
15731608
)
15741609
for tc in message.tool_calls
15751610
if tc["name"].startswith(TOOL_STRATEGY_TOOL_PREFIX)

splunklib/ai/messages.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,25 +23,25 @@
2323

2424
@dataclass(frozen=True)
2525
class ToolCall:
26+
id: str
2627
name: str
27-
args: dict[str, Any]
28-
id: str | None # TODO: can be None?
2928
type: ToolType
29+
args: dict[str, Any]
3030

3131

3232
@dataclass(frozen=True)
3333
class SubagentCall:
34+
id: str
3435
name: str
3536
args: str | dict[str, Any]
36-
id: str | None # TODO: can be None?
3737
thread_id: str | None
3838

3939

4040
@dataclass(frozen=True)
4141
class StructuredOutputCall:
42+
id: str
4243
name: str
4344
args: dict[str, Any]
44-
id: str | None # TODO: can be None?
4545

4646

4747
@dataclass(frozen=True)

0 commit comments

Comments
 (0)