| import { LlmIcon } from '@/components/svg-icon'; | |
| import { LlmModelType } from '@/constants/knowledge'; | |
| import { ResponseGetType } from '@/interfaces/database/base'; | |
| import { | |
| IFactory, | |
| IMyLlmValue, | |
| IThirdOAIModelCollection as IThirdAiModelCollection, | |
| IThirdOAIModelCollection, | |
| } from '@/interfaces/database/llm'; | |
| import { | |
| IAddLlmRequestBody, | |
| IDeleteLlmRequestBody, | |
| } from '@/interfaces/request/llm'; | |
| import userService from '@/services/user-service'; | |
| import { sortLLmFactoryListBySpecifiedOrder } from '@/utils/common-util'; | |
| import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'; | |
| import { Flex, message } from 'antd'; | |
| import { DefaultOptionType } from 'antd/es/select'; | |
| import { useMemo } from 'react'; | |
| import { useTranslation } from 'react-i18next'; | |
/**
 * Fetches the LLM collection via `userService.llm_list`.
 *
 * @param modelType Optional filter forwarded to the service as `model_type`.
 * @returns The `data` field of the service response, or `{}` before the
 *   first response arrives or when that field is nullish.
 */
export const useFetchLlmList = (
  modelType?: LlmModelType,
): IThirdAiModelCollection => {
  const { data } = useQuery({
    // The query function depends on `modelType`, so it has to be part of the
    // key; otherwise callers with different filters share one cache entry and
    // read each other's results. The unfiltered key is left as `['llmList']`
    // so existing prefix/exact matches on that key keep working.
    queryKey: modelType === undefined ? ['llmList'] : ['llmList', modelType],
    // Start with an empty collection so consumers can iterate immediately.
    initialData: {},
    queryFn: async () => {
      const { data } = await userService.llm_list({ model_type: modelType });
      return data?.data ?? {};
    },
  });
  return data;
};
/**
 * Turns the fetched LLM collection into grouped select options.
 *
 * Each key of the collection becomes a group labelled with that key; each
 * model in the group becomes an option whose value is `llm_name@fid` and
 * which is disabled when the model is not available.
 */
export const useSelectLlmOptions = () => {
  const llmInfo: IThirdOAIModelCollection = useFetchLlmList();

  // Recomputed only when the fetched collection changes.
  return useMemo(
    () =>
      Object.entries(llmInfo).map(([groupName, models]) => ({
        label: groupName,
        options: models.map((model) => ({
          label: model.llm_name,
          value: `${model.llm_name}@${model.fid}`,
          disabled: !model.available,
        })),
      })),
    [llmInfo],
  );
};
/**
 * Picks the icon name for a model.
 *
 * For the `FastEmbed` factory the icon name is the part of `llm_name` before
 * the first `/`; for every other factory it is the factory id itself.
 *
 * @param fid Factory id of the model.
 * @param llm_name Model name.
 * @returns The name to look up the icon by.
 */
const getLLMIconName = (fid: string, llm_name: string) => {
  if (fid !== 'FastEmbed') {
    return fid;
  }
  // `split` always yields at least one segment; the default only guards the type.
  const [firstSegment = ''] = llm_name.split('/');
  return firstSegment;
};
/**
 * Builds grouped select options for each supported model type.
 *
 * For every model type, the fetched LLM collection is reduced to the groups
 * that contain at least one available model whose `model_type` includes that
 * type. Option labels render the model icon next to the model name, and
 * option values are `llm_name@fid`.
 *
 * @returns An object keyed by `LlmModelType` whose values are the option
 *   groups for that type.
 */
export const useSelectLlmOptionsByModelType = () => {
  const llmInfo: IThirdOAIModelCollection = useFetchLlmList();

  const groupOptionsByModelType = (modelType: LlmModelType) =>
    Object.entries(llmInfo).flatMap(([groupName, models]) => {
      // Skip groups with no model of the requested type at all.
      if (
        modelType &&
        !models.some((model) => model.model_type.includes(modelType))
      ) {
        return [];
      }

      const options = models
        .filter(
          (model) =>
            (!modelType || model.model_type.includes(modelType)) &&
            model.available,
        )
        .map((model) => ({
          label: (
            <Flex align="center" gap={6}>
              <LlmIcon
                name={getLLMIconName(model.fid, model.llm_name)}
                width={26}
                height={26}
                size={'small'}
              />
              <span>{model.llm_name}</span>
            </Flex>
          ),
          value: `${model.llm_name}@${model.fid}`,
          disabled: !model.available,
        }));

      // Groups left without any option are dropped.
      return options.length > 0 ? [{ label: groupName, options }] : [];
    });

  const chat = groupOptionsByModelType(LlmModelType.Chat);
  const embedding = groupOptionsByModelType(LlmModelType.Embedding);
  const image2text = groupOptionsByModelType(LlmModelType.Image2text);
  const speech2text = groupOptionsByModelType(LlmModelType.Speech2text);
  const rerank = groupOptionsByModelType(LlmModelType.Rerank);
  const tts = groupOptionsByModelType(LlmModelType.TTS);

  return {
    [LlmModelType.Chat]: chat,
    [LlmModelType.Embedding]: embedding,
    [LlmModelType.Image2text]: image2text,
    [LlmModelType.Speech2text]: speech2text,
    [LlmModelType.Rerank]: rerank,
    [LlmModelType.TTS]: tts,
  };
};
/**
 * Combines the option groups of several model types into one list, merging
 * groups that share the same label.
 *
 * Groups appear in the order they are first encountered while walking
 * `modelTypes`; options of a later group with an already-seen label are
 * appended to the existing group.
 *
 * @param modelTypes Model types whose option groups should be combined.
 * @returns The merged option groups.
 */
export const useComposeLlmOptionsByModelTypes = (
  modelTypes: LlmModelType[],
) => {
  const allOptions = useSelectLlmOptionsByModelType();

  const merged: DefaultOptionType[] = [];
  for (const modelType of modelTypes) {
    for (const group of allOptions[modelType]) {
      const existing = merged.find((entry) => entry.label === group.label);
      if (existing) {
        // `options` is optional on DefaultOptionType, so guard it, and build
        // a new array instead of pushing into one we might not own.
        existing.options = [...(existing.options ?? []), ...group.options];
      } else {
        // Store a copy: the group objects belong to
        // useSelectLlmOptionsByModelType and must not be mutated by merging.
        merged.push({ ...group, options: [...group.options] });
      }
    }
  }
  return merged;
};
/**
 * Fetches the list of LLM factories via `userService.factories_list`.
 *
 * @returns `data` — the `data` field of the response (an empty array until
 *   the first response arrives or when that field is nullish); `loading` —
 *   whether a fetch is currently in flight.
 */
export const useFetchLlmFactoryList = (): ResponseGetType<IFactory[]> => {
  const fetchFactories = async () => {
    const response = await userService.factories_list();
    return response.data?.data ?? [];
  };

  const query = useQuery({
    queryKey: ['factoryList'],
    initialData: [],
    // No cache retention once the query is unused.
    gcTime: 0,
    queryFn: fetchFactories,
  });

  return { data: query.data, loading: query.isFetching };
};
/**
 * One entry of the user's LLM list as produced by `useSelectLlmList`:
 * the `IMyLlmValue` fields plus the entry's key (`name`) and the logo of the
 * factory with the same name (`logo`, empty string when no factory matches).
 */
export type LlmItem = IMyLlmValue & {
  name: string;
  logo: string;
};
/**
 * Fetches the current user's LLMs via `userService.my_llm`.
 *
 * @returns `data` — the `data` field of the response (an empty object until
 *   the first response arrives or when that field is nullish); `loading` —
 *   whether a fetch is currently in flight.
 */
export const useFetchMyLlmList = (): ResponseGetType<
  Record<string, IMyLlmValue>
> => {
  const fetchMyLlm = async () => {
    const response = await userService.my_llm();
    return response.data?.data ?? {};
  };

  const query = useQuery({
    queryKey: ['myLlmList'],
    initialData: {},
    // No cache retention once the query is unused.
    gcTime: 0,
    queryFn: fetchMyLlm,
  });

  return { data: query.data, loading: query.isFetching };
};
/**
 * Combines the user's LLM list with the factory list.
 *
 * @returns
 *  - `myLlmList`: one `LlmItem` per entry of the user's LLM map, carrying the
 *    entry key as `name` and the matching factory's logo as `logo`.
 *  - `factoryList`: the factories whose name is not a key of the user's LLM
 *    map, passed through `sortLLmFactoryListBySpecifiedOrder`.
 *  - `loading`: true while either underlying query is fetching.
 */
export const useSelectLlmList = () => {
  const myLlmQuery = useFetchMyLlmList();
  const factoryQuery = useFetchLlmFactoryList();
  const myLlmMap = myLlmQuery.data;
  const allFactories = factoryQuery.data;

  const nextMyLlmList: Array<LlmItem> = useMemo(
    () =>
      Object.entries(myLlmMap).map(([factoryName, llmValue]) => {
        const matchingFactory = allFactories.find(
          (factory) => factory.name === factoryName,
        );
        // Spread last so fields of the stored value take precedence.
        return {
          name: factoryName,
          logo: matchingFactory?.logo ?? '',
          ...llmValue,
        };
      }),
    [myLlmMap, allFactories],
  );

  const nextFactoryList = useMemo(() => {
    const configuredNames = Object.keys(myLlmMap);
    const unconfigured = allFactories.filter(
      (factory) => !configuredNames.includes(factory.name),
    );
    return sortLLmFactoryListBySpecifiedOrder(unconfigured);
  }, [allFactories, myLlmMap]);

  return {
    myLlmList: nextMyLlmList,
    factoryList: nextFactoryList,
    loading: myLlmQuery.loading || factoryQuery.loading,
  };
};
/**
 * Parameters accepted by the `saveApiKey` mutation of `useSaveApiKey`;
 * passed unchanged to `userService.set_api_key`.
 */
export interface IApiKeySavingParams {
  /** Required. */
  llm_factory: string;
  /** Required. */
  api_key: string;
  /** Optional. */
  llm_name?: string;
  /** Optional. */
  model_type?: string;
  /** Optional. */
  base_url?: string;
}
/**
 * Mutation hook that saves an API key via `userService.set_api_key`.
 *
 * When the response code is `0`, a success message is shown and the
 * `myLlmList` and `factoryList` queries are invalidated.
 *
 * @returns `data` — the response code of the last call; `loading` — whether
 *   the mutation is pending; `saveApiKey` — the async mutate function.
 */
export const useSaveApiKey = () => {
  const queryClient = useQueryClient();
  const { t } = useTranslation();

  const mutation = useMutation({
    mutationKey: ['saveApiKey'],
    mutationFn: async (params: IApiKeySavingParams) => {
      const response = await userService.set_api_key(params);
      const { code } = response.data;
      if (code === 0) {
        message.success(t('message.modified'));
        queryClient.invalidateQueries({ queryKey: ['myLlmList'] });
        queryClient.invalidateQueries({ queryKey: ['factoryList'] });
      }
      return code;
    },
  });

  return {
    data: mutation.data,
    loading: mutation.isPending,
    saveApiKey: mutation.mutateAsync,
  };
};
/**
 * Parameters accepted by the `saveTenantInfo` mutation of
 * `useSaveTenantInfo`; passed unchanged to `userService.set_tenant_info`.
 */
export interface ISystemModelSettingSavingParams {
  /** Required. */
  tenant_id: string;
  /** Optional. */
  name?: string;
  /** Required. */
  asr_id: string;
  /** Required. */
  embd_id: string;
  /** Required. */
  img2txt_id: string;
  /** Required. */
  llm_id: string;
}
/**
 * Mutation hook that saves tenant settings via
 * `userService.set_tenant_info`.
 *
 * When the response code is `0`, a success message is shown.
 *
 * @returns `data` — the response code of the last call; `loading` — whether
 *   the mutation is pending; `saveTenantInfo` — the async mutate function.
 */
export const useSaveTenantInfo = () => {
  const { t } = useTranslation();

  const mutation = useMutation({
    mutationKey: ['saveTenantInfo'],
    mutationFn: async (params: ISystemModelSettingSavingParams) => {
      const response = await userService.set_tenant_info(params);
      const { code } = response.data;
      if (code === 0) {
        message.success(t('message.modified'));
      }
      return code;
    },
  });

  return {
    data: mutation.data,
    loading: mutation.isPending,
    saveTenantInfo: mutation.mutateAsync,
  };
};
/**
 * Mutation hook that adds an LLM via `userService.add_llm`.
 *
 * When the response code is `0`, the `myLlmList` and `factoryList` queries
 * are invalidated and a success message is shown.
 *
 * @returns `data` — the response code of the last call; `loading` — whether
 *   the mutation is pending; `addLlm` — the async mutate function.
 */
export const useAddLlm = () => {
  const queryClient = useQueryClient();
  const { t } = useTranslation();

  const mutation = useMutation({
    mutationKey: ['addLlm'],
    mutationFn: async (params: IAddLlmRequestBody) => {
      const response = await userService.add_llm(params);
      const { code } = response.data;
      if (code === 0) {
        queryClient.invalidateQueries({ queryKey: ['myLlmList'] });
        queryClient.invalidateQueries({ queryKey: ['factoryList'] });
        message.success(t('message.modified'));
      }
      return code;
    },
  });

  return {
    data: mutation.data,
    loading: mutation.isPending,
    addLlm: mutation.mutateAsync,
  };
};
/**
 * Mutation hook that deletes an LLM via `userService.delete_llm`.
 *
 * When the response code is `0`, the `myLlmList` and `factoryList` queries
 * are invalidated and a success message is shown.
 *
 * @returns `data` — the response code of the last call; `loading` — whether
 *   the mutation is pending; `deleteLlm` — the async mutate function.
 */
export const useDeleteLlm = () => {
  const queryClient = useQueryClient();
  const { t } = useTranslation();

  const mutation = useMutation({
    mutationKey: ['deleteLlm'],
    mutationFn: async (params: IDeleteLlmRequestBody) => {
      const response = await userService.delete_llm(params);
      const { code } = response.data;
      if (code === 0) {
        queryClient.invalidateQueries({ queryKey: ['myLlmList'] });
        queryClient.invalidateQueries({ queryKey: ['factoryList'] });
        message.success(t('message.deleted'));
      }
      return code;
    },
  });

  return {
    data: mutation.data,
    loading: mutation.isPending,
    deleteLlm: mutation.mutateAsync,
  };
};
/**
 * Mutation hook that deletes a factory via `userService.deleteFactory`.
 *
 * When the response code is `0`, the `myLlmList` and `factoryList` queries
 * are invalidated and a success message is shown.
 *
 * @returns `data` — the response code of the last call; `loading` — whether
 *   the mutation is pending; `deleteFactory` — the async mutate function.
 */
export const useDeleteFactory = () => {
  const queryClient = useQueryClient();
  const { t } = useTranslation();

  const mutation = useMutation({
    mutationKey: ['deleteFactory'],
    mutationFn: async (params: IDeleteLlmRequestBody) => {
      const response = await userService.deleteFactory(params);
      const { code } = response.data;
      if (code === 0) {
        queryClient.invalidateQueries({ queryKey: ['myLlmList'] });
        queryClient.invalidateQueries({ queryKey: ['factoryList'] });
        message.success(t('message.deleted'));
      }
      return code;
    },
  });

  return {
    data: mutation.data,
    loading: mutation.isPending,
    deleteFactory: mutation.mutateAsync,
  };
};