Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 39 additions & 4 deletions splunklib/ai/engines/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,10 +481,45 @@ def unpack_tool_call(self, call: LC_ToolCall) -> LC_ToolCall:

return call

class _CheckCallIDMiddleware(LC_AgentMiddleware):
def _check_has_call_id(self, msg: LC_AIMessage) -> None:
for call in msg.tool_calls:
if not call["id"]:
# If we ever hit this with real model, just generate a random call_id here.
raise Exception("LLM returned a Tool Call without a call_id")

@override
async def awrap_model_call(
self,
request: LC_ModelRequest,
handler: Callable[[LC_ModelRequest], Awaitable[LC_ModelCallResult]],
) -> LC_ModelCallResult:
try:
resp = await handler(request)
ai_message = resp
if isinstance(ai_message, LC_ExtendedModelResponse):
ai_message = ai_message.model_response
if isinstance(ai_message, LC_ModelResponse):
ai_message = next(
(
m
for m in ai_message.result
if isinstance(m, LC_AIMessage)
),
None,
)
assert ai_message, "AIMessage not found found in response"
self._check_has_call_id(ai_message)
return resp
except LC_StructuredOutputError as e:
self._check_has_call_id(e.ai_message)
raise

lc_middleware.append(_ToolFailureArtifact())
if len(conversational_subagents) > 0:
lc_middleware.append(_ThreadIDMiddleware())
lc_middleware.append(_SubagentArgumentPacker())
lc_middleware.append(_CheckCallIDMiddleware())

class _DEBUGMiddleware(LC_AgentMiddleware):
@override
Expand Down Expand Up @@ -1254,7 +1289,7 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe
StructuredOutputCall(
name=tc["name"].removeprefix(TOOL_STRATEGY_TOOL_PREFIX),
args=tc["args"],
id=tc["id"],
id=tc["id"] or "",
)
for tc in ai_message.tool_calls
if tc["name"].startswith(TOOL_STRATEGY_TOOL_PREFIX)
Expand Down Expand Up @@ -1529,7 +1564,7 @@ def _map_tool_call_from_langchain(tool_call: LC_ToolCall) -> ToolCall | Subagent
name=_denormalize_agent_name(name),
args=SubagentLCArgs(**tool_call["args"]).args,
thread_id=SubagentLCArgs(**tool_call["args"]).thread_id,
id=tool_call["id"],
id=tool_call["id"] or "",
)

tool_type: ToolType = (
Expand All @@ -1538,7 +1573,7 @@ def _map_tool_call_from_langchain(tool_call: LC_ToolCall) -> ToolCall | Subagent
return ToolCall(
name=_denormalize_tool_name(name),
args=tool_call["args"],
id=tool_call["id"],
id=tool_call["id"] or "",
type=tool_type,
)

Expand Down Expand Up @@ -1567,9 +1602,9 @@ def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage:
],
structured_output_calls=[
StructuredOutputCall(
tc["id"] or "",
tc["name"].removeprefix(TOOL_STRATEGY_TOOL_PREFIX),
tc["args"],
tc["id"],
)
for tc in message.tool_calls
if tc["name"].startswith(TOOL_STRATEGY_TOOL_PREFIX)
Expand Down
8 changes: 4 additions & 4 deletions splunklib/ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,25 @@

@dataclass(frozen=True)
class ToolCall:
id: str
name: str
args: dict[str, Any]
id: str | None # TODO: can be None?
type: ToolType
args: dict[str, Any]


@dataclass(frozen=True)
class SubagentCall:
id: str
name: str
args: str | dict[str, Any]
id: str | None # TODO: can be None?
thread_id: str | None


@dataclass(frozen=True)
class StructuredOutputCall:
id: str
name: str
args: dict[str, Any]
id: str | None # TODO: can be None?


@dataclass(frozen=True)
Expand Down
Loading