Spaces:
Running
Running
File size: 4,552 Bytes
287a0bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import {IEmbeddingFunction} from "./IEmbeddingFunction";
// Module-level state filled in by OpenAIEmbeddingFunction.loadClient() after the
// openai package has been imported dynamically.
// eslint-disable-next-line @typescript-eslint/no-explicit-any -- holds either the v3 module (Configuration/OpenAIApi) or the v4+ client class
let OpenAIApi: any;
// Installed openai version with every character except digits and dots removed.
let openAiVersion: string | null = null;
// Major component of `openAiVersion`; used to pick the v3 or v4+ adapter.
let openAiMajorVersion: number | null = null;
/**
 * Version-independent embedding contract implemented by OpenAIAPIv3 and
 * OpenAIAPIv4: embed each string of `input` with `model` and resolve to the
 * resulting vectors.
 */
interface OpenAIAPI {
  createEmbedding(params: {
    model: string;
    input: string[];
    user?: string;
  }): Promise<number[][]>;
}
/**
 * Adapter over the openai v3 SDK, which exposes `Configuration` and
 * `OpenAIApi` on the imported module.
 */
class OpenAIAPIv3 implements OpenAIAPI {
  private readonly configuration: any;
  private openai: any;

  /**
   * @param configuration - Organization id (may be empty) and API key, passed
   *   to the SDK's `Configuration`.
   */
  constructor(configuration: { organization: string, apiKey: string }) {
    this.configuration = new OpenAIApi.Configuration({
      organization: configuration.organization,
      apiKey: configuration.apiKey,
    });
    this.openai = new OpenAIApi.OpenAIApi(this.configuration);
  }

  /**
   * Requests embeddings for `params.input` and returns the `embedding` of each
   * item in the response, in response order.
   *
   * `params.user` is forwarded only when provided. This keeps the request body
   * unchanged for callers that omit it, and it matches the v4 adapter, which
   * forwards `params` as-is.
   */
  public async createEmbedding(params: {
    model: string,
    input: string[],
    user?: string
  }): Promise<number[][]> {
    const request: { model: string; input: string[]; user?: string } = {
      model: params.model,
      input: params.input,
    };
    if (params.user !== undefined) {
      request.user = params.user;
    }
    const response = await this.openai.createEmbedding(request);
    // The API payload is nested under `response.data`; its `data` field is the item list.
    const items: Array<{ embedding: number[] }> = response.data.data;
    return items.map((item) => item.embedding);
  }
}
/**
 * Adapter over the openai v4+ SDK, whose default export is a client class
 * exposing `embeddings.create`.
 */
class OpenAIAPIv4 implements OpenAIAPI {
  private readonly apiKey: any;
  private openai: any;

  /**
   * @param apiKey - OpenAI API key.
   * @param organization - Optional organization id. It is only added to the
   *   client options when non-empty, so omitting it (or passing "") builds the
   *   client with exactly `{ apiKey }`, as before.
   */
  constructor(apiKey: any, organization?: string) {
    this.apiKey = apiKey;
    const options: { apiKey: any; organization?: string } = { apiKey: this.apiKey };
    if (organization) {
      options.organization = organization;
    }
    this.openai = new OpenAIApi(options);
  }

  /**
   * Forwards `params` unchanged to `embeddings.create` and returns the
   * `embedding` of each item in `response.data`, in response order.
   */
  public async createEmbedding(params: {
    model: string,
    input: string[],
    user?: string
  }): Promise<number[][]> {
    const response = await this.openai.embeddings.create(params);
    const items: Array<{ embedding: number[] }> = response.data;
    return items.map((item) => item.embedding);
  }
}

/**
 * Embedding function backed by the OpenAI embeddings endpoint.
 *
 * The openai package is imported lazily on first use, and its version decides
 * whether the v3 or the v4+ adapter is used.
 */
export class OpenAIEmbeddingFunction implements IEmbeddingFunction {
  private api_key: string;
  private org_id: string;
  private model: string;
  private openaiApi?: OpenAIAPI;

  /**
   * @param openai_api_key - API key handed to the OpenAI client.
   * @param openai_model - Embedding model; defaults to "text-embedding-ada-002".
   * @param openai_organization_id - Optional organization id; defaults to "".
   */
  constructor({openai_api_key, openai_model, openai_organization_id}: {
    openai_api_key: string,
    openai_model?: string,
    openai_organization_id?: string
  }) {
    // The client is not built here: the openai package is loaded with a
    // dynamic import, which a constructor cannot await (see loadClient()).
    this.api_key = openai_api_key;
    this.org_id = openai_organization_id || "";
    this.model = openai_model || "text-embedding-ada-002";
  }

  /**
   * Returns the cached adapter. On the first call, it imports openai, records
   * the module and its version in the module-level state, and builds the
   * matching adapter.
   *
   * @throws Error when the openai package cannot be imported.
   */
  private async loadClient(): Promise<OpenAIAPI> {
    if (this.openaiApi) return this.openaiApi;

    let imported: { openai: any; version: string };
    try {
      imported = await OpenAIEmbeddingFunction.import();
    } catch (error: unknown) {
      // OpenAIEmbeddingFunction.import() already wraps its own failures in a
      // plain Error. This translation only applies if `import` has been replaced.
      if (
        typeof error === "object" &&
        error !== null &&
        (error as { code?: unknown }).code === "MODULE_NOT_FOUND"
      ) {
        throw new Error("Please install the openai package to use the OpenAIEmbeddingFunction, `npm install -S openai`");
      }
      throw error;
    }

    const { openai, version } = imported;
    // Keep only digits and dots (e.g. "v4.20.1" -> "4.20.1"), then read the major part.
    const cleanVersion = version.replace(/[^0-9.]/g, "");
    const majorVersion = parseInt(cleanVersion.split(".")[0] ?? "", 10);

    OpenAIApi = openai;
    openAiVersion = cleanVersion;
    openAiMajorVersion = majorVersion;

    // Major > 3 selects the v4+ client; anything else (including NaN) uses v3.
    this.openaiApi = majorVersion > 3
      ? new OpenAIAPIv4(this.api_key, this.org_id)
      : new OpenAIAPIv3({
        organization: this.org_id,
        apiKey: this.api_key,
      });
    return this.openaiApi;
  }

  /**
   * Embeds `texts` with the configured model.
   *
   * @param texts - Strings to embed.
   * @returns The vectors produced by the version-specific adapter.
   */
  public async generate(texts: string[]): Promise<number[][]> {
    const api = await this.loadClient();
    return api.createEmbedding({
      model: this.model,
      input: texts,
    });
  }

  /**
   * Dynamically imports the openai package and its `openai/version` module.
   *
   * @ignore
   * @returns The package's default export and its `VERSION` string.
   * @throws Error asking the user to install openai if either import fails.
   */
  static async import(): Promise<{
    // @ts-ignore
    openai: typeof import("openai");
    version: string;
  }> {
    try {
      // @ts-ignore
      const { default: openai } = await import("openai");
      // @ts-ignore
      const { VERSION } = await import('openai/version');
      return { openai, version: VERSION };
    } catch (e) {
      throw new Error(
        "Please install openai as a dependency with, e.g. `yarn add openai`"
      );
    }
  }
}
|