import { useCallback } from 'react'
import {
  getIncomers,
  useReactFlow,
  useStoreApi,
} from 'reactflow'
import produce from 'immer'
import { v4 as uuidV4 } from 'uuid'
import { usePathname } from 'next/navigation'
import { useWorkflowStore } from '../store'
import { useNodesSyncDraft } from '../hooks'
import type { Node } from '../types'
import {
  NodeRunningStatus,
  WorkflowRunningStatus,
} from '../types'
import { DEFAULT_ITER_TIMES } from '../constants'
import { useWorkflowUpdate } from './use-workflow-interactions'
import { useStore as useAppStore } from '@/app/components/app/store'
import type { IOtherOptions } from '@/service/base'
import { ssePost } from '@/service/base'
import {
  fetchPublishedWorkflow,
  stopWorkflowRun,
} from '@/service/workflow'
import { useFeaturesStore } from '@/app/components/base/features/hooks'
import { AudioPlayerManager } from '@/app/components/base/audio-btn/audio.player.manager'
import {
  getProcessedFilesFromResponse,
} from '@/app/components/base/file-uploader/utils'
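
// Orchestrates workflow runs on the canvas: backing up and restoring the draft,
// starting a draft run over SSE and mapping each streamed event onto the canvas
// and the workflow store, stopping a run, and restoring from the published version.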
export const useWorkflowRun = () => {
  const store = useStoreApi()
  const workflowStore = useWorkflowStore()
  const reactflow = useReactFlow()
  const featuresStore = useFeaturesStore()
  const { doSyncWorkflowDraft } = useNodesSyncDraft()
  const { handleUpdateWorkflowCanvas } = useWorkflowUpdate()
  const pathname = usePathname()
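
  // Take a one-time snapshot of the current draft (nodes, edges, viewport, features,
  // environment variables) and sync the draft; skipped if a backup already exists.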
  const handleBackupDraft = useCallback(() => {
    const {
      getNodes,
      edges,
    } = store.getState()
    const { getViewport } = reactflow
    const {
      backupDraft,
      setBackupDraft,
      environmentVariables,
    } = workflowStore.getState()
    const { features } = featuresStore!.getState()

    if (!backupDraft) {
      setBackupDraft({
        nodes: getNodes(),
        edges,
        viewport: getViewport(),
        features,
        environmentVariables,
      })
      doSyncWorkflowDraft()
    }
  }, [reactflow, workflowStore, store, featuresStore, doSyncWorkflowDraft])
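
  // Restore the snapshot taken by handleBackupDraft back onto the canvas and the
  // stores, then clear the backup.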
  const handleLoadBackupDraft = useCallback(() => {
    const {
      backupDraft,
      setBackupDraft,
      setEnvironmentVariables,
    } = workflowStore.getState()

    if (backupDraft) {
      const {
        nodes,
        edges,
        viewport,
        features,
        environmentVariables,
      } = backupDraft
      handleUpdateWorkflowCanvas({
        nodes,
        edges,
        viewport,
      })
      setEnvironmentVariables(environmentVariables)
      featuresStore!.setState({ features })
      setBackupDraft(undefined)
    }
  }, [handleUpdateWorkflowCanvas, workflowStore, featuresStore])
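
  // Kick off a draft workflow run: reset node state, persist the draft, then stream
  // run events over SSE and reflect them on the canvas and in workflowRunningData.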
  const handleRun = useCallback(async (
    params: any,
    callback?: IOtherOptions,
  ) => {
    const {
      getNodes,
      setNodes,
    } = store.getState()
    const newNodes = produce(getNodes(), (draft) => {
      draft.forEach((node) => {
        node.data.selected = false
        node.data._runningStatus = undefined
      })
    })
    setNodes(newNodes)
    await doSyncWorkflowDraft()
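
    // Split the caller-supplied event callbacks from any remaining ssePost options,
    // which are forwarded untouched via restCallback.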
    const {
      onWorkflowStarted,
      onWorkflowFinished,
      onNodeStarted,
      onNodeFinished,
      onIterationStart,
      onIterationNext,
      onIterationFinish,
      onError,
      ...restCallback
    } = callback || {}
    workflowStore.setState({ historyWorkflowData: undefined })
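
    // Measure the canvas container (used later to center the viewport on running
    // nodes) and pick the draft-run endpoint for the current app mode.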
    const appDetail = useAppStore.getState().appDetail
    const workflowContainer = document.getElementById('workflow-container')

    const {
      clientWidth,
      clientHeight,
    } = workflowContainer!

    let url = ''
    if (appDetail?.mode === 'advanced-chat')
      url = `/apps/${appDetail.id}/advanced-chat/workflows/draft/run`

    if (appDetail?.mode === 'workflow')
      url = `/apps/${appDetail.id}/workflows/draft/run`
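
    // Remembers the most recently finished node so onIterationStart can mark the
    // edge leading into an iteration node as run.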
    let prevNodeId = ''

    const {
      setWorkflowRunningData,
    } = workflowStore.getState()
    setWorkflowRunningData({
      result: {
        status: WorkflowRunningStatus.Running,
      },
      tracing: [],
      resultText: '',
    })
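
    // Choose the text-to-audio endpoint: the public endpoint when a token is
    // supplied, otherwise the app-scoped (installed-app or console) endpoint.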
    let ttsUrl = ''
    let ttsIsPublic = false
    if (params.token) {
      ttsUrl = '/text-to-audio'
      ttsIsPublic = true
    }
    else if (params.appId) {
      if (pathname.search('explore/installed') > -1)
        ttsUrl = `/installed-apps/${params.appId}/text-to-audio`
      else
        ttsUrl = `/apps/${params.appId}/text-to-audio`
    }
    const player = AudioPlayerManager.getInstance().getAudioPlayer(ttsUrl, ttsIsPublic, uuidV4(), 'none', 'none', (_: any): any => {})
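
    // Start the run over server-sent events; each handler below maps one event
    // type onto the canvas and the workflow store.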
    ssePost(
      url,
      {
        body: params,
      },
      {
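        // Run accepted by the backend: record the task id, reset per-run
        // bookkeeping, and clear the "run" highlight from every edge.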
        onWorkflowStarted: (params) => {
          const { task_id, data } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
            setIterParallelLogMap,
          } = workflowStore.getState()
          const {
            edges,
            setEdges,
          } = store.getState()
          setIterParallelLogMap(new Map())
          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            draft.task_id = task_id
            draft.result = {
              ...draft?.result,
              ...data,
              status: WorkflowRunningStatus.Running,
            }
          }))

          const newEdges = produce(edges, (draft) => {
            draft.forEach((edge) => {
              edge.data = {
                ...edge.data,
                _run: false,
              }
            })
          })
          setEdges(newEdges)

          if (onWorkflowStarted)
            onWorkflowStarted(params)
        },
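        // Merge the final result (including processed files) into the run panel;
        // a single string output is also surfaced in the result tab.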
        onWorkflowFinished: (params) => {
          const { data } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
          } = workflowStore.getState()

          const isStringOutput = data.outputs && Object.keys(data.outputs).length === 1 && typeof data.outputs[Object.keys(data.outputs)[0]] === 'string'

          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            draft.result = {
              ...draft.result,
              ...data,
              files: getProcessedFilesFromResponse(data.files || []),
            } as any
            if (isStringOutput) {
              draft.resultTabActive = true
              draft.resultText = data.outputs[Object.keys(data.outputs)[0]]
            }
          }))

          prevNodeId = ''

          if (onWorkflowFinished)
            onWorkflowFinished(params)
        },
        onError: (params) => {
          const {
            workflowRunningData,
            setWorkflowRunningData,
          } = workflowStore.getState()

          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            draft.result = {
              ...draft.result,
              status: WorkflowRunningStatus.Failed,
            }
          }))

          if (onError)
            onError(params)
        },
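        // A node started: nodes inside an iteration are appended to that iteration's
        // trace; top-level nodes also re-center the viewport, get a running status,
        // and animate the edges coming from already-succeeded predecessors.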
        onNodeStarted: (params) => {
          const { data } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
            iterParallelLogMap,
            setIterParallelLogMap,
          } = workflowStore.getState()
          const {
            getNodes,
            setNodes,
            edges,
            setEdges,
            transform,
          } = store.getState()
          const nodes = getNodes()
          const node = nodes.find(node => node.id === data.node_id)
          if (node?.parentId) {
            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              const tracing = draft.tracing!
              const iterations = tracing.find(trace => trace.node_id === node?.parentId)
              const currIteration = iterations?.details![node.data.iteration_index] || iterations?.details![iterations.details!.length - 1]
              if (!data.parallel_run_id) {
                currIteration?.push({
                  ...data,
                  status: NodeRunningStatus.Running,
                } as any)
              }
              else {
                if (!iterParallelLogMap.has(data.parallel_run_id))
                  iterParallelLogMap.set(data.parallel_run_id, [{ ...data, status: NodeRunningStatus.Running } as any])
                else
                  iterParallelLogMap.get(data.parallel_run_id)!.push({ ...data, status: NodeRunningStatus.Running } as any)
                setIterParallelLogMap(iterParallelLogMap)
                if (iterations)
                  iterations.details = Array.from(iterParallelLogMap.values())
              }
            }))
          }
          else {
            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              draft.tracing!.push({
                ...data,
                status: NodeRunningStatus.Running,
              } as any)
            }))

            const {
              setViewport,
            } = reactflow
            const currentNodeIndex = nodes.findIndex(node => node.id === data.node_id)
            const currentNode = nodes[currentNodeIndex]
            const position = currentNode.position
            const zoom = transform[2]

            if (!currentNode.parentId) {
              setViewport({
                x: (clientWidth - 400 - currentNode.width! * zoom) / 2 - position.x * zoom,
                y: (clientHeight - currentNode.height! * zoom) / 2 - position.y * zoom,
                zoom: transform[2],
              })
            }
            const newNodes = produce(nodes, (draft) => {
              draft[currentNodeIndex].data._runningStatus = NodeRunningStatus.Running
            })
            setNodes(newNodes)
            const incomeNodesId = getIncomers({ id: data.node_id } as Node, newNodes, edges).filter(node => node.data._runningStatus === NodeRunningStatus.Succeeded).map(node => node.id)
            const newEdges = produce(edges, (draft) => {
              draft.forEach((edge) => {
                if (edge.target === data.node_id && incomeNodesId.includes(edge.source))
                  edge.data = { ...edge.data, _run: true } as any
              })
            })
            setEdges(newEdges)
          }

          if (onNodeStarted)
            onNodeStarted(params)
        },
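        // A node finished: write its final data back into the matching trace entry
        // (iteration child, parallel-iteration child, or top-level) and update its
        // status on the canvas.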
        onNodeFinished: (params) => {
          const { data } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
            iterParallelLogMap,
            setIterParallelLogMap,
          } = workflowStore.getState()
          const {
            getNodes,
            setNodes,
          } = store.getState()
          const nodes = getNodes()
          const nodeParentId = nodes.find(node => node.id === data.node_id)!.parentId
          if (nodeParentId) {
            if (!data.execution_metadata.parallel_mode_run_id) {
              setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
                const tracing = draft.tracing!
                const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node
                if (iterations && iterations.details) {
                  const iterationIndex = data.execution_metadata?.iteration_index || 0
                  if (!iterations.details[iterationIndex])
                    iterations.details[iterationIndex] = []

                  const currIteration = iterations.details[iterationIndex]
                  const nodeIndex = currIteration.findIndex(node =>
                    node.node_id === data.node_id && (
                      node.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || node.parallel_id === data.execution_metadata?.parallel_id),
                  )
                  if (nodeIndex !== -1) {
                    currIteration[nodeIndex] = {
                      ...currIteration[nodeIndex],
                      ...data,
                    } as any
                  }
                  else {
                    currIteration.push({
                      ...data,
                    } as any)
                  }
                }
              }))
            }
            else {
              // open parallel mode
              setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
                const tracing = draft.tracing!
                const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node
                if (iterations && iterations.details) {
                  const iterRunID = data.execution_metadata?.parallel_mode_run_id
                  const currIteration = iterParallelLogMap.get(iterRunID)
                  const nodeIndex = currIteration?.findIndex(node =>
                    node.node_id === data.node_id && (
                      node?.parallel_run_id === data.execution_metadata?.parallel_mode_run_id),
                  )
                  if (currIteration) {
                    if (nodeIndex !== undefined && nodeIndex !== -1) {
                      currIteration[nodeIndex] = {
                        ...currIteration[nodeIndex],
                        ...data,
                      } as any
                    }
                    else {
                      currIteration.push({
                        ...data,
                      } as any)
                    }
                  }
                  setIterParallelLogMap(iterParallelLogMap)
                  iterations.details = Array.from(iterParallelLogMap.values())
                }
              }))
            }
          }
          else {
            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              const currentIndex = draft.tracing!.findIndex((trace) => {
                if (!trace.execution_metadata?.parallel_id)
                  return trace.node_id === data.node_id

                return trace.node_id === data.node_id && trace.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id
              })
              if (currentIndex > -1 && draft.tracing) {
                draft.tracing[currentIndex] = {
                  ...(draft.tracing[currentIndex].extras
                    ? { extras: draft.tracing[currentIndex].extras }
                    : {}),
                  ...data,
                } as any
              }
            }))
            const newNodes = produce(nodes, (draft) => {
              const currentNode = draft.find(node => node.id === data.node_id)!
              currentNode.data._runningStatus = data.status as any
            })
            setNodes(newNodes)
            prevNodeId = data.node_id
          }

          if (onNodeFinished)
            onNodeFinished(params)
        },
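        // An iteration node started: push a trace entry with an empty details list,
        // center the viewport on it, and record the iteration length on the node.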
        onIterationStart: (params) => {
          const { data } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
            setIterTimes,
          } = workflowStore.getState()
          const {
            getNodes,
            setNodes,
            edges,
            setEdges,
            transform,
          } = store.getState()
          const nodes = getNodes()
          setIterTimes(DEFAULT_ITER_TIMES)
          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            draft.tracing!.push({
              ...data,
              status: NodeRunningStatus.Running,
              details: [],
            } as any)
          }))

          const {
            setViewport,
          } = reactflow
          const currentNodeIndex = nodes.findIndex(node => node.id === data.node_id)
          const currentNode = nodes[currentNodeIndex]
          const position = currentNode.position
          const zoom = transform[2]

          if (!currentNode.parentId) {
            setViewport({
              x: (clientWidth - 400 - currentNode.width! * zoom) / 2 - position.x * zoom,
              y: (clientHeight - currentNode.height! * zoom) / 2 - position.y * zoom,
              zoom: transform[2],
            })
          }
          const newNodes = produce(nodes, (draft) => {
            draft[currentNodeIndex].data._runningStatus = NodeRunningStatus.Running
            draft[currentNodeIndex].data._iterationLength = data.metadata.iterator_length
          })
          setNodes(newNodes)
          const newEdges = produce(edges, (draft) => {
            const edge = draft.find(edge => edge.target === data.node_id && edge.source === prevNodeId)

            if (edge)
              edge.data = { ...edge.data, _run: true } as any
          })
          setEdges(newEdges)

          if (onIterationStart)
            onIterationStart(params)
        },
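        // Next iteration round: open a new details bucket (sequential mode only)
        // and advance the iteration counter shown on the node.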
        onIterationNext: (params) => {
          const {
            workflowRunningData,
            setWorkflowRunningData,
            iterTimes,
            setIterTimes,
          } = workflowStore.getState()

          const { data } = params
          const {
            getNodes,
            setNodes,
          } = store.getState()

          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            const iteration = draft.tracing!.find(trace => trace.node_id === data.node_id)
            if (iteration) {
              if (iteration.details!.length >= iteration.metadata.iterator_length!)
                return
            }
            if (!data.parallel_mode_run_id)
              iteration?.details!.push([])
          }))
          const nodes = getNodes()
          const newNodes = produce(nodes, (draft) => {
            const currentNode = draft.find(node => node.id === data.node_id)!
            currentNode.data._iterationIndex = iterTimes
            setIterTimes(iterTimes + 1)
          })
          setNodes(newNodes)

          if (onIterationNext)
            onIterationNext(params)
        },
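        // Iteration node finished: mark its trace entry as succeeded, reset the
        // iteration counter, and update the node status on the canvas.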
        onIterationFinish: (params) => {
          const { data } = params

          const {
            workflowRunningData,
            setWorkflowRunningData,
            setIterTimes,
          } = workflowStore.getState()
          const {
            getNodes,
            setNodes,
          } = store.getState()
          const nodes = getNodes()
          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            const tracing = draft.tracing!
            const currIterationNode = tracing.find(trace => trace.node_id === data.node_id)
            if (currIterationNode) {
              Object.assign(currIterationNode, {
                ...data,
                status: NodeRunningStatus.Succeeded,
              })
            }
          }))
          setIterTimes(DEFAULT_ITER_TIMES)
          const newNodes = produce(nodes, (draft) => {
            const currentNode = draft.find(node => node.id === data.node_id)!
            currentNode.data._runningStatus = data.status
          })
          setNodes(newNodes)
          prevNodeId = data.node_id

          if (onIterationFinish)
            onIterationFinish(params)
        },
        onParallelBranchStarted: (params) => {
          // console.log(params, 'parallel start')
        },
        onParallelBranchFinished: (params) => {
          // console.log(params, 'finished')
        },
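        // Streaming text output: chunks append to the result text, replace events
        // overwrite it entirely.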
        onTextChunk: (params) => {
          const { data: { text } } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
          } = workflowStore.getState()
          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            draft.resultTabActive = true
            draft.resultText += text
          }))
        },
        onTextReplace: (params) => {
          const { data: { text } } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
          } = workflowStore.getState()
          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            draft.resultText = text
          }))
        },
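        // Text-to-speech audio streamed alongside the run is routed to the shared
        // audio player; the end event delivers the final chunk for this message.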
        onTTSChunk: (messageId: string, audio: string, audioType?: string) => {
          if (!audio || audio === '')
            return
          player.playAudioWithAudio(audio, true)
          AudioPlayerManager.getInstance().resetMsgId(messageId)
        },
        onTTSEnd: (messageId: string, audio: string, audioType?: string) => {
          player.playAudioWithAudio(audio, false)
        },
        ...restCallback,
      },
    )
  }, [store, reactflow, workflowStore, doSyncWorkflowDraft, pathname])
  const handleStopRun = useCallback((taskId: string) => {
    const appId = useAppStore.getState().appDetail?.id

    stopWorkflowRun(`/apps/${appId}/workflow-runs/tasks/${taskId}/stop`)
  }, [])
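
  // Replace the current canvas and stores with the latest published version of the workflow.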
  const handleRestoreFromPublishedWorkflow = useCallback(async () => {
    const appDetail = useAppStore.getState().appDetail
    const publishedWorkflow = await fetchPublishedWorkflow(`/apps/${appDetail?.id}/workflows/publish`)

    if (publishedWorkflow) {
      const nodes = publishedWorkflow.graph.nodes
      const edges = publishedWorkflow.graph.edges
      const viewport = publishedWorkflow.graph.viewport!

      handleUpdateWorkflowCanvas({
        nodes,
        edges,
        viewport,
      })
      featuresStore?.setState({ features: publishedWorkflow.features })
      workflowStore.getState().setPublishedAt(publishedWorkflow.created_at)
      workflowStore.getState().setEnvironmentVariables(publishedWorkflow.environment_variables || [])
    }
  }, [featuresStore, handleUpdateWorkflowCanvas, workflowStore])
  return {
    handleBackupDraft,
    handleLoadBackupDraft,
    handleRun,
    handleStopRun,
    handleRestoreFromPublishedWorkflow,
  }
}