Skip to content

Commit aee8a0d

Browse files
committed
Restrict subagent names
1 parent 9e45daf commit aee8a0d

File tree

2 files changed

+73
-1
lines changed

2 files changed

+73
-1
lines changed

splunklib/ai/engines/langchain.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import json
1616
import logging
1717
import os
18+
import string
1819
import uuid
1920
from collections.abc import Awaitable, Callable, Sequence
2021
from dataclasses import asdict, dataclass
@@ -1416,11 +1417,22 @@ def _denormalize_tool_name(name: str) -> str:
14161417
return name
14171418

14181419

1420+
def _is_agent_name_valid(name: str) -> bool:
1421+
AGENT_NAME_ALLOWED_CHARS = string.ascii_letters + string.digits + "_-"
1422+
if not (1 <= len(name) <= 128):
1423+
return False
1424+
1425+
return set(name).issubset(AGENT_NAME_ALLOWED_CHARS)
1426+
1427+
14191428
def _agent_as_tool(agent: BaseAgent[OutputT]) -> StructuredTool:
14201429
if not agent.name:
14211430
raise AssertionError("Agent must have a name to be used by other Agents")
14221431

1423-
# TODO: restrict subagent names
1432+
if not _is_agent_name_valid(agent.name):
1433+
raise AssertionError(
1434+
"Agent name is invalid, must contain only letters, numbers, '_' or '-' and have max 128 characters"
1435+
)
14241436

14251437
async def invoke_agent(
14261438
message: HumanMessage, thread_id: str | None

tests/integration/ai/test_agent.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,66 @@ async def test_duplicated_subagent_name(self) -> None:
440440
):
441441
pass
442442

443+
@pytest.mark.asyncio
444+
async def test_subagent_with_invalid_name(self) -> None:
445+
pytest.importorskip("langchain_openai")
446+
447+
async with (
448+
Agent(
449+
model=(await self.model()),
450+
system_prompt="",
451+
service=self.service,
452+
name="invalid name",
453+
) as subagent_invalid,
454+
Agent(
455+
model=(await self.model()),
456+
system_prompt="",
457+
service=self.service,
458+
name="invalid@name",
459+
) as subagent_invalid2,
460+
Agent(
461+
model=(await self.model()),
462+
system_prompt="",
463+
service=self.service,
464+
name="a" * 129,
465+
) as subagent_too_long,
466+
):
467+
with pytest.raises(
468+
AssertionError,
469+
match="Agent name is invalid",
470+
):
471+
async with Agent(
472+
model=(await self.model()),
473+
system_prompt="",
474+
service=self.service,
475+
agents=[subagent_invalid],
476+
):
477+
pass
478+
479+
with pytest.raises(
480+
AssertionError,
481+
match="Agent name is invalid",
482+
):
483+
async with Agent(
484+
model=(await self.model()),
485+
system_prompt="",
486+
service=self.service,
487+
agents=[subagent_invalid2],
488+
):
489+
pass
490+
491+
with pytest.raises(
492+
AssertionError,
493+
match="Agent name is invalid",
494+
):
495+
async with Agent(
496+
model=(await self.model()),
497+
system_prompt="",
498+
service=self.service,
499+
agents=[subagent_too_long],
500+
):
501+
pass
502+
443503
@pytest.mark.asyncio
444504
async def test_subagent_soft_failure_with_invalid_args(self) -> None:
445505
pytest.importorskip("langchain_openai")

0 commit comments

Comments
 (0)