@@ -1,10 +1,12 @@
package cn.iocoder.yudao.framework.ai.core.model.suno.api ;
import cn.iocoder.yudao.framework.ai.core.model.suno.SunoConfig ;
import cn.hutool.core.collection.CollUtil ;
import cn.hutool.core.text.StrPool ;
import com.fasterxml.jackson.annotation.JsonInclude ;
import com.fasterxml.jackson.annotation.JsonProperty ;
import lombok.extern.slf4j.Slf4j ;
import org.springframework.core.ParameterizedTypeReference ;
import org.springframework.http.HttpRequest ;
import org.springframework.http.HttpStatusCode ;
import org.springframework.http.MediaType ;
import org.springframework.web.reactive.function.client.ClientResponse ;
@@ -17,11 +19,10 @@ import java.util.function.Predicate;
/**
* Suno API
* <br >
* <b>
* 文档地址: https://github.com/status2xx/suno-api/blob/main/README_CN.md
*
* @A uthor xiaoxin
* @Date 2024/6/3
* @a uthor xiaoxin
*/
@Slf4j
public class SunoApi {
@@ -29,86 +30,88 @@ public class SunoApi {
private final WebClient webClient ;
private final Predicate < HttpStatusCode > STATUS_PREDICATE = status - > ! status . is2xxSuccessful ( ) ;
private final Function < ClientResponse , Mono < ? extends Throwable > > EXCEPTION_FUNCTION = response - > response . bodyToMono ( String . class )
private final Function < Object , Function < ClientResponse , Mono < ? extends Throwable > > > EXCEPTION_FUNCTION = reqParam - > response - > response . bodyToMono ( String . class )
. handle ( ( respBody , sink ) - > {
// TODO @xin: 最好是 request、response 都有哈
log . error ( " 【 suno-api】 调用失败!resp : 【 {}】 " , respBody ) ;
sink . error ( new IllegalStateException ( " 【 suno-api】 调用失败!" ) ) ;
HttpRequest request = response . request ( ) ;
log . error ( " [ suno-api] 调用失败!请求方式:[{}], 请求地址:[{}], 请求参数:[{}], 响应数据 : [ {}] " , request . getMethod ( ) , request . getURI ( ) , reqParam , respBody ) ;
sink . error ( new IllegalStateException ( " [ suno-api] 调用失败!" ) ) ;
} ) ;
public SunoApi ( SunoConfig config ) {
public SunoApi ( String baseUrl ) {
this . webClient = WebClient . builder ( )
. baseUrl ( config . getBaseUrl ( ) )
. baseUrl ( baseUrl )
. defaultHeaders ( ( headers ) - > headers . setContentType ( MediaType . APPLICATION_JSON ) )
. build ( ) ;
}
public List < MusicData > generate ( Suno Request request ) {
public List < MusicData > generate ( MusicGenerate Request request ) {
return this . webClient . post ( )
. uri ( " /api/generate " )
. body ( Mono . just ( request ) , Suno Request. class )
. body ( Mono . just ( request ) , MusicGenerate Request. class )
. retrieve ( )
. onStatus ( STATUS_PREDICATE , EXCEPTION_FUNCTION )
. bodyToMono ( new ParameterizedTypeReference < List < MusicData > > ( ) { } )
. onStatus ( STATUS_PREDICATE , EXCEPTION_FUNCTION . apply ( request ) )
. bodyToMono ( new ParameterizedTypeReference < List < MusicData > > ( ) {
} )
. block ( ) ;
}
public List < MusicData > customGenerate ( Suno Request request ) {
public List < MusicData > customGenerate ( MusicGenerate Request request ) {
return this . webClient . post ( )
. uri ( " /api/custom_generate " )
. body ( Mono . just ( request ) , Suno Request. class )
. body ( Mono . just ( request ) , MusicGenerate Request. class )
. retrieve ( )
. onStatus ( STATUS_PREDICATE , EXCEPTION_FUNCTION )
. bodyToMono ( new ParameterizedTypeReference < List < MusicData > > ( ) { } )
. onStatus ( STATUS_PREDICATE , EXCEPTION_FUNCTION . apply ( request ) )
. bodyToMono ( new ParameterizedTypeReference < List < MusicData > > ( ) {
} )
. block ( ) ;
}
// TODO @xin: 是不是叫 chatCompletion
public List < MusicData > doChatCompletion ( String prompt ) {
public List < MusicData > chatCompletion ( String prompt ) {
return this . webClient . post ( )
. uri ( " /v1/chat/completions " )
. body ( Mono . just ( new Suno Request( prompt ) ) , Suno Request. class )
. body ( Mono . just ( new MusicGenerate Request( prompt ) ) , MusicGenerate Request. class )
. retrieve ( )
. onStatus ( STATUS_PREDICATE , EXCEPTION_FUNCTION )
. bodyToMono ( new ParameterizedTypeReference < List < MusicData > > ( ) { } )
. onStatus ( STATUS_PREDICATE , EXCEPTION_FUNCTION . apply ( prompt ) )
. bodyToMono ( new ParameterizedTypeReference < List < MusicData > > ( ) {
} )
. block ( ) ;
}
public LyricsData generateLyrics ( String prompt ) {
return this . webClient . post ( )
. uri ( " /api/generate_lyrics " )
. body ( Mono . just ( new Suno Request( prompt ) ) , Suno Request. class )
. body ( Mono . just ( new MusicGenerate Request( prompt ) ) , MusicGenerate Request. class )
. retrieve ( )
. onStatus ( STATUS_PREDICATE , EXCEPTION_FUNCTION )
. onStatus ( STATUS_PREDICATE , EXCEPTION_FUNCTION . apply ( prompt ) )
. bodyToMono ( LyricsData . class )
. block ( ) ;
}
// TODO @xin:应该传入 List<String> ids
// TODO @xin:方法名,建议使用 getMusicList
public List < MusicData > selectById ( String ids ) {
public List < MusicData > getMusicList ( List < String > ids ) {
return this . webClient . get ( )
. uri ( uriBuilder - > uriBuilder
. path ( " /api/get " )
. queryParam ( " ids " , ids )
. queryParam ( " ids " , CollUtil . join ( ids , StrPool . COMMA ) )
. build ( ) )
. retrieve ( )
. onStatus ( STATUS_PREDICATE , EXCEPTION_FUNCTION )
. bodyToMono ( new ParameterizedTypeReference < List < MusicData > > ( ) { } )
. onStatus ( STATUS_PREDICATE , EXCEPTION_FUNCTION . apply ( ids ) )
. bodyToMono ( new ParameterizedTypeReference < List < MusicData > > ( ) {
} )
. block ( ) ;
}
// TODO @xin:方法名,建议使用 getLimitUsage
public LimitData selectLimit ( ) {
public LimitUsageData getLimitUsage ( ) {
return this . webClient . get ( )
. uri ( " /api/get_limit " )
. retrieve ( )
. onStatus ( STATUS_PREDICATE , EXCEPTION_FUNCTION )
. bodyToMono ( LimitData . class )
. onStatus ( STATUS_PREDICATE , EXCEPTION_FUNCTION . apply ( null ) )
. bodyToMono ( LimitUsage Data . class )
. block ( ) ;
}
// TODO @xin: 可以改成 MusicGenerateRequest
/**
* 根据提示生成音频
*
@@ -121,7 +124,7 @@ public class SunoApi {
* @param makeInstrumental 指示音乐音频是否为定制,如果为 true, 则从歌词生成, 否则从提示生成
*/
@JsonInclude ( value = JsonInclude . Include . NON_NULL )
public record Suno Request(
public record MusicGenerate Request(
String prompt ,
String tags ,
String title ,
@@ -130,15 +133,15 @@ public class SunoApi {
@JsonProperty ( " make_instrumental " ) boolean makeInstrumental
) {
public Suno Request( String prompt ) {
public MusicGenerate Request( String prompt ) {
this ( prompt , null , null , null , false , false ) ;
}
public Suno Request( String prompt , String mv , boolean makeInstrumental ) {
public MusicGenerate Request( String prompt , String mv , boolean makeInstrumental ) {
this ( prompt , null , null , mv , false , makeInstrumental ) ;
}
public Suno Request( String prompt , String mv , String tags , String title ) {
public MusicGenerate Request( String prompt , String mv , String tags , String title ) {
this ( prompt , tags , title , mv , false , false ) ;
}
@@ -154,12 +157,12 @@ public class SunoApi {
* @param audioUrl 音乐音频的 URL
* @param videoUrl 音乐视频的 URL
* @param createdAt 音乐音频的创建时间
* @param modelName
* @param modelName 模型名称
* @param status submitted、queued、streaming、complete
* @param gptDescriptionPrompt
* @param gptDescriptionPrompt 描述词
* @param prompt 生成音乐音频的提示
* @param type
* @param tags
* @param type 操作类型
* @param tags 音乐类型标签
*/
public record MusicData (
String id ,
@@ -195,7 +198,7 @@ public class SunoApi {
/**
* Suno API 响应的限额数据, 目前每日免费50
*/
public record LimitData (
public record LimitUsage Data (
@JsonProperty ( " credits_left " ) Long creditsLeft ,
String period ,
@JsonProperty ( " monthly_limit " ) Long monthlyLimit ,