这是阿里云 Spring AI Alibaba 官网中「工具(Function Calling)」一节给出的运行流程图:

7610E899-5AAB-46E7-9A47-994714D30499.png

具体实现在 ChatModel 内部。下面以 call 方法的调用为例:

//源码 /Users/xuanmiss/.m2/repository/com/alibaba/cloud/ai/spring-ai-alibaba-core/1.0.0-M5.1/spring-ai-alibaba-core-1.0.0-M5.1.jar!/com/alibaba/cloud/ai/dashscope/chat/DashScopeChatModel.class


/**
 * Synchronous chat call (DashScopeChatModel, decompiled from spring-ai-alibaba-core 1.0.0-M5.1).
 *
 * Sends the prompt to DashScope inside a Micrometer observation, converts each returned
 * choice into a Generation, and -- if the response is judged to be a tool call -- executes
 * the tools and re-invokes itself with the extended conversation.
 *
 * @param prompt the prompt; its options (or this.defaultOptions when null) are recorded as request options
 * @return the final ChatResponse (after any tool-call round trips), or an empty ChatResponse when the API body is null
 */
public ChatResponse call(Prompt prompt) {  
    // Observation context carries the prompt, provider name and effective request options.
    ChatModelObservationContext observationContext = ChatModelObservationContext.builder().prompt(prompt).provider(DashScopeApiConstants.PROVIDER_NAME).requestOptions((ChatOptions)(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions)).build();  
    ChatResponse chatResponse = (ChatResponse)ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> {  
        return observationContext;  
    }, this.observationRegistry).observe(() -> {  
        // NOTE(review): the `false` argument presumably means "not streaming" -- confirm in createRequest.
        DashScopeApi.ChatCompletionRequest request = this.createRequest(prompt, false);  
        // The HTTP call is wrapped in the retry template.
        ResponseEntity<DashScopeApi.ChatCompletion> completionEntity = (ResponseEntity)this.retryTemplate.execute((ctx) -> {  
            return this.dashscopeApi.chatCompletionEntity(request);  
        });  
        DashScopeApi.ChatCompletion chatCompletion = (DashScopeApi.ChatCompletion)completionEntity.getBody();  
        if (chatCompletion == null) {  
            logger.warn("No chat completion returned for prompt: {}", prompt);  
            return new ChatResponse(List.of());  
        } else {  
            List<DashScopeApi.ChatCompletionOutput.Choice> choices = chatCompletion.output().choices();  
            List<Generation> generations = choices.stream().map((choice) -> {  
                // NOTE(review): Map.of rejects null values, so a null requestId() would throw NPE here;
                // role/finishReason are null-guarded, the request id is not.
                Map<String, Object> metadata = Map.of("id", chatCompletion.requestId(), "role", choice.message().role() != null ? choice.message().role().name() : "", "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");  
                return buildGeneration(choice, metadata);  
            }).toList();  
            // getBody() is read a second time here; it is the same object as chatCompletion above.
            ChatResponse response = new ChatResponse(generations, this.from((DashScopeApi.ChatCompletion)completionEntity.getBody()));  
            observationContext.setResponse(response);  
            return response;  
        }  
    });  
    // Tool-call detection is delegated to isToolCall (not shown), with TOOL_CALLS and STOP as the
    // candidate finish reasons. On a tool call, the tool results are appended to the conversation and
    // call() recurses with the same options.
    // NOTE(review): no recursion limit is visible in this method.
    if (this.isToolCall(chatResponse, Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) {  
        List<Message> toolCallConversation = this.handleToolCalls(prompt, chatResponse);  
        return this.call(new Prompt(toolCallConversation, prompt.getOptions()));  
    } else {  
        return chatResponse;  
    }  
}

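这里会判断模型的返回结果(isToolCall 会结合 TOOL_CALLS、STOP 这两种 finishReason 来判断):如果返回的是 {"message":{"tool_calls": ...}} 这样的工具调用内容,就通过 handleToolCalls 执行 function/tool call,把工具的执行结果拼接进对话,再递归调用 call 重新请求模型;否则直接返回结果。
接下来继续看执行代码:框架如何找到我们注册的 tool,以及如何调用它。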

// 源码/Users/xuanmiss/.m2/repository/org/springframework/ai/spring-ai-core/1.0.0-M5/spring-ai-core-1.0.0-M5.jar!/org/springframework/ai/chat/model/AbstractToolCallSupport.class


/**
 * Executes the tool calls requested by the model and builds the follow-up conversation
 * (AbstractToolCallSupport, decompiled from spring-ai-core 1.0.0-M5).
 *
 * @param prompt   the prompt that produced the response; its instructions and options are reused
 * @param response the model response containing at least one generation with tool calls
 * @return the conversation produced by buildToolCallConversation from the original instructions,
 *         the assistant's tool-call message and the tool responses
 * @throws IllegalStateException if no generation in the response carries tool calls
 */
protected List<Message> handleToolCalls(Prompt prompt, ChatResponse response) {  
    // Only the FIRST generation that has tool calls is handled (findFirst).
    Optional<Generation> toolCallGeneration = response.getResults().stream().filter((g) -> {  
        return !CollectionUtils.isEmpty(g.getOutput().getToolCalls());  
    }).findFirst();  
    if (toolCallGeneration.isEmpty()) {  
        throw new IllegalStateException("No tool call generation found in the response!");  
    } else {  
        AssistantMessage assistantMessage = ((Generation)toolCallGeneration.get()).getOutput();  
        // Empty by default; only populated when the options carry a non-empty tool context.
        Map<String, Object> toolContextMap = Map.of();  
        ChatOptions var7 = prompt.getOptions();  
        if (var7 instanceof FunctionCallingOptions) {  
            FunctionCallingOptions functionCallOptions = (FunctionCallingOptions)var7;  
            if (!CollectionUtils.isEmpty(functionCallOptions.getToolContext())) {  
                // Copy the user's tool context and add the history so far (prompt instructions plus
                // a copy of the assistant's tool-call message) under "TOOL_CALL_HISTORY".
                toolContextMap = new HashMap(functionCallOptions.getToolContext());  
                List<Message> toolCallHistory = new ArrayList(prompt.copy().getInstructions());  
                toolCallHistory.add(new AssistantMessage(assistantMessage.getText(), assistantMessage.getMetadata(), assistantMessage.getToolCalls()));  
                ((Map)toolContextMap).put("TOOL_CALL_HISTORY", toolCallHistory);  
            }  
        }  
  
        // Run every requested tool, then assemble the conversation for the next model round.
        ToolResponseMessage toolMessageResponse = this.executeFunctions(assistantMessage, new ToolContext((Map)toolContextMap));  
        List<Message> toolConversationHistory = this.buildToolCallConversation(prompt.getInstructions(), assistantMessage, toolMessageResponse);  
        return toolConversationHistory;  
    }  
}
// ······

/**
 * Invokes each tool call in the assistant message, in order, via the registered FunctionCallback.
 *
 * @param assistantMessage the model message whose getToolCalls() are executed
 * @param toolContext      passed unchanged to every FunctionCallback.call(String, ToolContext)
 * @return a ToolResponseMessage with one ToolResponse (call id, function name, result) per tool call
 *         and empty metadata
 * @throws IllegalStateException if a requested function name is not in functionCallbackRegister;
 *         tools earlier in the list have already run by then and their results are discarded
 */
protected ToolResponseMessage executeFunctions(AssistantMessage assistantMessage, ToolContext toolContext) {  
    List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList();  
    // Decompiled form of a for-each loop over the tool calls.
    Iterator var4 = assistantMessage.getToolCalls().iterator();  
  
    while(var4.hasNext()) {  
        AssistantMessage.ToolCall toolCall = (AssistantMessage.ToolCall)var4.next();  
        String functionName = toolCall.name();  
        String functionArguments = toolCall.arguments();  
        // Callbacks are looked up by the function name the model returned.
        if (!this.functionCallbackRegister.containsKey(functionName)) {  
            throw new IllegalStateException("No function callback found for function name: " + functionName);  
        }  
  
        String functionResponse = ((FunctionCallback)this.functionCallbackRegister.get(functionName)).call(functionArguments, toolContext);  
        // The tool-call id is echoed back so the model can match each response to its request.
        toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), functionName, functionResponse));  
    }  
  
    return new ToolResponseMessage(toolResponses, Map.of());  
}

handleToolCalls 方法负责完成工具调用:先通过 ToolResponseMessage toolMessageResponse = this.executeFunctions(assistantMessage, new ToolContext((Map)toolContextMap)); 执行工具,再用 buildToolCallConversation 构造完整的 message 列表。executeFunctions 是同一个类中的方法,其核心是 String functionResponse = ((FunctionCallback)this.functionCallbackRegister.get(functionName)).call(functionArguments, toolContext);:根据模型返回的方法名称,从 functionCallbackRegister 中取出对应的 FunctionCallback 并执行 call。也就是说,调用最终是通过 FunctionCallback.java 这个接口实现的。下面看看这个接口的内容,以及它的两个具体实现,它们分别对应 Tools 的两种注册方式。

这个接口的源码如下:

/**
 * Executes the function with its JSON-encoded input (FunctionCallback interface).
 *
 * @param functionInput the arguments string produced by the model
 * @return the function result as a string
 */
String call(String functionInput);  
  
/**
 * Context-aware variant. The default implementation only supports an absent or empty context:
 * it delegates to call(String) in that case.
 *
 * NOTE(review): the parameter name `tooContext` looks like a typo for `toolContext`;
 * it is kept as quoted from the source.
 *
 * @param functionInput the arguments string produced by the model
 * @param tooContext    the tool context; may be null
 * @return the result of call(functionInput)
 * @throws UnsupportedOperationException if a non-empty context is supplied to an implementation
 *         that does not override this method
 */
default String call(String functionInput, ToolContext tooContext) {  
    if (tooContext != null && !tooContext.getContext().isEmpty()) {  
        throw new UnsupportedOperationException("Function context is not supported!");  
    } else {  
        return this.call(functionInput);  
    }  
}

这里看看默认提供的实现类和相关逻辑:

8F0412B1-454D-47B9-827F-5BCBE011FFB3.png
这里一共有三个实现类。其中第二个是 FunctionDefinition,只用于提供函数定义,没有具体的功能实现。
第一个是 AbstractFunctionCallback,实现如下。无论我们是直接定义 FunctionCallback,还是注册实现了 Function 接口的 bean,最终都会由这里执行:

/**
 * Deserializes the JSON arguments into the callback's input type and applies the function
 * (AbstractFunctionCallback).
 *
 * @param functionArguments JSON arguments produced by the model
 * @return the function output passed through responseConverter, cast to String
 */
public String call(String functionArguments) {  
    // I is the callback's declared input type; inputType is its runtime class.
    I request = this.fromJson(functionArguments, this.inputType);  
    // NOTE(review): the two-argument apply suggests this callback is a BiFunction<I, ToolContext, O>,
    // with null passed as the second argument on this context-free path -- confirm against the class header.
    return (String)this.andThen(this.responseConverter).apply(request, (Object)null);  
}

第三个是 MethodInvokingFunctionCallback,逻辑如下。如果我们把已有类(如 service 等)中的某个方法注册成 tool,最终就会执行到这里:

/**
 * Invokes the configured Java method reflectively, mapping the model's JSON arguments onto the
 * method parameters by name (MethodInvokingFunctionCallback).
 *
 * @param functionInput JSON object whose keys are matched against the method's parameter names
 * @param toolContext   injected into any parameter assignable to ToolContext; may be null
 * @return "Done" for void methods; JSON (ModelOptionsUtils.toJsonString) when the return type is
 *         exactly Class, List or Map, or is a record; otherwise responseConverter.apply(response)
 */
public String call(String functionInput, ToolContext toolContext) {  
    try {  
        // This IllegalArgumentException is thrown inside the try, so it also flows through the catch below.
        if (toolContext != null && !CollectionUtils.isEmpty(toolContext.getContext()) && !this.isToolContextMethod) {  
            throw new IllegalArgumentException("Configured method does not accept ToolContext as input parameter!");  
        } else {  
            Map<String, Object> map = (Map)this.mapper.readValue(functionInput, Map.class);  
            Object[] methodArgs = Stream.of(this.method.getParameters()).map((parameter) -> {  
                Class<?> type = parameter.getType();  
                if (ClassUtils.isAssignable(type, ToolContext.class)) {  
                    return toolContext;  
                } else {  
                    // NOTE(review): Parameter.getName() only yields source names when the target class
                    // is compiled with -parameters; otherwise it returns arg0, arg1, ... and the lookup misses.
                    Object rawValue = map.get(parameter.getName());  
                    return this.toJavaType(rawValue, type);  
                }  
            }).toArray();  
            Object response = ReflectionUtils.invokeMethod(this.method, this.functionObject, methodArgs);  
            Class<?> returnType = this.method.getReturnType();  
            if (returnType == Void.TYPE) {  
                return "Done";  
            } else {  
                // Exact-type comparison: e.g. a declared ArrayList or Set return type is NOT JSON-encoded
                // here and goes through responseConverter instead.
                return returnType != Class.class && !returnType.isRecord() && returnType != List.class && returnType != Map.class ? (String)this.responseConverter.apply(response) : ModelOptionsUtils.toJsonString(response);  
            }  
        }  
    } catch (Exception var7) {  
        Exception e = var7;  
        // NOTE(review): Spring's handleReflectionException is documented to always (re)throw,
        // so `return null` presumably only satisfies the compiler -- confirm for this Spring version.
        ReflectionUtils.handleReflectionException(e);  
        return null;  
    }  
}

本次没有写这种示例,可以参考官方文档中的内容。需要注意的是,官方示例是把已有 service 的方法包装成一个 Function bean 来注册的;按照上文的分析,这种注册方式实际走的是 AbstractFunctionCallback 的逻辑:

// 1. An existing service: MockOrderService
@Service
public class MockOrderService {

    /**
     * Mock order lookup used as the function-calling target.
     *
     * @param request the user id and order id extracted by the model from the conversation
     * @return a description of the order; the product name is hard-coded for this demo
     */
    public Response getOrder(Request request) {
        String productName = "尤尼克斯羽毛球拍";
        // Use the record accessor methods instead of reading the record's private fields directly.
        return new Response(String.format("%s的订单编号为%s, 购买的商品为: %s", request.userId(), request.orderId(), productName));
    }

    /**
     * Function input. The {@code @JsonProperty} / {@code @JsonPropertyDescription} annotations are
     * turned into the function's "parameters" schema (parameter names and descriptions), roughly:
     * <pre>
     * {
     *   "orderId": { "type": "string", "description": "订单编号, 比如1001***" },
     *   "userId":  { "type": "string", "description": "用户编号, 比如2001***" }
     * }
     * </pre>
     *
     * @param orderId the order number (required)
     * @param userId  the user number (required)
     */
    @JsonInclude(JsonInclude.Include.NON_NULL)
    public record Request(
            @JsonProperty(required = true, value = "orderId") @JsonPropertyDescription("订单编号, 比如1001***") String orderId,
            @JsonProperty(required = true, value = "userId") @JsonPropertyDescription("用户编号, 比如2001***") String userId) {
    }

    /**
     * Function output returned to the model.
     *
     * @param description human-readable order description
     */
    public record Response(String description) {
    }
}

// 2. Register MockOrderService#getOrder as a function-call bean
@Configuration
public class FunctionCallConfiguration {

    /**
     * Exposes {@link MockOrderService#getOrder} as a {@code Function} bean named
     * {@code getOrderFunction}; the {@code @Description} text serves as the function's description.
     *
     * @param mockOrderService the existing service whose method is exposed
     * @return a function delegating each request to {@code mockOrderService.getOrder}
     */
    @Bean
    @Description("根据用户编号和订单编号查询订单信息")
    public Function<MockOrderService.Request, MockOrderService.Response> getOrderFunction(MockOrderService mockOrderService) {
        return request -> mockOrderService.getOrder(request);
    }
}

// 3. Call the model with the registered function available
DashScopeChatModel dashscopeChatModel = ...;
// Make the "getOrderFunction" bean (registered above by name) available to every request from this client.
ChatClient chatClient = ChatClient.builder(dashscopeChatModel)
        .defaultFunctions("getOrderFunction")
        .build();

ChatResponse response = chatClient
        .prompt()
        .user("帮我查一下订单, 用户编号为1001, 订单编号为2001")
        .call()
        .chatResponse();

// getText() is the message-text accessor that spring-ai-core 1.0.0-M5 itself calls on AssistantMessage;
// getContent() is the older name.
String content = response.getResult().getOutput().getText();
logger.info("content: {}", content);