Spaces:
Running
Running
File size: 3,507 Bytes
a62d4c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import localforage from 'localforage'
/** The kinds of model this cache knows how to fetch and store. */
export type modelType =
  | 'inpaint'
  | 'superResolution'

// Every cached model is stored in a single localforage database named 'modelCache'.
localforage.config({ name: 'modelCache' })
/**
 * Stores `modelBlob` in the localforage cache.
 *
 * The storage key is the `name` of the model that `getModel` currently selects
 * for `modelType`.
 */
export async function saveModel(modelType: modelType, modelBlob: ArrayBuffer) {
  const { name: cacheKey } = getModel(modelType)
  await localforage.setItem(cacheKey, modelBlob)
}
/**
 * Returns the download descriptor for the model currently used for `modelType`.
 *
 * Each descriptor has a `name` (also used as the localforage cache key), a
 * primary `url`, and a `backupUrl` (empty string when there is none). A fresh
 * object is built on every call.
 *
 * @throws Error('wrong modelType') for any value outside the `modelType` union.
 */
function getModel(modelType: modelType) {
  switch (modelType) {
    case 'inpaint': {
      // Several migan checkpoints are listed; only index 2 (migan-pipeline-v2) is served.
      const inpaintCandidates = [
        {
          name: 'model',
          url: 'https://huggingface.co/lxfater/inpaint-web/resolve/main/migan.onnx',
          backupUrl: '',
        },
        {
          name: 'model-perf',
          url: 'https://huggingface.co/andraniksargsyan/migan/resolve/main/migan.onnx',
          backupUrl: '',
        },
        {
          name: 'migan-pipeline-v2',
          url: 'https://huggingface.co/andraniksargsyan/migan/resolve/main/migan_pipeline_v2.onnx',
          backupUrl:
            'https://worker-share-proxy-01f5.lxfater.workers.dev/andraniksargsyan/migan/resolve/main/migan_pipeline_v2.onnx',
        },
      ]
      return inpaintCandidates[2]
    }
    case 'superResolution': {
      const superResolutionCandidates = [
        {
          name: 'realesrgan-x4',
          url: 'https://huggingface.co/lxfater/inpaint-web/resolve/main/realesrgan-x4.onnx',
          backupUrl:
            'https://worker-share-proxy-01f5.lxfater.workers.dev/lxfater/inpaint-web/resolve/main/realesrgan-x4.onnx',
        },
      ]
      return superResolutionCandidates[0]
    }
    default:
      throw new Error('wrong modelType')
  }
}
/**
 * Reads the cached bytes for `modelType` from localforage.
 *
 * NOTE(review): the declared `ArrayBuffer` return type is enforced only by a
 * cast. localforage resolves a missing key to `null`, and `modelExists` relies
 * on that. Values written by older versions of `downloadModel` may be a
 * `Uint8Array` rather than an `ArrayBuffer`.
 */
export async function loadModel(modelType: modelType): Promise<ArrayBuffer> {
  const cacheKey = getModel(modelType).name
  const cached = await localforage.getItem(cacheKey)
  return cached as ArrayBuffer
}
/**
 * Resolves to `true` when a cached entry exists for `modelType`.
 *
 * The loose `!= null` comparison treats both `null` and `undefined` from
 * `loadModel` as "not cached".
 */
export async function modelExists(modelType: modelType) {
  return (await loadModel(modelType)) != null
}
/**
 * Returns the cached bytes for `modelType`. If nothing is cached yet, the model
 * is downloaded and cached first.
 *
 * The primary URL is tried first. When that fails and `getModel` lists a
 * non-empty `backupUrl`, the backup is tried. No progress is reported, and
 * errors are thrown to the caller rather than shown with `alert`.
 *
 * @throws Error when a download responds with a non-2xx HTTP status. Network
 *   errors from `fetch` propagate unchanged. When a backup exists, the error
 *   thrown is the backup's.
 */
export async function ensureModel(modelType: modelType): Promise<ArrayBuffer> {
  // A single cache read. The original called modelExists() and then
  // loadModel(), which read the same entry twice.
  const cached = await loadModel(modelType)
  // loadModel's declared type hides it, but a cache miss yields null/undefined.
  if (cached !== null && cached !== undefined) {
    return cached
  }

  // Downloads `url` completely. Rejects on a non-2xx status so that an HTTP
  // error page is never cached as if it were the model.
  async function fetchModelBytes(url: string): Promise<ArrayBuffer> {
    const response = await fetch(url)
    if (!response.ok) {
      throw new Error(
        `HTTP ${response.status} ${response.statusText} while fetching ${url}`
      )
    }
    return response.arrayBuffer()
  }

  const model = getModel(modelType)
  let buffer: ArrayBuffer
  try {
    buffer = await fetchModelBytes(model.url)
  } catch (e) {
    // Same fallback policy as downloadModel: try the backup only when one is configured.
    if (!model.backupUrl) {
      throw e
    }
    buffer = await fetchModelBytes(model.backupUrl)
  }
  await saveModel(modelType, buffer)
  return buffer
}
/**
 * Downloads the model currently selected for `modelType` into the localforage
 * cache and reports progress through `setDownloadProgress`.
 *
 * Does nothing when the model is already cached. The primary URL is tried
 * first. If it fails and a backup URL is configured, the backup is tried.
 * Failures are reported to the user with `alert` instead of being thrown, so
 * the returned promise resolves in every case.
 *
 * @param modelType - which model to fetch.
 * @param setDownloadProgress - receives a percentage. It is called with 0 when
 *   each attempt starts, with intermediate values (capped at 100) only when the
 *   server sends a usable content-length, and with 100 once the bytes are cached.
 */
export async function downloadModel(
  modelType: modelType,
  setDownloadProgress: (arg0: number) => void
): Promise<void> {
  if (await modelExists(modelType)) {
    return
  }

  // Streams `url` into memory, reports progress, and then caches the bytes.
  // Throws on network errors, on a non-2xx status, or when the body is missing.
  async function downloadFromUrl(url: string) {
    console.log('start download from', url)
    setDownloadProgress(0)
    const response = await fetch(url)
    // Without this check an HTTP error page (404, 5xx, ...) would be cached as the model.
    if (!response.ok) {
      throw new Error(
        `HTTP ${response.status} ${response.statusText} while fetching ${url}`
      )
    }
    if (!response.body) {
      throw new Error(`Empty response body while fetching ${url}`)
    }
    // content-length may be absent. Number(null) is 0, which would turn the
    // percentage into Infinity, so intermediate progress is reported only when
    // the size is known.
    const fullSize = Number(response.headers.get('content-length'))
    const reader = response.body.getReader()
    const chunks: Uint8Array[] = []
    let downloaded = 0
    while (true) {
      const { done, value } = await reader.read()
      if (done) {
        break
      }
      if (value) {
        chunks.push(value)
        downloaded += value.length
      }
      if (fullSize > 0) {
        // Clamped: for a content-encoded response, content-length can describe
        // the encoded size while the stream yields the decoded bytes.
        setDownloadProgress(Math.min((downloaded / fullSize) * 100, 100))
      }
    }
    // Concatenate into a standalone ArrayBuffer. This matches the ArrayBuffer
    // type that saveModel accepts and loadModel declares; ensureModel also
    // caches response.arrayBuffer().
    const arrayBuffer = new ArrayBuffer(downloaded)
    const view = new Uint8Array(arrayBuffer)
    let offset = 0
    for (const chunk of chunks) {
      view.set(chunk, offset)
      offset += chunk.length
    }
    await saveModel(modelType, arrayBuffer)
    setDownloadProgress(100)
  }

  const model = getModel(modelType)
  try {
    await downloadFromUrl(model.url)
  } catch (e) {
    if (model.backupUrl) {
      try {
        await downloadFromUrl(model.backupUrl)
        // The backup succeeded and the model is cached, so there is no failure to report.
        return
      } catch (r) {
        alert(`Failed to download the backup model: ${r}`)
      }
    }
    alert(`Failed to download the model, network problem: ${e}`)
  }
}
|