gradio-cmat / gradio_cmat /gradio_cmat.py
freemt
Fix gr.interface removed load
747c2ab
"""Calculate the corr matrix."""
# pylint: disable=invalid-name
from typing import List, Optional
import numpy as np
from logzero import logger
# from model_pool import load_model_s
from hf_model_s import model_s
# Module-level model instance, created once at import time; gradio_cmat()
# calls its encode() method on every request.
model = model_s()
def gradio_cmat(
    list1: List[str],
    list2_: Optional[List[str]] = None,
) -> np.ndarray:
    """Gen corr matrix given two lists of str.

    Each list is passed to the module-level ``model.encode`` and the result
    is the dot product of the first encoding with the transpose of the
    second.

    Args:
        list1: list of strings
        list2_: list of strings; if None or empty, list1 is compared with
            itself

    Returns:
        numpy.ndarray, (len(list1) x len(list2)) -- assumes ``model.encode``
        returns one row vector per input string (TODO confirm against
        hf_model_s).

    Raises:
        Exception: any error raised by ``model.encode`` or by the dot
            product is logged and then re-raised unchanged.
    """
    try:
        vec1 = model.encode(list1)
    except Exception as e:
        logger.error("model.encode(list1) error: %s", e)
        raise

    if not list2_:
        # No second list: reuse vec1 instead of encoding the very same
        # strings a second time (assumes encode() is deterministic for
        # identical input, so the result is unchanged).
        vec2 = vec1
    else:
        try:
            vec2 = model.encode(list2_)
        except Exception as e:
            logger.error("model.encode(list2) error: %s", e)
            raise

    try:
        res = vec1.dot(vec2.T)
    except Exception as e:
        logger.error("vec1.dot(vec2.T) error: %s", e)
        raise

    return res