Merge branch 'master-jdk17' of https://gitee.com/zhijiantianya/ruoyi-vue-pro into develop-api-remove
# Conflicts: # yudao-module-ai/yudao-module-ai-biz/pom.xml
This commit is contained in:
@@ -17,6 +17,10 @@
|
||||
国内:通义千问、文心一言、讯飞星火、智谱 GLM、DeepSeek
|
||||
国外:OpenAI、Ollama、Midjourney、StableDiffusion、Suno
|
||||
</description>
|
||||
<properties>
|
||||
<spring-ai.version>1.0.0-M6</spring-ai.version>
|
||||
<tinyflow.version>1.0.2</tinyflow.version>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
@@ -24,13 +28,18 @@
|
||||
<artifactId>yudao-module-ai-api</artifactId>
|
||||
<version>${revision}</version>
|
||||
</dependency>
|
||||
|
||||
<!-- 业务组件 -->
|
||||
<dependency>
|
||||
<groupId>cn.iocoder.boot</groupId>
|
||||
<artifactId>yudao-spring-boot-starter-ai</artifactId>
|
||||
<artifactId>yudao-module-system-api</artifactId>
|
||||
<version>${revision}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>cn.iocoder.boot</groupId>
|
||||
<artifactId>yudao-module-infra-api</artifactId>
|
||||
<version>${revision}</version>
|
||||
</dependency>
|
||||
|
||||
<!-- 业务组件 -->
|
||||
|
||||
<dependency>
|
||||
<groupId>cn.iocoder.boot</groupId>
|
||||
@@ -66,5 +75,142 @@
|
||||
<groupId>cn.iocoder.boot</groupId>
|
||||
<artifactId>yudao-spring-boot-starter-excel</artifactId>
|
||||
</dependency>
|
||||
|
||||
<!-- Spring AI Model 模型接入 -->
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-openai-spring-boot-starter</artifactId>
|
||||
<version>${spring-ai.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-azure-openai-spring-boot-starter</artifactId>
|
||||
<version>${spring-ai.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
|
||||
<version>${spring-ai.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-stability-ai-spring-boot-starter</artifactId>
|
||||
<version>${spring-ai.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<!-- 通义千问 -->
|
||||
<groupId>com.alibaba.cloud.ai</groupId>
|
||||
<artifactId>spring-ai-alibaba-starter</artifactId>
|
||||
<version>${spring-ai.version}.1</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<!-- 文心一言 -->
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-qianfan-spring-boot-starter</artifactId>
|
||||
<version>${spring-ai.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<!-- 智谱 GLM -->
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-zhipuai-spring-boot-starter</artifactId>
|
||||
<version>${spring-ai.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-minimax-spring-boot-starter</artifactId>
|
||||
<version>${spring-ai.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-moonshot-spring-boot-starter</artifactId>
|
||||
<version>${spring-ai.version}</version>
|
||||
</dependency>
|
||||
|
||||
<!-- 向量存储:https://db-engines.com/en/ranking/vector+dbms -->
|
||||
<dependency>
|
||||
<!-- Qdrant:https://qdrant.tech/ -->
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-qdrant-store</artifactId>
|
||||
<version>${spring-ai.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<!-- Redis:https://redis.io/docs/latest/develop/get-started/vector-database/ -->
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-redis-store</artifactId>
|
||||
<version>${spring-ai.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>cn.iocoder.boot</groupId>
|
||||
<artifactId>yudao-spring-boot-starter-redis</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<!-- Milvus:https://milvus.io/ -->
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-milvus-store</artifactId>
|
||||
<version>${spring-ai.version}</version>
|
||||
<exclusions>
|
||||
<!-- 解决和 logback 的日志冲突 -->
|
||||
<exclusion>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-reload4j</artifactId>
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<!-- Tika:负责内容的解析 -->
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-tika-document-reader</artifactId>
|
||||
<version>${spring-ai.version}</version>
|
||||
<!-- TODO 芋艿:boot 项目里,不引入 cloud 依赖!!!另外,这样也是为了解决启动报错的问题! -->
|
||||
<exclusions>
|
||||
<exclusion>
|
||||
<artifactId>spring-cloud-function-context</artifactId>
|
||||
<groupId>org.springframework.cloud</groupId>
|
||||
</exclusion>
|
||||
<exclusion>
|
||||
<artifactId>spring-cloud-function-core</artifactId>
|
||||
<groupId>org.springframework.cloud</groupId>
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
|
||||
<!-- TinyFlow:AI 工作流 -->
|
||||
<dependency>
|
||||
<groupId>dev.tinyflow</groupId>
|
||||
<artifactId>tinyflow-java-core</artifactId>
|
||||
<version>${tinyflow.version}</version>
|
||||
<exclusions>
|
||||
<exclusion>
|
||||
<groupId>com.jfinal</groupId>
|
||||
<artifactId>enjoy</artifactId>
|
||||
</exclusion>
|
||||
<exclusion>
|
||||
<!-- 解决 https://gitee.com/zhijiantianya/ruoyi-vue-pro/pulls/1318/ 问题 -->
|
||||
<groupId>com.agentsflex</groupId>
|
||||
<artifactId>agents-flex-store-elasticsearch</artifactId>
|
||||
</exclusion>
|
||||
<exclusion>
|
||||
<!-- TODO @芋艿:暂时移除 groovy,和 iot 冲突 -->
|
||||
<groupId>org.codehaus.groovy</groupId>
|
||||
<artifactId>groovy-all</artifactId>
|
||||
</exclusion>
|
||||
<!-- 解决和 logback 的日志冲突 -->
|
||||
<exclusion>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-simple</artifactId>
|
||||
</exclusion>
|
||||
<exclusion>
|
||||
<groupId>org.apache.logging.log4j</groupId>
|
||||
<artifactId>log4j-slf4j-impl</artifactId>
|
||||
</exclusion>
|
||||
<exclusion>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-reload4j</artifactId>
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
@@ -1,7 +1,7 @@
|
||||
package cn.iocoder.yudao.module.ai.controller.admin.image;
|
||||
|
||||
import cn.hutool.core.util.ObjUtil;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.Data;
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.model;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.module.ai.enums.model.AiModelTypeEnum;
|
||||
import cn.iocoder.yudao.module.ai.enums.model.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
|
||||
import cn.iocoder.yudao.framework.common.validation.InEnum;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package cn.iocoder.yudao.module.ai.dal.dataobject.image;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.module.ai.enums.model.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
|
||||
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
|
||||
@@ -49,7 +50,7 @@ public class AiImageDO extends BaseDO {
|
||||
/**
|
||||
* 平台
|
||||
*
|
||||
* 枚举 {@link cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum}
|
||||
* 枚举 {@link AiPlatformEnum}
|
||||
*/
|
||||
private String platform;
|
||||
/**
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package cn.iocoder.yudao.module.ai.dal.dataobject.mindmap;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.module.ai.enums.model.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
|
||||
import com.baomidou.mybatisplus.annotation.KeySequence;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package cn.iocoder.yudao.module.ai.dal.dataobject.model;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.module.ai.enums.model.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
|
||||
import com.baomidou.mybatisplus.annotation.KeySequence;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package cn.iocoder.yudao.module.ai.dal.dataobject.model;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.module.ai.enums.model.AiModelTypeEnum;
|
||||
import cn.iocoder.yudao.module.ai.enums.model.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
|
||||
import com.baomidou.mybatisplus.annotation.KeySequence;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package cn.iocoder.yudao.module.ai.dal.dataobject.music;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.module.ai.enums.model.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
|
||||
import cn.iocoder.yudao.module.ai.enums.music.AiMusicGenerateModeEnum;
|
||||
import cn.iocoder.yudao.module.ai.enums.music.AiMusicStatusEnum;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package cn.iocoder.yudao.module.ai.dal.dataobject.write;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.module.ai.enums.model.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
|
||||
import cn.iocoder.yudao.module.ai.enums.DictTypeConstants;
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
package cn.iocoder.yudao.module.ai.enums.model;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.core.ArrayValuable;
|
||||
import lombok.Getter;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
/**
|
||||
* AI 模型类型的枚举
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
@Getter
|
||||
@RequiredArgsConstructor
|
||||
public enum AiModelTypeEnum implements ArrayValuable<Integer> {
|
||||
|
||||
CHAT(1, "对话"),
|
||||
IMAGE(2, "图片"),
|
||||
VOICE(3, "语音"),
|
||||
VIDEO(4, "视频"),
|
||||
EMBEDDING(5, "向量"),
|
||||
RERANK(6, "重排序");
|
||||
|
||||
/**
|
||||
* 类型
|
||||
*/
|
||||
private final Integer type;
|
||||
/**
|
||||
* 类型名
|
||||
*/
|
||||
private final String name;
|
||||
|
||||
public static final Integer[] ARRAYS = Arrays.stream(values()).map(AiModelTypeEnum::getType).toArray(Integer[]::new);
|
||||
|
||||
@Override
|
||||
public Integer[] array() {
|
||||
return ARRAYS;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
package cn.iocoder.yudao.module.ai.enums.model;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.core.ArrayValuable;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
/**
|
||||
* AI 模型平台
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public enum AiPlatformEnum implements ArrayValuable<String> {
|
||||
|
||||
// ========== 国内平台 ==========
|
||||
|
||||
TONG_YI("TongYi", "通义千问"), // 阿里
|
||||
YI_YAN("YiYan", "文心一言"), // 百度
|
||||
DEEP_SEEK("DeepSeek", "DeepSeek"), // DeepSeek
|
||||
ZHI_PU("ZhiPu", "智谱"), // 智谱 AI
|
||||
XING_HUO("XingHuo", "星火"), // 讯飞
|
||||
DOU_BAO("DouBao", "豆包"), // 字节
|
||||
HUN_YUAN("HunYuan", "混元"), // 腾讯
|
||||
SILICON_FLOW("SiliconFlow", "硅基流动"), // 硅基流动
|
||||
MINI_MAX("MiniMax", "MiniMax"), // 稀宇科技
|
||||
MOONSHOT("Moonshot", "月之暗灭"), // KIMI
|
||||
BAI_CHUAN("BaiChuan", "百川智能"), // 百川智能
|
||||
|
||||
// ========== 国外平台 ==========
|
||||
|
||||
OPENAI("OpenAI", "OpenAI"), // OpenAI 官方
|
||||
AZURE_OPENAI("AzureOpenAI", "AzureOpenAI"), // OpenAI 微软
|
||||
OLLAMA("Ollama", "Ollama"),
|
||||
|
||||
STABLE_DIFFUSION("StableDiffusion", "StableDiffusion"), // Stability AI
|
||||
MIDJOURNEY("Midjourney", "Midjourney"), // Midjourney
|
||||
SUNO("Suno", "Suno"), // Suno AI
|
||||
|
||||
;
|
||||
|
||||
/**
|
||||
* 平台
|
||||
*/
|
||||
private final String platform;
|
||||
/**
|
||||
* 平台名
|
||||
*/
|
||||
private final String name;
|
||||
|
||||
public static final String[] ARRAYS = Arrays.stream(values()).map(AiPlatformEnum::getPlatform).toArray(String[]::new);
|
||||
|
||||
public static AiPlatformEnum validatePlatform(String platform) {
|
||||
for (AiPlatformEnum platformEnum : AiPlatformEnum.values()) {
|
||||
if (platformEnum.getPlatform().equals(platform)) {
|
||||
return platformEnum;
|
||||
}
|
||||
}
|
||||
throw new IllegalArgumentException("非法平台: " + platform);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String[] array() {
|
||||
return ARRAYS;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
/**
|
||||
* 占位
|
||||
*/
|
||||
package cn.iocoder.yudao.module.ai.enums;
|
||||
@@ -0,0 +1,253 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.config;
|
||||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.hutool.extra.spring.SpringUtil;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.AiModelFactory;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.AiModelFactoryImpl;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.baichuan.BaiChuanChatModel;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.deepseek.DeepSeekChatModel;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.doubao.DouBaoChatModel;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.hunyuan.HunYuanChatModel;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.siliconflow.SiliconFlowApiConstants;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.siliconflow.SiliconFlowChatModel;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.suno.api.SunoApi;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.autoconfigure.vectorstore.milvus.MilvusServiceClientProperties;
|
||||
import org.springframework.ai.autoconfigure.vectorstore.milvus.MilvusVectorStoreProperties;
|
||||
import org.springframework.ai.autoconfigure.vectorstore.qdrant.QdrantVectorStoreProperties;
|
||||
import org.springframework.ai.autoconfigure.vectorstore.redis.RedisVectorStoreProperties;
|
||||
import org.springframework.ai.embedding.BatchingStrategy;
|
||||
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
|
||||
import org.springframework.ai.model.tool.ToolCallingManager;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator;
|
||||
import org.springframework.ai.tokenizer.TokenCountEstimator;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
/**
|
||||
* 芋道 AI 自动配置
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
@Configuration
|
||||
@EnableConfigurationProperties({ YudaoAiProperties.class,
|
||||
QdrantVectorStoreProperties.class, // 解析 Qdrant 配置
|
||||
RedisVectorStoreProperties.class, // 解析 Redis 配置
|
||||
MilvusVectorStoreProperties.class, MilvusServiceClientProperties.class // 解析 Milvus 配置
|
||||
})
|
||||
@Slf4j
|
||||
public class AiAutoConfiguration {
|
||||
|
||||
@Bean
|
||||
public AiModelFactory aiModelFactory() {
|
||||
return new AiModelFactoryImpl();
|
||||
}
|
||||
|
||||
// ========== 各种 AI Client 创建 ==========
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(value = "yudao.ai.deepseek.enable", havingValue = "true")
|
||||
public DeepSeekChatModel deepSeekChatModel(YudaoAiProperties yudaoAiProperties) {
|
||||
YudaoAiProperties.DeepSeekProperties properties = yudaoAiProperties.getDeepseek();
|
||||
return buildDeepSeekChatModel(properties);
|
||||
}
|
||||
|
||||
public DeepSeekChatModel buildDeepSeekChatModel(YudaoAiProperties.DeepSeekProperties properties) {
|
||||
if (StrUtil.isEmpty(properties.getModel())) {
|
||||
properties.setModel(DeepSeekChatModel.MODEL_DEFAULT);
|
||||
}
|
||||
OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
|
||||
.openAiApi(OpenAiApi.builder()
|
||||
.baseUrl(DeepSeekChatModel.BASE_URL)
|
||||
.apiKey(properties.getApiKey())
|
||||
.build())
|
||||
.defaultOptions(OpenAiChatOptions.builder()
|
||||
.model(properties.getModel())
|
||||
.temperature(properties.getTemperature())
|
||||
.maxTokens(properties.getMaxTokens())
|
||||
.topP(properties.getTopP())
|
||||
.build())
|
||||
.toolCallingManager(getToolCallingManager())
|
||||
.build();
|
||||
return new DeepSeekChatModel(openAiChatModel);
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(value = "yudao.ai.doubao.enable", havingValue = "true")
|
||||
public DouBaoChatModel douBaoChatClient(YudaoAiProperties yudaoAiProperties) {
|
||||
YudaoAiProperties.DouBaoProperties properties = yudaoAiProperties.getDoubao();
|
||||
return buildDouBaoChatClient(properties);
|
||||
}
|
||||
|
||||
public DouBaoChatModel buildDouBaoChatClient(YudaoAiProperties.DouBaoProperties properties) {
|
||||
if (StrUtil.isEmpty(properties.getModel())) {
|
||||
properties.setModel(DouBaoChatModel.MODEL_DEFAULT);
|
||||
}
|
||||
OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
|
||||
.openAiApi(OpenAiApi.builder()
|
||||
.baseUrl(DouBaoChatModel.BASE_URL)
|
||||
.apiKey(properties.getApiKey())
|
||||
.build())
|
||||
.defaultOptions(OpenAiChatOptions.builder()
|
||||
.model(properties.getModel())
|
||||
.temperature(properties.getTemperature())
|
||||
.maxTokens(properties.getMaxTokens())
|
||||
.topP(properties.getTopP())
|
||||
.build())
|
||||
.toolCallingManager(getToolCallingManager())
|
||||
.build();
|
||||
return new DouBaoChatModel(openAiChatModel);
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(value = "yudao.ai.siliconflow.enable", havingValue = "true")
|
||||
public SiliconFlowChatModel siliconFlowChatClient(YudaoAiProperties yudaoAiProperties) {
|
||||
YudaoAiProperties.SiliconFlowProperties properties = yudaoAiProperties.getSiliconflow();
|
||||
return buildSiliconFlowChatClient(properties);
|
||||
}
|
||||
|
||||
public SiliconFlowChatModel buildSiliconFlowChatClient(YudaoAiProperties.SiliconFlowProperties properties) {
|
||||
if (StrUtil.isEmpty(properties.getModel())) {
|
||||
properties.setModel(SiliconFlowApiConstants.MODEL_DEFAULT);
|
||||
}
|
||||
OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
|
||||
.openAiApi(OpenAiApi.builder()
|
||||
.baseUrl(SiliconFlowApiConstants.DEFAULT_BASE_URL)
|
||||
.apiKey(properties.getApiKey())
|
||||
.build())
|
||||
.defaultOptions(OpenAiChatOptions.builder()
|
||||
.model(properties.getModel())
|
||||
.temperature(properties.getTemperature())
|
||||
.maxTokens(properties.getMaxTokens())
|
||||
.topP(properties.getTopP())
|
||||
.build())
|
||||
.toolCallingManager(getToolCallingManager())
|
||||
.build();
|
||||
return new SiliconFlowChatModel(openAiChatModel);
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(value = "yudao.ai.hunyuan.enable", havingValue = "true")
|
||||
public HunYuanChatModel hunYuanChatClient(YudaoAiProperties yudaoAiProperties) {
|
||||
YudaoAiProperties.HunYuanProperties properties = yudaoAiProperties.getHunyuan();
|
||||
return buildHunYuanChatClient(properties);
|
||||
}
|
||||
|
||||
public HunYuanChatModel buildHunYuanChatClient(YudaoAiProperties.HunYuanProperties properties) {
|
||||
if (StrUtil.isEmpty(properties.getModel())) {
|
||||
properties.setModel(HunYuanChatModel.MODEL_DEFAULT);
|
||||
}
|
||||
// 特殊:由于混元大模型不提供 deepseek,而是通过知识引擎,所以需要区分下 URL
|
||||
if (StrUtil.isEmpty(properties.getBaseUrl())) {
|
||||
properties.setBaseUrl(
|
||||
StrUtil.startWithIgnoreCase(properties.getModel(), "deepseek") ? HunYuanChatModel.DEEP_SEEK_BASE_URL
|
||||
: HunYuanChatModel.BASE_URL);
|
||||
}
|
||||
// 创建 OpenAiChatModel、HunYuanChatModel 对象
|
||||
OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
|
||||
.openAiApi(OpenAiApi.builder()
|
||||
.baseUrl(properties.getBaseUrl())
|
||||
.apiKey(properties.getApiKey())
|
||||
.build())
|
||||
.defaultOptions(OpenAiChatOptions.builder()
|
||||
.model(properties.getModel())
|
||||
.temperature(properties.getTemperature())
|
||||
.maxTokens(properties.getMaxTokens())
|
||||
.topP(properties.getTopP())
|
||||
.build())
|
||||
.toolCallingManager(getToolCallingManager())
|
||||
.build();
|
||||
return new HunYuanChatModel(openAiChatModel);
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(value = "yudao.ai.xinghuo.enable", havingValue = "true")
|
||||
public XingHuoChatModel xingHuoChatClient(YudaoAiProperties yudaoAiProperties) {
|
||||
YudaoAiProperties.XingHuoProperties properties = yudaoAiProperties.getXinghuo();
|
||||
return buildXingHuoChatClient(properties);
|
||||
}
|
||||
|
||||
public XingHuoChatModel buildXingHuoChatClient(YudaoAiProperties.XingHuoProperties properties) {
|
||||
if (StrUtil.isEmpty(properties.getModel())) {
|
||||
properties.setModel(XingHuoChatModel.MODEL_DEFAULT);
|
||||
}
|
||||
OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
|
||||
.openAiApi(OpenAiApi.builder()
|
||||
.baseUrl(XingHuoChatModel.BASE_URL)
|
||||
.apiKey(properties.getAppKey() + ":" + properties.getSecretKey())
|
||||
.build())
|
||||
.defaultOptions(OpenAiChatOptions.builder()
|
||||
.model(properties.getModel())
|
||||
.temperature(properties.getTemperature())
|
||||
.maxTokens(properties.getMaxTokens())
|
||||
.topP(properties.getTopP())
|
||||
.build())
|
||||
.toolCallingManager(getToolCallingManager())
|
||||
.build();
|
||||
return new XingHuoChatModel(openAiChatModel);
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(value = "yudao.ai.baichuan.enable", havingValue = "true")
|
||||
public BaiChuanChatModel baiChuanChatClient(YudaoAiProperties yudaoAiProperties) {
|
||||
YudaoAiProperties.BaiChuanProperties properties = yudaoAiProperties.getBaichuan();
|
||||
return buildBaiChuanChatClient(properties);
|
||||
}
|
||||
|
||||
public BaiChuanChatModel buildBaiChuanChatClient(YudaoAiProperties.BaiChuanProperties properties) {
|
||||
if (StrUtil.isEmpty(properties.getModel())) {
|
||||
properties.setModel(BaiChuanChatModel.MODEL_DEFAULT);
|
||||
}
|
||||
OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
|
||||
.openAiApi(OpenAiApi.builder()
|
||||
.baseUrl(BaiChuanChatModel.BASE_URL)
|
||||
.apiKey(properties.getApiKey())
|
||||
.build())
|
||||
.defaultOptions(OpenAiChatOptions.builder()
|
||||
.model(properties.getModel())
|
||||
.temperature(properties.getTemperature())
|
||||
.maxTokens(properties.getMaxTokens())
|
||||
.topP(properties.getTopP())
|
||||
.build())
|
||||
.toolCallingManager(getToolCallingManager())
|
||||
.build();
|
||||
return new BaiChuanChatModel(openAiChatModel);
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(value = "yudao.ai.midjourney.enable", havingValue = "true")
|
||||
public MidjourneyApi midjourneyApi(YudaoAiProperties yudaoAiProperties) {
|
||||
YudaoAiProperties.MidjourneyProperties config = yudaoAiProperties.getMidjourney();
|
||||
return new MidjourneyApi(config.getBaseUrl(), config.getApiKey(), config.getNotifyUrl());
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(value = "yudao.ai.suno.enable", havingValue = "true")
|
||||
public SunoApi sunoApi(YudaoAiProperties yudaoAiProperties) {
|
||||
return new SunoApi(yudaoAiProperties.getSuno().getBaseUrl());
|
||||
}
|
||||
|
||||
// ========== RAG 相关 ==========
|
||||
|
||||
@Bean
|
||||
public TokenCountEstimator tokenCountEstimator() {
|
||||
return new JTokkitTokenCountEstimator();
|
||||
}
|
||||
|
||||
@Bean
|
||||
public BatchingStrategy batchingStrategy() {
|
||||
return new TokenCountBatchingStrategy();
|
||||
}
|
||||
|
||||
private static ToolCallingManager getToolCallingManager() {
|
||||
return SpringUtil.getBean(ToolCallingManager.class);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.config;
|
||||
|
||||
import lombok.Data;
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
|
||||
/**
|
||||
* 芋道 AI 配置类
|
||||
*
|
||||
* @author fansili
|
||||
* @since 1.0
|
||||
*/
|
||||
@ConfigurationProperties(prefix = "yudao.ai")
|
||||
@Data
|
||||
public class YudaoAiProperties {
|
||||
|
||||
/**
|
||||
* DeepSeek
|
||||
*/
|
||||
@SuppressWarnings("SpellCheckingInspection")
|
||||
private DeepSeekProperties deepseek;
|
||||
|
||||
/**
|
||||
* 字节豆包
|
||||
*/
|
||||
@SuppressWarnings("SpellCheckingInspection")
|
||||
private DouBaoProperties doubao;
|
||||
|
||||
/**
|
||||
* 腾讯混元
|
||||
*/
|
||||
@SuppressWarnings("SpellCheckingInspection")
|
||||
private HunYuanProperties hunyuan;
|
||||
|
||||
/**
|
||||
* 硅基流动
|
||||
*/
|
||||
@SuppressWarnings("SpellCheckingInspection")
|
||||
private SiliconFlowProperties siliconflow;
|
||||
|
||||
/**
|
||||
* 讯飞星火
|
||||
*/
|
||||
@SuppressWarnings("SpellCheckingInspection")
|
||||
private XingHuoProperties xinghuo;
|
||||
|
||||
/**
|
||||
* 百川
|
||||
*/
|
||||
@SuppressWarnings("SpellCheckingInspection")
|
||||
private BaiChuanProperties baichuan;
|
||||
|
||||
/**
|
||||
* Midjourney 绘图
|
||||
*/
|
||||
private MidjourneyProperties midjourney;
|
||||
|
||||
/**
|
||||
* Suno 音乐
|
||||
*/
|
||||
@SuppressWarnings("SpellCheckingInspection")
|
||||
private SunoProperties suno;
|
||||
|
||||
@Data
|
||||
public static class DeepSeekProperties {
|
||||
|
||||
private String enable;
|
||||
private String apiKey;
|
||||
|
||||
private String model;
|
||||
private Double temperature;
|
||||
private Integer maxTokens;
|
||||
private Double topP;
|
||||
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class DouBaoProperties {
|
||||
|
||||
private String enable;
|
||||
private String apiKey;
|
||||
|
||||
private String model;
|
||||
private Double temperature;
|
||||
private Integer maxTokens;
|
||||
private Double topP;
|
||||
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class HunYuanProperties {
|
||||
|
||||
private String enable;
|
||||
private String baseUrl;
|
||||
private String apiKey;
|
||||
|
||||
private String model;
|
||||
private Double temperature;
|
||||
private Integer maxTokens;
|
||||
private Double topP;
|
||||
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class SiliconFlowProperties {
|
||||
|
||||
private String enable;
|
||||
private String apiKey;
|
||||
|
||||
private String model;
|
||||
private Double temperature;
|
||||
private Integer maxTokens;
|
||||
private Double topP;
|
||||
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class XingHuoProperties {
|
||||
|
||||
private String enable;
|
||||
private String appId;
|
||||
private String appKey;
|
||||
private String secretKey;
|
||||
|
||||
private String model;
|
||||
private Double temperature;
|
||||
private Integer maxTokens;
|
||||
private Double topP;
|
||||
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class BaiChuanProperties {
|
||||
|
||||
private String enable;
|
||||
private String apiKey;
|
||||
|
||||
private String model;
|
||||
private Double temperature;
|
||||
private Integer maxTokens;
|
||||
private Double topP;
|
||||
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class MidjourneyProperties {
|
||||
|
||||
private String enable;
|
||||
private String baseUrl;
|
||||
|
||||
private String apiKey;
|
||||
private String notifyUrl;
|
||||
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class SunoProperties {
|
||||
|
||||
private boolean enable = false;
|
||||
|
||||
private String baseUrl;
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,113 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core;
|
||||
|
||||
import cn.iocoder.yudao.module.ai.enums.model.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.suno.api.SunoApi;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.image.ImageModel;
|
||||
import org.springframework.ai.vectorstore.VectorStore;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* AI Model 模型工厂的接口类
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
public interface AiModelFactory {
|
||||
|
||||
/**
|
||||
* 基于指定配置,获得 ChatModel 对象
|
||||
*
|
||||
* 如果不存在,则进行创建
|
||||
*
|
||||
* @param platform 平台
|
||||
* @param apiKey API KEY
|
||||
* @param url API URL
|
||||
* @return ChatModel 对象
|
||||
*/
|
||||
ChatModel getOrCreateChatModel(AiPlatformEnum platform, String apiKey, String url);
|
||||
|
||||
/**
|
||||
* 基于默认配置,获得 ChatModel 对象
|
||||
*
|
||||
* 默认配置,指的是在 application.yaml 配置文件中的 spring.ai 相关的配置
|
||||
*
|
||||
* @param platform 平台
|
||||
* @return ChatModel 对象
|
||||
*/
|
||||
ChatModel getDefaultChatModel(AiPlatformEnum platform);
|
||||
|
||||
/**
|
||||
* 基于默认配置,获得 ImageModel 对象
|
||||
*
|
||||
* 默认配置,指的是在 application.yaml 配置文件中的 spring.ai 相关的配置
|
||||
*
|
||||
* @param platform 平台
|
||||
* @return ImageModel 对象
|
||||
*/
|
||||
ImageModel getDefaultImageModel(AiPlatformEnum platform);
|
||||
|
||||
/**
|
||||
* 基于指定配置,获得 ImageModel 对象
|
||||
*
|
||||
* 如果不存在,则进行创建
|
||||
*
|
||||
* @param platform 平台
|
||||
* @param apiKey API KEY
|
||||
* @param url API URL
|
||||
* @return ImageModel 对象
|
||||
*/
|
||||
ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url);
|
||||
|
||||
/**
|
||||
* 基于指定配置,获得 MidjourneyApi 对象
|
||||
*
|
||||
* 如果不存在,则进行创建
|
||||
*
|
||||
* @param apiKey API KEY
|
||||
* @param url API URL
|
||||
* @return MidjourneyApi 对象
|
||||
*/
|
||||
MidjourneyApi getOrCreateMidjourneyApi(String apiKey, String url);
|
||||
|
||||
/**
|
||||
* 基于指定配置,获得 SunoApi 对象
|
||||
*
|
||||
* 如果不存在,则进行创建
|
||||
*
|
||||
* @param apiKey API KEY
|
||||
* @param url API URL
|
||||
* @return SunoApi 对象
|
||||
*/
|
||||
SunoApi getOrCreateSunoApi(String apiKey, String url);
|
||||
|
||||
/**
|
||||
* 基于指定配置,获得 EmbeddingModel 对象
|
||||
*
|
||||
* 如果不存在,则进行创建
|
||||
*
|
||||
* @param platform 平台
|
||||
* @param apiKey API KEY
|
||||
* @param url API URL
|
||||
* @param model 模型
|
||||
* @return ChatModel 对象
|
||||
*/
|
||||
EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url, String model);
|
||||
|
||||
/**
|
||||
* 基于指定配置,获得 VectorStore 对象
|
||||
*
|
||||
* 如果不存在,则进行创建
|
||||
*
|
||||
* @param type 向量存储类型
|
||||
* @param embeddingModel 向量模型
|
||||
* @param metadataFields 元数据字段
|
||||
* @return VectorStore 对象
|
||||
*/
|
||||
VectorStore getOrCreateVectorStore(Class<? extends VectorStore> type,
|
||||
EmbeddingModel embeddingModel,
|
||||
Map<String, Class<?>> metadataFields);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,752 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core;
|
||||
|
||||
import cn.hutool.core.io.FileUtil;
|
||||
import cn.hutool.core.lang.Assert;
|
||||
import cn.hutool.core.lang.Singleton;
|
||||
import cn.hutool.core.lang.func.Func0;
|
||||
import cn.hutool.core.util.ArrayUtil;
|
||||
import cn.hutool.core.util.RuntimeUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.hutool.extra.spring.SpringUtil;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.config.AiAutoConfiguration;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.config.YudaoAiProperties;
|
||||
import cn.iocoder.yudao.module.ai.enums.model.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.baichuan.BaiChuanChatModel;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.deepseek.DeepSeekChatModel;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.doubao.DouBaoChatModel;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.hunyuan.HunYuanChatModel;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.siliconflow.SiliconFlowApiConstants;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.siliconflow.SiliconFlowChatModel;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.siliconflow.SiliconFlowImageApi;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.siliconflow.SiliconFlowImageModel;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.suno.api.SunoApi;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
||||
import cn.iocoder.yudao.framework.common.util.spring.SpringUtils;
|
||||
import com.alibaba.cloud.ai.autoconfigure.dashscope.DashScopeAutoConfiguration;
|
||||
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi;
|
||||
import com.alibaba.cloud.ai.dashscope.api.DashScopeImageApi;
|
||||
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel;
|
||||
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
|
||||
import com.alibaba.cloud.ai.dashscope.embedding.DashScopeEmbeddingModel;
|
||||
import com.alibaba.cloud.ai.dashscope.embedding.DashScopeEmbeddingOptions;
|
||||
import com.alibaba.cloud.ai.dashscope.image.DashScopeImageModel;
|
||||
import com.azure.ai.openai.OpenAIClientBuilder;
|
||||
import io.micrometer.observation.ObservationRegistry;
|
||||
import io.milvus.client.MilvusServiceClient;
|
||||
import io.qdrant.client.QdrantClient;
|
||||
import io.qdrant.client.QdrantGrpcClient;
|
||||
import lombok.SneakyThrows;
|
||||
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties;
|
||||
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiConnectionProperties;
|
||||
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiEmbeddingProperties;
|
||||
import org.springframework.ai.autoconfigure.minimax.MiniMaxAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.moonshot.MoonshotAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.stabilityai.StabilityAiImageAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.vectorstore.milvus.MilvusServiceClientConnectionDetails;
|
||||
import org.springframework.ai.autoconfigure.vectorstore.milvus.MilvusServiceClientProperties;
|
||||
import org.springframework.ai.autoconfigure.vectorstore.milvus.MilvusVectorStoreAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.vectorstore.milvus.MilvusVectorStoreProperties;
|
||||
import org.springframework.ai.autoconfigure.vectorstore.qdrant.QdrantVectorStoreAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.vectorstore.qdrant.QdrantVectorStoreProperties;
|
||||
import org.springframework.ai.autoconfigure.vectorstore.redis.RedisVectorStoreAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.vectorstore.redis.RedisVectorStoreProperties;
|
||||
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration;
|
||||
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
|
||||
import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingModel;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.document.MetadataMode;
|
||||
import org.springframework.ai.embedding.BatchingStrategy;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.image.ImageModel;
|
||||
import org.springframework.ai.minimax.MiniMaxChatModel;
|
||||
import org.springframework.ai.minimax.MiniMaxChatOptions;
|
||||
import org.springframework.ai.minimax.MiniMaxEmbeddingModel;
|
||||
import org.springframework.ai.minimax.MiniMaxEmbeddingOptions;
|
||||
import org.springframework.ai.minimax.api.MiniMaxApi;
|
||||
import org.springframework.ai.model.function.FunctionCallbackResolver;
|
||||
import org.springframework.ai.model.tool.ToolCallingManager;
|
||||
import org.springframework.ai.moonshot.MoonshotChatModel;
|
||||
import org.springframework.ai.moonshot.MoonshotChatOptions;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi;
|
||||
import org.springframework.ai.ollama.OllamaChatModel;
|
||||
import org.springframework.ai.ollama.OllamaEmbeddingModel;
|
||||
import org.springframework.ai.ollama.api.OllamaApi;
|
||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.openai.OpenAiEmbeddingModel;
|
||||
import org.springframework.ai.openai.OpenAiEmbeddingOptions;
|
||||
import org.springframework.ai.openai.OpenAiImageModel;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import org.springframework.ai.openai.api.OpenAiImageApi;
|
||||
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
|
||||
import org.springframework.ai.qianfan.QianFanChatModel;
|
||||
import org.springframework.ai.qianfan.QianFanEmbeddingModel;
|
||||
import org.springframework.ai.qianfan.QianFanEmbeddingOptions;
|
||||
import org.springframework.ai.qianfan.QianFanImageModel;
|
||||
import org.springframework.ai.qianfan.api.QianFanApi;
|
||||
import org.springframework.ai.qianfan.api.QianFanImageApi;
|
||||
import org.springframework.ai.stabilityai.StabilityAiImageModel;
|
||||
import org.springframework.ai.stabilityai.api.StabilityAiApi;
|
||||
import org.springframework.ai.vectorstore.SimpleVectorStore;
|
||||
import org.springframework.ai.vectorstore.VectorStore;
|
||||
import org.springframework.ai.vectorstore.milvus.MilvusVectorStore;
|
||||
import org.springframework.ai.vectorstore.observation.DefaultVectorStoreObservationConvention;
|
||||
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
|
||||
import org.springframework.ai.vectorstore.qdrant.QdrantVectorStore;
|
||||
import org.springframework.ai.vectorstore.redis.RedisVectorStore;
|
||||
import org.springframework.ai.zhipuai.*;
|
||||
import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
|
||||
import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
|
||||
import org.springframework.beans.BeansException;
|
||||
import org.springframework.beans.factory.ObjectProvider;
|
||||
import org.springframework.boot.autoconfigure.data.redis.RedisProperties;
|
||||
import org.springframework.web.client.RestClient;
|
||||
import redis.clients.jedis.JedisPooled;
|
||||
|
||||
import java.io.File;
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Timer;
|
||||
import java.util.TimerTask;
|
||||
|
||||
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
|
||||
import static org.springframework.ai.retry.RetryUtils.DEFAULT_RETRY_TEMPLATE;
|
||||
|
||||
/**
|
||||
* AI Model 模型工厂的实现类
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class AiModelFactoryImpl implements AiModelFactory {
|
||||
|
||||
@Override
|
||||
public ChatModel getOrCreateChatModel(AiPlatformEnum platform, String apiKey, String url) {
|
||||
String cacheKey = buildClientCacheKey(ChatModel.class, platform, apiKey, url);
|
||||
return Singleton.get(cacheKey, (Func0<ChatModel>) () -> {
|
||||
// noinspection EnhancedSwitchMigration
|
||||
switch (platform) {
|
||||
case TONG_YI:
|
||||
return buildTongYiChatModel(apiKey);
|
||||
case YI_YAN:
|
||||
return buildYiYanChatModel(apiKey);
|
||||
case DEEP_SEEK:
|
||||
return buildDeepSeekChatModel(apiKey);
|
||||
case DOU_BAO:
|
||||
return buildDouBaoChatModel(apiKey);
|
||||
case HUN_YUAN:
|
||||
return buildHunYuanChatModel(apiKey, url);
|
||||
case SILICON_FLOW:
|
||||
return buildSiliconFlowChatModel(apiKey);
|
||||
case ZHI_PU:
|
||||
return buildZhiPuChatModel(apiKey, url);
|
||||
case MINI_MAX:
|
||||
return buildMiniMaxChatModel(apiKey, url);
|
||||
case MOONSHOT:
|
||||
return buildMoonshotChatModel(apiKey, url);
|
||||
case XING_HUO:
|
||||
return buildXingHuoChatModel(apiKey);
|
||||
case BAI_CHUAN:
|
||||
return buildBaiChuanChatModel(apiKey);
|
||||
case OPENAI:
|
||||
return buildOpenAiChatModel(apiKey, url);
|
||||
case AZURE_OPENAI:
|
||||
return buildAzureOpenAiChatModel(apiKey, url);
|
||||
case OLLAMA:
|
||||
return buildOllamaChatModel(url);
|
||||
default:
|
||||
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatModel getDefaultChatModel(AiPlatformEnum platform) {
|
||||
// noinspection EnhancedSwitchMigration
|
||||
switch (platform) {
|
||||
case TONG_YI:
|
||||
return SpringUtil.getBean(DashScopeChatModel.class);
|
||||
case YI_YAN:
|
||||
return SpringUtil.getBean(QianFanChatModel.class);
|
||||
case DEEP_SEEK:
|
||||
return SpringUtil.getBean(DeepSeekChatModel.class);
|
||||
case DOU_BAO:
|
||||
return SpringUtil.getBean(DouBaoChatModel.class);
|
||||
case HUN_YUAN:
|
||||
return SpringUtil.getBean(HunYuanChatModel.class);
|
||||
case SILICON_FLOW:
|
||||
return SpringUtil.getBean(SiliconFlowChatModel.class);
|
||||
case ZHI_PU:
|
||||
return SpringUtil.getBean(ZhiPuAiChatModel.class);
|
||||
case MINI_MAX:
|
||||
return SpringUtil.getBean(MiniMaxChatModel.class);
|
||||
case MOONSHOT:
|
||||
return SpringUtil.getBean(MoonshotChatModel.class);
|
||||
case XING_HUO:
|
||||
return SpringUtil.getBean(XingHuoChatModel.class);
|
||||
case BAI_CHUAN:
|
||||
return SpringUtil.getBean(AzureOpenAiChatModel.class);
|
||||
case OPENAI:
|
||||
return SpringUtil.getBean(OpenAiChatModel.class);
|
||||
case AZURE_OPENAI:
|
||||
return SpringUtil.getBean(AzureOpenAiChatModel.class);
|
||||
case OLLAMA:
|
||||
return SpringUtil.getBean(OllamaChatModel.class);
|
||||
default:
|
||||
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public ImageModel getDefaultImageModel(AiPlatformEnum platform) {
|
||||
// noinspection EnhancedSwitchMigration
|
||||
switch (platform) {
|
||||
case TONG_YI:
|
||||
return SpringUtil.getBean(DashScopeImageModel.class);
|
||||
case YI_YAN:
|
||||
return SpringUtil.getBean(QianFanImageModel.class);
|
||||
case ZHI_PU:
|
||||
return SpringUtil.getBean(ZhiPuAiImageModel.class);
|
||||
case SILICON_FLOW:
|
||||
return SpringUtil.getBean(SiliconFlowImageModel.class);
|
||||
case OPENAI:
|
||||
return SpringUtil.getBean(OpenAiImageModel.class);
|
||||
case STABLE_DIFFUSION:
|
||||
return SpringUtil.getBean(StabilityAiImageModel.class);
|
||||
default:
|
||||
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url) {
|
||||
// noinspection EnhancedSwitchMigration
|
||||
switch (platform) {
|
||||
case TONG_YI:
|
||||
return buildTongYiImagesModel(apiKey);
|
||||
case YI_YAN:
|
||||
return buildQianFanImageModel(apiKey);
|
||||
case ZHI_PU:
|
||||
return buildZhiPuAiImageModel(apiKey, url);
|
||||
case OPENAI:
|
||||
return buildOpenAiImageModel(apiKey, url);
|
||||
case SILICON_FLOW:
|
||||
return buildSiliconFlowImageModel(apiKey,url);
|
||||
case STABLE_DIFFUSION:
|
||||
return buildStabilityAiImageModel(apiKey, url);
|
||||
default:
|
||||
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public MidjourneyApi getOrCreateMidjourneyApi(String apiKey, String url) {
|
||||
String cacheKey = buildClientCacheKey(MidjourneyApi.class, AiPlatformEnum.MIDJOURNEY.getPlatform(), apiKey,
|
||||
url);
|
||||
return Singleton.get(cacheKey, (Func0<MidjourneyApi>) () -> {
|
||||
YudaoAiProperties.MidjourneyProperties properties = SpringUtil.getBean(YudaoAiProperties.class)
|
||||
.getMidjourney();
|
||||
return new MidjourneyApi(url, apiKey, properties.getNotifyUrl());
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public SunoApi getOrCreateSunoApi(String apiKey, String url) {
|
||||
String cacheKey = buildClientCacheKey(SunoApi.class, AiPlatformEnum.SUNO.getPlatform(), apiKey, url);
|
||||
return Singleton.get(cacheKey, (Func0<SunoApi>) () -> new SunoApi(url));
|
||||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings("EnhancedSwitchMigration")
|
||||
public EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url, String model) {
|
||||
String cacheKey = buildClientCacheKey(EmbeddingModel.class, platform, apiKey, url, model);
|
||||
return Singleton.get(cacheKey, (Func0<EmbeddingModel>) () -> {
|
||||
switch (platform) {
|
||||
case TONG_YI:
|
||||
return buildTongYiEmbeddingModel(apiKey, model);
|
||||
case YI_YAN:
|
||||
return buildYiYanEmbeddingModel(apiKey, model);
|
||||
case ZHI_PU:
|
||||
return buildZhiPuEmbeddingModel(apiKey, url, model);
|
||||
case MINI_MAX:
|
||||
return buildMiniMaxEmbeddingModel(apiKey, url, model);
|
||||
case OPENAI:
|
||||
return buildOpenAiEmbeddingModel(apiKey, url, model);
|
||||
case AZURE_OPENAI:
|
||||
return buildAzureOpenAiEmbeddingModel(apiKey, url, model);
|
||||
case OLLAMA:
|
||||
return buildOllamaEmbeddingModel(url, model);
|
||||
default:
|
||||
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorStore getOrCreateVectorStore(Class<? extends VectorStore> type,
|
||||
EmbeddingModel embeddingModel,
|
||||
Map<String, Class<?>> metadataFields) {
|
||||
String cacheKey = buildClientCacheKey(VectorStore.class, embeddingModel, type);
|
||||
return Singleton.get(cacheKey, (Func0<VectorStore>) () -> {
|
||||
if (type == SimpleVectorStore.class) {
|
||||
return buildSimpleVectorStore(embeddingModel);
|
||||
}
|
||||
if (type == QdrantVectorStore.class) {
|
||||
return buildQdrantVectorStore(embeddingModel);
|
||||
}
|
||||
if (type == RedisVectorStore.class) {
|
||||
return buildRedisVectorStore(embeddingModel, metadataFields);
|
||||
}
|
||||
if (type == MilvusVectorStore.class) {
|
||||
return buildMilvusVectorStore(embeddingModel);
|
||||
}
|
||||
throw new IllegalArgumentException(StrUtil.format("未知类型({})", type));
|
||||
});
|
||||
}
|
||||
|
||||
private static String buildClientCacheKey(Class<?> clazz, Object... params) {
|
||||
if (ArrayUtil.isEmpty(params)) {
|
||||
return clazz.getName();
|
||||
}
|
||||
return StrUtil.format("{}#{}", clazz.getName(), ArrayUtil.join(params, "_"));
|
||||
}
|
||||
|
||||
// ========== 各种创建 spring-ai 客户端的方法 ==========
|
||||
|
||||
/**
|
||||
* 可参考 {@link DashScopeAutoConfiguration} 的 dashscopeChatModel 方法
|
||||
*/
|
||||
private static DashScopeChatModel buildTongYiChatModel(String key) {
|
||||
DashScopeApi dashScopeApi = new DashScopeApi(key);
|
||||
DashScopeChatOptions options = DashScopeChatOptions.builder().withModel(DashScopeApi.DEFAULT_CHAT_MODEL)
|
||||
.withTemperature(0.7).build();
|
||||
return new DashScopeChatModel(dashScopeApi, options, getFunctionCallbackResolver(), DEFAULT_RETRY_TEMPLATE);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link DashScopeAutoConfiguration} 的 dashScopeImageModel 方法
|
||||
*/
|
||||
private static DashScopeImageModel buildTongYiImagesModel(String key) {
|
||||
DashScopeImageApi dashScopeImageApi = new DashScopeImageApi(key);
|
||||
return new DashScopeImageModel(dashScopeImageApi);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link QianFanAutoConfiguration} 的 qianFanChatModel 方法
|
||||
*/
|
||||
private static QianFanChatModel buildYiYanChatModel(String key) {
|
||||
List<String> keys = StrUtil.split(key, '|');
|
||||
Assert.equals(keys.size(), 2, "YiYanChatClient 的密钥需要 (appKey|secretKey) 格式");
|
||||
String appKey = keys.get(0);
|
||||
String secretKey = keys.get(1);
|
||||
QianFanApi qianFanApi = new QianFanApi(appKey, secretKey);
|
||||
return new QianFanChatModel(qianFanApi);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link QianFanAutoConfiguration} 的 qianFanImageModel 方法
|
||||
*/
|
||||
private QianFanImageModel buildQianFanImageModel(String key) {
|
||||
List<String> keys = StrUtil.split(key, '|');
|
||||
Assert.equals(keys.size(), 2, "YiYanChatClient 的密钥需要 (appKey|secretKey) 格式");
|
||||
String appKey = keys.get(0);
|
||||
String secretKey = keys.get(1);
|
||||
QianFanImageApi qianFanApi = new QianFanImageApi(appKey, secretKey);
|
||||
return new QianFanImageModel(qianFanApi);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link AiAutoConfiguration#deepSeekChatModel(YudaoAiProperties)}
|
||||
*/
|
||||
private static DeepSeekChatModel buildDeepSeekChatModel(String apiKey) {
|
||||
YudaoAiProperties.DeepSeekProperties properties = new YudaoAiProperties.DeepSeekProperties()
|
||||
.setApiKey(apiKey);
|
||||
return new AiAutoConfiguration().buildDeepSeekChatModel(properties);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link AiAutoConfiguration#douBaoChatClient(YudaoAiProperties)}
|
||||
*/
|
||||
private ChatModel buildDouBaoChatModel(String apiKey) {
|
||||
YudaoAiProperties.DouBaoProperties properties = new YudaoAiProperties.DouBaoProperties()
|
||||
.setApiKey(apiKey);
|
||||
return new AiAutoConfiguration().buildDouBaoChatClient(properties);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link AiAutoConfiguration#hunYuanChatClient(YudaoAiProperties)}
|
||||
*/
|
||||
private ChatModel buildHunYuanChatModel(String apiKey, String url) {
|
||||
YudaoAiProperties.HunYuanProperties properties = new YudaoAiProperties.HunYuanProperties()
|
||||
.setBaseUrl(url).setApiKey(apiKey);
|
||||
return new AiAutoConfiguration().buildHunYuanChatClient(properties);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link AiAutoConfiguration#siliconFlowChatClient(YudaoAiProperties)}
|
||||
*/
|
||||
private ChatModel buildSiliconFlowChatModel(String apiKey) {
|
||||
YudaoAiProperties.SiliconFlowProperties properties = new YudaoAiProperties.SiliconFlowProperties()
|
||||
.setApiKey(apiKey);
|
||||
return new AiAutoConfiguration().buildSiliconFlowChatClient(properties);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link ZhiPuAiAutoConfiguration} 的 zhiPuAiChatModel 方法
|
||||
*/
|
||||
private ZhiPuAiChatModel buildZhiPuChatModel(String apiKey, String url) {
|
||||
ZhiPuAiApi zhiPuAiApi = StrUtil.isEmpty(url) ? new ZhiPuAiApi(apiKey)
|
||||
: new ZhiPuAiApi(url, apiKey);
|
||||
ZhiPuAiChatOptions options = ZhiPuAiChatOptions.builder().model(ZhiPuAiApi.DEFAULT_CHAT_MODEL).temperature(0.7).build();
|
||||
return new ZhiPuAiChatModel(zhiPuAiApi, options, getFunctionCallbackResolver(), DEFAULT_RETRY_TEMPLATE);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link ZhiPuAiAutoConfiguration} 的 zhiPuAiImageModel 方法
|
||||
*/
|
||||
private ZhiPuAiImageModel buildZhiPuAiImageModel(String apiKey, String url) {
|
||||
ZhiPuAiImageApi zhiPuAiApi = StrUtil.isEmpty(url) ? new ZhiPuAiImageApi(apiKey)
|
||||
: new ZhiPuAiImageApi(url, apiKey, RestClient.builder());
|
||||
return new ZhiPuAiImageModel(zhiPuAiApi);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link MiniMaxAutoConfiguration} 的 miniMaxChatModel 方法
|
||||
*/
|
||||
private MiniMaxChatModel buildMiniMaxChatModel(String apiKey, String url) {
|
||||
MiniMaxApi miniMaxApi = StrUtil.isEmpty(url) ? new MiniMaxApi(apiKey)
|
||||
: new MiniMaxApi(url, apiKey);
|
||||
MiniMaxChatOptions options = MiniMaxChatOptions.builder().model(MiniMaxApi.DEFAULT_CHAT_MODEL).temperature(0.7).build();
|
||||
return new MiniMaxChatModel(miniMaxApi, options, getFunctionCallbackResolver(), DEFAULT_RETRY_TEMPLATE);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link MoonshotAutoConfiguration} 的 moonshotChatModel 方法
|
||||
*/
|
||||
private MoonshotChatModel buildMoonshotChatModel(String apiKey, String url) {
|
||||
MoonshotApi moonshotApi = StrUtil.isEmpty(url)? new MoonshotApi(apiKey)
|
||||
: new MoonshotApi(url, apiKey);
|
||||
MoonshotChatOptions options = MoonshotChatOptions.builder().model(MoonshotApi.DEFAULT_CHAT_MODEL).build();
|
||||
return new MoonshotChatModel(moonshotApi, options, getFunctionCallbackResolver(), DEFAULT_RETRY_TEMPLATE);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link AiAutoConfiguration#xingHuoChatClient(YudaoAiProperties)}
|
||||
*/
|
||||
private static XingHuoChatModel buildXingHuoChatModel(String key) {
|
||||
List<String> keys = StrUtil.split(key, '|');
|
||||
Assert.equals(keys.size(), 2, "XingHuoChatClient 的密钥需要 (appKey|secretKey) 格式");
|
||||
YudaoAiProperties.XingHuoProperties properties = new YudaoAiProperties.XingHuoProperties()
|
||||
.setAppKey(keys.get(0)).setSecretKey(keys.get(1));
|
||||
return new AiAutoConfiguration().buildXingHuoChatClient(properties);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link AiAutoConfiguration#baiChuanChatClient(YudaoAiProperties)}
|
||||
*/
|
||||
private BaiChuanChatModel buildBaiChuanChatModel(String apiKey) {
|
||||
YudaoAiProperties.BaiChuanProperties properties = new YudaoAiProperties.BaiChuanProperties()
|
||||
.setApiKey(apiKey);
|
||||
return new AiAutoConfiguration().buildBaiChuanChatClient(properties);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link OpenAiAutoConfiguration} 的 openAiChatModel 方法
|
||||
*/
|
||||
private static OpenAiChatModel buildOpenAiChatModel(String openAiToken, String url) {
|
||||
url = StrUtil.blankToDefault(url, OpenAiApiConstants.DEFAULT_BASE_URL);
|
||||
OpenAiApi openAiApi = OpenAiApi.builder().baseUrl(url).apiKey(openAiToken).build();
|
||||
return OpenAiChatModel.builder().openAiApi(openAiApi).toolCallingManager(getToolCallingManager()).build();
|
||||
}
|
||||
|
||||
// TODO @芋艿:手头暂时没密钥,使用建议再测试下
|
||||
/**
|
||||
* 可参考 {@link AzureOpenAiAutoConfiguration}
|
||||
*/
|
||||
private static AzureOpenAiChatModel buildAzureOpenAiChatModel(String apiKey, String url) {
|
||||
AzureOpenAiAutoConfiguration azureOpenAiAutoConfiguration = new AzureOpenAiAutoConfiguration();
|
||||
// 创建 OpenAIClient 对象
|
||||
AzureOpenAiConnectionProperties connectionProperties = new AzureOpenAiConnectionProperties();
|
||||
connectionProperties.setApiKey(apiKey);
|
||||
connectionProperties.setEndpoint(url);
|
||||
OpenAIClientBuilder openAIClient = azureOpenAiAutoConfiguration.openAIClientBuilder(connectionProperties, null);
|
||||
// 获取 AzureOpenAiChatProperties 对象
|
||||
AzureOpenAiChatProperties chatProperties = SpringUtil.getBean(AzureOpenAiChatProperties.class);
|
||||
return azureOpenAiAutoConfiguration.azureOpenAiChatModel(openAIClient, chatProperties,
|
||||
getToolCallingManager(), null, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link OpenAiAutoConfiguration} 的 openAiImageModel 方法
|
||||
*/
|
||||
private OpenAiImageModel buildOpenAiImageModel(String openAiToken, String url) {
|
||||
url = StrUtil.blankToDefault(url, OpenAiApiConstants.DEFAULT_BASE_URL);
|
||||
OpenAiImageApi openAiApi = OpenAiImageApi.builder().baseUrl(url).apiKey(openAiToken).build();
|
||||
return new OpenAiImageModel(openAiApi);
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建 SiliconFlowImageModel 对象
|
||||
*/
|
||||
private SiliconFlowImageModel buildSiliconFlowImageModel(String apiToken, String url) {
|
||||
url = StrUtil.blankToDefault(url, SiliconFlowApiConstants.DEFAULT_BASE_URL);
|
||||
SiliconFlowImageApi openAiApi = new SiliconFlowImageApi(url, apiToken);
|
||||
return new SiliconFlowImageModel(openAiApi);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link OllamaAutoConfiguration} 的 ollamaApi 方法
|
||||
*/
|
||||
private static OllamaChatModel buildOllamaChatModel(String url) {
|
||||
OllamaApi ollamaApi = new OllamaApi(url);
|
||||
return OllamaChatModel.builder().ollamaApi(ollamaApi).toolCallingManager(getToolCallingManager()).build();
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link StabilityAiImageAutoConfiguration} 的 stabilityAiImageModel 方法
|
||||
*/
|
||||
private StabilityAiImageModel buildStabilityAiImageModel(String apiKey, String url) {
|
||||
url = StrUtil.blankToDefault(url, StabilityAiApi.DEFAULT_BASE_URL);
|
||||
StabilityAiApi stabilityAiApi = new StabilityAiApi(apiKey, StabilityAiApi.DEFAULT_IMAGE_MODEL, url);
|
||||
return new StabilityAiImageModel(stabilityAiApi);
|
||||
}
|
||||
|
||||
// ========== 各种创建 EmbeddingModel 的方法 ==========
|
||||
|
||||
/**
|
||||
* 可参考 {@link DashScopeAutoConfiguration} 的 dashscopeEmbeddingModel 方法
|
||||
*/
|
||||
private DashScopeEmbeddingModel buildTongYiEmbeddingModel(String apiKey, String model) {
|
||||
DashScopeApi dashScopeApi = new DashScopeApi(apiKey);
|
||||
DashScopeEmbeddingOptions dashScopeEmbeddingOptions = DashScopeEmbeddingOptions.builder().withModel(model).build();
|
||||
return new DashScopeEmbeddingModel(dashScopeApi, MetadataMode.EMBED, dashScopeEmbeddingOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link ZhiPuAiAutoConfiguration} 的 zhiPuAiEmbeddingModel 方法
|
||||
*/
|
||||
private ZhiPuAiEmbeddingModel buildZhiPuEmbeddingModel(String apiKey, String url, String model) {
|
||||
ZhiPuAiApi zhiPuAiApi = StrUtil.isEmpty(url) ? new ZhiPuAiApi(apiKey)
|
||||
: new ZhiPuAiApi(url, apiKey);
|
||||
ZhiPuAiEmbeddingOptions zhiPuAiEmbeddingOptions = ZhiPuAiEmbeddingOptions.builder().model(model).build();
|
||||
return new ZhiPuAiEmbeddingModel(zhiPuAiApi, MetadataMode.EMBED, zhiPuAiEmbeddingOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link MiniMaxAutoConfiguration} 的 miniMaxEmbeddingModel 方法
|
||||
*/
|
||||
private EmbeddingModel buildMiniMaxEmbeddingModel(String apiKey, String url, String model) {
|
||||
MiniMaxApi miniMaxApi = StrUtil.isEmpty(url)? new MiniMaxApi(apiKey)
|
||||
: new MiniMaxApi(url, apiKey);
|
||||
MiniMaxEmbeddingOptions miniMaxEmbeddingOptions = MiniMaxEmbeddingOptions.builder().model(model).build();
|
||||
return new MiniMaxEmbeddingModel(miniMaxApi, MetadataMode.EMBED, miniMaxEmbeddingOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link QianFanAutoConfiguration} 的 qianFanEmbeddingModel 方法
|
||||
*/
|
||||
private QianFanEmbeddingModel buildYiYanEmbeddingModel(String key, String model) {
|
||||
List<String> keys = StrUtil.split(key, '|');
|
||||
Assert.equals(keys.size(), 2, "YiYanChatClient 的密钥需要 (appKey|secretKey) 格式");
|
||||
String appKey = keys.get(0);
|
||||
String secretKey = keys.get(1);
|
||||
QianFanApi qianFanApi = new QianFanApi(appKey, secretKey);
|
||||
QianFanEmbeddingOptions qianFanEmbeddingOptions = QianFanEmbeddingOptions.builder().model(model).build();
|
||||
return new QianFanEmbeddingModel(qianFanApi, MetadataMode.EMBED, qianFanEmbeddingOptions);
|
||||
}
|
||||
|
||||
private OllamaEmbeddingModel buildOllamaEmbeddingModel(String url, String model) {
|
||||
OllamaApi ollamaApi = new OllamaApi(url);
|
||||
OllamaOptions ollamaOptions = OllamaOptions.builder().model(model).build();
|
||||
return OllamaEmbeddingModel.builder().ollamaApi(ollamaApi).defaultOptions(ollamaOptions).build();
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link OpenAiAutoConfiguration} 的 openAiEmbeddingModel 方法
|
||||
*/
|
||||
private OpenAiEmbeddingModel buildOpenAiEmbeddingModel(String openAiToken, String url, String model) {
|
||||
url = StrUtil.blankToDefault(url, OpenAiApiConstants.DEFAULT_BASE_URL);
|
||||
OpenAiApi openAiApi = OpenAiApi.builder().baseUrl(url).apiKey(openAiToken).build();
|
||||
OpenAiEmbeddingOptions openAiEmbeddingProperties = OpenAiEmbeddingOptions.builder().model(model).build();
|
||||
return new OpenAiEmbeddingModel(openAiApi, MetadataMode.EMBED, openAiEmbeddingProperties);
|
||||
}
|
||||
|
||||
// TODO @芋艿:手头暂时没密钥,使用建议再测试下
|
||||
/**
|
||||
* 可参考 {@link AzureOpenAiAutoConfiguration} 的 azureOpenAiEmbeddingModel 方法
|
||||
*/
|
||||
private AzureOpenAiEmbeddingModel buildAzureOpenAiEmbeddingModel(String apiKey, String url, String model) {
|
||||
AzureOpenAiAutoConfiguration azureOpenAiAutoConfiguration = new AzureOpenAiAutoConfiguration();
|
||||
// 创建 OpenAIClient 对象
|
||||
AzureOpenAiConnectionProperties connectionProperties = new AzureOpenAiConnectionProperties();
|
||||
connectionProperties.setApiKey(apiKey);
|
||||
connectionProperties.setEndpoint(url);
|
||||
OpenAIClientBuilder openAIClient = azureOpenAiAutoConfiguration.openAIClientBuilder(connectionProperties, null);
|
||||
// 获取 AzureOpenAiChatProperties 对象
|
||||
AzureOpenAiEmbeddingProperties embeddingProperties = SpringUtil.getBean(AzureOpenAiEmbeddingProperties.class);
|
||||
return azureOpenAiAutoConfiguration.azureOpenAiEmbeddingModel(openAIClient, embeddingProperties,
|
||||
null, null);
|
||||
}
|
||||
|
||||
// ========== 各种创建 VectorStore 的方法 ==========
|
||||
|
||||
/**
|
||||
* 注意:仅适合本地测试使用,生产建议还是使用 Qdrant、Milvus 等
|
||||
*/
|
||||
@SneakyThrows
|
||||
@SuppressWarnings("ResultOfMethodCallIgnored")
|
||||
private SimpleVectorStore buildSimpleVectorStore(EmbeddingModel embeddingModel) {
|
||||
SimpleVectorStore vectorStore = SimpleVectorStore.builder(embeddingModel).build();
|
||||
// 启动加载
|
||||
File file = new File(StrUtil.format("{}/vector_store/simple_{}.json",
|
||||
FileUtil.getUserHomePath(), embeddingModel.getClass().getSimpleName()));
|
||||
if (!file.exists()) {
|
||||
FileUtil.mkParentDirs(file);
|
||||
file.createNewFile();
|
||||
} else if (file.length() > 0) {
|
||||
vectorStore.load(file);
|
||||
}
|
||||
// 定时持久化,每分钟一次
|
||||
Timer timer = new Timer("SimpleVectorStoreTimer-" + file.getAbsolutePath());
|
||||
timer.scheduleAtFixedRate(new TimerTask() {
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
vectorStore.save(file);
|
||||
}
|
||||
|
||||
}, Duration.ofMinutes(1).toMillis(), Duration.ofMinutes(1).toMillis());
|
||||
// 关闭时,进行持久化
|
||||
RuntimeUtil.addShutdownHook(() -> vectorStore.save(file));
|
||||
return vectorStore;
|
||||
}
|
||||
|
||||
/**
|
||||
* 参考 {@link QdrantVectorStoreAutoConfiguration} 的 vectorStore 方法
|
||||
*/
|
||||
@SneakyThrows
|
||||
private QdrantVectorStore buildQdrantVectorStore(EmbeddingModel embeddingModel) {
|
||||
QdrantVectorStoreAutoConfiguration configuration = new QdrantVectorStoreAutoConfiguration();
|
||||
QdrantVectorStoreProperties properties = SpringUtil.getBean(QdrantVectorStoreProperties.class);
|
||||
// 参考 QdrantVectorStoreAutoConfiguration 实现,创建 QdrantClient 对象
|
||||
QdrantGrpcClient.Builder grpcClientBuilder = QdrantGrpcClient.newBuilder(
|
||||
properties.getHost(), properties.getPort(), properties.isUseTls());
|
||||
if (StrUtil.isNotEmpty(properties.getApiKey())) {
|
||||
grpcClientBuilder.withApiKey(properties.getApiKey());
|
||||
}
|
||||
QdrantClient qdrantClient = new QdrantClient(grpcClientBuilder.build());
|
||||
// 创建 QdrantVectorStore 对象
|
||||
QdrantVectorStore vectorStore = configuration.vectorStore(embeddingModel, properties, qdrantClient,
|
||||
getObservationRegistry(), getCustomObservationConvention(), getBatchingStrategy());
|
||||
// 初始化索引
|
||||
vectorStore.afterPropertiesSet();
|
||||
return vectorStore;
|
||||
}
|
||||
|
||||
/**
|
||||
* 参考 {@link RedisVectorStoreAutoConfiguration} 的 vectorStore 方法
|
||||
*/
|
||||
private RedisVectorStore buildRedisVectorStore(EmbeddingModel embeddingModel,
|
||||
Map<String, Class<?>> metadataFields) {
|
||||
// 创建 JedisPooled 对象
|
||||
RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class);
|
||||
JedisPooled jedisPooled = new JedisPooled(redisProperties.getHost(), redisProperties.getPort());
|
||||
// 创建 RedisVectorStoreProperties 对象
|
||||
RedisVectorStoreAutoConfiguration configuration = new RedisVectorStoreAutoConfiguration();
|
||||
RedisVectorStoreProperties properties = SpringUtil.getBean(RedisVectorStoreProperties.class);
|
||||
RedisVectorStore redisVectorStore = RedisVectorStore.builder(jedisPooled, embeddingModel)
|
||||
.indexName(properties.getIndex()).prefix(properties.getPrefix())
|
||||
.initializeSchema(properties.isInitializeSchema())
|
||||
.metadataFields(convertList(metadataFields.entrySet(), entry -> {
|
||||
String fieldName = entry.getKey();
|
||||
Class<?> fieldType = entry.getValue();
|
||||
if (Number.class.isAssignableFrom(fieldType)) {
|
||||
return RedisVectorStore.MetadataField.numeric(fieldName);
|
||||
}
|
||||
if (Boolean.class.isAssignableFrom(fieldType)) {
|
||||
return RedisVectorStore.MetadataField.tag(fieldName);
|
||||
}
|
||||
return RedisVectorStore.MetadataField.text(fieldName);
|
||||
}))
|
||||
.observationRegistry(getObservationRegistry().getObject())
|
||||
.customObservationConvention(getCustomObservationConvention().getObject())
|
||||
.batchingStrategy(getBatchingStrategy())
|
||||
.build();
|
||||
// 初始化索引
|
||||
redisVectorStore.afterPropertiesSet();
|
||||
return redisVectorStore;
|
||||
}
|
||||
|
||||
/**
|
||||
* 参考 {@link MilvusVectorStoreAutoConfiguration} 的 vectorStore 方法
|
||||
*/
|
||||
@SneakyThrows
|
||||
private MilvusVectorStore buildMilvusVectorStore(EmbeddingModel embeddingModel) {
|
||||
MilvusVectorStoreAutoConfiguration configuration = new MilvusVectorStoreAutoConfiguration();
|
||||
// 获取配置属性
|
||||
MilvusVectorStoreProperties serverProperties = SpringUtil.getBean(MilvusVectorStoreProperties.class);
|
||||
MilvusServiceClientProperties clientProperties = SpringUtil.getBean(MilvusServiceClientProperties.class);
|
||||
|
||||
// 创建 MilvusServiceClient 对象
|
||||
MilvusServiceClient milvusClient = configuration.milvusClient(serverProperties, clientProperties,
|
||||
new MilvusServiceClientConnectionDetails() {
|
||||
|
||||
@Override
|
||||
public String getHost() {
|
||||
return clientProperties.getHost();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getPort() {
|
||||
return clientProperties.getPort();
|
||||
}
|
||||
|
||||
}
|
||||
);
|
||||
// 创建 MilvusVectorStore 对象
|
||||
MilvusVectorStore vectorStore = configuration.vectorStore(milvusClient, embeddingModel, serverProperties,
|
||||
getBatchingStrategy(), getObservationRegistry(), getCustomObservationConvention());
|
||||
|
||||
// 初始化索引
|
||||
vectorStore.afterPropertiesSet();
|
||||
return vectorStore;
|
||||
}
|
||||
|
||||
private static ObjectProvider<ObservationRegistry> getObservationRegistry() {
|
||||
return new ObjectProvider<>() {
|
||||
|
||||
@Override
|
||||
public ObservationRegistry getObject() throws BeansException {
|
||||
return SpringUtil.getBean(ObservationRegistry.class);
|
||||
}
|
||||
|
||||
};
|
||||
}
|
||||
|
||||
private static ObjectProvider<VectorStoreObservationConvention> getCustomObservationConvention() {
|
||||
return new ObjectProvider<>() {
|
||||
@Override
|
||||
public VectorStoreObservationConvention getObject() throws BeansException {
|
||||
return new DefaultVectorStoreObservationConvention();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
private static BatchingStrategy getBatchingStrategy() {
|
||||
return SpringUtil.getBean(BatchingStrategy.class);
|
||||
}
|
||||
|
||||
private static ToolCallingManager getToolCallingManager() {
|
||||
return SpringUtil.getBean(ToolCallingManager.class);
|
||||
}
|
||||
|
||||
private static FunctionCallbackResolver getFunctionCallbackResolver() {
|
||||
return SpringUtil.getBean(FunctionCallbackResolver.class);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.baichuan;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
/**
|
||||
* 百川 {@link ChatModel} 实现类
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
@Slf4j
|
||||
@RequiredArgsConstructor
|
||||
public class BaiChuanChatModel implements ChatModel {
|
||||
|
||||
public static final String BASE_URL = "https://api.baichuan-ai.com";
|
||||
|
||||
public static final String MODEL_DEFAULT = "Baichuan4-Turbo";
|
||||
|
||||
/**
|
||||
* 兼容 OpenAI 接口,进行复用
|
||||
*/
|
||||
private final OpenAiChatModel openAiChatModel;
|
||||
|
||||
@Override
|
||||
public ChatResponse call(Prompt prompt) {
|
||||
return openAiChatModel.call(prompt);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<ChatResponse> stream(Prompt prompt) {
|
||||
return openAiChatModel.stream(prompt);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatOptions getDefaultOptions() {
|
||||
return openAiChatModel.getDefaultOptions();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.deepseek;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
/**
|
||||
* DeepSeek {@link ChatModel} 实现类
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
@Slf4j
|
||||
@RequiredArgsConstructor
|
||||
public class DeepSeekChatModel implements ChatModel {
|
||||
|
||||
public static final String BASE_URL = "https://api.deepseek.com";
|
||||
|
||||
public static final String MODEL_DEFAULT = "deepseek-chat";
|
||||
|
||||
/**
|
||||
* 兼容 OpenAI 接口,进行复用
|
||||
*/
|
||||
private final OpenAiChatModel openAiChatModel;
|
||||
|
||||
@Override
|
||||
public ChatResponse call(Prompt prompt) {
|
||||
return openAiChatModel.call(prompt);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<ChatResponse> stream(Prompt prompt) {
|
||||
return openAiChatModel.stream(prompt);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatOptions getDefaultOptions() {
|
||||
return openAiChatModel.getDefaultOptions();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.doubao;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
/**
|
||||
* 字节豆包 {@link ChatModel} 实现类
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
@Slf4j
|
||||
@RequiredArgsConstructor
|
||||
public class DouBaoChatModel implements ChatModel {
|
||||
|
||||
public static final String BASE_URL = "https://ark.cn-beijing.volces.com/api";
|
||||
|
||||
public static final String MODEL_DEFAULT = "doubao-1-5-lite-32k-250115";
|
||||
|
||||
/**
|
||||
* 兼容 OpenAI 接口,进行复用
|
||||
*/
|
||||
private final OpenAiChatModel openAiChatModel;
|
||||
|
||||
@Override
|
||||
public ChatResponse call(Prompt prompt) {
|
||||
return openAiChatModel.call(prompt);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<ChatResponse> stream(Prompt prompt) {
|
||||
return openAiChatModel.stream(prompt);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatOptions getDefaultOptions() {
|
||||
return openAiChatModel.getDefaultOptions();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.hunyuan;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
/**
|
||||
* 腾云混元 {@link ChatModel} 实现类
|
||||
*
|
||||
* 1. 混元大模型:基于 <a href="https://cloud.tencent.com/document/product/1729/111007">知识引擎原子能力</a> 实现
|
||||
* 2. 知识引擎原子能力:基于 <a href="https://cloud.tencent.com/document/product/1772/115969">知识引擎原子能力</a> 实现
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
@Slf4j
|
||||
@RequiredArgsConstructor
|
||||
public class HunYuanChatModel implements ChatModel {
|
||||
|
||||
public static final String BASE_URL = "https://api.hunyuan.cloud.tencent.com";
|
||||
|
||||
public static final String MODEL_DEFAULT = "hunyuan-turbo";
|
||||
|
||||
public static final String DEEP_SEEK_BASE_URL = "https://api.lkeap.cloud.tencent.com";
|
||||
|
||||
public static final String DEEP_SEEK_MODEL_DEFAULT = "deepseek-v3";
|
||||
|
||||
/**
|
||||
* 兼容 OpenAI 接口,进行复用
|
||||
*/
|
||||
private final OpenAiChatModel openAiChatModel;
|
||||
|
||||
@Override
|
||||
public ChatResponse call(Prompt prompt) {
|
||||
return openAiChatModel.call(prompt);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<ChatResponse> stream(Prompt prompt) {
|
||||
return openAiChatModel.stream(prompt);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatOptions getDefaultOptions() {
|
||||
return openAiChatModel.getDefaultOptions();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,351 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.midjourney.api;
|
||||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.collect.Lists;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.Getter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.http.HttpRequest;
|
||||
import org.springframework.http.HttpStatusCode;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.web.reactive.function.client.ClientResponse;
|
||||
import org.springframework.web.reactive.function.client.WebClient;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
/**
|
||||
* Midjourney API
|
||||
*
|
||||
* @author fansili
|
||||
* @since 1.0
|
||||
*/
|
||||
@Slf4j
|
||||
public class MidjourneyApi {
|
||||
|
||||
private final Predicate<HttpStatusCode> STATUS_PREDICATE = status -> !status.is2xxSuccessful();
|
||||
|
||||
private final Function<Object, Function<ClientResponse, Mono<? extends Throwable>>> EXCEPTION_FUNCTION =
|
||||
reqParam -> response -> response.bodyToMono(String.class).handle((responseBody, sink) -> {
|
||||
HttpRequest request = response.request();
|
||||
log.error("[midjourney-api] 调用失败!请求方式:[{}],请求地址:[{}],请求参数:[{}],响应数据: [{}]",
|
||||
request.getMethod(), request.getURI(), reqParam, responseBody);
|
||||
sink.error(new IllegalStateException("[midjourney-api] 调用失败!"));
|
||||
});
|
||||
|
||||
private final WebClient webClient;
|
||||
|
||||
/**
|
||||
* 回调地址
|
||||
*/
|
||||
private final String notifyUrl;
|
||||
|
||||
public MidjourneyApi(String baseUrl, String apiKey, String notifyUrl) {
|
||||
this.webClient = WebClient.builder()
|
||||
.baseUrl(baseUrl)
|
||||
.defaultHeaders(httpHeaders -> {
|
||||
httpHeaders.setContentType(MediaType.APPLICATION_JSON);
|
||||
httpHeaders.setBearerAuth(apiKey);
|
||||
})
|
||||
.build();
|
||||
this.notifyUrl = notifyUrl;
|
||||
}
|
||||
|
||||
/**
|
||||
* imagine - 根据提示词提交绘画任务
|
||||
*
|
||||
* @param request 请求
|
||||
* @return 提交结果
|
||||
*/
|
||||
public SubmitResponse imagine(ImagineRequest request) {
|
||||
if (StrUtil.isEmpty(request.getNotifyHook())) {
|
||||
request.setNotifyHook(notifyUrl);
|
||||
}
|
||||
String response = post("/submit/imagine", request);
|
||||
return JsonUtils.parseObject(response, SubmitResponse.class);
|
||||
}
|
||||
|
||||
/**
|
||||
* action - 放大、缩小、U1、U2...
|
||||
*
|
||||
* @param request 请求
|
||||
* @return 提交结果
|
||||
*/
|
||||
public SubmitResponse action(ActionRequest request) {
|
||||
if (StrUtil.isEmpty(request.getNotifyHook())) {
|
||||
request.setNotifyHook(notifyUrl);
|
||||
}
|
||||
String response = post("/submit/action", request);
|
||||
return JsonUtils.parseObject(response, SubmitResponse.class);
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量查询 task 任务
|
||||
*
|
||||
* @param ids 任务编号数组
|
||||
* @return task 任务
|
||||
*/
|
||||
public List<Notify> getTaskList(Collection<String> ids) {
|
||||
String res = post("/task/list-by-condition", ImmutableMap.of("ids", ids));
|
||||
return JsonUtils.parseArray(res, Notify.class);
|
||||
}
|
||||
|
||||
private String post(String uri, Object body) {
|
||||
return webClient.post()
|
||||
.uri(uri)
|
||||
.body(Mono.just(JsonUtils.toJsonString(body)), String.class)
|
||||
.retrieve()
|
||||
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(body))
|
||||
.bodyToMono(String.class)
|
||||
.block();
|
||||
}
|
||||
|
||||
// ========== record 结构 ==========
|
||||
|
||||
/**
|
||||
* Imagine 请求(生成图片)
|
||||
*/
|
||||
@Data
|
||||
public static final class ImagineRequest {
|
||||
|
||||
/**
|
||||
* 垫图(参考图) base64 数组
|
||||
*/
|
||||
private List<String> base64Array;
|
||||
/**
|
||||
* 提示词
|
||||
*/
|
||||
private String prompt;
|
||||
/**
|
||||
* 通知地址
|
||||
*/
|
||||
private String notifyHook;
|
||||
/**
|
||||
* 自定义参数
|
||||
*/
|
||||
private String state;
|
||||
|
||||
public ImagineRequest(List<String> base64Array, String prompt, String notifyHook, String state) {
|
||||
this.base64Array = base64Array;
|
||||
this.prompt = prompt;
|
||||
this.notifyHook = notifyHook;
|
||||
this.state = state;
|
||||
}
|
||||
|
||||
public static String buildState(Integer width, Integer height, String version, String model) {
|
||||
StringBuilder params = new StringBuilder();
|
||||
// --ar 来设置尺寸
|
||||
params.append(String.format(" --ar %s:%s ", width, height));
|
||||
// --niji 模型
|
||||
if (ModelEnum.NIJI.getModel().equals(model)) {
|
||||
params.append(String.format(" --niji %s ", version));
|
||||
} else {
|
||||
params.append(String.format(" --v %s ", version));
|
||||
}
|
||||
return params.toString();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Action 请求
|
||||
*/
|
||||
@Data
|
||||
public static final class ActionRequest {
|
||||
|
||||
private String customId;
|
||||
private String taskId;
|
||||
private String notifyHook;
|
||||
|
||||
public ActionRequest(String taskId, String customId, String notifyHook) {
|
||||
this.customId = customId;
|
||||
this.taskId = taskId;
|
||||
this.notifyHook = notifyHook;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Submit 统一返回
|
||||
*
|
||||
* @param code 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)
|
||||
* @param description 描述
|
||||
* @param properties 扩展字段
|
||||
* @param result 任务ID
|
||||
*/
|
||||
public record SubmitResponse(String code,
|
||||
String description,
|
||||
Map<String, Object> properties,
|
||||
String result) {
|
||||
}
|
||||
|
||||
/**
|
||||
* 通知 request
|
||||
*
|
||||
* @param id job id
|
||||
* @param action 任务类型 {@link TaskActionEnum}
|
||||
* @param status 任务状态 {@link TaskStatusEnum}
|
||||
* @param prompt 提示词
|
||||
* @param promptEn 提示词-英文
|
||||
* @param description 任务描述
|
||||
* @param state 自定义参数
|
||||
* @param submitTime 提交时间
|
||||
* @param startTime 开始执行时间
|
||||
* @param finishTime 结束时间
|
||||
* @param imageUrl 图片url
|
||||
* @param progress 任务进度
|
||||
* @param failReason 失败原因
|
||||
* @param buttons 任务完成后的可执行按钮
|
||||
*/
|
||||
public record Notify(String id,
|
||||
String action,
|
||||
String status,
|
||||
|
||||
String prompt,
|
||||
String promptEn,
|
||||
|
||||
String description,
|
||||
String state,
|
||||
|
||||
Long submitTime,
|
||||
Long startTime,
|
||||
Long finishTime,
|
||||
|
||||
String imageUrl,
|
||||
String progress,
|
||||
String failReason,
|
||||
List<Button> buttons) {
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* button
|
||||
*
|
||||
* @param customId MJ::JOB::upsample::1::85a4b4c1-8835-46c5-a15c-aea34fad1862 动作标识
|
||||
* @param emoji 图标 emoji
|
||||
* @param label Make Variations 文本
|
||||
* @param type 类型,系统内部使用
|
||||
* @param style 样式: 2(Primary)、3(Green)
|
||||
*/
|
||||
public record Button(String customId,
|
||||
String emoji,
|
||||
String label,
|
||||
String type,
|
||||
String style) {
|
||||
}
|
||||
|
||||
// ============ enums ============
|
||||
|
||||
/**
|
||||
* 模型枚举
|
||||
*/
|
||||
@AllArgsConstructor
|
||||
@Getter
|
||||
public enum ModelEnum {
|
||||
|
||||
MIDJOURNEY("midjourney", "midjourney"),
|
||||
NIJI("niji", "niji"),
|
||||
;
|
||||
|
||||
private final String model;
|
||||
private final String name;
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* 提交返回的状态码的枚举
|
||||
*/
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public enum SubmitCodeEnum {
|
||||
|
||||
SUBMIT_SUCCESS("1", "提交成功"),
|
||||
ALREADY_EXISTS("21", "已存在"),
|
||||
QUEUING("22", "排队中"),
|
||||
;
|
||||
|
||||
public static final List<String> SUCCESS_CODES = Lists.newArrayList(
|
||||
SUBMIT_SUCCESS.code,
|
||||
ALREADY_EXISTS.code,
|
||||
QUEUING.code
|
||||
);
|
||||
|
||||
private final String code;
|
||||
private final String name;
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Action 枚举
|
||||
*/
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public enum TaskActionEnum {
|
||||
|
||||
/**
|
||||
* 生成图片
|
||||
*/
|
||||
IMAGINE,
|
||||
/**
|
||||
* 选中放大
|
||||
*/
|
||||
UPSCALE,
|
||||
/**
|
||||
* 选中其中的一张图,生成四张相似的
|
||||
*/
|
||||
VARIATION,
|
||||
/**
|
||||
* 重新执行
|
||||
*/
|
||||
REROLL,
|
||||
/**
|
||||
* 图转 prompt
|
||||
*/
|
||||
DESCRIBE,
|
||||
/**
|
||||
* 多图混合
|
||||
*/
|
||||
BLEND
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* 任务状态枚举
|
||||
*/
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public enum TaskStatusEnum {
|
||||
|
||||
/**
|
||||
* 未启动
|
||||
*/
|
||||
NOT_START(0),
|
||||
/**
|
||||
* 已提交
|
||||
*/
|
||||
SUBMITTED(1),
|
||||
/**
|
||||
* 执行中
|
||||
*/
|
||||
IN_PROGRESS(3),
|
||||
/**
|
||||
* 失败
|
||||
*/
|
||||
FAILURE(4),
|
||||
/**
|
||||
* 成功
|
||||
*/
|
||||
SUCCESS(4);
|
||||
|
||||
private final int order;
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.siliconflow;
|
||||
|
||||
/**
|
||||
* SiliconFlow API 枚举类
|
||||
*
|
||||
* @author zzt
|
||||
*/
|
||||
public final class SiliconFlowApiConstants {
|
||||
|
||||
public static final String DEFAULT_BASE_URL = "https://api.siliconflow.cn";
|
||||
|
||||
public static final String MODEL_DEFAULT = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B";
|
||||
|
||||
public static final String DEFAULT_IMAGE_MODEL = "Kwai-Kolors/Kolors";
|
||||
|
||||
public static final String PROVIDER_NAME = "Siiconflow";
|
||||
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.siliconflow;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
/**
|
||||
* 硅基流动 {@link ChatModel} 实现类
|
||||
*
|
||||
* 1. API 文档:<a href="https://docs.siliconflow.cn/cn/api-reference/chat-completions/chat-completions">API 文档</a>
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
@Slf4j
|
||||
@RequiredArgsConstructor
|
||||
public class SiliconFlowChatModel implements ChatModel {
|
||||
|
||||
/**
|
||||
* 兼容 OpenAI 接口,进行复用
|
||||
*/
|
||||
private final OpenAiChatModel openAiChatModel;
|
||||
|
||||
@Override
|
||||
public ChatResponse call(Prompt prompt) {
|
||||
return openAiChatModel.call(prompt);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<ChatResponse> stream(Prompt prompt) {
|
||||
return openAiChatModel.stream(prompt);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatOptions getDefaultOptions() {
|
||||
return openAiChatModel.getDefaultOptions();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.siliconflow;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import org.springframework.ai.model.ApiKey;
|
||||
import org.springframework.ai.model.NoopApiKey;
|
||||
import org.springframework.ai.model.SimpleApiKey;
|
||||
import org.springframework.ai.openai.api.OpenAiImageApi;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
import org.springframework.web.client.ResponseErrorHandler;
|
||||
import org.springframework.web.client.RestClient;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* 硅基流动 Image API
|
||||
*
|
||||
* @see <a href= "https://docs.siliconflow.cn/cn/api-reference/images/images-generations">Images</a>
|
||||
*
|
||||
* @author zzt
|
||||
*/
|
||||
public class SiliconFlowImageApi {
|
||||
|
||||
private final RestClient restClient;
|
||||
|
||||
public SiliconFlowImageApi(String aiToken) {
|
||||
this(SiliconFlowApiConstants.DEFAULT_BASE_URL, aiToken, RestClient.builder());
|
||||
}
|
||||
|
||||
public SiliconFlowImageApi(String baseUrl, String openAiToken) {
|
||||
this(baseUrl, openAiToken, RestClient.builder());
|
||||
}
|
||||
|
||||
public SiliconFlowImageApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder) {
|
||||
this(baseUrl, openAiToken, restClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
|
||||
}
|
||||
|
||||
public SiliconFlowImageApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder,
|
||||
ResponseErrorHandler responseErrorHandler) {
|
||||
this(baseUrl, apiKey, CollectionUtils.toMultiValueMap(Map.of()), restClientBuilder, responseErrorHandler);
|
||||
}
|
||||
|
||||
public SiliconFlowImageApi(String baseUrl, String apiKey, MultiValueMap<String, String> headers,
|
||||
RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {
|
||||
this(baseUrl, new SimpleApiKey(apiKey), headers, restClientBuilder, responseErrorHandler);
|
||||
}
|
||||
|
||||
public SiliconFlowImageApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String> headers,
|
||||
RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {
|
||||
|
||||
// @formatter:off
|
||||
this.restClient = restClientBuilder.baseUrl(baseUrl)
|
||||
.defaultHeaders(h -> {
|
||||
if(!(apiKey instanceof NoopApiKey)) {
|
||||
h.setBearerAuth(apiKey.getValue());
|
||||
}
|
||||
h.setContentType(MediaType.APPLICATION_JSON);
|
||||
h.addAll(headers);
|
||||
})
|
||||
.defaultStatusHandler(responseErrorHandler)
|
||||
.build();
|
||||
// @formatter:on
|
||||
}
|
||||
|
||||
public ResponseEntity<OpenAiImageApi.OpenAiImageResponse> createImage(SiliconflowImageRequest siliconflowImageRequest) {
|
||||
Assert.notNull(siliconflowImageRequest, "Image request cannot be null.");
|
||||
Assert.hasLength(siliconflowImageRequest.prompt(), "Prompt cannot be empty.");
|
||||
|
||||
return this.restClient.post()
|
||||
.uri("v1/images/generations")
|
||||
.body(siliconflowImageRequest)
|
||||
.retrieve()
|
||||
.toEntity(OpenAiImageApi.OpenAiImageResponse.class);
|
||||
}
|
||||
|
||||
|
||||
// @formatter:off
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
public record SiliconflowImageRequest (
|
||||
@JsonProperty("prompt") String prompt,
|
||||
@JsonProperty("model") String model,
|
||||
@JsonProperty("batch_size") Integer batchSize,
|
||||
@JsonProperty("negative_prompt") String negativePrompt,
|
||||
@JsonProperty("seed") Integer seed,
|
||||
@JsonProperty("num_inference_steps") Integer numInferenceSteps,
|
||||
@JsonProperty("guidance_scale") Float guidanceScale,
|
||||
@JsonProperty("image") String image) {
|
||||
|
||||
public SiliconflowImageRequest(String prompt, String model) {
|
||||
this(prompt, model, null, null, null, null, null, null);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,159 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.siliconflow;
|
||||
|
||||
import io.micrometer.observation.ObservationRegistry;
|
||||
import lombok.Setter;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.image.*;
|
||||
import org.springframework.ai.image.observation.DefaultImageModelObservationConvention;
|
||||
import org.springframework.ai.image.observation.ImageModelObservationContext;
|
||||
import org.springframework.ai.image.observation.ImageModelObservationConvention;
|
||||
import org.springframework.ai.image.observation.ImageModelObservationDocumentation;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.openai.OpenAiImageModel;
|
||||
import org.springframework.ai.openai.api.OpenAiImageApi;
|
||||
import org.springframework.ai.openai.metadata.OpenAiImageGenerationMetadata;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.retry.support.RetryTemplate;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 硅基流动 {@link ImageModel} 实现类
|
||||
*
|
||||
* 参考 {@link OpenAiImageModel} 实现
|
||||
*
|
||||
* @author zzt
|
||||
*/
|
||||
public class SiliconFlowImageModel implements ImageModel {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(SiliconFlowImageModel.class);
|
||||
|
||||
private static final ImageModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultImageModelObservationConvention();
|
||||
|
||||
private final SiliconFlowImageOptions defaultOptions;
|
||||
|
||||
private final RetryTemplate retryTemplate;
|
||||
|
||||
private final SiliconFlowImageApi siliconFlowImageApi;
|
||||
|
||||
private final ObservationRegistry observationRegistry;
|
||||
|
||||
@Setter
|
||||
private ImageModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
|
||||
|
||||
public SiliconFlowImageModel(SiliconFlowImageApi siliconFlowImageApi) {
|
||||
this(siliconFlowImageApi, SiliconFlowImageOptions.builder().build(), RetryUtils.DEFAULT_RETRY_TEMPLATE);
|
||||
}
|
||||
|
||||
public SiliconFlowImageModel(SiliconFlowImageApi siliconFlowImageApi, SiliconFlowImageOptions options, RetryTemplate retryTemplate) {
|
||||
this(siliconFlowImageApi, options, retryTemplate, ObservationRegistry.NOOP);
|
||||
}
|
||||
|
||||
public SiliconFlowImageModel(SiliconFlowImageApi siliconFlowImageApi, SiliconFlowImageOptions options, RetryTemplate retryTemplate,
|
||||
ObservationRegistry observationRegistry) {
|
||||
Assert.notNull(siliconFlowImageApi, "OpenAiImageApi must not be null");
|
||||
Assert.notNull(options, "options must not be null");
|
||||
Assert.notNull(retryTemplate, "retryTemplate must not be null");
|
||||
Assert.notNull(observationRegistry, "observationRegistry must not be null");
|
||||
this.siliconFlowImageApi = siliconFlowImageApi;
|
||||
this.defaultOptions = options;
|
||||
this.retryTemplate = retryTemplate;
|
||||
this.observationRegistry = observationRegistry;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ImageResponse call(ImagePrompt imagePrompt) {
|
||||
SiliconFlowImageOptions requestImageOptions = mergeOptions(imagePrompt.getOptions(), this.defaultOptions);
|
||||
SiliconFlowImageApi.SiliconflowImageRequest imageRequest = createRequest(imagePrompt, requestImageOptions);
|
||||
|
||||
var observationContext = ImageModelObservationContext.builder()
|
||||
.imagePrompt(imagePrompt)
|
||||
.provider(SiliconFlowApiConstants.PROVIDER_NAME)
|
||||
.requestOptions(imagePrompt.getOptions())
|
||||
.build();
|
||||
|
||||
return ImageModelObservationDocumentation.IMAGE_MODEL_OPERATION
|
||||
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
|
||||
this.observationRegistry)
|
||||
.observe(() -> {
|
||||
ResponseEntity<OpenAiImageApi.OpenAiImageResponse> imageResponseEntity = this.retryTemplate
|
||||
.execute(ctx -> this.siliconFlowImageApi.createImage(imageRequest));
|
||||
|
||||
ImageResponse imageResponse = convertResponse(imageResponseEntity, imageRequest);
|
||||
|
||||
observationContext.setResponse(imageResponse);
|
||||
|
||||
return imageResponse;
|
||||
});
|
||||
}
|
||||
|
||||
private SiliconFlowImageApi.SiliconflowImageRequest createRequest(ImagePrompt imagePrompt,
|
||||
SiliconFlowImageOptions requestImageOptions) {
|
||||
String instructions = imagePrompt.getInstructions().get(0).getText();
|
||||
|
||||
SiliconFlowImageApi.SiliconflowImageRequest imageRequest = new SiliconFlowImageApi.SiliconflowImageRequest(instructions,
|
||||
SiliconFlowApiConstants.DEFAULT_IMAGE_MODEL);
|
||||
|
||||
return ModelOptionsUtils.merge(requestImageOptions, imageRequest, SiliconFlowImageApi.SiliconflowImageRequest.class);
|
||||
}
|
||||
|
||||
private ImageResponse convertResponse(ResponseEntity<OpenAiImageApi.OpenAiImageResponse> imageResponseEntity,
|
||||
SiliconFlowImageApi.SiliconflowImageRequest siliconflowImageRequest) {
|
||||
OpenAiImageApi.OpenAiImageResponse imageApiResponse = imageResponseEntity.getBody();
|
||||
if (imageApiResponse == null) {
|
||||
logger.warn("No image response returned for request: {}", siliconflowImageRequest);
|
||||
return new ImageResponse(List.of());
|
||||
}
|
||||
|
||||
List<ImageGeneration> imageGenerationList = imageApiResponse.data()
|
||||
.stream()
|
||||
.map(entry -> new ImageGeneration(new Image(entry.url(), entry.b64Json()),
|
||||
new OpenAiImageGenerationMetadata(entry.revisedPrompt())))
|
||||
.toList();
|
||||
|
||||
ImageResponseMetadata openAiImageResponseMetadata = new ImageResponseMetadata(imageApiResponse.created());
|
||||
return new ImageResponse(imageGenerationList, openAiImageResponseMetadata);
|
||||
}
|
||||
|
||||
private SiliconFlowImageOptions mergeOptions(@Nullable ImageOptions runtimeOptions, SiliconFlowImageOptions defaultOptions) {
|
||||
var runtimeOptionsForProvider = ModelOptionsUtils.copyToTarget(runtimeOptions, ImageOptions.class,
|
||||
SiliconFlowImageOptions.class);
|
||||
|
||||
if (runtimeOptionsForProvider == null) {
|
||||
return defaultOptions;
|
||||
}
|
||||
|
||||
return SiliconFlowImageOptions.builder()
|
||||
// Handle portable image options
|
||||
.model(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getModel(), defaultOptions.getModel()))
|
||||
.batchSize(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getN(), defaultOptions.getN()))
|
||||
.width(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getWidth(), defaultOptions.getWidth()))
|
||||
.height(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getHeight(), defaultOptions.getHeight()))
|
||||
// Handle SiliconFlow specific image options
|
||||
.negativePrompt(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getNegativePrompt(), defaultOptions.getNegativePrompt()))
|
||||
.numInferenceSteps(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getNumInferenceSteps(), defaultOptions.getNumInferenceSteps()))
|
||||
.guidanceScale(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getGuidanceScale(), defaultOptions.getGuidanceScale()))
|
||||
.seed(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getSeed(), defaultOptions.getSeed()))
|
||||
.build();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.siliconflow;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.springframework.ai.image.ImageOptions;
|
||||
|
||||
/**
|
||||
* 硅基流动 {@link ImageOptions}
|
||||
*
|
||||
* @author zzt
|
||||
*/
|
||||
@Data
|
||||
@Builder
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class SiliconFlowImageOptions implements ImageOptions {
|
||||
|
||||
@JsonProperty("model")
|
||||
private String model;
|
||||
|
||||
@JsonProperty("negative_prompt")
|
||||
private String negativePrompt;
|
||||
|
||||
/**
|
||||
* The number of images to generate. Must be between 1 and 4.
|
||||
*/
|
||||
@JsonProperty("image_size")
|
||||
private String imageSize;
|
||||
|
||||
/**
|
||||
* The number of images to generate. Must be between 1 and 4.
|
||||
*/
|
||||
@JsonProperty("batch_size")
|
||||
private Integer batchSize = 1;
|
||||
|
||||
/**
|
||||
* number of inference steps
|
||||
*/
|
||||
@JsonProperty("num_inference_steps")
|
||||
private Integer numInferenceSteps = 25;
|
||||
|
||||
/**
|
||||
* This value is used to control the degree of match between the generated image and the given prompt. The higher the value, the more the generated image will tend to strictly match the text prompt. The lower the value, the more creative and diverse the generated image will be, potentially containing more unexpected elements.
|
||||
*
|
||||
* Required range: 0 <= x <= 20
|
||||
*/
|
||||
@JsonProperty("guidance_scale")
|
||||
private Float guidanceScale = 0.75F;
|
||||
|
||||
/**
|
||||
* 如果想要每次都生成固定的图片,可以把 seed 设置为固定值
|
||||
*
|
||||
*/
|
||||
@JsonProperty("seed")
|
||||
private Integer seed = (int)(Math.random() * 1_000_000_000);
|
||||
|
||||
/**
|
||||
* The image that needs to be uploaded should be converted into base64 format.
|
||||
*/
|
||||
@JsonProperty("image")
|
||||
private String image;
|
||||
|
||||
/**
|
||||
* 宽
|
||||
*/
|
||||
private Integer width;
|
||||
|
||||
/**
|
||||
* 高
|
||||
*/
|
||||
private Integer height;
|
||||
|
||||
public void setHeight(Integer height) {
|
||||
this.height = height;
|
||||
if (this.width != null && this.height != null) {
|
||||
this.imageSize = this.width + "x" + this.height;
|
||||
}
|
||||
}
|
||||
|
||||
public void setWidth(Integer width) {
|
||||
this.width = width;
|
||||
if (this.width != null && this.height != null) {
|
||||
this.imageSize = this.width + "x" + this.height;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getN() {
|
||||
return batchSize;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getResponseFormat() {
|
||||
return "url";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getStyle() {
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,200 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.suno.api;
|
||||
|
||||
import cn.hutool.core.collection.CollUtil;
|
||||
import cn.hutool.core.text.StrPool;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
import org.springframework.http.HttpRequest;
|
||||
import org.springframework.http.HttpStatusCode;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.web.reactive.function.client.ClientResponse;
|
||||
import org.springframework.web.reactive.function.client.WebClient;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.function.Function;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
/**
|
||||
* Suno API
|
||||
* <p>
|
||||
* 对接 Suno Proxy:<a href="https://github.com/gcui-art/suno-api">suno-api</a>
|
||||
*
|
||||
* @author xiaoxin
|
||||
*/
|
||||
@Slf4j
|
||||
public class SunoApi {
|
||||
|
||||
private final WebClient webClient;
|
||||
|
||||
private final Predicate<HttpStatusCode> STATUS_PREDICATE = status -> !status.is2xxSuccessful();
|
||||
|
||||
private final Function<Object, Function<ClientResponse, Mono<? extends Throwable>>> EXCEPTION_FUNCTION =
|
||||
reqParam -> response -> response.bodyToMono(String.class).handle((responseBody, sink) -> {
|
||||
HttpRequest request = response.request();
|
||||
log.error("[suno-api] 调用失败!请求方式:[{}],请求地址:[{}],请求参数:[{}],响应数据: [{}]",
|
||||
request.getMethod(), request.getURI(), reqParam, responseBody);
|
||||
sink.error(new IllegalStateException("[suno-api] 调用失败!"));
|
||||
});
|
||||
|
||||
public SunoApi(String baseUrl) {
|
||||
this.webClient = WebClient.builder()
|
||||
.baseUrl(baseUrl)
|
||||
.defaultHeaders((headers) -> headers.setContentType(MediaType.APPLICATION_JSON))
|
||||
.build();
|
||||
}
|
||||
|
||||
public List<MusicData> generate(MusicGenerateRequest request) {
|
||||
return this.webClient.post()
|
||||
.uri("/api/generate")
|
||||
.body(Mono.just(request), MusicGenerateRequest.class)
|
||||
.retrieve()
|
||||
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(request))
|
||||
.bodyToMono(new ParameterizedTypeReference<List<MusicData>>() {
|
||||
})
|
||||
.block();
|
||||
}
|
||||
|
||||
public List<MusicData> customGenerate(MusicGenerateRequest request) {
|
||||
return this.webClient.post()
|
||||
.uri("/api/custom_generate")
|
||||
.body(Mono.just(request), MusicGenerateRequest.class)
|
||||
.retrieve()
|
||||
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(request))
|
||||
.bodyToMono(new ParameterizedTypeReference<List<MusicData>>() {
|
||||
})
|
||||
.block();
|
||||
}
|
||||
|
||||
public LyricsData generateLyrics(String prompt) {
|
||||
return this.webClient.post()
|
||||
.uri("/api/generate_lyrics")
|
||||
.body(Mono.just(new MusicGenerateRequest(prompt)), MusicGenerateRequest.class)
|
||||
.retrieve()
|
||||
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(prompt))
|
||||
.bodyToMono(LyricsData.class)
|
||||
.block();
|
||||
}
|
||||
|
||||
public List<MusicData> getMusicList(List<String> ids) {
|
||||
return this.webClient.get()
|
||||
.uri(uriBuilder -> uriBuilder
|
||||
.path("/api/get")
|
||||
.queryParam("ids", CollUtil.join(ids, StrPool.COMMA))
|
||||
.build())
|
||||
.retrieve()
|
||||
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(ids))
|
||||
.bodyToMono(new ParameterizedTypeReference<List<MusicData>>() {
|
||||
})
|
||||
.block();
|
||||
}
|
||||
|
||||
public LimitUsageData getLimitUsage() {
|
||||
return this.webClient.get()
|
||||
.uri("/api/get_limit")
|
||||
.retrieve()
|
||||
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(null))
|
||||
.bodyToMono(LimitUsageData.class)
|
||||
.block();
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据提示生成音频
|
||||
*
|
||||
* @param prompt 用于生成音乐音频的提示
|
||||
* @param tags 音乐风格
|
||||
* @param title 音乐名称
|
||||
* @param model 模型
|
||||
* @param waitAudio false 表示后台模式,仅返回音频任务信息,需要调用 get API 获取详细的音频信息。
|
||||
* true 表示同步模式,API 最多等待 100s,音频生成完毕后直接返回音频链接等信息,建议在 GPT 等 agent 中使用。
|
||||
* @param makeInstrumental 指示音乐音频是否为定制,如果为 true,则从歌词生成,否则从提示生成
|
||||
*/
|
||||
@JsonInclude(value = JsonInclude.Include.NON_NULL)
|
||||
public record MusicGenerateRequest(
|
||||
String prompt,
|
||||
String tags,
|
||||
String title,
|
||||
String model,
|
||||
@JsonProperty("wait_audio") boolean waitAudio,
|
||||
@JsonProperty("make_instrumental") boolean makeInstrumental
|
||||
) {
|
||||
|
||||
public MusicGenerateRequest(String prompt) {
|
||||
this(prompt, null, null, null, false, false);
|
||||
}
|
||||
|
||||
public MusicGenerateRequest(String prompt, String model, boolean makeInstrumental) {
|
||||
this(prompt, null, null, model, false, makeInstrumental);
|
||||
}
|
||||
|
||||
public MusicGenerateRequest(String prompt, String model, String tags, String title) {
|
||||
this(prompt, tags, title, model, false, false);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Suno API 响应的音频数据
|
||||
*
|
||||
* @param id 音乐数据的 ID
|
||||
* @param title 音乐音频的标题
|
||||
* @param imageUrl 音乐音频的图片 URL
|
||||
* @param lyric 音乐音频的歌词
|
||||
* @param audioUrl 音乐音频的 URL
|
||||
* @param videoUrl 音乐视频的 URL
|
||||
* @param createdAt 音乐音频的创建时间
|
||||
* @param modelName 模型名称
|
||||
* @param status submitted、queued、streaming、complete
|
||||
* @param gptDescriptionPrompt 描述词
|
||||
* @param prompt 生成音乐音频的提示
|
||||
* @param type 操作类型
|
||||
* @param tags 音乐类型标签
|
||||
* @param duration 音乐时长
|
||||
*/
|
||||
public record MusicData(
|
||||
String id,
|
||||
String title,
|
||||
@JsonProperty("image_url") String imageUrl,
|
||||
String lyric,
|
||||
@JsonProperty("audio_url") String audioUrl,
|
||||
@JsonProperty("video_url") String videoUrl,
|
||||
@JsonProperty("created_at") String createdAt,
|
||||
@JsonProperty("model_name") String modelName,
|
||||
String status,
|
||||
@JsonProperty("gpt_description_prompt") String gptDescriptionPrompt,
|
||||
@JsonProperty("error_message") String errorMessage,
|
||||
String prompt,
|
||||
String type,
|
||||
String tags,
|
||||
Double duration
|
||||
) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Suno API 响应的歌词数据。
|
||||
*
|
||||
* @param text 歌词
|
||||
* @param title 标题
|
||||
* @param status 状态
|
||||
*/
|
||||
public record LyricsData(
|
||||
String text,
|
||||
String title,
|
||||
String status
|
||||
) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Suno API 响应的限额数据,目前每日免费 50
|
||||
*/
|
||||
public record LimitUsageData(
|
||||
@JsonProperty("credits_left") Long creditsLeft,
|
||||
String period,
|
||||
@JsonProperty("monthly_limit") Long monthlyLimit,
|
||||
@JsonProperty("monthly_usage") Long monthlyUsage
|
||||
) {
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,381 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.wenduoduo.api;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
|
||||
import com.fasterxml.jackson.annotation.JsonFormat;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
import org.springframework.http.HttpRequest;
|
||||
import org.springframework.http.HttpStatusCode;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.LinkedMultiValueMap;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
import org.springframework.web.reactive.function.BodyInserters;
|
||||
import org.springframework.web.reactive.function.client.ClientResponse;
|
||||
import org.springframework.web.reactive.function.client.WebClient;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.function.Function;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
/**
|
||||
* 文多多 API
|
||||
*
|
||||
* @author xiaoxin
|
||||
* @see <a href="https://docmee.cn/open-platform/api">PPT 生成 API</a>
|
||||
*/
|
||||
@Slf4j
|
||||
public class WenDuoDuoPptApi {
|
||||
|
||||
public static final String BASE_URL = "https://docmee.cn";
|
||||
public static final String TOKEN_NAME = "token";
|
||||
|
||||
private final WebClient webClient;
|
||||
|
||||
private final Predicate<HttpStatusCode> STATUS_PREDICATE = status -> !status.is2xxSuccessful();
|
||||
|
||||
private final Function<Object, Function<ClientResponse, Mono<? extends Throwable>>> EXCEPTION_FUNCTION =
|
||||
reqParam -> response -> response.bodyToMono(String.class).handle((responseBody, sink) -> {
|
||||
HttpRequest request = response.request();
|
||||
log.error("[WenDuoDuoPptApi] 调用失败!请求方式:[{}],请求地址:[{}],请求参数:[{}],响应数据: [{}]",
|
||||
request.getMethod(), request.getURI(), reqParam, responseBody);
|
||||
sink.error(new IllegalStateException("[WenDuoDuoPptApi] 调用失败!"));
|
||||
});
|
||||
|
||||
public WenDuoDuoPptApi(String token) {
|
||||
Assert.hasText(token, "token 不能为空");
|
||||
this.webClient = WebClient.builder()
|
||||
.baseUrl(BASE_URL)
|
||||
.defaultHeaders((headers) -> {
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
headers.add(TOKEN_NAME, token);
|
||||
})
|
||||
.build();
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建 token
|
||||
*
|
||||
* @param request 请求信息
|
||||
* @return token
|
||||
*/
|
||||
public String createApiToken(CreateTokenRequest request) {
|
||||
return this.webClient.post()
|
||||
.uri("/api/user/createApiToken")
|
||||
.header("Api-Key", request.apiKey)
|
||||
.body(Mono.just(request), CreateTokenRequest.class)
|
||||
.retrieve()
|
||||
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(request))
|
||||
.bodyToMono(ApiResponse.class)
|
||||
.<String>handle((response, sink) -> {
|
||||
if (response.code != 0) {
|
||||
sink.error(new IllegalStateException("创建 token 异常," + response.message));
|
||||
return;
|
||||
}
|
||||
sink.next(response.data.get("token").toString());
|
||||
})
|
||||
.block();
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建任务
|
||||
*
|
||||
* @param type 类型
|
||||
* @param content 内容
|
||||
* @param files 文件列表
|
||||
* @return 任务 ID
|
||||
* @see <a href="https://docmee.cn/open-platform/api#%E5%88%9B%E5%BB%BA%E4%BB%BB%E5%8A%A1">创建任务</a>
|
||||
*/
|
||||
public ApiResponse createTask(Integer type, String content, List<MultipartFile> files) {
|
||||
MultiValueMap<String, Object> formData = new LinkedMultiValueMap<>();
|
||||
formData.add("type", type);
|
||||
if (content != null) {
|
||||
formData.add("content", content);
|
||||
}
|
||||
if (files != null) {
|
||||
for (MultipartFile file : files) {
|
||||
formData.add("file", file.getResource());
|
||||
}
|
||||
}
|
||||
return this.webClient.post()
|
||||
.uri("/api/ppt/v2/createTask")
|
||||
.contentType(MediaType.MULTIPART_FORM_DATA)
|
||||
.body(BodyInserters.fromMultipartData(formData))
|
||||
.retrieve()
|
||||
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(formData))
|
||||
.bodyToMono(ApiResponse.class)
|
||||
.block();
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取生成选项
|
||||
*
|
||||
* @param lang 语种
|
||||
* @return 生成选项
|
||||
*/
|
||||
public Map<String, Object> getOptions(String lang) {
|
||||
String uri = "/api/ppt/v2/options";
|
||||
if (lang != null) {
|
||||
uri += "?lang=" + lang;
|
||||
}
|
||||
return this.webClient.get()
|
||||
.uri(uri)
|
||||
.retrieve()
|
||||
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(lang))
|
||||
.bodyToMono(new ParameterizedTypeReference<ApiResponse>() {
|
||||
})
|
||||
.<Map<String, Object>>handle((response, sink) -> {
|
||||
if (response.code != 0) {
|
||||
sink.error(new IllegalStateException("获取生成选项异常," + response.message));
|
||||
return;
|
||||
}
|
||||
sink.next(response.data);
|
||||
})
|
||||
.block();
|
||||
}
|
||||
|
||||
/**
|
||||
* 分页查询 PPT 模板
|
||||
*
|
||||
* @param token 令牌
|
||||
* @param request 请求体
|
||||
* @return 模板列表
|
||||
*/
|
||||
public PagePptTemplateInfo getTemplatePage(TemplateQueryRequest request) {
|
||||
return this.webClient.post()
|
||||
.uri("/api/ppt/templates")
|
||||
.bodyValue(request)
|
||||
.retrieve()
|
||||
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(request))
|
||||
.bodyToMono(new ParameterizedTypeReference<PagePptTemplateInfo>() {
|
||||
})
|
||||
.block();
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成大纲内容
|
||||
*
|
||||
* @return 大纲内容流
|
||||
*/
|
||||
public Flux<Map<String, Object>> createOutline(CreateOutlineRequest request) {
|
||||
return this.webClient.post()
|
||||
.uri("/api/ppt/v2/generateContent")
|
||||
.body(Mono.just(request), CreateOutlineRequest.class)
|
||||
.retrieve()
|
||||
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(request))
|
||||
.bodyToFlux(new ParameterizedTypeReference<>() {
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 修改大纲内容
|
||||
*
|
||||
* @param request 请求体
|
||||
* @return 大纲内容流
|
||||
*/
|
||||
public Flux<Map<String, Object>> updateOutline(UpdateOutlineRequest request) {
|
||||
return this.webClient.post()
|
||||
.uri("/api/ppt/v2/updateContent")
|
||||
.body(Mono.just(request), UpdateOutlineRequest.class)
|
||||
.retrieve()
|
||||
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(request))
|
||||
.bodyToFlux(new ParameterizedTypeReference<>() {
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成 PPT
|
||||
*
|
||||
* @return PPT信息
|
||||
*/
|
||||
public PptInfo create(PptCreateRequest request) {
|
||||
return this.webClient.post()
|
||||
.uri("/api/ppt/v2/generatePptx")
|
||||
.body(Mono.just(request), PptCreateRequest.class)
|
||||
.retrieve()
|
||||
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(request))
|
||||
.bodyToMono(ApiResponse.class)
|
||||
.<PptInfo>handle((response, sink) -> {
|
||||
if (response.code != 0) {
|
||||
sink.error(new IllegalStateException("生成 PPT 异常," + response.message));
|
||||
return;
|
||||
}
|
||||
sink.next(Objects.requireNonNull(JsonUtils.parseObject(JsonUtils.toJsonString(response.data.get("pptInfo")), PptInfo.class)));
|
||||
})
|
||||
.block();
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建 Token 请求参数
|
||||
*/
|
||||
@JsonInclude(value = JsonInclude.Include.NON_NULL)
|
||||
public record CreateTokenRequest(
|
||||
String apiKey,
|
||||
String uid,
|
||||
Integer limit
|
||||
) {
|
||||
|
||||
public CreateTokenRequest(String apiKey) {
|
||||
this(apiKey, null, null);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* API 通用响应
|
||||
*/
|
||||
@JsonInclude(value = JsonInclude.Include.NON_NULL)
|
||||
public record ApiResponse(
|
||||
Integer code,
|
||||
String message,
|
||||
Map<String, Object> data
|
||||
) {
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建任务
|
||||
*/
|
||||
@JsonInclude(value = JsonInclude.Include.NON_NULL)
|
||||
public record CreateTaskRequest(
|
||||
Integer type,
|
||||
String content,
|
||||
List<MultipartFile> files
|
||||
) {
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成大纲内容请求
|
||||
*/
|
||||
@JsonInclude(value = JsonInclude.Include.NON_NULL)
|
||||
public record CreateOutlineRequest(
|
||||
String id,
|
||||
String length,
|
||||
String scene,
|
||||
String audience,
|
||||
String lang,
|
||||
String prompt
|
||||
) {
|
||||
}
|
||||
|
||||
/**
|
||||
* 修改大纲内容请求
|
||||
*/
|
||||
@JsonInclude(value = JsonInclude.Include.NON_NULL)
|
||||
public record UpdateOutlineRequest(
|
||||
String id,
|
||||
String markdown,
|
||||
String question
|
||||
) {
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成 PPT 请求参数
|
||||
*/
|
||||
@JsonInclude(value = JsonInclude.Include.NON_NULL)
|
||||
public record PptCreateRequest(
|
||||
String id,
|
||||
String templateId,
|
||||
String markdown
|
||||
) {
|
||||
}
|
||||
|
||||
/**
|
||||
* PPT 信息
|
||||
*/
|
||||
@JsonInclude(value = JsonInclude.Include.NON_NULL)
|
||||
public record PptInfo(
|
||||
String id,
|
||||
String name,
|
||||
String subject,
|
||||
String coverUrl,
|
||||
String fileUrl,
|
||||
String templateId,
|
||||
String pptxProperty,
|
||||
String userId,
|
||||
String userName,
|
||||
int companyId,
|
||||
@JsonFormat(pattern = "yyyy-MM-dd HH:mm:ss")
|
||||
LocalDateTime updateTime,
|
||||
@JsonFormat(pattern = "yyyy-MM-dd HH:mm:ss")
|
||||
LocalDateTime createTime,
|
||||
String createUser,
|
||||
String updateUser
|
||||
) {
|
||||
}
|
||||
|
||||
/**
|
||||
* 模板查询请求参数
|
||||
*/
|
||||
@JsonInclude(value = JsonInclude.Include.NON_NULL)
|
||||
public record TemplateQueryRequest(
|
||||
int page,
|
||||
int size,
|
||||
Filter filters
|
||||
) {
|
||||
|
||||
/**
|
||||
* 模板查询过滤条件
|
||||
*/
|
||||
@JsonInclude(value = JsonInclude.Include.NON_NULL)
|
||||
public record Filter(
|
||||
int type,
|
||||
String category,
|
||||
String style,
|
||||
String themeColor
|
||||
) {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* PPT模板分页信息
|
||||
*/
|
||||
@JsonInclude(value = JsonInclude.Include.NON_NULL)
|
||||
public record PagePptTemplateInfo(
|
||||
List<PptTemplateInfo> data,
|
||||
String total
|
||||
) {
|
||||
}
|
||||
|
||||
/**
|
||||
* PPT模板信息
|
||||
*/
|
||||
@JsonInclude(value = JsonInclude.Include.NON_NULL)
|
||||
public record PptTemplateInfo(
|
||||
String id,
|
||||
int type,
|
||||
Integer subType,
|
||||
String layout,
|
||||
String category,
|
||||
String style,
|
||||
String themeColor,
|
||||
String lang,
|
||||
boolean animation,
|
||||
String subject,
|
||||
String coverUrl,
|
||||
String fileUrl,
|
||||
List<String> pageCoverUrls,
|
||||
String pptxProperty,
|
||||
int sort,
|
||||
int num,
|
||||
Integer imgNum,
|
||||
int isDeleted,
|
||||
String userId,
|
||||
int companyId,
|
||||
@JsonFormat(pattern = "yyyy-MM-dd HH:mm:ss")
|
||||
LocalDateTime updateTime,
|
||||
@JsonFormat(pattern = "yyyy-MM-dd HH:mm:ss")
|
||||
LocalDateTime createTime,
|
||||
String createUser,
|
||||
String updateUser
|
||||
) {
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.xinghuo;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
/**
|
||||
* 讯飞星火 {@link ChatModel} 实现类
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
@Slf4j
|
||||
@RequiredArgsConstructor
|
||||
public class XingHuoChatModel implements ChatModel {
|
||||
|
||||
public static final String BASE_URL = "https://spark-api-open.xf-yun.com";
|
||||
|
||||
public static final String MODEL_DEFAULT = "generalv3.5";
|
||||
|
||||
/**
|
||||
* 兼容 OpenAI 接口,进行复用
|
||||
*/
|
||||
private final OpenAiChatModel openAiChatModel;
|
||||
|
||||
@Override
|
||||
public ChatResponse call(Prompt prompt) {
|
||||
return openAiChatModel.call(prompt);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<ChatResponse> stream(Prompt prompt) {
|
||||
return openAiChatModel.stream(prompt);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatOptions getDefaultOptions() {
|
||||
return openAiChatModel.getDefaultOptions();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,522 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.xinghuo.api;
|
||||
|
||||
import cn.hutool.core.util.ObjUtil;
|
||||
import cn.hutool.crypto.SecureUtil;
|
||||
import cn.hutool.crypto.digest.HmacAlgorithm;
|
||||
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import lombok.Builder;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.core.io.ByteArrayResource;
|
||||
import org.springframework.http.HttpStatusCode;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.util.LinkedMultiValueMap;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
import org.springframework.util.StringUtils;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
import org.springframework.web.reactive.function.BodyInserters;
|
||||
import org.springframework.web.reactive.function.client.ClientResponse;
|
||||
import org.springframework.web.reactive.function.client.WebClient;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
/**
|
||||
* 讯飞智能 PPT 生成 API
|
||||
*
|
||||
* @author xiaoxin
|
||||
* @see <a href="https://www.xfyun.cn/doc/spark/PPTv2.html">智能 PPT 生成 API</a>
|
||||
*/
|
||||
@Slf4j
|
||||
public class XunFeiPptApi {
|
||||
|
||||
public static final String BASE_URL = "https://zwapi.xfyun.cn/api/ppt/v2";
|
||||
private static final String HEADER_APP_ID = "appId";
|
||||
private static final String HEADER_TIMESTAMP = "timestamp";
|
||||
private static final String HEADER_SIGNATURE = "signature";
|
||||
|
||||
private final WebClient webClient;
|
||||
private final String appId;
|
||||
private final String apiSecret;
|
||||
|
||||
private final Predicate<HttpStatusCode> STATUS_PREDICATE = status -> !status.is2xxSuccessful();
|
||||
|
||||
private final Function<Object, Function<ClientResponse, Mono<? extends Throwable>>> EXCEPTION_FUNCTION =
|
||||
reqParam -> response -> response.bodyToMono(String.class).handle((responseBody, sink) -> {
|
||||
log.error("[XunFeiPptApi] 调用失败!请求参数:[{}],响应数据: [{}]", reqParam, responseBody);
|
||||
sink.error(new IllegalStateException("[XunFeiPptApi] 调用失败!"));
|
||||
});
|
||||
|
||||
public XunFeiPptApi(String appId, String apiSecret) {
|
||||
this.appId = appId;
|
||||
this.apiSecret = apiSecret;
|
||||
this.webClient = WebClient.builder()
|
||||
.baseUrl(BASE_URL)
|
||||
.defaultHeaders((headers) -> {
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
headers.add(HEADER_APP_ID, appId);
|
||||
})
|
||||
.build();
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取签名
|
||||
*
|
||||
* @return 签名信息
|
||||
*/
|
||||
private SignatureInfo getSignature() {
|
||||
long timestamp = System.currentTimeMillis() / 1000;
|
||||
String ts = String.valueOf(timestamp);
|
||||
String signature = generateSignature(appId, apiSecret, timestamp);
|
||||
return new SignatureInfo(ts, signature);
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成签名
|
||||
*
|
||||
* @param appId 应用ID
|
||||
* @param apiSecret 应用密钥
|
||||
* @param timestamp 时间戳(秒)
|
||||
* @return 签名
|
||||
*/
|
||||
private String generateSignature(String appId, String apiSecret, long timestamp) {
|
||||
String auth = SecureUtil.md5(appId + timestamp);
|
||||
return SecureUtil.hmac(HmacAlgorithm.HmacSHA1, apiSecret).digestBase64(auth, false);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 PPT 模板列表
|
||||
*
|
||||
* @param style 风格,如"商务"
|
||||
* @param pageSize 每页数量
|
||||
* @return 模板列表
|
||||
*/
|
||||
public TemplatePageResponse getTemplatePage(String style, Integer pageSize) {
|
||||
SignatureInfo signInfo = getSignature();
|
||||
Map<String, Object> requestBody = new HashMap<>();
|
||||
requestBody.put("style", style);
|
||||
requestBody.put("pageSize", ObjUtil.defaultIfNull(pageSize, 20));
|
||||
return this.webClient.post()
|
||||
.uri("/template/list")
|
||||
.header(HEADER_TIMESTAMP, signInfo.timestamp)
|
||||
.header(HEADER_SIGNATURE, signInfo.signature)
|
||||
.contentType(MediaType.APPLICATION_JSON)
|
||||
.bodyValue(requestBody)
|
||||
.retrieve()
|
||||
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(requestBody))
|
||||
.bodyToMono(TemplatePageResponse.class)
|
||||
.block();
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建大纲(通过文本)
|
||||
*
|
||||
* @param query 查询文本
|
||||
* @return 大纲创建响应
|
||||
*/
|
||||
public CreateResponse createOutline(String query) {
|
||||
SignatureInfo signInfo = getSignature();
|
||||
MultiValueMap<String, Object> formData = new LinkedMultiValueMap<>();
|
||||
formData.add("query", query);
|
||||
return this.webClient.post()
|
||||
.uri("/createOutline")
|
||||
.header(HEADER_TIMESTAMP, signInfo.timestamp)
|
||||
.header(HEADER_SIGNATURE, signInfo.signature)
|
||||
.contentType(MediaType.MULTIPART_FORM_DATA)
|
||||
.body(BodyInserters.fromMultipartData(formData))
|
||||
.retrieve()
|
||||
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(formData))
|
||||
.bodyToMono(CreateResponse.class)
|
||||
.block();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 直接创建 PPT(简化版 - 通过文本)
|
||||
*
|
||||
* @param query 查询文本
|
||||
* @return 创建响应
|
||||
*/
|
||||
public CreateResponse create(String query) {
|
||||
CreatePptRequest request = CreatePptRequest.builder()
|
||||
.query(query)
|
||||
.build();
|
||||
return create(request);
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接创建 PPT(简化版 - 通过文件)
|
||||
*
|
||||
* @param file 文件
|
||||
* @param fileName 文件名
|
||||
* @return 创建响应
|
||||
*/
|
||||
public CreateResponse create(MultipartFile file, String fileName) {
|
||||
CreatePptRequest request = CreatePptRequest.builder()
|
||||
.file(file).fileName(fileName).build();
|
||||
return create(request);
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接创建 PPT(完整版)
|
||||
*
|
||||
* @param request 请求参数
|
||||
* @return 创建响应
|
||||
*/
|
||||
public CreateResponse create(CreatePptRequest request) {
|
||||
SignatureInfo signInfo = getSignature();
|
||||
MultiValueMap<String, Object> formData = buildCreatePptFormData(request);
|
||||
return this.webClient.post()
|
||||
.uri("/create")
|
||||
.header(HEADER_TIMESTAMP, signInfo.timestamp)
|
||||
.header(HEADER_SIGNATURE, signInfo.signature)
|
||||
.contentType(MediaType.MULTIPART_FORM_DATA)
|
||||
.body(BodyInserters.fromMultipartData(formData))
|
||||
.retrieve()
|
||||
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(formData))
|
||||
.bodyToMono(CreateResponse.class)
|
||||
.block();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 通过大纲创建 PPT(简化版)
|
||||
*
|
||||
* @param outline 大纲内容
|
||||
* @param query 查询文本
|
||||
* @return 创建响应
|
||||
*/
|
||||
public CreateResponse createPptByOutline(OutlineData outline, String query) {
|
||||
CreatePptByOutlineRequest request = CreatePptByOutlineRequest.builder()
|
||||
.outline(outline)
|
||||
.query(query)
|
||||
.build();
|
||||
return createPptByOutline(request);
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过大纲创建 PPT(完整版)
|
||||
*
|
||||
* @param request 请求参数
|
||||
* @return 创建响应
|
||||
*/
|
||||
public CreateResponse createPptByOutline(CreatePptByOutlineRequest request) {
|
||||
SignatureInfo signInfo = getSignature();
|
||||
return this.webClient.post()
|
||||
.uri("/createPptByOutline")
|
||||
.header(HEADER_TIMESTAMP, signInfo.timestamp)
|
||||
.header(HEADER_SIGNATURE, signInfo.signature)
|
||||
.contentType(MediaType.APPLICATION_JSON)
|
||||
.bodyValue(request)
|
||||
.retrieve()
|
||||
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(request))
|
||||
.bodyToMono(CreateResponse.class)
|
||||
.block();
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查 PPT 生成进度
|
||||
*
|
||||
* @param sid 任务 ID
|
||||
* @return 进度响应
|
||||
*/
|
||||
public ProgressResponse checkProgress(String sid) {
|
||||
SignatureInfo signInfo = getSignature();
|
||||
return this.webClient.get()
|
||||
.uri(uriBuilder -> uriBuilder
|
||||
.path("/progress")
|
||||
.queryParam("sid", sid)
|
||||
.build())
|
||||
.header(HEADER_TIMESTAMP, signInfo.timestamp)
|
||||
.header(HEADER_SIGNATURE, signInfo.signature)
|
||||
.retrieve()
|
||||
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(sid))
|
||||
.bodyToMono(ProgressResponse.class)
|
||||
.block();
|
||||
}
|
||||
|
||||
/**
|
||||
* 签名信息
|
||||
*/
|
||||
@JsonInclude(value = JsonInclude.Include.NON_NULL)
|
||||
private record SignatureInfo(
|
||||
String timestamp,
|
||||
String signature
|
||||
) {
|
||||
}
|
||||
|
||||
/**
|
||||
* 模板列表响应
|
||||
*/
|
||||
@JsonInclude(value = JsonInclude.Include.NON_NULL)
|
||||
public record TemplatePageResponse(
|
||||
boolean flag,
|
||||
int code,
|
||||
String desc,
|
||||
Integer count,
|
||||
TemplatePageData data
|
||||
) {
|
||||
}
|
||||
|
||||
/**
|
||||
* 模板列表数据
|
||||
*/
|
||||
@JsonInclude(value = JsonInclude.Include.NON_NULL)
|
||||
public record TemplatePageData(
|
||||
String total,
|
||||
List<TemplateInfo> records,
|
||||
Integer pageNum
|
||||
) {
|
||||
}
|
||||
|
||||
/**
|
||||
* 模板信息
|
||||
*/
|
||||
@JsonInclude(value = JsonInclude.Include.NON_NULL)
|
||||
public record TemplateInfo(
|
||||
String templateIndexId,
|
||||
Integer pageCount,
|
||||
String type,
|
||||
String color,
|
||||
String industry,
|
||||
String style,
|
||||
String detailImage
|
||||
) {
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建响应
|
||||
*/
|
||||
@JsonInclude(value = JsonInclude.Include.NON_NULL)
|
||||
public record CreateResponse(
|
||||
boolean flag,
|
||||
int code,
|
||||
String desc,
|
||||
Integer count,
|
||||
CreateResponseData data
|
||||
) {
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建响应数据
|
||||
*/
|
||||
@JsonInclude(value = JsonInclude.Include.NON_NULL)
|
||||
public record CreateResponseData(
|
||||
String sid,
|
||||
String coverImgSrc,
|
||||
String title,
|
||||
String subTitle,
|
||||
OutlineData outline
|
||||
) {
|
||||
}
|
||||
|
||||
/**
|
||||
* 大纲数据结构
|
||||
*/
|
||||
@JsonInclude(value = JsonInclude.Include.NON_NULL)
|
||||
public record OutlineData(
|
||||
String title,
|
||||
String subTitle,
|
||||
List<Chapter> chapters
|
||||
) {
|
||||
|
||||
/**
|
||||
* 章节结构
|
||||
*/
|
||||
@JsonInclude(value = JsonInclude.Include.NON_NULL)
|
||||
public record Chapter(
|
||||
String chapterTitle,
|
||||
List<ChapterContent> chapterContents
|
||||
) {
|
||||
|
||||
/**
|
||||
* 章节内容
|
||||
*/
|
||||
@JsonInclude(value = JsonInclude.Include.NON_NULL)
|
||||
public record ChapterContent(
|
||||
String chapterTitle
|
||||
) {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* 将大纲对象转换为JSON字符串
|
||||
*
|
||||
* @return 大纲JSON字符串
|
||||
*/
|
||||
public String toJsonString() {
|
||||
return JsonUtils.toJsonString(this);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 进度响应
|
||||
*/
|
||||
@JsonInclude(value = JsonInclude.Include.NON_NULL)
|
||||
public record ProgressResponse(
|
||||
int code,
|
||||
String desc,
|
||||
ProgressResponseData data
|
||||
) {
|
||||
}
|
||||
|
||||
/**
|
||||
* 进度响应数据
|
||||
*/
|
||||
@JsonInclude(value = JsonInclude.Include.NON_NULL)
|
||||
public record ProgressResponseData(
|
||||
int process,
|
||||
String pptId,
|
||||
String pptUrl,
|
||||
String pptStatus,
|
||||
String aiImageStatus,
|
||||
String cardNoteStatus,
|
||||
String errMsg,
|
||||
Integer totalPages,
|
||||
Integer donePages
|
||||
) {
|
||||
|
||||
/**
|
||||
* 是否全部完成
|
||||
*
|
||||
* @return 是否全部完成
|
||||
*/
|
||||
public boolean isAllDone() {
|
||||
return "done".equals(pptStatus)
|
||||
&& ("done".equals(aiImageStatus) || aiImageStatus == null)
|
||||
&& ("done".equals(cardNoteStatus) || cardNoteStatus == null);
|
||||
}
|
||||
|
||||
/**
|
||||
* 是否失败
|
||||
*
|
||||
* @return 是否失败
|
||||
*/
|
||||
public boolean isFailed() {
|
||||
return "build_failed".equals(pptStatus);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取进度百分比
|
||||
*
|
||||
* @return 进度百分比
|
||||
*/
|
||||
public int getProgressPercent() {
|
||||
if (totalPages == null || totalPages == 0 || donePages == null) {
|
||||
return process; // 兼容旧版返回
|
||||
}
|
||||
return (int) (donePages * 100.0 / totalPages);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过大纲创建 PPT 请求参数
|
||||
*/
|
||||
@JsonInclude(value = JsonInclude.Include.NON_NULL)
|
||||
@Builder
|
||||
public record CreatePptByOutlineRequest(
|
||||
String query, // 用户生成PPT要求(最多8000字)
|
||||
String outlineSid, // 已生成大纲后,响应返回的请求大纲唯一id
|
||||
OutlineData outline, // 大纲内容
|
||||
String templateId, // 模板ID
|
||||
String businessId, // 业务ID(非必传)
|
||||
String author, // PPT作者名
|
||||
Boolean isCardNote, // 是否生成PPT演讲备注
|
||||
Boolean search, // 是否联网搜索
|
||||
String language, // 语种
|
||||
String fileUrl, // 文件地址
|
||||
String fileName, // 文件名(带文件名后缀)
|
||||
Boolean isFigure, // 是否自动配图
|
||||
String aiImage // ai配图类型:normal、advanced
|
||||
) {
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 构建创建 PPT 的表单数据
|
||||
*
|
||||
* @param request 请求参数
|
||||
* @return 表单数据
|
||||
*/
|
||||
private MultiValueMap<String, Object> buildCreatePptFormData(CreatePptRequest request) {
|
||||
MultiValueMap<String, Object> formData = new LinkedMultiValueMap<>();
|
||||
if (request.file() != null) {
|
||||
try {
|
||||
formData.add("file", new ByteArrayResource(request.file().getBytes()) {
|
||||
@Override
|
||||
public String getFilename() {
|
||||
return request.file().getOriginalFilename();
|
||||
}
|
||||
});
|
||||
} catch (IOException e) {
|
||||
log.error("[XunFeiPptApi] 文件处理失败", e);
|
||||
throw new IllegalStateException("[XunFeiPptApi] 文件处理失败", e);
|
||||
}
|
||||
}
|
||||
Map<String, Object> param = new HashMap<>();
|
||||
addIfPresent(param, "query", request.query());
|
||||
addIfPresent(param, "fileUrl", request.fileUrl());
|
||||
addIfPresent(param, "fileName", request.fileName());
|
||||
addIfPresent(param, "templateId", request.templateId());
|
||||
addIfPresent(param, "businessId", request.businessId());
|
||||
addIfPresent(param, "author", request.author());
|
||||
addIfPresent(param, "isCardNote", request.isCardNote());
|
||||
addIfPresent(param, "search", request.search());
|
||||
addIfPresent(param, "language", request.language());
|
||||
addIfPresent(param, "isFigure", request.isFigure());
|
||||
addIfPresent(param, "aiImage", request.aiImage());
|
||||
param.forEach(formData::add);
|
||||
return formData;
|
||||
}
|
||||
|
||||
public static <K, V> void addIfPresent(Map<K, V> map, K key, V value) {
|
||||
if (ObjUtil.isNull(key) || ObjUtil.isNull(map)) {
|
||||
return;
|
||||
}
|
||||
|
||||
boolean isPresent = false;
|
||||
if (ObjUtil.isNotNull(value)) {
|
||||
if (value instanceof String) {
|
||||
// 字符串:需要有实际内容
|
||||
isPresent = StringUtils.hasText((String) value);
|
||||
} else {
|
||||
// 其他类型:非 null 即视为存在
|
||||
isPresent = true;
|
||||
}
|
||||
}
|
||||
if (isPresent) {
|
||||
map.put(key, value);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接生成PPT请求参数
|
||||
*/
|
||||
@JsonInclude(value = JsonInclude.Include.NON_NULL)
|
||||
@Builder
|
||||
public record CreatePptRequest(
|
||||
String query, // 用户生成PPT要求(最多8000字)
|
||||
MultipartFile file, // 上传文件
|
||||
String fileUrl, // 文件地址
|
||||
String fileName, // 文件名(带文件名后缀)
|
||||
String templateId, // 模板ID
|
||||
String businessId, // 业务ID(非必传)
|
||||
String author, // PPT作者名
|
||||
Boolean isCardNote, // 是否生成PPT演讲备注
|
||||
Boolean search, // 是否联网搜索
|
||||
String language, // 语种
|
||||
Boolean isFigure, // 是否自动配图
|
||||
String aiImage // ai配图类型:normal、advanced
|
||||
) {
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
/**
|
||||
* AI 大模型组件,基于 Spring AI 拓展
|
||||
*
|
||||
* models 包路径:
|
||||
* 1. xinghuo 包:【讯飞】星火,自己实现
|
||||
* 2. deepseek 包:【深度求索】DeepSeek,自己实现
|
||||
* 3. doubao 包:【字节豆包】DouBao,自己实现
|
||||
* 4. hunyuan 包:【腾讯混元】HunYuan,自己实现
|
||||
* 5. siliconflow 包:【硅基硅流】SiliconFlow,自己实现
|
||||
* 6. midjourney 包:Midjourney API,对接 https://github.com/novicezk/midjourney-proxy 实现
|
||||
* 7. suno 包:Suno API,对接 https://github.com/gcui-art/suno-api 实现
|
||||
*/
|
||||
package cn.iocoder.yudao.module.ai.framework.ai;
|
||||
@@ -4,7 +4,7 @@ import cn.hutool.core.collection.CollUtil;
|
||||
import cn.hutool.core.lang.Assert;
|
||||
import cn.hutool.core.util.ObjUtil;
|
||||
import cn.hutool.core.util.ObjectUtil;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
|
||||
import cn.iocoder.yudao.module.ai.enums.model.AiModelTypeEnum;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationCreateMyReqVO;
|
||||
|
||||
@@ -3,8 +3,8 @@ package cn.iocoder.yudao.module.ai.service.chat;
|
||||
import cn.hutool.core.collection.CollUtil;
|
||||
import cn.hutool.core.util.ObjUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
|
||||
import cn.iocoder.yudao.module.ai.enums.model.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.module.ai.util.AiUtils;
|
||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package cn.iocoder.yudao.module.ai.service.image;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO;
|
||||
|
||||
@@ -9,9 +9,9 @@ import cn.hutool.core.util.ObjUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.hutool.extra.spring.SpringUtil;
|
||||
import cn.hutool.http.HttpUtil;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowImageOptions;
|
||||
import cn.iocoder.yudao.module.ai.enums.model.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.siliconflow.SiliconFlowImageOptions;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
|
||||
|
||||
@@ -3,9 +3,9 @@ package cn.iocoder.yudao.module.ai.service.mindmap;
|
||||
import cn.hutool.core.collection.CollUtil;
|
||||
import cn.hutool.core.util.ObjUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
|
||||
import cn.iocoder.yudao.module.ai.enums.model.AiModelTypeEnum;
|
||||
import cn.iocoder.yudao.module.ai.enums.model.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.module.ai.util.AiUtils;
|
||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package cn.iocoder.yudao.module.ai.service.model;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.suno.api.SunoApi;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelSaveReqVO;
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package cn.iocoder.yudao.module.ai.service.model;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
|
||||
import cn.iocoder.yudao.module.ai.enums.model.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.AiModelFactory;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.suno.api.SunoApi;
|
||||
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package cn.iocoder.yudao.module.ai.service.model.tool;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
|
||||
import cn.iocoder.yudao.module.ai.util.AiUtils;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
import cn.iocoder.yudao.framework.security.core.LoginUser;
|
||||
import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
|
||||
|
||||
@@ -6,7 +6,7 @@ import cn.hutool.core.util.ObjUtil;
|
||||
import cn.hutool.core.util.ObjectUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.hutool.http.HttpUtil;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.suno.api.SunoApi;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.music.vo.AiMusicPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.music.vo.AiMusicUpdateMyReqVO;
|
||||
|
||||
@@ -3,9 +3,10 @@ package cn.iocoder.yudao.module.ai.service.write;
|
||||
import cn.hutool.core.collection.CollUtil;
|
||||
import cn.hutool.core.util.ObjUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
|
||||
import cn.iocoder.yudao.framework.dict.core.DictFrameworkUtils;
|
||||
import cn.iocoder.yudao.module.ai.enums.model.AiModelTypeEnum;
|
||||
import cn.iocoder.yudao.module.ai.enums.model.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.module.ai.util.AiUtils;
|
||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
package cn.iocoder.yudao.module.ai.util;
|
||||
|
||||
import cn.hutool.core.util.ObjUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.iocoder.yudao.module.ai.enums.model.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
|
||||
import cn.iocoder.yudao.framework.tenant.core.context.TenantContextHolder;
|
||||
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
|
||||
import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
|
||||
import org.springframework.ai.chat.messages.*;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.minimax.MiniMaxChatOptions;
|
||||
import org.springframework.ai.moonshot.MoonshotChatOptions;
|
||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.qianfan.QianFanChatOptions;
|
||||
import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* Spring AI 工具类
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class AiUtils {
|
||||
|
||||
public static final String TOOL_CONTEXT_LOGIN_USER = "LOGIN_USER";
|
||||
public static final String TOOL_CONTEXT_TENANT_ID = "TENANT_ID";
|
||||
|
||||
public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
|
||||
return buildChatOptions(platform, model, temperature, maxTokens, null, null);
|
||||
}
|
||||
|
||||
public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens,
|
||||
Set<String> toolNames, Map<String, Object> toolContext) {
|
||||
toolNames = ObjUtil.defaultIfNull(toolNames, Collections.emptySet());
|
||||
// noinspection EnhancedSwitchMigration
|
||||
switch (platform) {
|
||||
case TONG_YI:
|
||||
return DashScopeChatOptions.builder().withModel(model).withTemperature(temperature).withMaxToken(maxTokens)
|
||||
.withFunctions(toolNames).withToolContext(toolContext).build();
|
||||
case YI_YAN:
|
||||
return QianFanChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens).build();
|
||||
case ZHI_PU:
|
||||
return ZhiPuAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
||||
.functions(toolNames).toolContext(toolContext).build();
|
||||
case MINI_MAX:
|
||||
return MiniMaxChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
||||
.functions(toolNames).toolContext(toolContext).build();
|
||||
case MOONSHOT:
|
||||
return MoonshotChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
||||
.functions(toolNames).toolContext(toolContext).build();
|
||||
case OPENAI:
|
||||
case DEEP_SEEK: // 复用 OpenAI 客户端
|
||||
case DOU_BAO: // 复用 OpenAI 客户端
|
||||
case HUN_YUAN: // 复用 OpenAI 客户端
|
||||
case XING_HUO: // 复用 OpenAI 客户端
|
||||
case SILICON_FLOW: // 复用 OpenAI 客户端
|
||||
case BAI_CHUAN: // 复用 OpenAI 客户端
|
||||
return OpenAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
||||
.toolNames(toolNames).toolContext(toolContext).build();
|
||||
case AZURE_OPENAI:
|
||||
return AzureOpenAiChatOptions.builder().deploymentName(model).temperature(temperature).maxTokens(maxTokens)
|
||||
.toolNames(toolNames).toolContext(toolContext).build();
|
||||
case OLLAMA:
|
||||
return OllamaOptions.builder().model(model).temperature(temperature).numPredict(maxTokens)
|
||||
.toolNames(toolNames).toolContext(toolContext).build();
|
||||
default:
|
||||
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
||||
}
|
||||
}
|
||||
|
||||
public static Message buildMessage(String type, String content) {
|
||||
if (MessageType.USER.getValue().equals(type)) {
|
||||
return new UserMessage(content);
|
||||
}
|
||||
if (MessageType.ASSISTANT.getValue().equals(type)) {
|
||||
return new AssistantMessage(content);
|
||||
}
|
||||
if (MessageType.SYSTEM.getValue().equals(type)) {
|
||||
return new SystemMessage(content);
|
||||
}
|
||||
if (MessageType.TOOL.getValue().equals(type)) {
|
||||
throw new UnsupportedOperationException("暂不支持 tool 消息:" + content);
|
||||
}
|
||||
throw new IllegalArgumentException(StrUtil.format("未知消息类型({})", type));
|
||||
}
|
||||
|
||||
public static Map<String, Object> buildCommonToolContext() {
|
||||
Map<String, Object> context = new HashMap<>();
|
||||
context.put(TOOL_CONTEXT_LOGIN_USER, SecurityFrameworkUtils.getLoginUser());
|
||||
context.put(TOOL_CONTEXT_TENANT_ID, TenantContextHolder.getTenantId());
|
||||
return context;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.chat;
|
||||
|
||||
import com.azure.ai.openai.OpenAIClientBuilder;
|
||||
import com.azure.core.credential.AzureKeyCredential;
|
||||
import com.azure.core.util.ClientOptions;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
|
||||
import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties.DEFAULT_DEPLOYMENT_NAME;
|
||||
|
||||
/**
|
||||
* {@link AzureOpenAiChatModel} 集成测试
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class AzureOpenAIChatModelTests {
|
||||
|
||||
// TODO @芋艿:晚点在调整
|
||||
private final OpenAIClientBuilder openAiApi = new OpenAIClientBuilder()
|
||||
.endpoint("https://eastusprejade.openai.azure.com")
|
||||
.credential(new AzureKeyCredential("xxx"))
|
||||
.clientOptions((new ClientOptions()).setApplicationId("spring-ai"));
|
||||
private final AzureOpenAiChatModel chatModel = new AzureOpenAiChatModel(openAiApi,
|
||||
AzureOpenAiChatOptions.builder().deploymentName(DEFAULT_DEPLOYMENT_NAME).build());
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(response -> {
|
||||
// System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}).then().block();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.chat;
|
||||
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.baichuan.BaiChuanChatModel;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.deepseek.DeepSeekChatModel;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link BaiChuanChatModel} 集成测试
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class BaiChuanChatModelTests {
|
||||
|
||||
private final OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
|
||||
.openAiApi(OpenAiApi.builder()
|
||||
.baseUrl(BaiChuanChatModel.BASE_URL)
|
||||
.apiKey("sk-61b6766a94c70786ed02673f5e16af3c") // apiKey
|
||||
.build())
|
||||
.defaultOptions(OpenAiChatOptions.builder()
|
||||
.model("Baichuan4-Turbo") // 模型(https://platform.baichuan-ai.com/docs/api)
|
||||
.temperature(0.7)
|
||||
.build())
|
||||
.build();
|
||||
|
||||
private final DeepSeekChatModel chatModel = new DeepSeekChatModel(openAiChatModel);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(System.out::println).then().block();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.chat;
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 基于 {@link OpenAiChatModel} 集成 Coze 测试
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class CozeChatModelTests {
|
||||
|
||||
private final OpenAiChatModel chatModel = OpenAiChatModel.builder()
|
||||
.openAiApi(OpenAiApi.builder()
|
||||
.baseUrl("http://127.0.0.1:3000")
|
||||
.apiKey("app-4hy2d7fJauSbrKbzTKX1afuP") // apiKey
|
||||
.build())
|
||||
.build();
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(response -> {
|
||||
// System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}).then().block();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.chat;
|
||||
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.deepseek.DeepSeekChatModel;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link DeepSeekChatModel} 集成测试
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class DeepSeekChatModelTests {
|
||||
|
||||
private final OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
|
||||
.openAiApi(OpenAiApi.builder()
|
||||
.baseUrl(DeepSeekChatModel.BASE_URL)
|
||||
.apiKey("sk-e52047409b144d97b791a6a46a2d") // apiKey
|
||||
.build())
|
||||
.defaultOptions(OpenAiChatOptions.builder()
|
||||
.model("deepseek-chat") // 模型
|
||||
.temperature(0.7)
|
||||
.build())
|
||||
.build();
|
||||
|
||||
private final DeepSeekChatModel chatModel = new DeepSeekChatModel(openAiChatModel);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(System.out::println).then().block();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.chat;
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 基于 {@link OpenAiChatModel} 集成 Dify 测试
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class DifyChatModelTests {
|
||||
|
||||
private final OpenAiChatModel chatModel = OpenAiChatModel.builder()
|
||||
.openAiApi(OpenAiApi.builder()
|
||||
.baseUrl("http://127.0.0.1:3000")
|
||||
.apiKey("app-4hy2d7fJauSbrKbzTKX1afuP") // apiKey
|
||||
.build())
|
||||
.build();
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(response -> {
|
||||
// System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}).then().block();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.chat;
|
||||
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.doubao.DouBaoChatModel;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link DouBaoChatModel} 集成测试
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class DouBaoChatModelTests {
|
||||
|
||||
private final OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
|
||||
.openAiApi(OpenAiApi.builder()
|
||||
.baseUrl(DouBaoChatModel.BASE_URL)
|
||||
.apiKey("5c1b5747-26d2-4ebd-a4e0-dd0e8d8b4272") // apiKey
|
||||
.build())
|
||||
.defaultOptions(OpenAiChatOptions.builder()
|
||||
.model("doubao-1-5-lite-32k-250115") // 模型(doubao)
|
||||
// .model("deepseek-r1-250120") // 模型(deepseek)
|
||||
.temperature(0.7)
|
||||
.build())
|
||||
.build();
|
||||
|
||||
private final DouBaoChatModel chatModel = new DouBaoChatModel(openAiChatModel);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
}
|
||||
|
||||
// TODO @芋艿:因为使用的是 v1 api,导致 deepseek-r1-250120 不返回 think 过程,后续需要优化
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(System.out::println).then().block();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.chat;
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 基于 {@link OpenAiChatModel} 集成 FastGPT 测试
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class FastGPTChatModelTests {
|
||||
|
||||
private final OpenAiChatModel chatModel = OpenAiChatModel.builder()
|
||||
.openAiApi(OpenAiApi.builder()
|
||||
.baseUrl("https://cloud.fastgpt.cn/api")
|
||||
.apiKey("fastgpt-aqcc61kFtF8CeaglnGAfQOCIDWwjGdJVJHv6hIlMo28otFlva2aZNK") // apiKey
|
||||
.build())
|
||||
.build();
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(response -> {
|
||||
// System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}).then().block();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,110 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.chat;
|
||||
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.hunyuan.HunYuanChatModel;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link HunYuanChatModel} 集成测试
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class HunYuanChatModelTests {
|
||||
|
||||
private final OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
|
||||
.openAiApi(OpenAiApi.builder()
|
||||
.baseUrl(HunYuanChatModel.BASE_URL)
|
||||
.apiKey("sk-bcd") // apiKey
|
||||
.build())
|
||||
.defaultOptions(OpenAiChatOptions.builder()
|
||||
.model(HunYuanChatModel.MODEL_DEFAULT) // 模型
|
||||
.temperature(0.7)
|
||||
.build())
|
||||
.build();
|
||||
|
||||
private final HunYuanChatModel chatModel = new HunYuanChatModel(openAiChatModel);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(System.out::println).then().block();
|
||||
}
|
||||
|
||||
private final OpenAiChatModel deepSeekOpenAiChatModel = OpenAiChatModel.builder()
|
||||
.openAiApi(OpenAiApi.builder()
|
||||
.baseUrl(HunYuanChatModel.DEEP_SEEK_BASE_URL)
|
||||
.apiKey("sk-abc") // apiKey
|
||||
.build())
|
||||
.defaultOptions(OpenAiChatOptions.builder()
|
||||
// .model(HunYuanChatModel.DEEP_SEEK_MODEL_DEFAULT) // 模型("deepseek-v3")
|
||||
.model("deepseek-r1") // 模型("deepseek-r1")
|
||||
.temperature(0.7)
|
||||
.build())
|
||||
.build();
|
||||
|
||||
private final HunYuanChatModel deepSeekChatModel = new HunYuanChatModel(deepSeekOpenAiChatModel);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall_deepseek() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = deepSeekChatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream_deekseek() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = deepSeekChatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(System.out::println).then().block();
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.chat;
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.ollama.OllamaChatModel;
|
||||
import org.springframework.ai.ollama.api.OllamaApi;
|
||||
import org.springframework.ai.ollama.api.OllamaModel;
|
||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link OllamaChatModel} 集成测试
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class LlamaChatModelTests {
|
||||
|
||||
private final OllamaChatModel chatModel = OllamaChatModel.builder()
|
||||
.ollamaApi(new OllamaApi("http://127.0.0.1:11434")) // Ollama 服务地址
|
||||
.defaultOptions(OllamaOptions.builder()
|
||||
.model(OllamaModel.LLAMA3.getName()) // 模型
|
||||
.build())
|
||||
.build();
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(response -> {
|
||||
// System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}).then().block();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.chat;
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.minimax.MiniMaxChatModel;
|
||||
import org.springframework.ai.minimax.MiniMaxChatOptions;
|
||||
import org.springframework.ai.minimax.api.MiniMaxApi;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link MiniMaxChatModel} 的集成测试
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class MiniMaxChatModelTests {
|
||||
|
||||
private final MiniMaxChatModel chatModel = new MiniMaxChatModel(
|
||||
new MiniMaxApi("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiLnjovmlofmlowiLCJVc2VyTmFtZSI6IueOi-aWh-aWjCIsIkFjY291bnQiOiIiLCJTdWJqZWN0SUQiOiIxODk3Mjg3MjQ5NDU2ODA4MzQ2IiwiUGhvbmUiOiIxNTYwMTY5MTM5OSIsIkdyb3VwSUQiOiIxODk3Mjg3MjQ5NDQ4NDE5NzM4IiwiUGFnZU5hbWUiOiIiLCJNYWlsIjoiIiwiQ3JlYXRlVGltZSI6IjIwMjUtMDMtMTEgMTI6NTI6MDIiLCJUb2tlblR5cGUiOjEsImlzcyI6Im1pbmltYXgifQ.aAuB7gWW_oA4IYhh-CF7c9MfWWxKN49B_HK-DYjXaDwwffhiG-H1571z1WQhp9QytWG-DqgLejneeSxkiq1wQIe3FsEP2wz4BmGBct31LehbJu8ehLxg_vg75Uod1nFAHbm5mZz6JSVLNIlSo87Xr3UtSzJhAXlapEkcqlA4YOzOpKrZ8l5_OJPTORTCmHWZYgJcRS-faNiH62ZnUEHUozesTFhubJHo5GfJCw_edlnmfSUocERV1BjWvenhZ9My-aYXNktcW9WaSj9l6gayV7A0Ium_PL55T9ln1PcI8gayiVUKJGJDoqNyF1AF9_aF9NOKtTnQzwNqnZdlTYH6hw"), // 密钥
|
||||
MiniMaxChatOptions.builder()
|
||||
.model(MiniMaxApi.ChatModel.ABAB_6_5_G_Chat.getValue()) // 模型
|
||||
.build());
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(response -> {
|
||||
// System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}).then().block();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.chat;
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.moonshot.MoonshotChatModel;
|
||||
import org.springframework.ai.moonshot.MoonshotChatOptions;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link org.springframework.ai.moonshot.MoonshotChatModel} 的集成测试
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class MoonshotChatModelTests {
|
||||
|
||||
private final MoonshotChatModel chatModel = new MoonshotChatModel(
|
||||
new MoonshotApi("sk-aHYYV1SARscItye5QQRRNbXij4fy65Ee7pNZlC9gsSQnUKXA"), // 密钥
|
||||
MoonshotChatOptions.builder()
|
||||
.model("moonshot-v1-8k") // 模型
|
||||
.build());
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(response -> {
|
||||
// System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}).then().block();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.chat;
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.ollama.OllamaChatModel;
|
||||
import org.springframework.ai.ollama.api.OllamaApi;
|
||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link OllamaChatModel} 集成测试
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class OllamaChatModelTests {
|
||||
|
||||
private final OllamaChatModel chatModel = OllamaChatModel.builder()
|
||||
.ollamaApi(new OllamaApi("http://127.0.0.1:11434")) // Ollama 服务地址
|
||||
.defaultOptions(OllamaOptions.builder()
|
||||
// .model("qwen") // 模型(https://ollama.com/library/qwen)
|
||||
.model("deepseek-r1") // 模型(https://ollama.com/library/deepseek-r1)
|
||||
.build())
|
||||
.build();
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(response -> {
|
||||
// System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}).then().block();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.chat;
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link OpenAiChatModel} 集成测试
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class OpenAIChatModelTests {
|
||||
|
||||
private final OpenAiChatModel chatModel = OpenAiChatModel.builder()
|
||||
.openAiApi(OpenAiApi.builder()
|
||||
.baseUrl("https://api.holdai.top")
|
||||
.apiKey("sk-aN6nWn3fILjrgLFT0fC4Aa60B72e4253826c77B29dC94f17") // apiKey
|
||||
.build())
|
||||
.defaultOptions(OpenAiChatOptions.builder()
|
||||
.model(OpenAiApi.ChatModel.GPT_4_O) // 模型
|
||||
.temperature(0.7)
|
||||
.build())
|
||||
.build();
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(response -> {
|
||||
// System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}).then().block();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.chat;
|
||||
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.siliconflow.SiliconFlowApiConstants;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.siliconflow.SiliconFlowChatModel;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link SiliconFlowChatModel} 集成测试
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class SiliconFlowChatModelTests {
|
||||
|
||||
private final OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
|
||||
.openAiApi(OpenAiApi.builder()
|
||||
.baseUrl(SiliconFlowApiConstants.DEFAULT_BASE_URL)
|
||||
.apiKey("sk-epsakfenqnyzoxhmbucsxlhkdqlcbnimslqoivkshalvdozz") // apiKey
|
||||
.build())
|
||||
.defaultOptions(OpenAiChatOptions.builder()
|
||||
.model(SiliconFlowApiConstants.MODEL_DEFAULT) // 模型
|
||||
// .model("deepseek-ai/DeepSeek-R1") // 模型(deepseek-ai/DeepSeek-R1)可用赠费
|
||||
// .model("Pro/deepseek-ai/DeepSeek-R1") // 模型(Pro/deepseek-ai/DeepSeek-R1)需要付费
|
||||
.temperature(0.7)
|
||||
.build())
|
||||
.build();
|
||||
|
||||
private final SiliconFlowChatModel chatModel = new SiliconFlowChatModel(openAiChatModel);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(System.out::println).then().block();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.chat;
|
||||
|
||||
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi;
|
||||
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel;
|
||||
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link DashScopeChatModel} 集成测试类
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
public class TongYiChatModelTests {
|
||||
|
||||
private final DashScopeChatModel chatModel = new DashScopeChatModel(
|
||||
new DashScopeApi("sk-7d903764249848cfa912733146da12d1"),
|
||||
DashScopeChatOptions.builder()
|
||||
.withModel("qwen1.5-72b-chat") // 模型
|
||||
// .withModel("deepseek-r1") // 模型(deepseek-r1)
|
||||
// .withModel("deepseek-v3") // 模型(deepseek-v3)
|
||||
// .withModel("deepseek-r1-distill-qwen-1.5b") // 模型(deepseek-r1-distill-qwen-1.5b)
|
||||
.build());
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(response -> {
|
||||
// System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}).then().block();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.chat;
|
||||
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link XingHuoChatModel} 集成测试
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
public class XingHuoChatModelTests {
|
||||
|
||||
private final OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
|
||||
.openAiApi(OpenAiApi.builder()
|
||||
.baseUrl(XingHuoChatModel.BASE_URL)
|
||||
.apiKey("75b161ed2aef4719b275d6e7f2a4d4cd:YWYxYWI2MTA4ODI2NGZlYTQyNjAzZTcz") // appKey:secretKey
|
||||
.build())
|
||||
.defaultOptions(OpenAiChatOptions.builder()
|
||||
.model("generalv3.5") // 模型
|
||||
.temperature(0.7)
|
||||
.build())
|
||||
.build();
|
||||
|
||||
private final XingHuoChatModel chatModel = new XingHuoChatModel(openAiChatModel);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(System.out::println).then().block();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.chat;
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.qianfan.QianFanChatModel;
|
||||
import org.springframework.ai.qianfan.QianFanChatOptions;
|
||||
import org.springframework.ai.qianfan.api.QianFanApi;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
// TODO @芋艿:百度千帆 API 提供了 V2 版本,目前 Spring AI 不兼容,可关键 <https://github.com/spring-projects/spring-ai/issues/2179> 进展
|
||||
/**
|
||||
* {@link QianFanChatModel} 的集成测试
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
public class YiYanChatModelTests {
|
||||
|
||||
private final QianFanChatModel chatModel = new QianFanChatModel(
|
||||
new QianFanApi("qS8k8dYr2nXunagK4SSU8Xjj", "pHGbx51ql2f0hOyabQvSZezahVC3hh3e"), // 密钥
|
||||
QianFanChatOptions.builder()
|
||||
.model(QianFanApi.ChatModel.ERNIE_4_0_8K_Preview.getValue())
|
||||
.build()
|
||||
);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
// TODO @芋艿:文心一言,只要带上 system message 就报错,已经各种测试,很莫名!
|
||||
// messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
// TODO @芋艿:文心一言,只要带上 system message 就报错,已经各种测试,很莫名!
|
||||
// messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(System.out::println).then().block();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.chat;
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.zhipuai.ZhiPuAiChatModel;
|
||||
import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
|
||||
import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link ZhiPuAiChatModel} 的集成测试
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class ZhiPuAiChatModelTests {
|
||||
|
||||
private final ZhiPuAiChatModel chatModel = new ZhiPuAiChatModel(
|
||||
new ZhiPuAiApi("32f84543e54eee31f8d56b2bd6020573.3vh9idLJZ2ZhxDEs"), // 密钥
|
||||
ZhiPuAiChatOptions.builder()
|
||||
.model(ZhiPuAiApi.ChatModel.GLM_4.getName()) // 模型
|
||||
.build()
|
||||
);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(response -> {
|
||||
// System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}).then().block();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.image;
|
||||
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link MidjourneyApi} 集成测试
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class MidjourneyApiTests {
|
||||
|
||||
private final MidjourneyApi midjourneyApi = new MidjourneyApi(
|
||||
"https://api.holdai.top/mj", // 链接
|
||||
"sk-aN6nWn3fILjrgLFT0fC4Aa60B72e4253826c77B29dC94f17", // 密钥
|
||||
null);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testImagine() {
|
||||
// 准备参数
|
||||
MidjourneyApi.ImagineRequest request = new MidjourneyApi.ImagineRequest(null,
|
||||
"生成一个小猫,可爱的", null,
|
||||
MidjourneyApi.ImagineRequest.buildState(512, 512, "6.0", MidjourneyApi.ModelEnum.MIDJOURNEY.getModel()));
|
||||
|
||||
// 方法调用
|
||||
MidjourneyApi.SubmitResponse response = midjourneyApi.imagine(request);
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testAction() {
|
||||
// 准备参数
|
||||
MidjourneyApi.ActionRequest request = new MidjourneyApi.ActionRequest("1720277033455953",
|
||||
"MJ::JOB::upsample::1::ee267661-ee52-4ced-a530-0343ba95af3b", null);
|
||||
|
||||
// 方法调用
|
||||
MidjourneyApi.SubmitResponse response = midjourneyApi.action(request);
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testGetTaskList() {
|
||||
// 准备参数。该参数可以通过 MidjourneyApi.SubmitResponse 的 result 获取
|
||||
// String taskId = "1720277033455953";
|
||||
String taskId = "1720277214045971";
|
||||
|
||||
// 方法调用
|
||||
List<MidjourneyApi.Notify> taskList = midjourneyApi.getTaskList(Collections.singletonList(taskId));
|
||||
// 打印结果
|
||||
System.out.println(taskList);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.image;
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.image.ImageOptions;
|
||||
import org.springframework.ai.image.ImagePrompt;
|
||||
import org.springframework.ai.image.ImageResponse;
|
||||
import org.springframework.ai.openai.OpenAiImageModel;
|
||||
import org.springframework.ai.openai.OpenAiImageOptions;
|
||||
import org.springframework.ai.openai.api.OpenAiImageApi;
|
||||
|
||||
/**
|
||||
* {@link OpenAiImageModel} 集成测试类
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
public class OpenAiImageModelTests {
|
||||
|
||||
private final OpenAiImageModel imageModel = new OpenAiImageModel(OpenAiImageApi.builder()
|
||||
.baseUrl("https://api.holdai.top") // apiKey
|
||||
.apiKey("sk-aN6nWn3fILjrgLFT0fC4Aa60B72e4253826c77B29dC94f17")
|
||||
.build());
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
ImageOptions options = OpenAiImageOptions.builder()
|
||||
.withModel(OpenAiImageApi.ImageModel.DALL_E_2.getValue()) // 这个模型比较便宜
|
||||
.withHeight(256).withWidth(256)
|
||||
.build();
|
||||
ImagePrompt prompt = new ImagePrompt("中国长城!", options);
|
||||
|
||||
// 方法调用
|
||||
ImageResponse response = imageModel.call(prompt);
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.image;
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.image.ImagePrompt;
|
||||
import org.springframework.ai.image.ImageResponse;
|
||||
import org.springframework.ai.qianfan.QianFanImageModel;
|
||||
import org.springframework.ai.qianfan.QianFanImageOptions;
|
||||
import org.springframework.ai.qianfan.api.QianFanImageApi;
|
||||
|
||||
import static cn.iocoder.yudao.module.ai.framework.ai.core.model.image.StabilityAiImageModelTests.viewImage;
|
||||
|
||||
// TODO @芋艿:百度千帆 API 提供了 V2 版本,目前 Spring AI 不兼容,可关键 <https://github.com/spring-projects/spring-ai/issues/2179> 进展
|
||||
|
||||
/**
|
||||
* {@link QianFanImageModel} 集成测试类
|
||||
*/
|
||||
public class QianFanImageTests {
|
||||
|
||||
private final QianFanImageModel imageModel = new QianFanImageModel(
|
||||
new QianFanImageApi("qS8k8dYr2nXunagK4SSU8Xjj", "pHGbx51ql2f0hOyabQvSZezahVC3hh3e")); // 密钥
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
// 只支持 1024x1024、768x768、768x1024、1024x768、576x1024、1024x576
|
||||
QianFanImageOptions imageOptions = QianFanImageOptions.builder()
|
||||
.model(QianFanImageApi.ImageModel.Stable_Diffusion_XL.getValue())
|
||||
.width(1024).height(1024)
|
||||
.N(1)
|
||||
.build();
|
||||
ImagePrompt prompt = new ImagePrompt("good", imageOptions);
|
||||
|
||||
// 方法调用
|
||||
ImageResponse response = imageModel.call(prompt);
|
||||
// 打印结果
|
||||
String b64Json = response.getResult().getOutput().getB64Json();
|
||||
System.out.println(response);
|
||||
viewImage(b64Json);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.image;
|
||||
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.siliconflow.SiliconFlowImageApi;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.siliconflow.SiliconFlowImageModel;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.siliconflow.SiliconFlowImageOptions;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.image.ImagePrompt;
|
||||
import org.springframework.ai.image.ImageResponse;
|
||||
|
||||
/**
|
||||
* {@link SiliconFlowImageModel} 集成测试
|
||||
*/
|
||||
public class SiliconFlowImageModelTests {
|
||||
|
||||
private final SiliconFlowImageModel imageModel = new SiliconFlowImageModel(
|
||||
new SiliconFlowImageApi("sk-epsakfenqnyzoxhmbucsxlhkdqlcbnimslqoivkshalvdozz") // 密钥
|
||||
);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
SiliconFlowImageOptions imageOptions = SiliconFlowImageOptions.builder()
|
||||
.model("Kwai-Kolors/Kolors")
|
||||
.build();
|
||||
ImagePrompt prompt = new ImagePrompt("万里长城", imageOptions);
|
||||
|
||||
// 方法调用
|
||||
ImageResponse response = imageModel.call(prompt);
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.image;
|
||||
|
||||
import cn.hutool.core.codec.Base64;
|
||||
import cn.hutool.core.thread.ThreadUtil;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.image.ImageOptions;
|
||||
import org.springframework.ai.image.ImagePrompt;
|
||||
import org.springframework.ai.image.ImageResponse;
|
||||
import org.springframework.ai.openai.OpenAiImageOptions;
|
||||
import org.springframework.ai.stabilityai.StabilityAiImageModel;
|
||||
import org.springframework.ai.stabilityai.api.StabilityAiApi;
|
||||
|
||||
import javax.swing.*;
|
||||
import java.awt.*;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* {@link StabilityAiImageModel} 集成测试类
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
public class StabilityAiImageModelTests {
|
||||
|
||||
private final StabilityAiImageModel imageModel = new StabilityAiImageModel(
|
||||
new StabilityAiApi("sk-e53UqbboF8QJCscYvzJscJxJXoFcFg4iJjl1oqgE7baJETmx") // 密钥
|
||||
);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
ImageOptions options = OpenAiImageOptions.builder()
|
||||
.withModel("stable-diffusion-v1-6")
|
||||
.withHeight(320).withWidth(320)
|
||||
.build();
|
||||
ImagePrompt prompt = new ImagePrompt("great wall", options);
|
||||
|
||||
// 方法调用
|
||||
ImageResponse response = imageModel.call(prompt);
|
||||
// 打印结果
|
||||
String b64Json = response.getResult().getOutput().getB64Json();
|
||||
System.out.println(response);
|
||||
viewImage(b64Json);
|
||||
}
|
||||
|
||||
public static void viewImage(String b64Json) {
|
||||
// 创建一个 JFrame
|
||||
JFrame frame = new JFrame("Byte Image Display");
|
||||
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
|
||||
frame.setSize(800, 600);
|
||||
|
||||
// 创建一个 JLabel 来显示图片
|
||||
byte[] imageBytes = Base64.decode(b64Json);
|
||||
JLabel label = new JLabel(new ImageIcon(imageBytes));
|
||||
|
||||
// 将 JLabel 添加到 JFrame
|
||||
frame.getContentPane().add(label, BorderLayout.CENTER);
|
||||
|
||||
// 显示 JFrame
|
||||
frame.setVisible(true);
|
||||
ThreadUtil.sleep(1, TimeUnit.HOURS);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.image;
|
||||
|
||||
import com.alibaba.cloud.ai.dashscope.api.DashScopeImageApi;
|
||||
import com.alibaba.cloud.ai.dashscope.image.DashScopeImageModel;
|
||||
import com.alibaba.cloud.ai.dashscope.image.DashScopeImageOptions;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.image.ImageOptions;
|
||||
import org.springframework.ai.image.ImagePrompt;
|
||||
import org.springframework.ai.image.ImageResponse;
|
||||
|
||||
/**
|
||||
* {@link DashScopeImageModel} 集成测试类
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
public class TongYiImagesModelTest {
|
||||
|
||||
private final DashScopeImageModel imageModel = new DashScopeImageModel(
|
||||
new DashScopeImageApi("sk-7d903764249848cfa912733146da12d1"));
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void imageCallTest() {
|
||||
// 准备参数
|
||||
ImageOptions options = DashScopeImageOptions.builder()
|
||||
.withModel("wanx-v1")
|
||||
.withHeight(256).withWidth(256)
|
||||
.build();
|
||||
ImagePrompt prompt = new ImagePrompt("中国长城!", options);
|
||||
|
||||
// 方法调用
|
||||
ImageResponse response = imageModel.call(prompt);
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.image;
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.image.ImagePrompt;
|
||||
import org.springframework.ai.image.ImageResponse;
|
||||
import org.springframework.ai.zhipuai.ZhiPuAiImageModel;
|
||||
import org.springframework.ai.zhipuai.ZhiPuAiImageOptions;
|
||||
import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
|
||||
|
||||
/**
|
||||
* {@link ZhiPuAiImageModel} 集成测试
|
||||
*/
|
||||
public class ZhiPuAiImageModelTests {
|
||||
|
||||
private final ZhiPuAiImageModel imageModel = new ZhiPuAiImageModel(
|
||||
new ZhiPuAiImageApi("78d3228c1d9e5e342a3e1ab349e2dd7b.VXLoq5vrwK2ofboy") // 密钥
|
||||
);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
ZhiPuAiImageOptions imageOptions = ZhiPuAiImageOptions.builder()
|
||||
.model(ZhiPuAiImageApi.ImageModel.CogView_3.getValue())
|
||||
.build();
|
||||
ImagePrompt prompt = new ImagePrompt("万里长城", imageOptions);
|
||||
|
||||
// 方法调用
|
||||
ImageResponse response = imageModel.call(prompt);
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,124 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.mcp;
|
||||
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.doubao.DouBaoChatModel;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import org.springframework.ai.tool.annotation.Tool;
|
||||
import org.springframework.ai.tool.method.MethodToolCallbackProvider;
|
||||
|
||||
@Disabled
|
||||
public class DouBaoMcpTests {
|
||||
|
||||
private final OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
|
||||
.openAiApi(OpenAiApi.builder()
|
||||
.baseUrl(DouBaoChatModel.BASE_URL)
|
||||
.apiKey("5c1b5747-26d2-4ebd-a4e0-dd0e8d8b4272") // apiKey
|
||||
.build())
|
||||
.defaultOptions(OpenAiChatOptions.builder()
|
||||
.model("doubao-1-5-lite-32k-250115") // 模型(doubao)
|
||||
.temperature(0.7)
|
||||
.build())
|
||||
.build();
|
||||
|
||||
private final DouBaoChatModel chatModel = new DouBaoChatModel(openAiChatModel);
|
||||
|
||||
private final MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder()
|
||||
.toolObjects(new UserService())
|
||||
.build();
|
||||
|
||||
private final ChatClient chatClient = ChatClient.builder(chatModel)
|
||||
.defaultTools(provider)
|
||||
.build();
|
||||
|
||||
@Test
|
||||
public void testMcpGetUserInfo() {
|
||||
|
||||
// 打印结果
|
||||
System.out.println(chatClient.prompt()
|
||||
.user("目前有哪些工具可以使用")
|
||||
.call()
|
||||
.content());
|
||||
System.out.println("====================================");
|
||||
// 打印结果
|
||||
System.out.println(chatClient.prompt()
|
||||
.user("小新的年龄是多少")
|
||||
.call()
|
||||
.content());
|
||||
System.out.println("====================================");
|
||||
// 打印结果
|
||||
System.out.println(chatClient.prompt()
|
||||
.user("获取小新的基本信息")
|
||||
.call()
|
||||
.content());
|
||||
System.out.println("====================================");
|
||||
// 打印结果
|
||||
System.out.println(chatClient.prompt()
|
||||
.user("小新是什么职业的")
|
||||
.call()
|
||||
.content());
|
||||
System.out.println("====================================");
|
||||
// 打印结果
|
||||
System.out.println(chatClient.prompt()
|
||||
.user("小新的教育背景")
|
||||
.call()
|
||||
.content());
|
||||
System.out.println("====================================");
|
||||
// 打印结果
|
||||
System.out.println(chatClient.prompt()
|
||||
.user("小新的兴趣爱好是什么")
|
||||
.call()
|
||||
.content());
|
||||
System.out.println("====================================");
|
||||
|
||||
}
|
||||
|
||||
|
||||
static class UserService {
|
||||
|
||||
@Tool(name = "getUserAge", description = "获取用户年龄")
|
||||
public String getUserAge(String userName) {
|
||||
return "《" + userName + "》的年龄为:18";
|
||||
}
|
||||
|
||||
@Tool(name = "getUserSex", description = "获取用户性别")
|
||||
public String getUserSex(String userName) {
|
||||
return "《" + userName + "》的性别为:男";
|
||||
}
|
||||
|
||||
@Tool(name = "getUserBasicInfo", description = "获取用户基本信息,包括姓名、年龄、性别等")
|
||||
public String getUserBasicInfo(String userName) {
|
||||
return "《" + userName + "》的基本信息:\n姓名:" + userName + "\n年龄:18\n性别:男\n身高:175cm\n体重:65kg";
|
||||
}
|
||||
|
||||
@Tool(name = "getUserContact", description = "获取用户联系方式,包括电话、邮箱等")
|
||||
public String getUserContact(String userName) {
|
||||
return "《" + userName + "》的联系方式:\n电话:138****1234\n邮箱:" + userName.toLowerCase() + "@example.com\nQQ:123456789";
|
||||
}
|
||||
|
||||
@Tool(name = "getUserAddress", description = "获取用户地址信息")
|
||||
public String getUserAddress(String userName) {
|
||||
return "《" + userName + "》的地址信息:北京市朝阳区科技园区88号";
|
||||
}
|
||||
|
||||
@Tool(name = "getUserJob", description = "获取用户职业信息")
|
||||
public String getUserJob(String userName) {
|
||||
return "《" + userName + "》的职业信息:软件工程师,就职于ABC科技有限公司,工作年限5年";
|
||||
}
|
||||
|
||||
@Tool(name = "getUserHobbies", description = "获取用户兴趣爱好")
|
||||
public String getUserHobbies(String userName) {
|
||||
return "《" + userName + "》的兴趣爱好:编程、阅读、旅游、摄影、打篮球";
|
||||
}
|
||||
|
||||
@Tool(name = "getUserEducation", description = "获取用户教育背景")
|
||||
public String getUserEducation(String userName) {
|
||||
return "《" + userName + "》的教育背景:\n本科:计算机科学与技术专业,北京大学\n硕士:软件工程专业,清华大学";
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.music;
|
||||
|
||||
import cn.hutool.core.collection.ListUtil;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.suno.api.SunoApi;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link SunoApi} 集成测试
|
||||
*
|
||||
* @author xiaoxin
|
||||
*/
|
||||
public class SunoApiTests {
|
||||
|
||||
private final SunoApi sunoApi = new SunoApi("https://suno-3tah0ycyt-status2xxs-projects.vercel.app");
|
||||
// private final SunoApi sunoApi = new SunoApi("http://127.0.0.1:3001");
|
||||
|
||||
@Test // 描述模式
|
||||
@Disabled
|
||||
public void testGenerate() {
|
||||
// 准备参数
|
||||
SunoApi.MusicGenerateRequest generateRequest = new SunoApi.MusicGenerateRequest(
|
||||
"happy music",
|
||||
"chirp-v3-5",
|
||||
false);
|
||||
|
||||
// 调用方法
|
||||
List<SunoApi.MusicData> musicList = sunoApi.generate(generateRequest);
|
||||
// 打印结果
|
||||
System.out.println(musicList);
|
||||
}
|
||||
|
||||
@Test // 歌词模式
|
||||
@Disabled
|
||||
public void testCustomGenerate() {
|
||||
// 准备参数
|
||||
SunoApi.MusicGenerateRequest generateRequest = new SunoApi.MusicGenerateRequest(
|
||||
"创作一首带有轻松吉他旋律的流行歌曲,[verse] 描述夏日海滩的宁静,[chorus] 节奏加快,表达对自由的向往。",
|
||||
"Happy",
|
||||
"Happy Song",
|
||||
"chirp-v3.5",
|
||||
false,
|
||||
false);
|
||||
|
||||
// 调用方法
|
||||
List<SunoApi.MusicData> musicList = sunoApi.customGenerate(generateRequest);
|
||||
// 打印结果
|
||||
System.out.println(musicList);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testGenerateLyrics() {
|
||||
// 调用方法
|
||||
SunoApi.LyricsData lyricsData = sunoApi.generateLyrics("A soothing lullaby");
|
||||
// 打印结果
|
||||
System.out.println(lyricsData);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testGetMusicList() {
|
||||
// 准备参数
|
||||
// String id = "d460ddda-7c87-4f34-b751-419b08a590ca";
|
||||
String id = "584729e5-0fe9-4157-86da-1b4803ff42bf";
|
||||
|
||||
// 调用方法
|
||||
List<SunoApi.MusicData> musicList = sunoApi.getMusicList(ListUtil.of(id));
|
||||
// 打印结果
|
||||
System.out.println(musicList);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testGetLimitUsage() {
|
||||
// 调用方法
|
||||
SunoApi.LimitUsageData limitUsageData = sunoApi.getLimitUsage();
|
||||
// 打印结果
|
||||
System.out.println(limitUsageData);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,315 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.ppt.wdd;
|
||||
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.wenduoduo.api.WenDuoDuoPptApi;
|
||||
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* {@link WenDuoDuoPptApi} 集成测试
|
||||
*
|
||||
* @author xiaoxin
|
||||
*/
|
||||
@Disabled
|
||||
public class WenDuoDuoPptApiTests {
|
||||
|
||||
private final String token = ""; // API Token
|
||||
private final WenDuoDuoPptApi wenDuoDuoPptApi = new WenDuoDuoPptApi(token);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCreateApiToken() {
|
||||
// 准备参数
|
||||
String apiKey = "";
|
||||
WenDuoDuoPptApi.CreateTokenRequest request = new WenDuoDuoPptApi.CreateTokenRequest(apiKey);
|
||||
// 调用方法
|
||||
String token = wenDuoDuoPptApi.createApiToken(request);
|
||||
// 打印结果
|
||||
System.out.println(token);
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建任务
|
||||
*/
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCreateTask() {
|
||||
WenDuoDuoPptApi.ApiResponse apiResponse = wenDuoDuoPptApi.createTask(1, "dify 介绍", null);
|
||||
System.out.println(apiResponse);
|
||||
}
|
||||
|
||||
|
||||
@Test // 创建大纲
|
||||
@Disabled
|
||||
public void testGenerateOutlineRequest() {
|
||||
WenDuoDuoPptApi.CreateOutlineRequest request = new WenDuoDuoPptApi.CreateOutlineRequest(
|
||||
"1901539019628613632", "medium", null, null, null, null);
|
||||
// 调用
|
||||
Flux<Map<String, Object>> flux = wenDuoDuoPptApi.createOutline(request);
|
||||
StringBuffer contentBuffer = new StringBuffer();
|
||||
flux.doOnNext(chunk -> {
|
||||
contentBuffer.append(chunk.get("text"));
|
||||
if (Objects.equals(Integer.parseInt(String.valueOf(chunk.get("status"))), 4)) {
|
||||
// status 为 4,最终 markdown 结构树
|
||||
System.out.println(JsonUtils.toJsonString(chunk.get("result")));
|
||||
System.out.println(" ########################################################################");
|
||||
}
|
||||
}).then().block();
|
||||
// 打印结果
|
||||
System.out.println(contentBuffer);
|
||||
}
|
||||
|
||||
/**
|
||||
* 修改大纲
|
||||
*/
|
||||
@Test
|
||||
@Disabled
|
||||
public void testUpdateOutlineRequest() {
|
||||
WenDuoDuoPptApi.UpdateOutlineRequest request = new WenDuoDuoPptApi.UpdateOutlineRequest(
|
||||
"1901539019628613632", TEST_OUT_LINE_CONTENT, "精简一点,三个章节即可");
|
||||
// 调用
|
||||
Flux<Map<String, Object>> flux = wenDuoDuoPptApi.updateOutline(request);
|
||||
StringBuffer contentBuffer = new StringBuffer();
|
||||
flux.doOnNext(chunk -> {
|
||||
contentBuffer.append(chunk.get("text"));
|
||||
if (Objects.equals(Integer.parseInt(String.valueOf(chunk.get("status"))), 4)) {
|
||||
// status 为 4,最终 markdown 结构树
|
||||
System.out.println(JsonUtils.toJsonString(chunk.get("result")));
|
||||
System.out.println(" ########################################################################");
|
||||
}
|
||||
}).then().block();
|
||||
// 打印结果
|
||||
System.out.println(contentBuffer);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 PPT 模版分页
|
||||
*/
|
||||
@Test
|
||||
@Disabled
|
||||
public void testGetPptTemplatePage() {
|
||||
// 准备参数
|
||||
WenDuoDuoPptApi.TemplateQueryRequest.Filter filter = new WenDuoDuoPptApi.TemplateQueryRequest.Filter(
|
||||
1, null, null, null);
|
||||
WenDuoDuoPptApi.TemplateQueryRequest request = new WenDuoDuoPptApi.TemplateQueryRequest(1, 10, filter);
|
||||
// 调用
|
||||
WenDuoDuoPptApi.PagePptTemplateInfo pptTemplatePage = wenDuoDuoPptApi.getTemplatePage(request);
|
||||
// 打印结果
|
||||
System.out.println(pptTemplatePage);
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成 PPT
|
||||
*/
|
||||
@Test
|
||||
@Disabled
|
||||
public void testGeneratePptx() {
|
||||
// 准备参数
|
||||
WenDuoDuoPptApi.PptCreateRequest request = new WenDuoDuoPptApi.PptCreateRequest("1901539019628613632", "1805081814809960448", TEST_OUT_LINE_CONTENT);
|
||||
// 调用
|
||||
WenDuoDuoPptApi.PptInfo pptInfo = wenDuoDuoPptApi.create(request);
|
||||
// 打印结果
|
||||
System.out.println(pptInfo);
|
||||
}
|
||||
|
||||
private final String TEST_OUT_LINE_CONTENT = """
|
||||
# Dify:新一代AI应用开发平台
|
||||
|
||||
## 1 什么是Dify
|
||||
### 1.1 Dify定义:AI应用开发平台
|
||||
#### 1.1.1 低代码开发
|
||||
Dify是一个低代码AI应用开发平台,旨在简化AI应用的构建过程,让开发者无需编写大量代码即可快速创建各种智能应用。
|
||||
#### 1.1.2 核心功能
|
||||
Dify的核心功能包括数据集成、模型选择、流程编排和应用部署,提供一站式解决方案,加速AI应用的落地和迭代。
|
||||
#### 1.1.3 开源与商业
|
||||
Dify提供开源版本和商业版本,满足不同用户的需求,开源版本适合个人开发者和小型团队,商业版本则提供更强大的功能和技术支持。
|
||||
|
||||
### 1.2 Dify解决的问题:AI开发痛点
|
||||
#### 1.2.1 开发周期长
|
||||
传统AI应用开发周期长,需要大量的人力和时间投入,Dify通过可视化界面和预置组件,大幅缩短开发周期。
|
||||
#### 1.2.2 技术门槛高
|
||||
AI技术门槛高,需要专业的知识和技能,Dify降低技术门槛,让更多开发者能够参与到AI应用的开发中来。
|
||||
#### 1.2.3 部署和维护复杂
|
||||
AI应用的部署和维护复杂,需要专业的运维团队,Dify提供自动化的部署和维护工具,简化流程,降低成本。
|
||||
|
||||
### 1.3 Dify发展历程
|
||||
#### 1.3.1 早期探索
|
||||
Dify的早期版本主要关注于自然语言处理领域的应用,通过集成各种NLP模型,提供文本分类、情感分析等功能。
|
||||
#### 1.3.2 功能扩展
|
||||
随着用户需求的不断增长,Dify的功能逐渐扩展到图像识别、语音识别等领域,支持更多类型的AI应用。
|
||||
#### 1.3.3 生态建设
|
||||
Dify积极建设开发者生态,提供丰富的文档、教程和案例,帮助开发者更好地使用Dify平台,共同推动AI技术的发展。
|
||||
|
||||
## 2 Dify的核心功能
|
||||
### 2.1 数据集成:连接各种数据源
|
||||
#### 2.1.1 支持多种数据源
|
||||
Dify支持连接各种数据源,包括关系型数据库、NoSQL数据库、文件系统、云存储等,满足不同场景的数据需求。
|
||||
#### 2.1.2 数据转换和清洗
|
||||
Dify提供数据转换和清洗功能,可以将不同格式的数据转换为统一的格式,并去除无效数据,提高数据质量。
|
||||
#### 2.1.3 数据安全
|
||||
Dify注重数据安全,采用各种安全措施保护用户的数据,包括数据加密、访问控制、权限管理等。
|
||||
|
||||
### 2.2 模型选择:丰富的AI模型库
|
||||
#### 2.2.1 预置模型
|
||||
Dify预置了丰富的AI模型,包括自然语言处理、图像识别、语音识别等领域的模型,开发者可以直接使用这些模型,无需自行训练,极大的简化了开发流程。
|
||||
#### 2.2.2 自定义模型
|
||||
Dify支持开发者上传自定义模型,满足个性化的需求。开发者可以将自己训练的模型部署到Dify平台上,与其他开发者共享。
|
||||
#### 2.2.3 模型评估
|
||||
Dify提供模型评估功能,可以对不同模型进行评估,选择最优的模型,提高应用性能。
|
||||
|
||||
### 2.3 流程编排:可视化流程设计器
|
||||
#### 2.3.1 可视化界面
|
||||
Dify提供可视化的流程设计器,开发者可以通过拖拽组件的方式,设计AI应用的流程,无需编写代码,简单高效。
|
||||
#### 2.3.2 灵活的流程控制
|
||||
Dify支持灵活的流程控制,可以根据不同的条件执行不同的分支,实现复杂的业务逻辑。
|
||||
#### 2.3.3 实时调试
|
||||
Dify提供实时调试功能,可以在设计流程的过程中,实时查看流程的执行结果,及时发现和解决问题。
|
||||
|
||||
### 2.4 应用部署:一键部署和管理
|
||||
#### 2.4.1 快速部署
|
||||
Dify提供一键部署功能,可以将AI应用快速部署到各种环境,包括本地环境、云环境、容器环境等。
|
||||
#### 2.4.2 自动伸缩
|
||||
Dify支持自动伸缩,可以根据应用的负载自动调整资源,保证应用的稳定性和性能。
|
||||
#### 2.4.3 监控和告警
|
||||
Dify提供监控和告警功能,可以实时监控应用的状态,并在出现问题时及时告警,方便运维人员进行处理。
|
||||
|
||||
## 3 Dify的特点和优势
|
||||
### 3.1 低代码:降低开发门槛
|
||||
#### 3.1.1 可视化开发
|
||||
Dify采用可视化开发模式,开发者无需编写大量代码,只需通过拖拽组件即可完成AI应用的开发,降低了开发门槛。
|
||||
#### 3.1.2 预置组件
|
||||
Dify预置了丰富的组件,包括数据源组件、模型组件、流程控制组件等,开发者可以直接使用这些组件,提高开发效率。
|
||||
#### 3.1.3 减少代码量
|
||||
Dify可以显著减少代码量,降低开发难度,让更多开发者能够参与到AI应用的开发中来。
|
||||
|
||||
### 3.2 灵活:满足不同场景需求
|
||||
#### 3.2.1 支持多种数据源
|
||||
Dify支持多种数据源,可以连接各种数据源,满足不同场景的数据需求。
|
||||
#### 3.2.2 支持自定义模型
|
||||
Dify支持自定义模型,开发者可以将自己训练的模型部署到Dify平台上,满足个性化的需求。
|
||||
#### 3.2.3 灵活的流程控制
|
||||
Dify支持灵活的流程控制,可以根据不同的条件执行不同的分支,实现复杂的业务逻辑。
|
||||
|
||||
### 3.3 高效:加速应用落地
|
||||
#### 3.3.1 快速开发
|
||||
Dify通过可视化界面和预置组件,大幅缩短开发周期,加速AI应用的落地。
|
||||
#### 3.3.2 快速部署
|
||||
Dify提供一键部署功能,可以将AI应用快速部署到各种环境,提高部署效率。
|
||||
#### 3.3.3 自动化运维
|
||||
Dify提供自动化的运维工具,简化运维流程,降低运维成本。
|
||||
|
||||
### 3.4 开放:构建繁荣生态
|
||||
#### 3.4.1 开源社区
|
||||
Dify拥有活跃的开源社区,开发者可以在社区中交流经验、分享资源、共同推动Dify的发展。
|
||||
#### 3.4.2 丰富的文档
|
||||
Dify提供丰富的文档、教程和案例,帮助开发者更好地使用Dify平台。
|
||||
#### 3.4.3 API支持
|
||||
Dify提供API支持,开发者可以通过API将Dify集成到自己的系统中,扩展Dify的功能。
|
||||
|
||||
## 4 Dify的使用场景
|
||||
### 4.1 智能客服:提升客户服务质量
|
||||
#### 4.1.1 自动回复
|
||||
Dify可以用于构建智能客服系统,实现自动回复客户的常见问题,提高客户服务效率。
|
||||
#### 4.1.2 情感分析
|
||||
Dify可以对客户的语音或文本进行情感分析,判断客户的情绪,并根据情绪提供个性化的服务。
|
||||
#### 4.1.3 知识库问答
|
||||
Dify可以构建知识库问答系统,让客户通过提问的方式获取所需的信息,提高客户满意度。
|
||||
|
||||
### 4.2 金融风控:提高风险识别能力
|
||||
#### 4.2.1 欺诈检测
|
||||
Dify可以用于构建金融风控系统,实现欺诈检测,识别可疑交易,降低风险。
|
||||
#### 4.2.2 信用评估
|
||||
Dify可以对用户的信用进行评估,并根据评估结果提供不同的金融服务。
|
||||
#### 4.2.3 反洗钱
|
||||
Dify可以用于反洗钱,识别可疑资金流动,防止犯罪行为。
|
||||
|
||||
### 4.3 智慧医疗:提升医疗服务水平
|
||||
#### 4.3.1 疾病诊断
|
||||
Dify可以用于辅助疾病诊断,提高诊断准确率,缩短诊断时间。
|
||||
#### 4.3.2 药物研发
|
||||
Dify可以用于药物研发,加速新药的发现和开发。
|
||||
#### 4.3.3 智能健康管理
|
||||
Dify可以构建智能健康管理系统,为用户提供个性化的健康建议和服务。
|
||||
|
||||
### 4.4 智慧城市:提升城市管理效率
|
||||
#### 4.4.1 交通优化
|
||||
Dify可以用于交通优化,提高交通效率,缓解交通拥堵。
|
||||
#### 4.4.2 环境监测
|
||||
Dify可以用于环境监测,实时监测空气质量、水质等环境指标,及时发现和解决环境问题。
|
||||
#### 4.4.3 智能安防
|
||||
Dify可以用于智能安防,提高城市安全水平,预防犯罪行为。
|
||||
|
||||
## 5 Dify的成功案例
|
||||
### 5.1 Case 1:某电商平台的智能客服
|
||||
#### 5.1.1 项目背景
|
||||
该电商平台客户服务压力大,人工客服成本高,需要一种智能化的解决方案。
|
||||
#### 5.1.2 解决方案
|
||||
使用Dify构建智能客服系统,实现自动回复客户的常见问题,并根据客户的情绪提供个性化的服务。
|
||||
#### 5.1.3 效果
|
||||
客户服务效率提高50%,客户满意度提高20%,人工客服成本降低30%。
|
||||
|
||||
### 5.2 Case 2:某银行的金融风控系统
|
||||
#### 5.2.1 项目背景
|
||||
该银行面临日益增长的金融风险,需要一种更有效的风险识别和控制手段。
|
||||
#### 5.2.2 解决方案
|
||||
使用Dify构建金融风控系统,实现欺诈检测、信用评估和反洗钱等功能,提高风险识别能力。
|
||||
#### 5.2.3 效果
|
||||
欺诈交易识别率提高40%,信用评估准确率提高30%,洗钱风险降低25%。
|
||||
|
||||
### 5.3 Case 3:某医院的辅助疾病诊断系统
|
||||
#### 5.3.1 项目背景
|
||||
该医院医生工作压力大,疾病诊断准确率有待提高,需要一种辅助诊断工具。
|
||||
#### 5.3.2 解决方案
|
||||
使用Dify构建辅助疾病诊断系统,根据患者的病历和症状,提供诊断建议,提高诊断准确率。
|
||||
#### 5.3.3 效果
|
||||
疾病诊断准确率提高20%,诊断时间缩短15%,医生工作效率提高10%。
|
||||
|
||||
## 6 Dify的未来展望
|
||||
### 6.1 技术升级
|
||||
#### 6.1.1 模型优化
|
||||
Dify将不断优化预置模型,提高模型性能,并支持更多类型的AI模型。
|
||||
#### 6.1.2 流程引擎升级
|
||||
Dify将升级流程引擎,提高流程的灵活性和可扩展性,支持更复杂的业务逻辑。
|
||||
#### 6.1.3 平台性能优化
|
||||
Dify将不断优化平台性能,提高平台的稳定性和可靠性,满足大规模应用的需求。
|
||||
|
||||
### 6.2 生态建设
|
||||
#### 6.2.1 社区建设
|
||||
Dify将继续加强开源社区建设,吸引更多开发者参与,共同推动Dify的发展。
|
||||
#### 6.2.2 合作伙伴拓展
|
||||
Dify将拓展合作伙伴,与更多的企业和机构合作,共同推动AI技术的应用。
|
||||
#### 6.2.3 应用商店
|
||||
Dify将构建应用商店,让开发者可以分享自己的应用,用户可以购买和使用这些应用,构建繁荣的生态系统。
|
||||
|
||||
### 6.3 应用领域拓展
|
||||
#### 6.3.1 智能制造
|
||||
Dify将拓展到智能制造领域,为企业提供智能化的生产管理和质量控制解决方案。
|
||||
#### 6.3.2 智慧农业
|
||||
Dify将拓展到智慧农业领域,为农民提供智能化的种植和养殖管理解决方案。
|
||||
#### 6.3.3 更多领域
|
||||
Dify将拓展到更多领域,为各行各业提供智能化的解决方案,推动社会发展。
|
||||
|
||||
## 7 总结
|
||||
### 7.1 Dify的价值
|
||||
#### 7.1.1 降低AI开发门槛
|
||||
Dify通过低代码的方式,让更多开发者能够参与到AI应用的开发中来。
|
||||
#### 7.1.2 加速AI应用落地
|
||||
Dify提供一站式解决方案,加速AI应用的落地和迭代。
|
||||
#### 7.1.3 构建繁荣的AI生态
|
||||
Dify通过开源社区和应用商店,构建繁荣的AI生态系统。
|
||||
|
||||
### 7.2 共同发展
|
||||
#### 7.2.1 欢迎加入Dify社区
|
||||
欢迎更多开发者加入Dify社区,共同推动Dify的发展。
|
||||
#### 7.2.2 合作共赢
|
||||
期待与更多的企业和机构合作,共同推动AI技术的应用。
|
||||
#### 7.2.3 共创未来
|
||||
让我们一起用AI技术改变世界,共创美好未来。
|
||||
""";
|
||||
|
||||
}
|
||||
@@ -0,0 +1,319 @@
|
||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.ppt.xunfei;
|
||||
|
||||
import cn.hutool.core.io.FileUtil;
|
||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.xinghuo.api.XunFeiPptApi;
|
||||
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.mock.web.MockMultipartFile;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
|
||||
import java.io.File;
|
||||
|
||||
/**
|
||||
* {@link XunFeiPptApi} 集成测试
|
||||
*
|
||||
* @author xiaoxin
|
||||
*/
|
||||
public class XunFeiPptApiTests {
|
||||
|
||||
// 讯飞 API 配置信息,实际使用时请替换为您的应用信息
|
||||
private static final String APP_ID = "6c8ac023";
|
||||
private static final String API_SECRET = "Y2RjM2Q1MWJjZTdkYmFiODc0OGE5NmRk";
|
||||
|
||||
private final XunFeiPptApi xunfeiPptApi = new XunFeiPptApi(APP_ID, API_SECRET);
|
||||
|
||||
/**
|
||||
* 获取 PPT 模板列表
|
||||
*/
|
||||
@Test
|
||||
@Disabled
|
||||
public void testGetTemplatePage() {
|
||||
// 调用方法
|
||||
XunFeiPptApi.TemplatePageResponse response = xunfeiPptApi.getTemplatePage("商务", 10);
|
||||
// 打印结果
|
||||
System.out.println("模板列表响应:" + JsonUtils.toJsonString(response));
|
||||
|
||||
if (response != null && response.data() != null && response.data().records() != null) {
|
||||
System.out.println("模板总数:" + response.data().total());
|
||||
System.out.println("当前页码:" + response.data().pageNum());
|
||||
System.out.println("模板数量:" + response.data().records().size());
|
||||
|
||||
// 打印第一个模板的信息(如果存在)
|
||||
if (!response.data().records().isEmpty()) {
|
||||
XunFeiPptApi.TemplateInfo firstTemplate = response.data().records().get(0);
|
||||
System.out.println("模板ID:" + firstTemplate.templateIndexId());
|
||||
System.out.println("模板风格:" + firstTemplate.style());
|
||||
System.out.println("模板颜色:" + firstTemplate.color());
|
||||
System.out.println("模板行业:" + firstTemplate.industry());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建大纲(通过文本)
|
||||
*/
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCreateOutline() {
|
||||
XunFeiPptApi.CreateResponse response = getCreateResponse();
|
||||
// 打印结果
|
||||
System.out.println("创建大纲响应:" + JsonUtils.toJsonString(response));
|
||||
|
||||
// 保存 sid 和 outline 用于后续测试
|
||||
if (response != null && response.data() != null) {
|
||||
System.out.println("sid: " + response.data().sid());
|
||||
if (response.data().outline() != null) {
|
||||
// 使用 OutlineData 的 toJsonString 方法
|
||||
System.out.println("outline: " + response.data().outline().toJsonString());
|
||||
// 将 outline 对象转换为 JSON 字符串,用于后续 createPptByOutline 测试
|
||||
String outlineJson = response.data().outline().toJsonString();
|
||||
System.out.println("可用于 createPptByOutline 的 outline 字符串: " + outlineJson);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建大纲(通过文本)
|
||||
*
|
||||
* @return 创建大纲响应
|
||||
*/
|
||||
private XunFeiPptApi.CreateResponse getCreateResponse() {
|
||||
String param = "智能体平台 Dify 介绍";
|
||||
return xunfeiPptApi.createOutline(param);
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过大纲创建 PPT(完整参数)
|
||||
*/
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCreatePptByOutlineWithFullParams() {
|
||||
// 创建大纲对象
|
||||
XunFeiPptApi.CreateResponse createResponse = getCreateResponse();
|
||||
// 调用方法
|
||||
XunFeiPptApi.CreateResponse response = xunfeiPptApi.createPptByOutline(createResponse.data().outline(), "精简一些,不要超过6个章节");
|
||||
// 打印结果
|
||||
System.out.println("通过大纲创建 PPT 响应:" + JsonUtils.toJsonString(response));
|
||||
|
||||
// 保存sid用于后续进度查询
|
||||
if (response != null && response.data() != null) {
|
||||
System.out.println("sid: " + response.data().sid());
|
||||
if (response.data().coverImgSrc() != null) {
|
||||
System.out.println("封面图片: " + response.data().coverImgSrc());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查 PPT 生成进度
|
||||
*/
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCheckProgress() {
|
||||
// 准备参数 - 使用之前创建 PPT 时返回的 sid
|
||||
String sid = "e96dac09f2ec4ee289f029a5fb874ecd"; // 替换为实际的sid
|
||||
|
||||
// 调用方法
|
||||
XunFeiPptApi.ProgressResponse response = xunfeiPptApi.checkProgress(sid);
|
||||
// 打印结果
|
||||
System.out.println("检查进度响应:" + JsonUtils.toJsonString(response));
|
||||
|
||||
// 安全地访问响应数据
|
||||
if (response != null && response.data() != null) {
|
||||
XunFeiPptApi.ProgressResponseData data = response.data();
|
||||
|
||||
// 打印PPT生成状态
|
||||
System.out.println("PPT 构建状态: " + data.pptStatus());
|
||||
System.out.println("AI 配图状态: " + data.aiImageStatus());
|
||||
System.out.println("演讲备注状态: " + data.cardNoteStatus());
|
||||
|
||||
// 打印进度信息
|
||||
if (data.totalPages() != null && data.donePages() != null) {
|
||||
System.out.println("总页数: " + data.totalPages());
|
||||
System.out.println("已完成页数: " + data.donePages());
|
||||
System.out.println("完成进度: " + data.getProgressPercent() + "%");
|
||||
} else {
|
||||
System.out.println("进度: " + data.process() + "%");
|
||||
}
|
||||
|
||||
// 检查是否完成
|
||||
if (data.isAllDone()) {
|
||||
System.out.println("PPT 生成已完成!");
|
||||
System.out.println("PPT 下载链接: " + data.pptUrl());
|
||||
}
|
||||
// 检查是否失败
|
||||
else if (data.isFailed()) {
|
||||
System.out.println("PPT 生成失败!");
|
||||
System.out.println("错误信息: " + data.errMsg());
|
||||
}
|
||||
// 正在进行中
|
||||
else {
|
||||
System.out.println("PPT 生成中,请稍后再查询...");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 轮询检查 PPT 生成进度直到完成
|
||||
*/
|
||||
@Test
|
||||
@Disabled
|
||||
public void testPollCheckProgress() throws InterruptedException {
|
||||
// 准备参数 - 使用之前创建 PP T时返回的 sid
|
||||
String sid = "1690ef6ee0344e72b5c5434f403b8eaa"; // 替换为实际的sid
|
||||
|
||||
// 最大轮询次数
|
||||
int maxPolls = 20;
|
||||
// 轮询间隔(毫秒)- 讯飞 API 限流为 3 秒一次
|
||||
long pollInterval = 3500;
|
||||
|
||||
for (int i = 0; i < maxPolls; i++) {
|
||||
System.out.println("第" + (i + 1) + "次查询进度...");
|
||||
|
||||
// 调用方法
|
||||
XunFeiPptApi.ProgressResponse response = xunfeiPptApi.checkProgress(sid);
|
||||
|
||||
// 安全地访问响应数据
|
||||
if (response != null && response.data() != null) {
|
||||
XunFeiPptApi.ProgressResponseData data = response.data();
|
||||
|
||||
// 打印进度信息
|
||||
System.out.println("PPT 构建状态: " + data.pptStatus());
|
||||
if (data.totalPages() != null && data.donePages() != null) {
|
||||
System.out.println("完成进度: " + data.donePages() + "/" + data.totalPages()
|
||||
+ " (" + data.getProgressPercent() + "%)");
|
||||
}
|
||||
|
||||
// 检查是否完成
|
||||
if (data.isAllDone()) {
|
||||
System.out.println("PPT 生成已完成!");
|
||||
System.out.println("PPT 下载链接: " + data.pptUrl());
|
||||
break;
|
||||
}
|
||||
// 检查是否失败
|
||||
else if (data.isFailed()) {
|
||||
System.out.println("PPT 生成失败!");
|
||||
System.out.println("错误信息: " + data.errMsg());
|
||||
break;
|
||||
}
|
||||
// 正在进行中,继续轮询
|
||||
else {
|
||||
System.out.println("PPT 生成中,等待" + (pollInterval / 1000) + "秒后继续查询...");
|
||||
Thread.sleep(pollInterval);
|
||||
}
|
||||
} else {
|
||||
System.out.println("查询失败,等待" + (pollInterval / 1000) + "秒后重试...");
|
||||
Thread.sleep(pollInterval);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接创建 PPT(通过文本)
|
||||
*/
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCreatePptByText() {
|
||||
// 准备参数
|
||||
String query = "合肥天气趋势分析,包括近5年的气温变化、降水量变化、极端天气事件,以及对城市生活的影响";
|
||||
|
||||
// 调用方法
|
||||
XunFeiPptApi.CreateResponse response = xunfeiPptApi.create(query);
|
||||
// 打印结果
|
||||
System.out.println("直接创建 PPT 响应:" + JsonUtils.toJsonString(response));
|
||||
|
||||
// 保存 sid 用于后续进度查询
|
||||
if (response != null && response.data() != null) {
|
||||
System.out.println("sid: " + response.data().sid());
|
||||
if (response.data().coverImgSrc() != null) {
|
||||
System.out.println("封面图片: " + response.data().coverImgSrc());
|
||||
}
|
||||
System.out.println("标题: " + response.data().title());
|
||||
System.out.println("副标题: " + response.data().subTitle());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接创建 PPT(通过文件)
|
||||
*/
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCreatePptByFile() {
|
||||
// 准备参数
|
||||
File file = new File("src/test/resources/test.txt"); // 请确保此文件存在
|
||||
MultipartFile multipartFile = convertFileToMultipartFile(file);
|
||||
|
||||
// 调用方法
|
||||
XunFeiPptApi.CreateResponse response = xunfeiPptApi.create(multipartFile, file.getName());
|
||||
// 打印结果
|
||||
System.out.println("通过文件创建PPT响应:" + JsonUtils.toJsonString(response));
|
||||
|
||||
// 保存 sid 用于后续进度查询
|
||||
if (response != null && response.data() != null) {
|
||||
System.out.println("sid: " + response.data().sid());
|
||||
if (response.data().coverImgSrc() != null) {
|
||||
System.out.println("封面图片: " + response.data().coverImgSrc());
|
||||
}
|
||||
System.out.println("标题: " + response.data().title());
|
||||
System.out.println("副标题: " + response.data().subTitle());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接创建 PPT(完整参数)
|
||||
*/
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCreatePptWithFullParams() {
|
||||
// 准备参数
|
||||
String query = "合肥天气趋势分析,包括近 5 年的气温变化、降水量变化、极端天气事件,以及对城市生活的影响";
|
||||
|
||||
// 创建请求对象
|
||||
XunFeiPptApi.CreatePptRequest request = XunFeiPptApi.CreatePptRequest.builder()
|
||||
.query(query)
|
||||
.language("cn")
|
||||
.isCardNote(true)
|
||||
.search(true)
|
||||
.isFigure(true)
|
||||
.aiImage("advanced")
|
||||
.author("测试用户")
|
||||
.build();
|
||||
|
||||
// 调用方法
|
||||
XunFeiPptApi.CreateResponse response = xunfeiPptApi.create(request);
|
||||
// 打印结果
|
||||
System.out.println("使用完整参数创建 PPT 响应:" + JsonUtils.toJsonString(response));
|
||||
|
||||
// 保存 sid 用于后续进度查询
|
||||
if (response != null && response.data() != null) {
|
||||
String sid = response.data().sid();
|
||||
System.out.println("sid: " + sid);
|
||||
if (response.data().coverImgSrc() != null) {
|
||||
System.out.println("封面图片: " + response.data().coverImgSrc());
|
||||
}
|
||||
System.out.println("标题: " + response.data().title());
|
||||
System.out.println("副标题: " + response.data().subTitle());
|
||||
|
||||
// 立即查询一次进度
|
||||
System.out.println("立即查询进度...");
|
||||
XunFeiPptApi.ProgressResponse progressResponse = xunfeiPptApi.checkProgress(sid);
|
||||
if (progressResponse != null && progressResponse.data() != null) {
|
||||
XunFeiPptApi.ProgressResponseData progressData = progressResponse.data();
|
||||
System.out.println("PPT 构建状态: " + progressData.pptStatus());
|
||||
if (progressData.totalPages() != null && progressData.donePages() != null) {
|
||||
System.out.println("完成进度: " + progressData.donePages() + "/" + progressData.totalPages()
|
||||
+ " (" + progressData.getProgressPercent() + "%)");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 File 转换为 MultipartFile
|
||||
*/
|
||||
private MultipartFile convertFileToMultipartFile(File file) {
|
||||
return new MockMultipartFile("file", file.getName(), "text/plain", FileUtil.readBytes(file));
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user