File size: 1,063 Bytes
f2dd44e
 
 
 
 
 
747c2ab
 
f2dd44e
747c2ab
f2dd44e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
747c2ab
f2dd44e
 
 
 
 
747c2ab
f2dd44e
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""Calculate the corr matrix."""
# pylint: disable=invalid-name
from typing import List, Optional

import numpy as np
from logzero import logger
# from model_pool import load_model_s
from hf_model_s import model_s

# Built once at import time; gradio_cmat below uses this shared instance
# for every call instead of constructing a model per request.
model = model_s()


def gradio_cmat(
    list1: List[str],
    list2_: Optional[List[str]] = None,
) -> np.ndarray:
    """Gen corr matrix given two lists of str.

    Both lists are encoded with the module-level ``model`` and the result
    is the dot product of the first embedding array with the transpose of
    the second.

    Args:
        list1: list of strings
        list2_: list of strings; if None or empty, list1 is used for both
            sides of the matrix
    Returns:
        numpy.ndarray, (len(list1)xlen(list2)) -- assuming ``model.encode``
        yields one row per input string
    Raises:
        Exception: any error raised by ``model.encode`` or by the dot
            product is logged and then re-raised unchanged.
    """
    try:
        vec1 = model.encode(list1)
    except Exception as e:
        logger.error("model.encode(list1) error: %s", e)
        raise

    if list2_:
        try:
            vec2 = model.encode(list2_)
        except Exception as e:
            logger.error("model.encode(list2) error: %s", e)
            raise
    else:
        # No second list given: the second operand would be an identical
        # copy of list1, so reuse its embeddings rather than encoding the
        # same strings a second time.
        vec2 = vec1

    try:
        res = vec1.dot(vec2.T)
    except Exception as e:
        logger.error("vec1.dot(vec2.T) error: %s", e)
        raise

    return res