Spaces:
Running
Running
| 'use client'; | |
| import { useEffect, useState } from 'react'; | |
| import { useSearchParams, useRouter } from 'next/navigation'; | |
| import { defaultJobConfig, defaultDatasetConfig } from './jobConfig'; | |
| import { JobConfig } from '@/types'; | |
| import { objectCopy } from '@/utils/basic'; | |
| import { useNestedState } from '@/utils/hooks'; | |
| import { SelectInput} from '@/components/formInputs'; | |
| import useSettings from '@/hooks/useSettings'; | |
| import useGPUInfo from '@/hooks/useGPUInfo'; | |
| import useDatasetList from '@/hooks/useDatasetList'; | |
| import path from 'path'; | |
| import { TopBar, MainContent } from '@/components/layout'; | |
| import { Button } from '@headlessui/react'; | |
| import { FaChevronLeft } from 'react-icons/fa'; | |
| import SimpleJob from './SimpleJob'; | |
| import AdvancedJob from './AdvancedJob'; | |
| import ErrorBoundary from '@/components/ErrorBoundary'; | |
| import { apiClient } from '@/utils/api'; | |
// True only when NODE_ENV is exactly 'development'.
| const isDev = process.env.NODE_ENV === 'development'; | |
/**
 * Extracts an HTTP status code from a value thrown by `apiClient`, if it has one.
 *
 * Looks for a numeric `response.status` on the thrown value (the shape the
 * original `.catch` handler probed with `error.response?.status`).
 *
 * @param error - Whatever was caught; may be any value.
 * @returns The numeric status, or `undefined` when none is present.
 */
function getResponseStatus(error: unknown): number | undefined {
  if (typeof error !== 'object' || error === null) return undefined;
  const { response } = error as { response?: { status?: unknown } };
  return typeof response?.status === 'number' ? response.status : undefined;
}

/**
 * Page component for creating a new training job or editing an existing one.
 *
 * When the URL carries an `id` search parameter, the matching job is fetched
 * from `/api/jobs` and loaded into the form. Saving posts the job to
 * `/api/jobs` and then navigates to `/jobs/<id>`. The form can be shown in a
 * simple or an advanced view; both views share the same `jobConfig` state.
 */
export default function TrainingForm() {
  const router = useRouter();
  const searchParams = useSearchParams();
  const runId = searchParams.get('id');
  const [gpuIDs, setGpuIDs] = useState<string | null>(null);
  const { settings, isSettingsLoaded } = useSettings();
  const { gpuList, isGPUInfoLoaded } = useGPUInfo();
  const { datasets, status: datasetFetchStatus } = useDatasetList();
  const [datasetOptions, setDatasetOptions] = useState<{ value: string; label: string }[]>([]);
  const [showAdvancedView, setShowAdvancedView] = useState(false);
  const [jobConfig, setJobConfig] = useNestedState<JobConfig>(objectCopy(defaultJobConfig));
  const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle');

  // Build the dataset dropdown options once settings and the dataset list are
  // both available, and point any dataset entry that still holds the
  // placeholder default path at the first real dataset.
  // `jobConfig` is read here but is not listed as a dependency, so this only
  // re-runs when the dataset list or settings change.
  useEffect(() => {
    if (!isSettingsLoaded || datasetFetchStatus !== 'success') return;

    const options = datasets.map(name => ({
      value: path.join(settings.DATASETS_FOLDER, name),
      label: name,
    }));
    setDatasetOptions(options);

    if (options.length === 0) return;

    const placeholderPath = defaultDatasetConfig.folder_path;
    jobConfig.config.process[0].datasets.forEach((dataset, i) => {
      if (dataset.folder_path === placeholderPath) {
        setJobConfig(options[0].value, `config.process[0].datasets[${i}].folder_path`);
      }
    });
  }, [datasets, settings, isSettingsLoaded, datasetFetchStatus]);

  // When editing an existing job, load its GPU selection and config.
  useEffect(() => {
    if (!runId) return;

    // Set by the cleanup below so a response that arrives after `runId`
    // changed (or after unmount) cannot overwrite newer state.
    let isStale = false;

    apiClient
      // `runId` comes straight from the URL, so encode it before re-embedding it.
      .get(`/api/jobs?id=${encodeURIComponent(runId)}`)
      .then(res => res.data)
      .then(data => {
        if (isStale) return;
        console.log('Training:', data);
        setGpuIDs(data.gpu_ids);
        // `job_config` is parsed as a JSON string; a parse failure falls
        // through to the `.catch` below.
        setJobConfig(JSON.parse(data.job_config));
      })
      .catch(error => console.error('Error fetching training:', error));

    return () => {
      isStale = true;
    };
  }, [runId]);

  // Default to the first GPU once GPU info is loaded, unless a GPU selection
  // has already been made.
  useEffect(() => {
    if (isGPUInfoLoaded && gpuIDs === null && gpuList.length > 0) {
      setGpuIDs(`${gpuList[0].index}`);
    }
  }, [gpuList, isGPUInfoLoaded]);

  // Copy the configured training folder into the job config once settings load.
  useEffect(() => {
    if (isSettingsLoaded) {
      setJobConfig(settings.TRAINING_FOLDER, 'config.process[0].training_folder');
    }
  }, [settings, isSettingsLoaded]);

  /**
   * Posts the current job to `/api/jobs` and navigates to its page on success.
   *
   * Does nothing while a save is already in flight. On failure the status
   * becomes `'error'` and the user is alerted (with a specific message for a
   * 409 response). Two seconds after the request settles, the status returns
   * to `'idle'`.
   */
  const saveJob = async () => {
    if (status === 'saving') return;
    setStatus('saving');

    try {
      const res = await apiClient.post('/api/jobs', {
        id: runId,
        name: jobConfig.config.name,
        gpu_ids: gpuIDs,
        job_config: jobConfig,
      });
      setStatus('success');
      // An edited job keeps its id; a new job uses the id from the response.
      router.push(`/jobs/${runId ? runId : res.data.id}`);
    } catch (error: unknown) {
      // Leave the 'saving' state right away so the UI reflects the failure
      // instead of showing "Saving..." until the delayed reset fires.
      setStatus('error');
      if (getResponseStatus(error) === 409) {
        alert('Training name already exists. Please choose a different name.');
      } else {
        alert('Failed to save job. Please try again.');
      }
      console.error('Error saving training:', error);
    } finally {
      setTimeout(() => {
        // Functional update: if a retry has started in the meantime, do not
        // reset it to 'idle' and thereby defeat the double-submit guard.
        setStatus(current => (current === 'saving' ? current : 'idle'));
      }, 2000);
    }
  };

  /** Form submit handler: suppresses the native submit and saves the job. */
  const handleSubmit = async (e: React.FormEvent) => {
    e.preventDefault();
    await saveJob();
  };

  return (
    <>
      <TopBar>
        <div>
          <Button className="text-gray-500 dark:text-gray-300 px-3 mt-1" onClick={() => history.back()}>
            <FaChevronLeft />
          </Button>
        </div>
        <div>
          <h1 className="text-lg">{runId ? 'Edit Training Job' : 'New Training Job'}</h1>
        </div>
        <div className="flex-1"></div>
        {showAdvancedView && (
          <>
            <div>
              <SelectInput
                value={`${gpuIDs}`}
                onChange={value => setGpuIDs(value)}
                options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))}
              />
            </div>
            <div className="mx-4 bg-gray-200 dark:bg-gray-800 w-1 h-6"></div>
          </>
        )}
        <div className="pr-2">
          <Button
            className="text-gray-200 bg-gray-800 px-3 py-1 rounded-md"
            onClick={() => setShowAdvancedView(!showAdvancedView)}
          >
            {showAdvancedView ? 'Show Simple' : 'Show Advanced'}
          </Button>
        </div>
        <div>
          <Button
            className="text-gray-200 bg-green-800 px-3 py-1 rounded-md"
            onClick={() => saveJob()}
            disabled={status === 'saving'}
          >
            {status === 'saving' ? 'Saving...' : runId ? 'Update Job' : 'Create Job'}
          </Button>
        </div>
      </TopBar>
      {showAdvancedView ? (
        <div className="pt-[48px] absolute top-0 left-0 w-full h-full overflow-auto">
          <AdvancedJob
            jobConfig={jobConfig}
            setJobConfig={setJobConfig}
            status={status}
            handleSubmit={handleSubmit}
            runId={runId}
            gpuIDs={gpuIDs}
            setGpuIDs={setGpuIDs}
            gpuList={gpuList}
            datasetOptions={datasetOptions}
            settings={settings}
          />
        </div>
      ) : (
        <MainContent>
          {/* The fallback is shown when the simple view throws while rendering. */}
          <ErrorBoundary
            fallback={
              <div className="flex items-center justify-center h-64 text-lg text-red-600 font-medium bg-red-100 dark:bg-red-900/20 dark:text-red-400 border border-red-300 dark:border-red-700 rounded-lg">
                Advanced job detected. Please switch to advanced view to continue.
              </div>
            }
          >
            <SimpleJob
              jobConfig={jobConfig}
              setJobConfig={setJobConfig}
              status={status}
              handleSubmit={handleSubmit}
              runId={runId}
              gpuIDs={gpuIDs}
              setGpuIDs={setGpuIDs}
              gpuList={gpuList}
              datasetOptions={datasetOptions}
            />
          </ErrorBoundary>
          <div className="pt-20"></div>
        </MainContent>
      )}
    </>
  );
}