tanbushi commited on
Commit
bad204d
·
1 Parent(s): 10bcd3f

langchain message to openai

Browse files
Files changed (1) hide show
  1. routers/openai_v1.py +31 -18
routers/openai_v1.py CHANGED
@@ -1,7 +1,7 @@
1
  from fastapi import APIRouter, Depends
2
- import os
3
- from langchain_core.output_parsers import StrOutputParser
4
  from langchain_core.prompts import ChatPromptTemplate
 
 
5
 
6
  from global_state import get
7
  from db.tbs_db import TbsDb
@@ -45,15 +45,16 @@ async def chat_completions(chat_model:ChatModel, current_user: UserModel = Depen
45
 
46
  # 生成prompt模板
47
  lc_messages = [(message.role, message.content) for message in chat_model.messages]
48
- parser = StrOutputParser()
49
  prompt_template = ChatPromptTemplate.from_messages(lc_messages)
50
-
51
- chain = prompt_template | llm | parser
52
  try:
53
- result = chain.invoke({})
54
  except Exception as e:
55
  return {'error': str(e)}
56
- return result
 
 
 
57
 
58
  # 从数据库获取默认模型
59
  async def get_default_model():
@@ -76,9 +77,6 @@ ORDER BY ak.last_call_at
76
  limit 1
77
  """
78
  response = TbsDb(db_module_filename, "Cloudflare").get_item(query)
79
- # return response
80
- # results = Db(os.getenv("DB_PATH")).list_query(sql)
81
- # return results
82
  try:
83
  result = response['result'][0]['results'][0]
84
  api_key = result['api_key']
@@ -89,11 +87,26 @@ limit 1
89
  TbsDb(db_module_filename, "Cloudflare").execute_query(query)
90
  return result
91
 
92
- async def get_user_id(api_key):
93
- sql = f"select id from users where api_key='{api_key}'"
94
- results = Db(os.getenv("DB_PATH")).list_query(sql)
95
- try:
96
- result = results[0]['id']
97
- except:
98
- result = 0
99
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import APIRouter, Depends
 
 
2
  from langchain_core.prompts import ChatPromptTemplate
3
+ from datetime import datetime
4
+ # import uuid
5
 
6
  from global_state import get
7
  from db.tbs_db import TbsDb
 
45
 
46
  # 生成prompt模板
47
  lc_messages = [(message.role, message.content) for message in chat_model.messages]
 
48
  prompt_template = ChatPromptTemplate.from_messages(lc_messages)
49
+ chain = prompt_template | llm
 
50
  try:
51
+ result = chain.invoke({}) # AIMessage 类对象
52
  except Exception as e:
53
  return {'error': str(e)}
54
+
55
+ # 转换为OpenAI格式
56
+ converted_data = convert_to_openai_format(result)
57
+ return converted_data
58
 
59
  # 从数据库获取默认模型
60
  async def get_default_model():
 
77
  limit 1
78
  """
79
  response = TbsDb(db_module_filename, "Cloudflare").get_item(query)
 
 
 
80
  try:
81
  result = response['result'][0]['results'][0]
82
  api_key = result['api_key']
 
87
  TbsDb(db_module_filename, "Cloudflare").execute_query(query)
88
  return result
89
 
90
+ def convert_to_openai_format(original_json):
91
+ # 创建新的JSON对象
92
+ new_json = {
93
+ "id": "chatcmpl-123", # 这里可以生成一个唯一的ID,或者使用传入的id
94
+ "object": "chat.completion",
95
+ "created": int(datetime.now().timestamp()), # 当前时间戳
96
+ "choices": [
97
+ {
98
+ "index": 0,
99
+ "message": {
100
+ "role": "assistant",
101
+ "content": original_json.content # 使用原始内容
102
+ },
103
+ "finish_reason": "stop"
104
+ }
105
+ ],
106
+ "usage": {
107
+ "prompt_tokens": original_json.usage_metadata.get("input_tokens",0),
108
+ "completion_tokens": original_json.usage_metadata.get("output_tokens", 0),
109
+ "total_tokens": original_json.usage_metadata.get("total_tokens", 0)
110
+ }
111
+ }
112
+ return new_json