如何开发自定义工具和中间件以应对需求?
摘要:自定义工具与中间件开发。在前几篇文章中,我们学习了如何使用Agent Framework内置的工具,以及如何通过工作流编排来协调多个任务。但是在实际业务中,内置工具往往不能满足所有需求。我们经常需要创建自定义工具来对接企业系统、调用外部API、操作数据库等。
自定义工具与中间件开发
前言
在前几篇文章中,我们学习了如何使用Agent Framework内置的工具,以及如何通过工作流编排来协调多个任务。但是在实际业务中,内置工具往往不能满足所有需求。我们经常需要创建自定义工具来对接企业系统、调用外部API、操作数据库等。
本文将深入探讨如何在Agent Framework中开发自定义工具和中间件,包括工具的定义、注册、参数验证、错误处理,以及中间件的开发和使用。通过本文的学习,你将能够创建功能强大且安全的自定义工具,构建可扩展的Agent系统。
一、自定义工具基础
1.1 工具接口定义
为了统一管理自定义工具,本文示例定义了如下工具接口及相关类型,每个工具都需要实现该接口:
// ITool.cs - base tool contract and its supporting data types

/// <summary>
/// Contract implemented by every tool the agent can invoke.
/// </summary>
public interface ITool
{
    /// <summary>Unique identifier of the tool.</summary>
    string Id { get; }

    /// <summary>Human-readable display name.</summary>
    string Name { get; }

    /// <summary>Natural-language description of what the tool does.</summary>
    string Description { get; }

    /// <summary>Describes the accepted arguments so the LLM knows how to call the tool.</summary>
    ToolParameterSchema GetParameterSchema();

    /// <summary>Runs the tool with the supplied named arguments.</summary>
    Task<ToolExecutionResult> ExecuteAsync(Dictionary<string, object> parameters);
}

/// <summary>
/// JSON-Schema-style description of a tool's argument object.
/// </summary>
public class ToolParameterSchema
{
    /// <summary>Schema type of the argument container; defaults to "object".</summary>
    public string Type { get; set; } = "object";

    /// <summary>Argument name to property description.</summary>
    public Dictionary<string, ToolParameterProperty> Properties { get; set; } =
        new Dictionary<string, ToolParameterProperty>();

    /// <summary>Names of arguments that must be supplied.</summary>
    public List<string> Required { get; set; } = new List<string>();
}

/// <summary>
/// Description of a single tool argument.
/// </summary>
public class ToolParameterProperty
{
    /// <summary>Schema type name; defaults to "string".</summary>
    public string Type { get; set; } = "string";

    /// <summary>Explanation of the argument shown to the LLM.</summary>
    public string Description { get; set; } = "";

    /// <summary>Per-property required flag.</summary>
    public bool Required { get; set; }

    /// <summary>Value used when the argument is omitted, if any.</summary>
    public object? Default { get; set; }

    /// <summary>Allowed values, when the argument is an enumeration.</summary>
    public List<string>? Enum { get; set; }
}

/// <summary>
/// Outcome of a tool execution.
/// </summary>
public class ToolExecutionResult
{
    /// <summary>True when the tool completed successfully.</summary>
    public bool Success { get; set; }

    /// <summary>Tool output text on success.</summary>
    public string? Output { get; set; }

    /// <summary>Error description on failure.</summary>
    public string? Error { get; set; }

    /// <summary>Additional key/value details about the execution.</summary>
    public Dictionary<string, object> Metadata { get; set; } = new Dictionary<string, object>();
}
1.2 抽象基类实现
为了简化工具开发,我们创建一个抽象基类:
// BaseTool.cs
/// <summary>
/// Base class that performs required-argument validation and converts exceptions into failed
/// results, so concrete tools only implement <see cref="ExecuteInternalAsync"/>.
/// </summary>
public abstract class BaseTool : ITool
{
    public abstract string Id { get; }
    public abstract string Name { get; }
    public abstract string Description { get; }

    /// <summary>
    /// Tool-specific logic. Called only after <see cref="ValidateParameters"/> has succeeded.
    /// </summary>
    /// <param name="parameters">Validated arguments; never null when called from <see cref="ExecuteAsync"/>.</param>
    protected abstract Task<ToolExecutionResult> ExecuteInternalAsync(
        Dictionary<string, object> parameters);

    /// <summary>
    /// Default schema: an object with no declared properties. Override to describe arguments.
    /// </summary>
    public virtual ToolParameterSchema GetParameterSchema()
    {
        return new ToolParameterSchema();
    }

    /// <summary>
    /// Validates <paramref name="parameters"/> and then runs <see cref="ExecuteInternalAsync"/>.
    /// Validation failures and exceptions are returned as a failed <see cref="ToolExecutionResult"/>
    /// instead of being thrown.
    /// </summary>
    /// <param name="parameters">Named arguments; null is treated as "no arguments".</param>
    public async Task<ToolExecutionResult> ExecuteAsync(Dictionary<string, object> parameters)
    {
        // Normalize a null argument bag so validation reports the missing argument by name
        // and subclasses never receive null.
        parameters ??= new Dictionary<string, object>();

        try
        {
            var validationResult = ValidateParameters(parameters);
            if (!validationResult.IsValid)
            {
                return new ToolExecutionResult
                {
                    Success = false,
                    Error = validationResult.ErrorMessage
                };
            }

            return await ExecuteInternalAsync(parameters);
        }
        catch (Exception ex)
        {
            // Report failures to the caller (the LLM) as data rather than crashing the agent loop.
            return new ToolExecutionResult
            {
                Success = false,
                Error = $"工具执行错误: {ex.Message}"
            };
        }
    }

    /// <summary>
    /// Checks that every name in <see cref="ToolParameterSchema.Required"/> is present with a non-null value.
    /// </summary>
    /// <param name="parameters">Arguments to check; null is treated as empty.</param>
    /// <returns>The first missing argument as an error, or a valid result.</returns>
    protected virtual ParameterValidationResult ValidateParameters(
        Dictionary<string, object> parameters)
    {
        var schema = GetParameterSchema();

        foreach (var requiredParam in schema.Required)
        {
            // Single dictionary lookup; a key that is present but maps to null counts as missing.
            if (parameters is null ||
                !parameters.TryGetValue(requiredParam, out var value) ||
                value is null)
            {
                return new ParameterValidationResult
                {
                    IsValid = false,
                    ErrorMessage = $"缺少必需参数: {requiredParam}"
                };
            }
        }

        return new ParameterValidationResult { IsValid = true };
    }
}

/// <summary>
/// Result of argument validation: <see cref="IsValid"/> plus an error message when invalid.
/// </summary>
public class ParameterValidationResult
{
    public bool IsValid { get; set; }
    public string? ErrorMessage { get; set; }
}
二、实用工具开发
2.1 数据库查询工具
首先,让我们创建一个可以执行SQL查询的工具:
// DatabaseQueryTool.cs
/// <summary>
/// Tool that runs a single SELECT statement and returns at most <c>maxRows</c> rows as JSON.
/// </summary>
/// <remarks>
/// The SQL text is produced by the LLM, so the checks in <see cref="ExecuteInternalAsync"/> are only a
/// first line of defence (for example, they do not detect <c>SELECT ... INTO</c>). The connection string
/// should use a database account that has read-only permissions.
/// </remarks>
public class DatabaseQueryTool : BaseTool
{
    private readonly string _connectionString;
    private readonly ILogger<DatabaseQueryTool> _logger;
    private readonly int _maxRows;

    public override string Id => "db_query";
    public override string Name => "数据库查询";
    public override string Description => "执行SQL查询并返回结果。适用于查询数据、统计信息等。只支持SELECT语句。";

    /// <param name="connectionString">Connection string used to open a <c>SqlConnection</c>.</param>
    /// <param name="logger">Logger for queries and failures.</param>
    /// <param name="maxRows">Maximum number of rows returned per query; must be positive.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="maxRows"/> is zero or negative.</exception>
    public DatabaseQueryTool(
        string connectionString,
        ILogger<DatabaseQueryTool> logger,
        int maxRows = 100)
    {
        if (maxRows <= 0)
        {
            throw new ArgumentOutOfRangeException(nameof(maxRows), maxRows, "maxRows must be positive.");
        }

        _connectionString = connectionString;
        _logger = logger;
        _maxRows = maxRows;
    }

    public override ToolParameterSchema GetParameterSchema()
    {
        return new ToolParameterSchema
        {
            Type = "object",
            Properties = new Dictionary<string, ToolParameterProperty>
            {
                ["sql"] = new ToolParameterProperty
                {
                    Type = "string",
                    Description = "要执行的SQL查询语句。只支持SELECT语句,不支持UPDATE/INSERT/DELETE。",
                    Required = true
                },
                ["parameters"] = new ToolParameterProperty
                {
                    Type = "object",
                    Description = "查询参数,用于参数化查询",
                    Required = false
                }
            },
            Required = new List<string> { "sql" }
        };
    }

    /// <summary>
    /// Checks that "sql" is a single SELECT statement, executes it with any bound "parameters", and
    /// serializes up to <c>_maxRows</c> rows as <c>{ columns, rows, totalCount, truncated }</c>.
    /// </summary>
    protected override async Task<ToolExecutionResult> ExecuteInternalAsync(
        Dictionary<string, object> parameters)
    {
        var sql = parameters.GetValueOrDefault("sql")?.ToString();
        if (string.IsNullOrWhiteSpace(sql))
        {
            return new ToolExecutionResult
            {
                Success = false,
                Error = "SQL语句不能为空"
            };
        }

        // Drop surrounding whitespace and trailing ';' so that "SELECT ...;" is still accepted.
        var statement = sql.Trim().TrimEnd(';').TrimEnd();
        if (!statement.StartsWith("SELECT", StringComparison.OrdinalIgnoreCase))
        {
            return new ToolExecutionResult
            {
                Success = false,
                Error = "只允许执行SELECT查询语句"
            };
        }

        // A ';' that remains means a batch such as "SELECT 1; DROP TABLE t", which would pass the
        // prefix check above. This is deliberately conservative: it also rejects ';' inside string literals.
        if (statement.Contains(';'))
        {
            return new ToolExecutionResult
            {
                Success = false,
                Error = "只允许执行单条SELECT查询语句"
            };
        }

        _logger.LogInformation("执行查询: {Sql}", sql);

        try
        {
            using var connection = new SqlConnection(_connectionString);
            await connection.OpenAsync();

            using var command = new SqlCommand(statement, connection);

            // Bind caller-supplied values. This protects the values only; the statement text itself
            // still comes from the LLM.
            if (parameters.TryGetValue("parameters", out var rawQueryParams) &&
                rawQueryParams is IDictionary<string, object> queryParams)
            {
                foreach (var kvp in queryParams)
                {
                    command.Parameters.AddWithValue(kvp.Key, kvp.Value ?? DBNull.Value);
                }
            }

            using var reader = await command.ExecuteReaderAsync();

            var columns = new List<string>(reader.FieldCount);
            for (int i = 0; i < reader.FieldCount; i++)
            {
                columns.Add(reader.GetName(i));
            }

            // Cap the row count while reading instead of appending "LIMIT n" to the SQL. LIMIT is not
            // valid T-SQL (this tool uses SqlConnection), and it breaks queries that already restrict rows.
            var results = new List<Dictionary<string, object?>>();
            while (results.Count < _maxRows && await reader.ReadAsync())
            {
                var row = new Dictionary<string, object?>(columns.Count);
                for (int i = 0; i < columns.Count; i++)
                {
                    row[columns[i]] = reader.IsDBNull(i) ? null : reader.GetValue(i);
                }
                results.Add(row);
            }

            // Probe for one more row so the caller can tell that the result was cut off.
            var truncated = results.Count == _maxRows && await reader.ReadAsync();
            if (truncated)
            {
                // Cancelling before the reader is closed avoids draining the remaining rows from the server.
                command.Cancel();
            }

            _logger.LogInformation("查询返回 {Count} 行", results.Count);

            return new ToolExecutionResult
            {
                Success = true,
                Output = JsonSerializer.Serialize(new
                {
                    columns,
                    rows = results,
                    totalCount = results.Count,
                    truncated
                }),
                Metadata = new Dictionary<string, object>
                {
                    { "rowCount", results.Count },
                    // Timestamp of completion (UTC), not a duration.
                    { "executionTime", DateTime.UtcNow },
                    { "truncated", truncated }
                }
            };
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "查询执行失败");
            return new ToolExecutionResult
            {
                Success = false,
                Error = $"查询执行失败: {ex.Message}"
            };
        }
    }
}
2.2 HTTP请求工具
创建一个通用的HTTP请求工具:
// HttpRequestTool.cs
/// <summary>
/// Generic HTTP tool: sends the request described by the arguments and returns the response body.
/// </summary>
/// <remarks>
/// NOTE(review): the URL is chosen by the LLM, so this tool can reach any host the server can reach,
/// including internal addresses. Consider a host allow-list before exposing it to untrusted users.
/// </remarks>
public class HttpRequestTool : BaseTool
{
    private const double DefaultTimeoutSeconds = 30;

    // Caps absurd values so that TimeSpan/CancellationTokenSource arguments stay in range.
    private const double MaxTimeoutSeconds = 3600;

    private readonly HttpClient _httpClient;
    private readonly ILogger<HttpRequestTool> _logger;

    // Headers added to every request before the per-call "headers" argument.
    private readonly Dictionary<string, string> _defaultHeaders;

    public override string Id => "http_request";
    public override string Name => "HTTP请求";
    public override string Description => "发送HTTP请求,支持GET、POST、PUT、DELETE等方法。可以用于调用外部API。";

    /// <param name="httpClient">Client used to send requests; its own Timeout also applies.</param>
    /// <param name="logger">Logger for requests and failures.</param>
    /// <param name="defaultHeaders">Headers added to every request; null means none.</param>
    public HttpRequestTool(
        HttpClient httpClient,
        ILogger<HttpRequestTool> logger,
        Dictionary<string, string>? defaultHeaders = null)
    {
        _httpClient = httpClient;
        _logger = logger;
        _defaultHeaders = defaultHeaders ?? new Dictionary<string, string>();
    }

    public override ToolParameterSchema GetParameterSchema()
    {
        return new ToolParameterSchema
        {
            Type = "object",
            Properties = new Dictionary<string, ToolParameterProperty>
            {
                ["url"] = new ToolParameterProperty
                {
                    Type = "string",
                    Description = "请求的URL地址",
                    Required = true
                },
                ["method"] = new ToolParameterProperty
                {
                    Type = "string",
                    Description = "HTTP方法:GET、POST、PUT、DELETE、PATCH",
                    Required = false,
                    Default = "GET",
                    Enum = new List<string> { "GET", "POST", "PUT", "DELETE", "PATCH" }
                },
                ["headers"] = new ToolParameterProperty
                {
                    Type = "object",
                    Description = "请求头",
                    Required = false
                },
                ["body"] = new ToolParameterProperty
                {
                    Type = "object",
                    Description = "请求体(JSON对象)",
                    Required = false
                },
                ["timeout"] = new ToolParameterProperty
                {
                    Type = "number",
                    Description = "请求超时时间(秒)",
                    Required = false,
                    Default = 30
                }
            },
            Required = new List<string> { "url" }
        };
    }

    /// <summary>
    /// Sends the request and returns the response body as <see cref="ToolExecutionResult.Output"/>.
    /// Success mirrors <c>IsSuccessStatusCode</c>.
    /// </summary>
    protected override async Task<ToolExecutionResult> ExecuteInternalAsync(
        Dictionary<string, object> parameters)
    {
        var url = parameters["url"]?.ToString();

        // HTTP method tokens are case-sensitive on the wire; normalize so that "get" behaves like "GET".
        var rawMethod = parameters.GetValueOrDefault("method")?.ToString();
        var method = string.IsNullOrWhiteSpace(rawMethod) ? "GET" : rawMethod.Trim().ToUpperInvariant();

        var timeout = ReadTimeoutSeconds(parameters.GetValueOrDefault("timeout"));

        if (string.IsNullOrWhiteSpace(url))
        {
            return new ToolExecutionResult
            {
                Success = false,
                Error = "URL不能为空"
            };
        }

        _logger.LogInformation("发送HTTP请求: {Method} {Url}", method, url);

        try
        {
            // Request and response own content streams/buffers, so dispose both.
            using var request = new HttpRequestMessage(new HttpMethod(method), url);

            foreach (var header in _defaultHeaders)
            {
                request.Headers.TryAddWithoutValidation(header.Key, header.Value);
            }

            if (parameters.TryGetValue("headers", out var rawHeaders) &&
                rawHeaders is IDictionary<string, object> customHeaders)
            {
                foreach (var header in customHeaders)
                {
                    request.Headers.TryAddWithoutValidation(header.Key, header.Value?.ToString());
                }
            }

            // A missing or null body sends no content, rather than the JSON literal "null".
            if (method != "GET" &&
                parameters.TryGetValue("body", out var body) &&
                body is not null)
            {
                var json = JsonSerializer.Serialize(body);
                request.Content = new StringContent(json, Encoding.UTF8, "application/json");
            }

            using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(timeout));
            using var response = await _httpClient.SendAsync(request, cts.Token);
            var content = await response.Content.ReadAsStringAsync(cts.Token);

            _logger.LogInformation("HTTP响应: {StatusCode}, Content-Length={Length}",
                response.StatusCode, content.Length);

            return new ToolExecutionResult
            {
                Success = response.IsSuccessStatusCode,
                Output = content,
                Metadata = new Dictionary<string, object>
                {
                    { "statusCode", (int)response.StatusCode },
                    { "reasonPhrase", response.ReasonPhrase ?? "" }
                }
            };
        }
        catch (TaskCanceledException)
        {
            return new ToolExecutionResult
            {
                Success = false,
                Error = $"请求超时({timeout}秒)"
            };
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "HTTP请求失败");
            return new ToolExecutionResult
            {
                Success = false,
                Error = $"请求失败: {ex.Message}"
            };
        }
    }

    /// <summary>
    /// Converts the raw "timeout" argument to seconds. It accepts numeric CLR values, invariant-culture
    /// numeric strings and JSON numbers. Missing, non-numeric, non-positive and NaN values fall back to
    /// <see cref="DefaultTimeoutSeconds"/>; larger values are capped at <see cref="MaxTimeoutSeconds"/>.
    /// </summary>
    private static double ReadTimeoutSeconds(object? raw)
    {
        double? seconds = raw switch
        {
            int i => i,
            long l => l,
            double d => d,
            float f => f,
            decimal m => (double)m,
            string s when double.TryParse(
                s,
                System.Globalization.NumberStyles.Float,
                System.Globalization.CultureInfo.InvariantCulture,
                out var parsed) => parsed,
            JsonElement { ValueKind: JsonValueKind.Number } element => element.GetDouble(),
            _ => null
        };

        // Covers null, NaN, zero and negative values.
        if (seconds is not > 0)
        {
            return DefaultTimeoutSeconds;
        }

        return Math.Min(seconds.Value, MaxTimeoutSeconds);
    }
}
2.3 文件操作工具
创建用于读取和写入文件的工具:
// FileOperationTool.cs
/// <summary>
/// Reads or writes text files, confined to a base directory and a fixed set of file extensions.
/// </summary>
public class FileOperationTool : BaseTool
{
    private readonly string _baseDirectory;
    private readonly ILogger<FileOperationTool> _logger;
    private static readonly string[] AllowedExtensions = { ".txt", ".json", ".xml", ".csv", ".md" };

    public override string Id => "file_operation";
    public override string Name => "文件操作";
    public override string Description => "读取或写入文件。支持读取文本文件内容,或将内容写入文件。";

    /// <param name="baseDirectory">
    /// Root directory that all paths are resolved against. It is created if missing; a relative value
    /// is resolved against the current working directory.
    /// </param>
    /// <param name="logger">Logger for file operations.</param>
    public FileOperationTool(
        string baseDirectory,
        ILogger<FileOperationTool> logger)
    {
        // Store the base as an absolute path without a trailing separator, so that the containment
        // check in GetFullPath compares normalized paths.
        _baseDirectory = Path.TrimEndingDirectorySeparator(Path.GetFullPath(baseDirectory));
        _logger = logger;

        Directory.CreateDirectory(_baseDirectory);
    }

    public override ToolParameterSchema GetParameterSchema()
    {
        return new ToolParameterSchema
        {
            Type = "object",
            Properties = new Dictionary<string, ToolParameterProperty>
            {
                ["operation"] = new ToolParameterProperty
                {
                    Type = "string",
                    Description = "操作类型:read(读取)或 write(写入)",
                    Required = true,
                    Enum = new List<string> { "read", "write" }
                },
                ["path"] = new ToolParameterProperty
                {
                    Type = "string",
                    Description = "文件路径(相对于基础目录)",
                    Required = true
                },
                ["content"] = new ToolParameterProperty
                {
                    Type = "string",
                    Description = "写入文件的内容(write操作时必需)",
                    Required = false
                }
            },
            Required = new List<string> { "operation", "path" }
        };
    }

    /// <summary>
    /// Dispatches to read or write after resolving "path" safely inside the base directory.
    /// </summary>
    protected override Task<ToolExecutionResult> ExecuteInternalAsync(
        Dictionary<string, object> parameters)
    {
        // Use the invariant culture: under tr-TR, for example, "WRITE".ToLower() yields "wrıte".
        var operation = parameters["operation"]?.ToString()?.ToLowerInvariant();
        var path = parameters["path"]?.ToString();

        if (string.IsNullOrWhiteSpace(path))
        {
            return Task.FromResult(new ToolExecutionResult
            {
                Success = false,
                Error = "文件路径不能为空"
            });
        }

        var fullPath = GetFullPath(path);
        if (fullPath == null)
        {
            return Task.FromResult(new ToolExecutionResult
            {
                Success = false,
                Error = "无效的文件路径"
            });
        }

        return operation switch
        {
            "read" => ReadFileAsync(fullPath),
            "write" => WriteFileAsync(fullPath, parameters),
            _ => Task.FromResult(new ToolExecutionResult
            {
                Success = false,
                Error = $"不支持的操作: {operation}"
            })
        };
    }

    /// <summary>
    /// Resolves <paramref name="relativePath"/> against the base directory.
    /// </summary>
    /// <returns>
    /// The absolute path, or null if it lies outside the base directory, has a disallowed extension,
    /// or cannot be parsed.
    /// </returns>
    /// <remarks>
    /// NOTE(review): symbolic links inside the base directory are not resolved and could point outside it.
    /// </remarks>
    private string? GetFullPath(string relativePath)
    {
        try
        {
            var fullPath = Path.GetFullPath(Path.Combine(_baseDirectory, relativePath));

            // Path.GetRelativePath applies the platform's path case-sensitivity rules. A raw string prefix
            // test would also accept sibling folders such as "<base>2/..." and, with a case-insensitive
            // comparison on Linux, case-variant absolute paths outside the base.
            var relative = Path.GetRelativePath(_baseDirectory, fullPath);
            if (relative == "." ||
                relative == ".." ||
                relative.StartsWith(".." + Path.DirectorySeparatorChar, StringComparison.Ordinal) ||
                Path.IsPathRooted(relative))
            {
                return null;
            }

            var extension = Path.GetExtension(fullPath);
            if (!AllowedExtensions.Contains(extension, StringComparer.OrdinalIgnoreCase))
            {
                return null;
            }

            return fullPath;
        }
        catch
        {
            // Malformed paths (invalid characters, too long, etc.) are reported as "invalid path".
            return null;
        }
    }

    /// <summary>Reads the whole file as text.</summary>
    private async Task<ToolExecutionResult> ReadFileAsync(string path)
    {
        try
        {
            if (!File.Exists(path))
            {
                return new ToolExecutionResult
                {
                    Success = false,
                    Error = $"文件不存在: {path}"
                };
            }

            var content = await File.ReadAllTextAsync(path);

            // "size" is the length in characters, not bytes.
            _logger.LogInformation("读取文件: {Path}, Size={Size}", path, content.Length);

            return new ToolExecutionResult
            {
                Success = true,
                Output = content,
                Metadata = new Dictionary<string, object>
                {
                    { "path", path },
                    { "size", content.Length }
                }
            };
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "读取文件失败: {Path}", path);
            return new ToolExecutionResult
            {
                Success = false,
                Error = $"读取文件失败: {ex.Message}"
            };
        }
    }

    /// <summary>
    /// Writes the "content" argument to <paramref name="path"/>, creating parent directories and
    /// overwriting any existing file. Empty content is rejected.
    /// </summary>
    private async Task<ToolExecutionResult> WriteFileAsync(
        string path,
        Dictionary<string, object> parameters)
    {
        var content = parameters.GetValueOrDefault("content")?.ToString();
        if (string.IsNullOrEmpty(content))
        {
            return new ToolExecutionResult
            {
                Success = false,
                Error = "写入内容不能为空"
            };
        }

        try
        {
            var directory = Path.GetDirectoryName(path);
            if (!string.IsNullOrEmpty(directory))
            {
                Directory.CreateDirectory(directory);
            }

            await File.WriteAllTextAsync(path, content);

            _logger.LogInformation("写入文件: {Path}, Size={Size}", path, content.Length);

            return new ToolExecutionResult
            {
                Success = true,
                Output = $"文件已写入: {path}",
                Metadata = new Dictionary<string, object>
                {
                    { "path", path },
                    { "size", content.Length }
                }
            };
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "写入文件失败: {Path}", path);
            return new ToolExecutionResult
            {
                Success = false,
                Error = $"写入文件失败: {ex.Message}"
            };
        }
    }
}
三、工具注册与管理
3.1 工具管理器
// ToolRegistry.cs
/// <summary>
/// In-memory <see cref="IToolRegistry"/> keyed by <see cref="ITool.Id"/>. Every access to the
/// underlying dictionary takes a private lock, and enumerating methods return snapshots, so
/// registration can run concurrently with lookups.
/// </summary>
public class ToolRegistry : IToolRegistry
{
    private readonly Dictionary<string, ITool> _tools = new();
    private readonly ILogger<ToolRegistry> _logger;

    // A plain monitor is enough: no critical section awaits anything.
    private readonly object _gate = new();

    public ToolRegistry(ILogger<ToolRegistry> logger)
    {
        _logger = logger;
    }

    /// <summary>Adds <paramref name="tool"/>, replacing (with a warning) any tool that has the same Id.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="tool"/> is null.</exception>
    public void Register(ITool tool)
    {
        ArgumentNullException.ThrowIfNull(tool);

        var id = tool.Id;
        bool replaced;
        lock (_gate)
        {
            replaced = _tools.ContainsKey(id);
            _tools[id] = tool;
        }

        if (replaced)
        {
            _logger.LogWarning("工具已存在,将被替换: {ToolId}", id);
        }
        else
        {
            _logger.LogInformation("注册工具: {ToolName} ({ToolId})", tool.Name, id);
        }
    }

    /// <summary>Removes the tool with <paramref name="toolId"/>, if present.</summary>
    public void Unregister(string toolId)
    {
        bool removed;
        lock (_gate)
        {
            removed = _tools.Remove(toolId);
        }

        if (removed)
        {
            _logger.LogInformation("注销工具: {ToolId}", toolId);
        }
    }

    /// <summary>Returns the tool with <paramref name="toolId"/>, or null if none is registered.</summary>
    public ITool? Get(string toolId)
    {
        lock (_gate)
        {
            return _tools.TryGetValue(toolId, out var tool) ? tool : null;
        }
    }

    /// <summary>Returns a snapshot of all registered tools.</summary>
    public IEnumerable<ITool> GetAll()
    {
        return Snapshot();
    }

    /// <summary>
    /// Returns tools whose name or description contains <paramref name="keyword"/> (case-insensitive),
    /// or all tools when the keyword is null or blank.
    /// </summary>
    public IEnumerable<ITool> Search(string? keyword)
    {
        var all = Snapshot();
        if (string.IsNullOrWhiteSpace(keyword))
        {
            return all;
        }

        return all
            .Where(t => t.Name.Contains(keyword, StringComparison.OrdinalIgnoreCase) ||
                        t.Description.Contains(keyword, StringComparison.OrdinalIgnoreCase))
            .ToList();
    }

    /// <summary>
    /// Builds function-calling definitions for the LLM. Each tool's Id is used as the function name.
    /// </summary>
    public ToolList GetToolListForLLM()
    {
        var tools = Snapshot().Select(tool => new ToolDefinition
        {
            Type = "function",
            Function = new ToolFunctionDefinition
            {
                Name = tool.Id,
                Description = tool.Description,
                Parameters = tool.GetParameterSchema()
            }
        }).ToList();

        return new ToolList { Tools = tools };
    }

    // Copies the current tools under the lock so that callers can enumerate without holding it.
    private ITool[] Snapshot()
    {
        lock (_gate)
        {
            return _tools.Values.ToArray();
        }
    }
}
/// <summary>
/// Registry of the tools available to the agent.
/// </summary>
public interface IToolRegistry
{
    /// <summary>Adds (or replaces) a tool, keyed by its Id.</summary>
    void Register(ITool tool);

    /// <summary>Removes the tool with the given id, if present.</summary>
    void Unregister(string toolId);

    /// <summary>Returns the tool with the given id, or null.</summary>
    ITool? Get(string toolId);

    /// <summary>Returns every registered tool.</summary>
    IEnumerable<ITool> GetAll();

    /// <summary>Returns tools matching <paramref name="keyword"/>; all tools when it is blank.</summary>
    IEnumerable<ITool> Search(string? keyword);

    /// <summary>Builds the function-calling tool list sent to the LLM.</summary>
    ToolList GetToolListForLLM();
}

/// <summary>Container for the tool definitions sent to the LLM.</summary>
public class ToolList
{
    public List<ToolDefinition> Tools { get; set; } = new List<ToolDefinition>();
}

/// <summary>One entry of the LLM tool list.</summary>
public class ToolDefinition
{
    /// <summary>Definition kind (ToolRegistry uses "function").</summary>
    public string Type { get; set; } = "";

    /// <summary>The callable function described by this entry.</summary>
    public ToolFunctionDefinition Function { get; set; } = new ToolFunctionDefinition();
}

/// <summary>Name, description and argument schema of a callable function.</summary>
public class ToolFunctionDefinition
{
    public string Name { get; set; } = "";

    public string Description { get; set; } = "";

    public ToolParameterSchema Parameters { get; set; } = new ToolParameterSchema();
}
3.2 依赖注入集成
在ASP.NET Core中集成工具注册:
// ToolExtensions.cs
/// <summary>
/// Dependency-injection helpers for registering agent tools.
/// </summary>
public static class ToolExtensions
{
    /// <summary>
    /// Registers <see cref="IToolRegistry"/> (unless already registered), an <c>IHttpClientFactory</c>,
    /// and the built-in database, HTTP and file tools as <see cref="ITool"/> singletons.
    /// </summary>
    /// <remarks>
    /// Tools are only registered with DI here. They reach the registry when something copies them
    /// over, for example <see cref="UseTools"/>.
    /// </remarks>
    public static IServiceCollection AddAgentTools(this IServiceCollection services)
    {
        // Register the registry only once, even if the host already added its own IToolRegistry.
        if (!services.Any(d => d.ServiceType == typeof(IToolRegistry)))
        {
            services.AddSingleton<IToolRegistry, ToolRegistry>();
        }

        // HttpClient itself is not registered by default; HttpRequestTool gets one from IHttpClientFactory.
        services.AddHttpClient();

        services.AddSingleton<ITool>(sp =>
        {
            var logger = sp.GetRequiredService<ILogger<DatabaseQueryTool>>();
            var connectionString = sp.GetRequiredService<IConfiguration>()["Database:ConnectionString"];

            // Fail with a clear message instead of passing a null connection string to the tool.
            if (string.IsNullOrWhiteSpace(connectionString))
            {
                throw new InvalidOperationException(
                    "Configuration value 'Database:ConnectionString' is required by DatabaseQueryTool.");
            }

            return new DatabaseQueryTool(connectionString, logger);
        });

        services.AddSingleton<ITool>(sp =>
        {
            var logger = sp.GetRequiredService<ILogger<HttpRequestTool>>();

            // NOTE(review): this singleton keeps the client for the app's lifetime, so the factory's
            // handler rotation does not apply to it.
            var httpClient = sp.GetRequiredService<IHttpClientFactory>().CreateClient(nameof(HttpRequestTool));
            var headers = sp.GetRequiredService<IConfiguration>()
                .GetSection("HttpClient:DefaultHeaders")
                .Get<Dictionary<string, string>>();
            return new HttpRequestTool(httpClient, logger, headers);
        });

        services.AddSingleton<ITool>(sp =>
        {
            var logger = sp.GetRequiredService<ILogger<FileOperationTool>>();
            var baseDirectory = sp.GetRequiredService<IConfiguration>()["FileStorage:BaseDirectory"]
                ?? Path.Combine(AppContext.BaseDirectory, "files");
            return new FileOperationTool(baseDirectory, logger);
        });

        // Custom tools. NOTE(review): these are registered as ICustomTool, which UseTools does not read;
        // register them as ITool too if they should appear in the registry.
        services.AddTransient<ICustomTool, OrderQueryTool>();
        services.AddTransient<ICustomTool, WeatherQueryTool>();

        return services;
    }

    /// <summary>
    /// Copies every DI-registered <see cref="ITool"/> into the <see cref="IToolRegistry"/>.
    /// </summary>
    public static IAgentHost UseTools(this IAgentHost host, IServiceProvider serviceProvider)
    {
        var registry = serviceProvider.GetRequiredService<IToolRegistry>();

        foreach (var tool in serviceProvider.GetServices<ITool>())
        {
            registry.Register(tool);
        }

        return host;
    }
}
四、中间件开发
4.1 中间件基础
中间件允许在请求处理过程中插入自定义逻辑:
// IAgentMiddleware.cs

/// <summary>
/// Continuation that runs the remainder of the middleware pipeline for a context.
/// </summary>
public delegate Task AgentDelegate(AgentContext context);

/// <summary>
/// One step of the agent request pipeline.
/// </summary>
public interface IAgentMiddleware
{
    /// <summary>
    /// Handles <paramref name="context"/>. Call <paramref name="next"/> to continue with the following
    /// middleware; skip it to short-circuit the pipeline.
    /// </summary>
    Task InvokeAsync(AgentContext context, AgentDelegate next);
}
// Middleware pipeline
/// <summary>
/// Ordered middleware chain. Middleware runs in registration order, and each one decides whether to
/// continue by calling the <see cref="AgentDelegate"/> it receives. The delegate after the last
/// middleware completes without doing anything.
/// </summary>
public class AgentMiddlewarePipeline
{
    private readonly List<IAgentMiddleware> _middlewares = new();
    private readonly ILogger _logger;

    /// <summary>Creates a pipeline that logs under its own category (the DI-friendly constructor).</summary>
    public AgentMiddlewarePipeline(ILogger<AgentMiddlewarePipeline> logger)
        : this((ILogger)logger)
    {
    }

    /// <summary>
    /// Creates a pipeline that logs through any logger, e.g. an owning host's
    /// <c>ILogger&lt;AgentHost&gt;</c>, which is not convertible to <c>ILogger&lt;AgentMiddlewarePipeline&gt;</c>.
    /// </summary>
    public AgentMiddlewarePipeline(ILogger logger)
    {
        _logger = logger;
    }

    /// <summary>Appends <paramref name="middleware"/> to the end of the chain.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="middleware"/> is null.</exception>
    public void Use(IAgentMiddleware middleware)
    {
        ArgumentNullException.ThrowIfNull(middleware);
        _middlewares.Add(middleware);
        _logger.LogInformation("注册中间件: {MiddlewareType}", middleware.GetType().Name);
    }

    /// <summary>Appends an inline middleware function to the end of the chain.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="middleware"/> is null.</exception>
    public void Use(Func<AgentContext, AgentDelegate, Task> middleware)
    {
        ArgumentNullException.ThrowIfNull(middleware);
        _middlewares.Add(new AnonymousMiddleware(middleware));
    }

    /// <summary>Runs the chain for <paramref name="context"/>.</summary>
    public async Task InvokeAsync(AgentContext context)
    {
        // Build the chain back-to-front. Each middleware is closed over its own successor, so each one
        // runs at most once per call and concurrent calls share no mutable state.
        AgentDelegate chain = _ => Task.CompletedTask;
        for (var i = _middlewares.Count - 1; i >= 0; i--)
        {
            var current = _middlewares[i];
            var successor = chain;
            chain = ctx => current.InvokeAsync(ctx, successor);
        }

        await chain(context);
    }

    // Adapts a Func-based inline middleware to IAgentMiddleware.
    private class AnonymousMiddleware : IAgentMiddleware
    {
        private readonly Func<AgentContext, AgentDelegate, Task> _middleware;

        public AnonymousMiddleware(Func<AgentContext, AgentDelegate, Task> middleware)
        {
            _middleware = middleware;
        }

        public Task InvokeAsync(AgentContext context, AgentDelegate next)
        {
            return _middleware(context, next);
        }
    }
}
4.2 日志记录中间件
// LoggingMiddleware.cs
/// <summary>
/// Logs the start, completion (with elapsed time) and failure of each request. It also stores a short
/// request id in <c>context.Properties["RequestId"]</c> for the rest of the pipeline.
/// </summary>
public class LoggingMiddleware : IAgentMiddleware
{
    private readonly ILogger<LoggingMiddleware> _logger;

    public LoggingMiddleware(ILogger<LoggingMiddleware> logger)
    {
        _logger = logger;
    }

    /// <summary>Logs around <paramref name="next"/> and rethrows any exception after logging it.</summary>
    public async Task InvokeAsync(AgentContext context, AgentDelegate next)
    {
        // Stopwatch is monotonic; differences between DateTime.UtcNow readings jump when the clock is adjusted.
        var stopwatch = System.Diagnostics.Stopwatch.StartNew();
        var requestId = Guid.NewGuid().ToString("N")[..8];
        context.Properties["RequestId"] = requestId;

        // NOTE(review): this logs the full user message; consider truncating or redacting it if it may contain PII.
        _logger.LogInformation(
            "[{RequestId}] 开始处理请求: UserId={UserId}, Message={Message}",
            requestId,
            context.UserId,
            context.Message);

        try
        {
            await next(context);

            // NOTE(review): no code in this article sets Properties["Success"], so this logs null unless
            // other middleware sets it.
            _logger.LogInformation(
                "[{RequestId}] 请求处理完成: Duration={Duration}ms, Success={Success}",
                requestId,
                stopwatch.Elapsed.TotalMilliseconds,
                context.Properties.GetValueOrDefault("Success"));
        }
        catch (Exception ex)
        {
            _logger.LogError(ex,
                "[{RequestId}] 请求处理失败: Duration={Duration}ms, Error={Error}",
                requestId,
                stopwatch.Elapsed.TotalMilliseconds,
                ex.Message);
            throw;
        }
    }
}
4.3 认证中间件
// AuthenticationMiddleware.cs
/// <summary>
/// Validates a token taken from the Authorization header ("Bearer ...") or the "token" cookie. On success
/// it sets <c>context.UserId</c> and <c>context.Properties["UserRoles"]</c> and continues the pipeline.
/// On failure it sets <c>context.Response</c> and stops. Requests whose <c>Properties["Path"]</c> is an
/// excluded path skip validation.
/// </summary>
public class AuthenticationMiddleware : IAgentMiddleware
{
    private const string BearerPrefix = "Bearer ";

    // Matches a cookie named exactly "token", at the start of the header or after a ';' separator,
    // so that cookies such as "csrftoken=..." are not mistaken for it.
    private static readonly System.Text.RegularExpressions.Regex TokenCookiePattern =
        new(@"(?:^|;)\s*token=([^;]+)", System.Text.RegularExpressions.RegexOptions.CultureInvariant);

    private readonly ITokenValidator _tokenValidator;
    private readonly ILogger<AuthenticationMiddleware> _logger;
    private readonly List<string> _excludedPaths;

    /// <param name="tokenValidator">Validates extracted tokens.</param>
    /// <param name="logger">Logger for authentication events.</param>
    /// <param name="excludedPaths">Paths that skip authentication; defaults to "/health" and "/ready".</param>
    public AuthenticationMiddleware(
        ITokenValidator tokenValidator,
        ILogger<AuthenticationMiddleware> logger,
        List<string>? excludedPaths = null)
    {
        _tokenValidator = tokenValidator;
        _logger = logger;
        _excludedPaths = excludedPaths ?? new List<string> { "/health", "/ready" };
    }

    public async Task InvokeAsync(AgentContext context, AgentDelegate next)
    {
        if (ShouldSkipAuthentication(context))
        {
            await next(context);
            return;
        }

        var token = ExtractToken(context);
        if (string.IsNullOrEmpty(token))
        {
            _logger.LogWarning("缺少认证令牌: Path={Path}", context.Properties.GetValueOrDefault("Path"));
            context.Properties["AuthError"] = "Missing authentication token";
            context.Response = new AgentResponse
            {
                Success = false,
                Error = "需要认证"
            };
            return;
        }

        try
        {
            var validationResult = await _tokenValidator.ValidateAsync(token);
            if (!validationResult.IsValid)
            {
                _logger.LogWarning("Token验证失败: {Error}", validationResult.Error);
                context.Properties["AuthError"] = validationResult.Error;
                context.Response = new AgentResponse
                {
                    Success = false,
                    Error = "认证失败"
                };
                return;
            }

            context.UserId = validationResult.UserId;
            context.Properties["UserRoles"] = validationResult.Roles;
            _logger.LogDebug("认证成功: UserId={UserId}", validationResult.UserId);
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "认证过程出错");
            context.Properties["AuthError"] = ex.Message;
            context.Response = new AgentResponse
            {
                Success = false,
                Error = "认证服务错误"
            };
            return;
        }

        // Call next outside the try block so exceptions from later middleware propagate
        // instead of being reported as authentication-service errors.
        await next(context);
    }

    /// <summary>True when <c>Properties["Path"]</c> equals an excluded path or is nested under one.</summary>
    private bool ShouldSkipAuthentication(AgentContext context)
    {
        var path = context.Properties.GetValueOrDefault("Path")?.ToString() ?? "";
        return _excludedPaths.Any(excluded => IsSameOrChildPath(path, excluded));
    }

    // Matches "/health", "/health/db" and "/health?x" for prefix "/health", but not "/healthz-admin".
    private static bool IsSameOrChildPath(string path, string prefix)
    {
        if (!path.StartsWith(prefix, StringComparison.OrdinalIgnoreCase))
        {
            return false;
        }

        if (path.Length == prefix.Length || prefix.EndsWith('/'))
        {
            return true;
        }

        return path[prefix.Length] is '/' or '?';
    }

    /// <summary>
    /// Returns the bearer token from the Authorization header, or else the value of the "token" cookie;
    /// null if neither is present.
    /// </summary>
    private string? ExtractToken(AgentContext context)
    {
        if (context.Properties.TryGetValue("Authorization", out var authHeader))
        {
            var header = authHeader?.ToString()?.Trim();

            // Authentication scheme names are case-insensitive (RFC 7235), so "bearer xyz" is accepted too.
            if (header is not null && header.StartsWith(BearerPrefix, StringComparison.OrdinalIgnoreCase))
            {
                var bearer = header[BearerPrefix.Length..].Trim();
                if (bearer.Length > 0)
                {
                    return bearer;
                }
            }
        }

        if (context.Properties.TryGetValue("Cookie", out var cookie))
        {
            var match = TokenCookiePattern.Match(cookie?.ToString() ?? "");
            if (match.Success)
            {
                return match.Groups[1].Value.Trim();
            }
        }

        return null;
    }
}
4.4 限流中间件
// RateLimitingMiddleware.cs
/// <summary>
/// Throttles requests per caller through an <see cref="IRateLimiter"/>. Allowed requests get
/// RateLimit-* entries in <c>context.Properties</c>. Rejected requests get an error
/// <c>context.Response</c>, and the rest of the pipeline is skipped.
/// </summary>
public class RateLimitingMiddleware : IAgentMiddleware
{
    private readonly IRateLimiter _limiter;
    private readonly ILogger<RateLimitingMiddleware> _log;

    public RateLimitingMiddleware(
        IRateLimiter rateLimiter,
        ILogger<RateLimitingMiddleware> logger)
    {
        _limiter = rateLimiter;
        _log = logger;
    }

    public async Task InvokeAsync(AgentContext context, AgentDelegate next)
    {
        var limitKey = ResolveKey(context);
        var verdict = await _limiter.CheckLimitAsync(limitKey);

        if (verdict.Allowed)
        {
            _log.LogDebug(
                "限流检查通过: Key={Key}, Remaining={Remaining}/{Limit}",
                limitKey, verdict.Remaining, verdict.Limit);

            // Expose quota details for whoever builds the outgoing response.
            var props = context.Properties;
            props["RateLimit-Limit"] = verdict.Limit;
            props["RateLimit-Remaining"] = verdict.Remaining;
            props["RateLimit-Reset"] = verdict.ResetTime;

            await next(context);
            return;
        }

        _log.LogWarning(
            "请求被限流: Key={Key}, Limit={Limit}, Remaining={Remaining}",
            limitKey, verdict.Limit, verdict.Remaining);

        context.Response = new AgentResponse
        {
            Success = false,
            Error = $"请求过于频繁,请{verdict.RetryAfter}秒后再试"
        };
        context.Properties["RateLimited"] = true;
        context.Properties["RetryAfter"] = verdict.RetryAfter;
    }

    // Throttling key: the authenticated user id, else the client IP, else a shared "anonymous" bucket.
    private static string ResolveKey(AgentContext context) =>
        context.UserId
        ?? context.Properties.GetValueOrDefault("ClientIP")?.ToString()
        ?? "anonymous";
}
// Simple in-memory rate limiter implementation
/// <summary>
/// Fixed-window limiter. Each key may make <c>maxRequests</c> admitted calls per window. A key's window
/// starts at its first request and restarts on the first request after the window has ended.
/// </summary>
/// <remarks>
/// State lives in this process only, and entries are never evicted, so memory grows with the number
/// of distinct keys (e.g. client IPs).
/// </remarks>
public class MemoryRateLimiter : IRateLimiter
{
    private readonly ConcurrentDictionary<string, RateLimitEntry> _entries = new();
    private readonly int _maxRequests;
    private readonly TimeSpan _window;

    /// <param name="maxRequests">Requests allowed per key per window; must be positive.</param>
    /// <param name="windowSeconds">Window length in seconds; must be positive.</param>
    /// <exception cref="ArgumentOutOfRangeException">Either argument is zero or negative.</exception>
    public MemoryRateLimiter(int maxRequests = 100, int windowSeconds = 60)
    {
        if (maxRequests <= 0)
        {
            throw new ArgumentOutOfRangeException(nameof(maxRequests), maxRequests, "maxRequests must be positive.");
        }
        if (windowSeconds <= 0)
        {
            throw new ArgumentOutOfRangeException(nameof(windowSeconds), windowSeconds, "windowSeconds must be positive.");
        }

        _maxRequests = maxRequests;
        _window = TimeSpan.FromSeconds(windowSeconds);
    }

    /// <summary>Records one request for <paramref name="key"/> and reports whether it is allowed.</summary>
    public Task<RateLimitResult> CheckLimitAsync(string key)
    {
        var now = DateTime.UtcNow;
        var entry = _entries.GetOrAdd(key, _ => new RateLimitEntry
        {
            Count = 0,
            WindowStart = now
        });

        // Entries are instances of a private type, so locking on them cannot clash with outside code.
        lock (entry)
        {
            var windowEnd = entry.WindowStart + _window;

            // Start a new window once the current one has ended (consistent with ResetTime).
            if (now >= windowEnd)
            {
                entry.WindowStart = now;
                entry.Count = 0;
                windowEnd = now + _window;
            }

            // Only admitted requests are counted, so Count never exceeds _maxRequests,
            // even under sustained rejected traffic.
            var allowed = entry.Count < _maxRequests;
            if (allowed)
            {
                entry.Count++;
            }

            return Task.FromResult(new RateLimitResult
            {
                Allowed = allowed,
                Limit = _maxRequests,
                Remaining = _maxRequests - entry.Count,
                // Round up so a blocked client is never told to retry after 0 seconds.
                RetryAfter = allowed ? 0 : (int)Math.Ceiling((windowEnd - now).TotalSeconds),
                ResetTime = windowEnd
            });
        }
    }

    // Per-key counter for the current window.
    private class RateLimitEntry
    {
        public int Count { get; set; }
        public DateTime WindowStart { get; set; }
    }
}

/// <summary>Decides whether a request identified by a key may proceed.</summary>
public interface IRateLimiter
{
    Task<RateLimitResult> CheckLimitAsync(string key);
}

/// <summary>Outcome of a rate-limit check.</summary>
public class RateLimitResult
{
    /// <summary>True when the request may proceed.</summary>
    public bool Allowed { get; set; }

    /// <summary>Requests allowed per window.</summary>
    public int Limit { get; set; }

    /// <summary>Requests still allowed in the current window.</summary>
    public int Remaining { get; set; }

    /// <summary>Seconds until the window resets when rejected; 0 when allowed.</summary>
    public int RetryAfter { get; set; }

    /// <summary>UTC time at which the current window ends.</summary>
    public DateTime ResetTime { get; set; }
}
五、工具与中间件组合使用
5.1 完整的Agent主机
// AgentHost.cs
/// <summary>
/// Entry point for agent requests. It runs the middleware pipeline and, unless a middleware has already
/// produced a response, calls the agent with every registered tool.
/// </summary>
public class AgentHost
{
    // Property keys that carry credentials; they are removed before Properties is returned as Metadata.
    private static readonly string[] SensitivePropertyKeys = { "Authorization", "Cookie" };

    private readonly IToolRegistry _toolRegistry;
    private readonly AgentMiddlewarePipeline _pipeline;
    private readonly IAIAgent _agent;
    private readonly ILogger<AgentHost> _logger;

    /// <param name="toolRegistry">Source of the tools offered to the agent.</param>
    /// <param name="agent">Agent that processes requests that pass the pipeline.</param>
    /// <param name="logger">Host logger.</param>
    /// <param name="loggerFactory">
    /// Optional factory used to create the pipeline's logger; when null, pipeline logging is discarded.
    /// </param>
    public AgentHost(
        IToolRegistry toolRegistry,
        IAIAgent agent,
        ILogger<AgentHost> logger,
        ILoggerFactory? loggerFactory = null)
    {
        _toolRegistry = toolRegistry;
        _agent = agent;
        _logger = logger;

        // The pipeline needs a logger of its own category; ILogger<AgentHost> is not convertible
        // to ILogger<AgentMiddlewarePipeline>.
        ILogger<AgentMiddlewarePipeline> pipelineLogger =
            loggerFactory?.CreateLogger<AgentMiddlewarePipeline>()
            ?? Microsoft.Extensions.Logging.Abstractions.NullLogger<AgentMiddlewarePipeline>.Instance;
        _pipeline = new AgentMiddlewarePipeline(pipelineLogger);
    }

    /// <summary>Appends a middleware to the pipeline; returns this host for chaining.</summary>
    public AgentHost Use(IAgentMiddleware middleware)
    {
        _pipeline.Use(middleware);
        return this;
    }

    /// <summary>Appends an inline middleware to the pipeline; returns this host for chaining.</summary>
    public AgentHost Use(Func<AgentContext, AgentDelegate, Task> middleware)
    {
        _pipeline.Use(middleware);
        return this;
    }

    /// <summary>
    /// Processes <paramref name="request"/>. Exceptions are logged and returned as a failed response.
    /// </summary>
    public async Task<AgentResponse> ProcessAsync(AgentRequest request)
    {
        var context = new AgentContext
        {
            UserId = request.UserId,
            ConversationId = request.ConversationId,
            Message = request.Message,
            Properties = request.Properties ?? new Dictionary<string, object>(),
            Timestamp = DateTime.UtcNow
        };

        try
        {
            // NOTE(review): the pipeline's final delegate does nothing, so any code a middleware runs
            // after awaiting next (e.g. a completion/duration log) executes before the agent call below.
            await _pipeline.InvokeAsync(context);

            // A middleware (authentication, rate limiting, ...) short-circuited the request.
            if (context.Response != null)
            {
                return context.Response;
            }

            var tools = _toolRegistry.GetAll().ToList();
            var response = await _agent.ProcessAsync(context, tools);

            return new AgentResponse
            {
                Success = true,
                Message = response,
                ConversationId = context.ConversationId,
                Metadata = CreatePublicMetadata(context.Properties)
            };
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Agent处理失败");
            return new AgentResponse
            {
                Success = false,
                Error = $"处理失败: {ex.Message}"
            };
        }
    }

    // Copies the context properties, minus credential-bearing keys, so tokens are not echoed back.
    private static Dictionary<string, object> CreatePublicMetadata(
        IEnumerable<KeyValuePair<string, object>> properties)
    {
        var metadata = new Dictionary<string, object>();
        foreach (var (key, value) in properties)
        {
            if (!SensitivePropertyKeys.Contains(key, StringComparer.OrdinalIgnoreCase))
            {
                metadata[key] = value;
            }
        }
        return metadata;
    }
}
5.2 配置示例
// Program.cs configuration
var builder = WebApplication.CreateBuilder(args);

// Services. AddAgentTools registers IToolRegistry itself, so it is not registered separately here.
builder.Services.AddAgentTools();

// Middleware dependencies
builder.Services.AddSingleton<ITokenValidator, JwtTokenValidator>();
builder.Services.AddSingleton<IRateLimiter>(sp =>
    new MemoryRateLimiter(maxRequests: 100, windowSeconds: 60));

// NOTE(review): AgentHost below resolves IAIAgent, which this snippet never registers. Add your
// IAIAgent registration here, or GetRequiredService will throw at startup.

var app = builder.Build();

// Copy every DI-registered ITool into the registry. The agent is only offered tools from the registry,
// which would otherwise stay empty.
var toolRegistry = app.Services.GetRequiredService<IToolRegistry>();
foreach (var tool in app.Services.GetServices<ITool>())
{
    toolRegistry.Register(tool);
}

// Agent host. Middleware runs in the order it is added: logging -> authentication -> rate limiting.
var agentHost = new AgentHost(
        toolRegistry,
        app.Services.GetRequiredService<IAIAgent>(),
        app.Services.GetRequiredService<ILogger<AgentHost>>())
    .Use(new LoggingMiddleware(app.Services.GetRequiredService<ILogger<LoggingMiddleware>>()))
    .Use(new AuthenticationMiddleware(
        app.Services.GetRequiredService<ITokenValidator>(),
        app.Services.GetRequiredService<ILogger<AuthenticationMiddleware>>()))
    .Use(new RateLimitingMiddleware(
        app.Services.GetRequiredService<IRateLimiter>(),
        app.Services.GetRequiredService<ILogger<RateLimitingMiddleware>>()));

app.MapPost("/agent", async (AgentRequest request) =>
{
    var response = await agentHost.ProcessAsync(request);
    return Results.Ok(response);
});

app.Run();
六、最佳实践
6.1 工具设计原则
在开发自定义工具时,应遵循以下原则:
单一职责:每个工具应该只完成一项具体任务。这样可以提高工具的可测试性和可维护性。
清晰的参数定义:为每个参数提供清晰的描述和验证规则。这有助于LLM正确理解何时以及如何使用工具。
完善的错误处理:工具应该捕获并处理可预见的错误,返回有意义的错误信息,而不是抛出原始异常。
安全优先:对所有输入进行验证,特别是涉及文件操作、数据库查询、HTTP请求等可能存在安全风险的场景。
幂等性:在可能的情况下,工具应该具有幂等性,即使用相同参数重复执行同一操作,对系统产生的效果与只执行一次相同。这样即使发生网络重试或LLM重复调用,也不会产生意外的副作用。
6.2 中间件设计原则
短小精悍:中间件应该专注于单一功能,不要在中间件中混合多个不相关的逻辑。
不要阻塞:避免在中间件中执行耗时的同步操作,使用异步方式处理。
正确传递上下文:确保修改context时不要影响后续中间件的正确执行。
错误处理:在中间件中妥善处理异常,避免未处理的异常导致请求崩溃。
七、总结与展望
通过本文的学习,我们已经掌握了自定义工具和中间件开发的核心技术:
✅ 工具接口:理解Agent Framework的工具接口定义
✅ 工具开发:创建数据库查询、HTTP请求、文件操作等实用工具
✅ 工具管理:实现工具注册、搜索、动态管理
✅ 中间件开发:创建日志、认证、限流等中间件
✅ 组合使用:将工具和中间件组合构建完整的Agent系统
✅ 最佳实践:遵循工具和中间件的设计原则
关键收获:
自定义工具和中间件是扩展Agent Framework功能的关键手段。通过合理设计工具,我们可以让Agent与各种外部系统进行交互;通过使用中间件,我们可以在请求处理过程中添加认证、限流、日志等功能。在实际开发中,要注意工具的安全性和中间件的性能,确保系统稳定可靠运行。
下一篇文章预告:
在第八篇文章中,我们将探索监控与可观测性。我们将学习如何使用OpenTelemetry实现Agent系统的全面监控,包括性能指标、分布式追踪、日志聚合等,确保生产环境的稳定运行。
实践建议:
为每个工具编写单元测试,确保功能正确
使用依赖注入管理工具生命周期
为敏感操作添加审计日志
监控工具调用频率和性能
定期审查和更新工具权限
相关资源:
Agent Framework工具开发文档
中间件模式详解
.NET依赖注入指南
"好的工具让Agent更强大,好的中间件让系统更可靠。"
