// File: chroma/clients/js/src/embeddings/OpenAIEmbeddingFunction.ts
// Commit 287a0bc by badalsahani: "feat: chroma initial deploy" (4.55 kB)
import {IEmbeddingFunction} from "./IEmbeddingFunction";
// Module-level state describing the optional `openai` package. Nothing here
// is populated at load time; OpenAIEmbeddingFunction.loadClient() fills all
// three in after the package has been imported dynamically.

// Leading component of `openAiVersion`, parsed as an integer. loadClient()
// compares it against 3 to choose between the two adapter classes.
let openAiMajorVersion = null;

// Version string reported by the package, reduced to digits and dots.
let openAiVersion = null;

// Default export of the `openai` package. The v3 adapter reads
// `Configuration` and `OpenAIApi` off it; the v4 adapter calls it directly
// as a constructor.
let OpenAIApi: any;
/**
 * Minimal surface shared by the version-specific SDK adapters in this file.
 * Each adapter hides how its SDK version is called and how the response is
 * laid out, and exposes only the resulting vectors.
 */
interface OpenAIAPI {
  /**
   * Requests embeddings for a batch of texts.
   *
   * @param params - `model` names the embedding model, `input` holds the
   *   texts to embed, and the optional `user` is an end-user identifier.
   * @returns A promise of one `number[]` per returned embedding.
   */
  createEmbedding: (params: { model: string; input: string[]; user?: string }) => Promise<number[][]>;
}
/**
 * Adapter used when the installed `openai` package is version 3 or lower.
 * For that SDK the client is built from a separate `Configuration` object,
 * and the embeddings payload sits under `response.data`.
 */
class OpenAIAPIv3 implements OpenAIAPI {
  private readonly configuration: any;
  private openai: any;

  /**
   * @param configuration - Organization id and API key forwarded to the SDK's
   *   `Configuration` constructor.
   */
  constructor(configuration: { organization: string, apiKey: string }) {
    this.configuration = new OpenAIApi.Configuration({
      organization: configuration.organization,
      apiKey: configuration.apiKey,
    });
    this.openai = new OpenAIApi.OpenAIApi(this.configuration);
  }

  /**
   * Requests embeddings for `params.input` and unwraps the vectors.
   *
   * @param params - Model name, texts to embed and an optional end-user id.
   * @returns One embedding vector per entry in the response's `data` array.
   * @throws Whatever the underlying SDK call rejects with; errors are not
   *   wrapped or translated here.
   */
  public async createEmbedding(params: {
    model: string,
    input: string[],
    user?: string
  }): Promise<number[][]> {
    const request: { model: string; input: string[]; user?: string } = {
      model: params.model,
      input: params.input,
    };
    // Forward `user` only when the caller supplied one, so requests made
    // without it keep exactly the same shape as before.
    if (params.user !== undefined) {
      request.user = params.user;
    }

    const response = await this.openai.createEmbedding(request);
    const data: Array<{ embedding: number[] }> = response.data["data"];
    return data.map((item) => item["embedding"]);
  }
}
/**
 * Adapter used when the installed `openai` package reports a major version
 * greater than 3. Here the package's default export is itself the client
 * constructor, and embeddings are requested through `embeddings.create`.
 */
class OpenAIAPIv4 implements OpenAIAPI {
  private readonly apiKey: any;
  private openai: any;

  /**
   * @param apiKey - API key passed to the SDK client.
   * @param organization - Optional organization id. It is only added to the
   *   client options when non-empty, so callers that omit it construct the
   *   client exactly as before.
   */
  constructor(apiKey: any, organization?: string) {
    this.apiKey = apiKey;
    const clientOptions: { apiKey: any; organization?: string } = {
      apiKey: this.apiKey,
    };
    if (organization) {
      clientOptions.organization = organization;
    }
    this.openai = new OpenAIApi(clientOptions);
  }

  /**
   * Requests embeddings for `params.input` and unwraps the vectors.
   *
   * @param params - Passed to the SDK unchanged: model name, texts to embed
   *   and an optional end-user id.
   * @returns One embedding vector per entry in the response's `data` array.
   * @throws Whatever the underlying SDK call rejects with.
   */
  public async createEmbedding(params: {
    model: string,
    input: string[],
    user?: string
  }): Promise<number[][]> {
    const response = await this.openai.embeddings.create(params);
    const data: Array<{ embedding: number[] }> = response["data"];
    return data.map((item) => item["embedding"]);
  }
}

/**
 * Embedding function backed by the `openai` npm package.
 *
 * The package is imported lazily on the first `generate` call, and the
 * adapter matching the installed major version (v3 or v4+) is selected then.
 */
export class OpenAIEmbeddingFunction implements IEmbeddingFunction {
  private api_key: string;
  private org_id: string;
  private model: string;
  private openaiApi?: OpenAIAPI;

  /**
   * @param openai_api_key - API key used for every request.
   * @param openai_model - Embedding model name; defaults to
   *   `"text-embedding-ada-002"` when omitted or empty.
   * @param openai_organization_id - Optional organization id; stored as an
   *   empty string when omitted.
   */
  constructor({openai_api_key, openai_model, openai_organization_id}: {
    openai_api_key: string,
    openai_model?: string,
    openai_organization_id?: string
  }) {
    // The client is not built here: the openai package has to be imported
    // asynchronously, and a constructor cannot be async.
    this.api_key = openai_api_key;
    this.org_id = openai_organization_id || "";
    this.model = openai_model || "text-embedding-ada-002";
  }

  /**
   * Imports the `openai` package on first use, records its version in the
   * module-level state and builds the matching adapter. Later calls return
   * the cached adapter.
   *
   * @returns The adapter used for embedding requests.
   * @throws Error if the package cannot be imported.
   */
  private async loadClient(): Promise<OpenAIAPI> {
    // Reuse the adapter built by an earlier call.
    if (this.openaiApi) return this.openaiApi;

    let majorVersion: number;
    try {
      const { openai, version } = await OpenAIEmbeddingFunction.import();
      OpenAIApi = openai;
      // Keep only digits and dots so the leading component parses as a number.
      const numericVersion = version.replace(/[^0-9.]/g, '');
      majorVersion = parseInt(numericVersion.split('.')[0], 10);
      openAiVersion = numericVersion;
      openAiMajorVersion = majorVersion;
    } catch (error) {
      // import() normally wraps failures in a plain Error; this branch only
      // matters if a raw module-resolution error reaches this point.
      const code = (error as { code?: unknown } | null | undefined)?.code;
      if (code === 'MODULE_NOT_FOUND') {
        throw new Error("Please install the openai package to use the OpenAIEmbeddingFunction, `npm install -S openai`");
      }
      throw error; // Re-throw other errors
    }

    // Both adapters receive the organization id, so it is honored regardless
    // of which SDK major version is installed.
    const client: OpenAIAPI = majorVersion > 3
      ? new OpenAIAPIv4(this.api_key, this.org_id)
      : new OpenAIAPIv3({
          organization: this.org_id,
          apiKey: this.api_key,
        });
    this.openaiApi = client;
    return client;
  }

  /**
   * Embeds a batch of texts with the configured model.
   *
   * @param texts - Texts to embed.
   * @returns The vectors produced by the adapter for `texts`.
   * @throws Error if the `openai` package cannot be loaded, or whatever the
   *   SDK request rejects with.
   */
  public async generate(texts: string[]): Promise<number[][]> {
    const client = await this.loadClient();
    return await client.createEmbedding({
      model: this.model,
      input: texts,
    });
  }

  /**
   * Dynamically imports the `openai` package and its `openai/version` module.
   *
   * @returns The package's default export and its `VERSION` string.
   * @throws Error with an installation hint if either import fails.
   * @ignore
   */
  static async import(): Promise<{
    // @ts-ignore
    openai: typeof import("openai");
    version: string;
  }> {
    try {
      // @ts-ignore
      const { default: openai } = await import("openai");
      // @ts-ignore
      const { VERSION } = await import('openai/version');
      return { openai, version: VERSION };
    } catch (e) {
      throw new Error(
        "Please install openai as a dependency with, e.g. `yarn add openai`"
      );
    }
  }
}