Middleware Pipeline
Composable middleware for logging, caching, auth, and custom guard behaviors.
Overview
agentguard's middleware pipeline lets you compose guard behaviors like Express.js middleware. Each middleware function can inspect, modify, or short-circuit tool calls.
from agentguard import guard
from agentguard.middleware import Middleware, Context


class LoggingMiddleware(Middleware):
    """Print a line before and after every tool call."""

    async def __call__(self, ctx: Context, next_fn):
        print(f"→ {ctx.tool_name}({ctx.args})")
        result = await next_fn(ctx)
        # elapsed_ms() is called as a method, matching the context helpers documented below
        print(f"← {ctx.tool_name} returned in {ctx.elapsed_ms():.0f}ms")
        return result


class AuthMiddleware(Middleware):
    """Reject calls to tools that are not on the allow list."""

    def __init__(self, allowed_tools):
        self.allowed_tools = allowed_tools

    async def __call__(self, ctx: Context, next_fn):
        if ctx.tool_name not in self.allowed_tools:
            # Raising here short-circuits the call: the tool never runs
            raise PermissionError(f"Tool {ctx.tool_name} not allowed")
        return await next_fn(ctx)


@guard(
    middleware=[
        LoggingMiddleware(),
        AuthMiddleware(allowed_tools=["search", "calculate"]),
    ]
)
def search(query: str) -> dict:
    # `api` stands in for your own client object
    return api.search(query)
Execution Order
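Middleware runs in the order listed, each layer wrapping the next, with the tool function at the center. The request passes through the layers from first to last, and the response unwinds through them in reverse, like a stack: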
Request  → Middleware 1 → Middleware 2 → Tool Function
Response ← Middleware 1 ← Middleware 2 ← Tool Function
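The wrapping order can be reproduced without agentguard. The sketch below is a minimal illustration, not the library's implementation: `make_middleware`, `compose`, and the dict-based context are hypothetical names introduced only for this example.

```python
import asyncio

trace = []  # records the order in which each layer runs


def make_middleware(name):
    """Build a hypothetical middleware that records entry and exit."""
    async def middleware(ctx, next_fn):
        trace.append(f"{name}: before")
        result = await next_fn(ctx)
        trace.append(f"{name}: after")
        return result
    return middleware


async def tool(ctx):
    """Stand-in for the guarded tool function."""
    trace.append("tool")
    return ctx["query"].upper()


def compose(middlewares, handler):
    """Wrap handler so the first middleware in the list is outermost."""
    for mw in reversed(middlewares):
        # Bind mw and the current handler as defaults to avoid late binding.
        async def wrapped(ctx, mw=mw, inner=handler):
            return await mw(ctx, inner)
        handler = wrapped
    return handler


pipeline = compose([make_middleware("mw1"), make_middleware("mw2")], tool)
order_result = asyncio.run(pipeline({"query": "hello"}))
print(trace)
```

The first middleware in the list is the first to see the request and the last to see the response, which is why logging or timing layers usually go at the front.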
Built-in Middleware
from agentguard.middleware import (
    CacheMiddleware,      # Cache tool results by arguments
    SanitizeMiddleware,   # Strip PII from inputs/outputs
    AuditMiddleware,      # Audit log to file or webhook
    TransformMiddleware,  # Transform args or results
)


@guard(
    middleware=[
        CacheMiddleware(ttl=300),  # Cache for 5 minutes
        SanitizeMiddleware(fields=["ssn", "password"]),
        AuditMiddleware(webhook="https://hooks.example.com/audit"),
    ]
)
def query_user(user_id: str) -> dict:
    # `db` stands in for your own database client
    return db.get_user(user_id)
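To make the caching behavior concrete, here is a rough sketch of how a TTL cache can be written as middleware. It is an assumption about the general technique, not agentguard's `CacheMiddleware`: the class name `TTLCacheMiddleware`, the dict-based context, and the `clock` parameter are all hypothetical.

```python
import asyncio
import time


class TTLCacheMiddleware:
    """Hypothetical stand-in for a cache middleware; not agentguard's code."""

    def __init__(self, ttl, clock=time.monotonic):
        self.ttl = ttl
        self.clock = clock  # injectable so expiry can be tested without sleeping
        self._store = {}    # key -> (expires_at, result)

    async def __call__(self, ctx, next_fn):
        # Build the key from the tool name and the repr of its arguments.
        key = (ctx["tool_name"], repr(ctx["args"]))
        hit = self._store.get(key)
        if hit is not None and hit[0] > self.clock():
            return hit[1]  # short-circuit: the tool is never called
        result = await next_fn(ctx)
        self._store[key] = (self.clock() + self.ttl, result)
        return result


calls = []  # records every time the underlying tool really runs


async def get_user(ctx):
    calls.append(ctx["args"])
    return {"id": ctx["args"][0]}


async def demo():
    cache = TTLCacheMiddleware(ttl=300)
    ctx = {"tool_name": "query_user", "args": ("u1",)}
    first = await cache(ctx, get_user)
    second = await cache(ctx, get_user)  # served from the cache
    return first, second


first, second = asyncio.run(demo())
```

A cache hit returns without calling `next_fn`, so any middleware placed after the cache in the list is skipped on a hit. Ordering therefore matters when combining caching with sanitizing or auditing.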
Multi-Agent Shared State
When multiple agents share tools, use shared state to coordinate budgets, rate limits, and circuit breakers across agent boundaries:
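from agentguard import guard, SharedState
# Assumption: the config classes are exported from the package root;
# adjust this import if they live in a submodule.
from agentguard import BudgetConfig, RateLimitConfig

# Create shared state (in-memory or Redis-backed)
state = SharedState.redis(url="redis://localhost:6379")


@guard(
    shared_state=state,
    budget=BudgetConfig(max_cost_per_session=10.00),
    rate_limit=RateLimitConfig(calls_per_minute=100),
)
def shared_api(query: str) -> dict:
    # `api` stands in for your own client object
    return api.call(query)

# Agent A and Agent B can now both call shared_api.
# The budget and rate limit are enforced across both agents combined.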
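The idea of one budget enforced across agents can be illustrated with a plain in-memory counter. This is a simplified sketch under stated assumptions, not how agentguard's `SharedState` works internally: `InMemorySharedState`, `BudgetExceeded`, and the per-call `cost` argument are hypothetical.

```python
import threading


class BudgetExceeded(Exception):
    """Raised when a call would push spending past the shared limit."""


class InMemorySharedState:
    """Hypothetical shared counter. A Redis-backed version would keep the
    running total in Redis so separate processes see the same value."""

    def __init__(self, max_cost):
        self.max_cost = max_cost
        self.spent = 0.0
        self._lock = threading.Lock()

    def charge(self, cost):
        with self._lock:  # the check and the update must happen atomically
            if self.spent + cost > self.max_cost:
                raise BudgetExceeded(f"budget of {self.max_cost} exhausted")
            self.spent += cost


shared = InMemorySharedState(max_cost=10.00)


def shared_api_sketch(agent, query, cost=4.00):
    shared.charge(cost)  # every agent draws from the same budget
    return f"{agent}: result for {query}"


outcomes = []
for agent in ["agent_a", "agent_b", "agent_a"]:
    try:
        outcomes.append(shared_api_sketch(agent, "q"))
    except BudgetExceeded:
        outcomes.append(f"{agent}: blocked")
print(outcomes)
```

The third call is refused even though `agent_a` has only spent 4.00 itself, because the limit applies to the combined total.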
Async Support
All guards work with async functions out of the box:
import aiohttp

from agentguard import guard


@guard(validate_input=True, detect_hallucination=True, max_retries=3)
async def async_search(query: str) -> dict:
    async with aiohttp.ClientSession() as session:
        # Pass the query via params so aiohttp URL-encodes it
        async with session.get("https://api.search.com", params={"q": query}) as resp:
            return await resp.json()

# Call from inside a coroutine (top-level await only works in an async REPL or notebook)
result = await async_search("hello world")
For async tools, agentguard handles retries and timeouts with asyncio directly, so no thread pool is involved.
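As a rough illustration of retries and timeouts built on asyncio alone, the sketch below uses `asyncio.wait_for` and `asyncio.sleep`. The helper `call_with_retries`, its parameters, and the exponential backoff policy are assumptions made for the example; agentguard's internal retry logic is not shown in this document.

```python
import asyncio


async def call_with_retries(fn, *args, max_retries=3, timeout=1.0, backoff=0.01):
    """Retry an async callable. Hypothetical helper, not agentguard's API."""
    last_error = None
    for attempt in range(max_retries + 1):
        try:
            # wait_for cancels the call if it runs longer than the timeout.
            return await asyncio.wait_for(fn(*args), timeout=timeout)
        except (asyncio.TimeoutError, ConnectionError) as exc:
            last_error = exc
            if attempt < max_retries:
                # asyncio.sleep yields to the event loop instead of blocking a thread.
                await asyncio.sleep(backoff * 2 ** attempt)
    raise last_error


attempts = 0


async def flaky_search(query):
    """Fails twice with a transient error, then succeeds."""
    global attempts
    attempts += 1
    if attempts < 3:
        raise ConnectionError("transient failure")
    return {"query": query, "hits": 1}


retry_result = asyncio.run(call_with_retries(flaky_search, "hello world"))
print(attempts, retry_result)
```

Because both the timeout and the backoff delay are awaitable, other coroutines keep running while one call waits to retry.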
MiddlewareChain
The MiddlewareChain is an ordered list of middleware functions that run around every tool call. Middleware is invoked in the order added via use().
from agentguard.middleware import MiddlewareChain
chain = MiddlewareChain()
chain.use(logging_middleware)
chain.use(auth_middleware)
chain.use(timing_middleware)
# Fluent chaining
chain = MiddlewareChain().use(mw1).use(mw2).use(mw3)
Writing Custom Middleware
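Each middleware receives a MiddlewareContext and a next callable. To pass control down the chain, it calls await next(ctx) and returns the result. Raising an exception before that call, as the auth example does, stops the chain so the tool never runs: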
async def auth_middleware(ctx, next):
    if not ctx.metadata.get("api_key"):
        raise PermissionError("No API key provided")
    ctx.mark("auth_checked")
    return await next(ctx)


async def rate_limit_middleware(ctx, next):
    # Pre-call: check rate limit
    # (check_rate_limit is your own helper; it should raise when the limit is hit)
    check_rate_limit(ctx.tool_name)
    result = await next(ctx)
    # Post-call: log timing
    print(f"{ctx.tool_name} took {ctx.elapsed_ms():.1f}ms")
    return result
MiddlewareContext Fields
| Field | Type | Description |
|---|---|---|
| tool_name | str | Name of the guarded tool being called |
| args | tuple | Positional arguments |
| kwargs | dict | Keyword arguments |
| config | GuardConfig | The tool's guard configuration |
| metadata | dict | Caller-supplied metadata (api_key, session_id, etc.) |
| timestamps | dict | Named timestamps (start, auth_checked, etc.) |
| call_id | str | Unique ID for this invocation |
Context Helper Methods
- ctx.mark("label"): record a named timestamp
- ctx.elapsed_ms(since="start"): milliseconds elapsed since a named timestamp
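The fields and helpers above can be modeled with a small dataclass. This is a hypothetical sketch for illustration, not agentguard's `MiddlewareContext`: the class name `SketchContext` is invented, the `config` field is omitted because `GuardConfig` is not defined here, and recording a `start` timestamp at construction is an assumption.

```python
import time
import uuid
from dataclasses import dataclass, field


@dataclass
class SketchContext:
    """Hypothetical model of the documented fields; not agentguard's class."""
    tool_name: str
    args: tuple = ()
    kwargs: dict = field(default_factory=dict)
    metadata: dict = field(default_factory=dict)
    timestamps: dict = field(default_factory=dict)
    call_id: str = field(default_factory=lambda: uuid.uuid4().hex)

    def __post_init__(self):
        # Assumption: "start" is recorded when the context is created.
        self.mark("start")

    def mark(self, label):
        """Record a named timestamp."""
        self.timestamps[label] = time.perf_counter()

    def elapsed_ms(self, since="start"):
        """Milliseconds elapsed since the named timestamp."""
        return (time.perf_counter() - self.timestamps[since]) * 1000.0


ctx = SketchContext(tool_name="search", args=("hello",))
time.sleep(0.01)            # simulate work done before the auth check
ctx.mark("auth_checked")
total_ms = ctx.elapsed_ms()
since_auth_ms = ctx.elapsed_ms(since="auth_checked")
```

Named timestamps make it possible to time individual stages, for example how long a call spent after authentication, rather than only the whole invocation.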
Built-in Middleware Factories
from agentguard.middleware import (
    MiddlewareChain,
    logging_middleware,
    timing_middleware,
    metadata_middleware,
)

chain = MiddlewareChain()

# Print log lines before/after each call
chain.use(logging_middleware(log_args=True, log_result=True))

# Measure wall-clock time → ctx.metadata["elapsed_ms"]
chain.use(timing_middleware())

# Inject static metadata into every call
chain.use(metadata_middleware(environment="production", version="2.1"))
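Each factory above is called with options and returns the middleware that the chain actually runs. A minimal sketch of that pattern, assuming a dict-based context, looks like this. The name `sketch_metadata_middleware` is hypothetical, and letting caller-supplied values take precedence over the static ones is an assumption, not documented agentguard behavior.

```python
import asyncio


def sketch_metadata_middleware(**static_metadata):
    """Hypothetical factory: returns a middleware closed over its settings."""
    async def middleware(ctx, next_fn):
        for key, value in static_metadata.items():
            # Assumption: values already supplied by the caller are kept.
            ctx["metadata"].setdefault(key, value)
        return await next_fn(ctx)
    return middleware


async def echo_tool(ctx):
    """Stand-in tool that returns the metadata it was called with."""
    return dict(ctx["metadata"])


mw = sketch_metadata_middleware(environment="production", version="2.1")
seen = asyncio.run(mw({"metadata": {"version": "3.0"}}, echo_tool))
print(seen)
```

Configuring through a factory keeps the middleware signature uniform, so the chain can treat every entry the same way regardless of its options.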