|
from loguru import logger |
|
|
|
from app.Services.provider import ServiceProvider |
|
|
|
# Presumably the newest database schema version; migrate() logs
# "Already up to date." for version 2 -- confirm the two stay in sync.
CURRENT_VERSION = 2

# Shared service provider used by the migration steps. It is None at import
# time and is assigned by migrate() before any migration step is awaited.
services: ServiceProvider | None = None
|
|
|
|
|
async def migrate_v1_v2():
    """Migrate every stored point from the v1 layout to the v2 layout.

    Pages through the database with ``services.db_context.scroll_points`` in
    batches of 100 until the returned cursor is ``None``. For each point:

    * if its URL starts with ``'/'``, ``point.local`` is set to ``True``;
    * the point is written back with ``updatePayload``;
    * if ``ocr_text`` is present, ``text_contain_vector`` is computed from
      ``ocr_text_lower`` via ``services.transformers_service.get_bert_vector``.

    After each batch, the points that carry a ``text_contain_vector`` are
    passed to ``updateVectors``. A batch without any such point is skipped
    rather than sent as an empty update.

    Relies on the module-level ``services`` being initialised beforehand
    (``migrate`` does this); calling it while ``services`` is ``None`` fails.
    """
    logger.info("Migrating from v1 to v2...")
    next_id = None
    count = 0  # running total across all batches, used only for logging
    while True:
        points, next_id = await services.db_context.scroll_points(next_id, count=100)
        for point in points:
            count += 1
            logger.info("[{}] Migrating point {}", count, point.id)
            if point.url.startswith('/'):
                # A URL with a leading '/' is flagged as local in v2.
                point.local = True
            # Persist the payload for every point, not only the local ones.
            await services.db_context.updatePayload(point)
            if point.ocr_text is not None:
                point.text_contain_vector = services.transformers_service.get_bert_vector(point.ocr_text_lower)

        # Only issue a vector update when this batch has something to update;
        # how the backend treats an empty update list is not visible here, so
        # avoid sending one at all.
        points_with_vectors = [t for t in points if t.text_contain_vector is not None]
        if points_with_vectors:
            logger.info("Updating vectors...")
            await services.db_context.updateVectors(points_with_vectors)
        if next_id is None:
            # No further page to scroll: migration is complete.
            break
|
|
|
|
|
async def migrate(from_version: int):
    """Load the shared services and run the migration for ``from_version``.

    A fresh ``ServiceProvider`` is created, stored in the module-level
    ``services`` and its ``onload()`` awaited before the version is examined.

    * version 1 runs ``migrate_v1_v2``;
    * version 2 only logs that nothing needs to be done.

    Raises:
        ValueError: if ``from_version`` is neither 1 nor 2.
    """
    global services
    provider = ServiceProvider()
    services = provider
    await provider.onload()

    if from_version == 1:
        await migrate_v1_v2()
        return
    if from_version == 2:
        logger.info("Already up to date.")
        return
    raise ValueError(f"Unknown version {from_version}")
|
|