Spaces:
Runtime error
Runtime error
"""Calculate the corr matrix.""" | |
# pylint: disable=invalid-name | |
from typing import List, Optional | |
import numpy as np | |
from logzero import logger | |
# from model_pool import load_model_s | |
from hf_model_s import model_s | |
# Instantiate the model once at import time so every call to gradio_cmat
# reuses the same object instead of constructing it per request.
model = model_s()
def gradio_cmat(
    list1: List[str],
    list2_: Optional[List[str]] = None,
) -> np.ndarray:
    """Generate the corr matrix for two lists of strings.

    Each list is passed to the module-level ``model.encode`` and the result
    is the dot product of the first encoding with the transpose of the
    second.

    Args:
        list1: list of strings (rows of the result).
        list2_: list of strings (columns of the result); if None or empty,
            list1 is used for the columns as well.

    Returns:
        numpy.ndarray, (len(list1) x len(list2)) -- assuming ``model.encode``
        returns one vector per input string (TODO confirm against the model).

    Raises:
        Exception: any error raised by ``model.encode`` or by the dot
            product is logged and re-raised unchanged.
    """
    try:
        vec1 = model.encode(list1)
    except Exception as e:
        logger.error("mode_s.encode(list1) error: %s", e)
        raise

    if not list2_:
        # Same strings on both axes: reuse the encoding already computed
        # instead of running the model a second time on identical input.
        vec2 = vec1
    else:
        try:
            vec2 = model.encode(list2_)
        except Exception as e:
            logger.error("mode_s.encode(list2) error: %s", e)
            raise

    try:
        res = vec1.dot(vec2.T)
    except Exception as e:
        logger.error("vec1.dot(vec2.T) error: %s", e)
        raise

    return res