File size: 2,970 Bytes
287a0bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import { IEmbeddingFunction } from "./IEmbeddingFunction";

/**
 * Version-independent view of the Cohere SDK used by this module.
 *
 * Each supported SDK generation gets its own adapter class that implements
 * this contract, so callers never touch the SDK-specific request/response
 * shapes directly.
 */
interface CohereAIAPI {
  /**
   * Requests embeddings for a batch of texts.
   *
   * @param params.model - Name of the Cohere embedding model to use.
   * @param params.input - Texts to embed.
   * @returns A promise resolving to the embedding vectors returned by the SDK.
   */
  createEmbedding(params: {
    model: string;
    input: string[];
  }): Promise<number[][]>;
}

/**
 * Adapter for the older `cohere-ai` API surface: the module's default export
 * is configured once through `init(apiKey)` and then used directly, and the
 * embeddings are read from `response.body.embeddings`.
 *
 * The SDK is loaded lazily on the first request so that `cohere-ai` stays an
 * optional dependency.
 */
class CohereAISDK56 implements CohereAIAPI {
  // Initialised SDK module; typed `any` because the optional `cohere-ai`
  // package may not be resolvable at compile time.
  private cohereClient: any;
  private apiKey: string;

  /**
   * @param configuration.apiKey - API key handed to `cohere.init` on first use.
   */
  constructor(configuration: { apiKey: string }) {
    this.apiKey = configuration.apiKey;
  }

  /** Imports and initialises the SDK once; later calls are no-ops. */
  private async loadClient(): Promise<void> {
    if (this.cohereClient) return;
    //@ts-ignore
    const { default: cohere } = await import("cohere-ai");
    // @ts-ignore
    cohere.init(this.apiKey);
    this.cohereClient = cohere;
  }

  /**
   * Embeds the given texts with the given model.
   *
   * @param params.model - Model name forwarded to the SDK's `embed` call.
   * @param params.input - Texts forwarded to the SDK as `texts`.
   * @returns The embedding vectors from `response.body.embeddings`.
   * @throws Error when the response carries no embeddings array; the status
   *   code and `body.message` are included in the message when present.
   */
  public async createEmbedding(params: {
    model: string;
    input: string[];
  }): Promise<number[][]> {
    await this.loadClient();
    const response = await this.cohereClient.embed({
      texts: params.input,
      model: params.model,
    });

    const body = response ? response.body : undefined;
    const embeddings = body ? body.embeddings : undefined;
    if (!Array.isArray(embeddings)) {
      // Without this guard a response lacking embeddings would resolve to
      // `undefined` and only fail later, far away from the actual cause.
      const status =
        response && response.statusCode !== undefined
          ? ` (status ${response.statusCode})`
          : "";
      const reason =
        body && typeof body.message === "string" ? `: ${body.message}` : "";
      throw new Error(
        `Cohere embed request did not return embeddings${status}${reason}`
      );
    }
    return embeddings;
  }
}

/**
 * Adapter for the `cohere-ai` API surface that exposes a `CohereClient`
 * class: a client instance is constructed with `{ token }` and the embeddings
 * are read from the top-level `embeddings` field of the response.
 *
 * The SDK is imported lazily on the first request.
 */
class CohereAISDK7 implements CohereAIAPI {
  // Lazily created `CohereClient` instance (untyped: optional dependency).
  private cohereClient: any;
  private apiKey: string;

  /**
   * @param configuration.apiKey - API key passed to `CohereClient` as `token`.
   */
  constructor({ apiKey }: { apiKey: string }) {
    this.apiKey = apiKey;
  }

  /** Creates the SDK client on first use; does nothing once it exists. */
  private async loadClient() {
    if (!this.cohereClient) {
      //@ts-ignore
      const sdk = await import("cohere-ai");
      // @ts-ignore
      this.cohereClient = new sdk.CohereClient({
        token: this.apiKey,
      });
    }
  }

  /**
   * Embeds the given texts with the given model.
   *
   * @param params.model - Model name forwarded to the client's `embed` call.
   * @param params.input - Texts forwarded to the client as `texts`.
   * @returns The `embeddings` field of the SDK response.
   */
  public async createEmbedding(params: {
    model: string;
    input: string[];
  }): Promise<number[][]> {
    await this.loadClient();
    const { input: texts, model } = params;
    const result: any = await this.cohereClient.embed({ texts, model });
    return result.embeddings;
  }
}

/**
 * Embedding function backed by the optional `cohere-ai` package.
 *
 * The package is imported lazily on the first `generate` call. Which adapter
 * is used depends on the shape of the installed module: if it exports
 * `CohereClient`, `CohereAISDK7` is used, otherwise `CohereAISDK56`.
 */
export class CohereEmbeddingFunction implements IEmbeddingFunction {
  private cohereAiApi?: CohereAIAPI;
  private model: string;
  private apiKey: string;

  /**
   * @param cohere_api_key - API key handed to the selected SDK adapter.
   * @param model - Embedding model name; falls back to `"large"` when
   *   omitted or empty.
   */
  constructor({
    cohere_api_key,
    model,
  }: {
    cohere_api_key: string;
    model?: string;
  }) {
    this.model = model || "large";
    this.apiKey = cohere_api_key;
  }

  /**
   * Returns the SDK adapter, creating it on first use.
   *
   * @throws Error with installation instructions when `cohere-ai` cannot be
   *   resolved; any other import failure is rethrown unchanged.
   */
  private async initCohereClient(): Promise<CohereAIAPI> {
    if (this.cohereAiApi) return this.cohereAiApi;

    let cohereModule: { CohereClient?: unknown };
    try {
      // `cohere-ai` is an optional dependency and may be absent at compile time.
      // @ts-ignore
      cohereModule = await import("cohere-ai");
    } catch (e) {
      const code =
        typeof e === "object" && e !== null
          ? (e as { code?: unknown }).code
          : undefined;
      // `require()`-based resolution reports MODULE_NOT_FOUND, while Node's
      // ESM loader reports ERR_MODULE_NOT_FOUND for the same situation.
      if (code === "MODULE_NOT_FOUND" || code === "ERR_MODULE_NOT_FOUND") {
        throw new Error(
          "Please install the cohere-ai package to use the CohereEmbeddingFunction, `npm install -S cohere-ai`"
        );
      }
      throw e;
    }

    const api: CohereAIAPI = cohereModule.CohereClient
      ? new CohereAISDK7({ apiKey: this.apiKey })
      : new CohereAISDK56({ apiKey: this.apiKey });
    this.cohereAiApi = api;
    return api;
  }

  /**
   * Generates embeddings for the given texts using the configured model.
   *
   * @param texts - Texts to embed.
   * @returns The embedding vectors produced by the SDK adapter.
   */
  public async generate(texts: string[]): Promise<number[][]> {
    const api = await this.initCohereClient();
    return api.createEmbedding({
      model: this.model,
      input: texts,
    });
  }
}