Mustehson committed on
Commit
5023c74
·
1 Parent(s): 9d5557c

Langsmith Logs

Browse files
Files changed (2) hide show
  1. app.py +36 -18
  2. requirements.txt +7 -4
app.py CHANGED
@@ -6,8 +6,10 @@ import pandas as pd
6
  import pandera as pa
7
  from pandera import Column
8
  import ydata_profiling as pp
9
- from huggingface_hub import InferenceClient
 
10
  from prompt import PROMPT_PANDERA, PANDERA_USER_INPUT_PROMPT
 
11
  import warnings
12
  warnings.filterwarnings("ignore", category=DeprecationWarning)
13
 
@@ -16,12 +18,9 @@ warnings.filterwarnings("ignore", category=DeprecationWarning)
16
  TAB_LINES = 8
17
  # Load Token
18
  md_token = os.getenv('MD_TOKEN')
19
- os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN')
20
-
21
-
22
 
23
  INPUT_PROMPT = '''
24
- Here is the frist few samples of data:
25
  <Sample Data>
26
  {data}
27
  </Sample Data<>
@@ -29,7 +28,7 @@ Here is the frist few samples of data:
29
 
30
 
31
  USER_INPUT = '''
32
- Here is the frist few samples of data:
33
  <Sample Data>
34
  {data}
35
  </Sample Data<>
@@ -44,7 +43,22 @@ Here is the User Description:
44
  print('Connecting to DB...')
45
  # Connect to DB
46
  conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
47
- client = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  # Get Databases
50
  def get_schemas():
@@ -71,29 +85,33 @@ def get_data_df(schema):
71
 
72
 
73
  def chat_template(system_prompt, user_prompt, df):
74
-
75
- messages=[
76
- {"role": "system", "content": system_prompt},
77
- {"role": "user", "content": user_prompt.format(data=df.head().to_json(orient='records'))},
78
  ]
79
  return messages
 
80
  def chat_template_user(system_prompt, user_prompt, user_description, df):
81
 
82
- messages=[
83
- {"role": "system", "content": system_prompt},
84
- {"role": "user", "content": user_prompt.format(data=df.head(1).to_json(orient='records'), user_description=user_description)},
85
  ]
86
- return messages
 
87
 
 
88
  def run_llm(messages):
89
  try:
90
- response = client.chat_completion(messages, max_tokens=1024)
91
- print(response.choices[0].message.content)
92
- tests = json.loads(response.choices[0].message.content)
93
  except Exception as e:
94
  return e
95
  return tests
96
 
 
97
  # Get Schema
98
  def get_table_schema(table):
99
  result = conn.sql(f"SELECT sql, database_name, schema_name FROM duckdb_tables() where table_name ='{table}';").df()
 
6
  import pandera as pa
7
  from pandera import Column
8
  import ydata_profiling as pp
9
+ from langchain_core.messages import HumanMessage, SystemMessage
10
+ from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
11
  from prompt import PROMPT_PANDERA, PANDERA_USER_INPUT_PROMPT
12
+ from langsmith import traceable
13
  import warnings
14
  warnings.filterwarnings("ignore", category=DeprecationWarning)
15
 
 
18
  TAB_LINES = 8
19
  # Load Token
20
  md_token = os.getenv('MD_TOKEN')
 
 
 
21
 
22
  INPUT_PROMPT = '''
23
+ Here are the first few samples of data:
24
  <Sample Data>
25
  {data}
26
  </Sample Data<>
 
28
 
29
 
30
  USER_INPUT = '''
31
+ Here are the first few samples of data:
32
  <Sample Data>
33
  {data}
34
  </Sample Data<>
 
43
  print('Connecting to DB...')
44
  # Connect to DB
45
  conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
46
+
47
+ models = ["Qwen/Qwen2.5-72B-Instruct","meta-llama/Meta-Llama-3-70B-Instruct",
48
+ "meta-llama/Llama-3.1-70B-Instruct"]
49
+
50
+ model_loaded = False
51
+ for model in models:
52
+ try:
53
+ endpoint = HuggingFaceEndpoint(repo_id=model, max_new_tokens=8192)
54
+ info = endpoint.client.get_endpoint_info()
55
+ model_loaded = True
56
+ break
57
+ except Exception as e:
58
+ print(f"Error for model {model}: {e}")
59
+ continue
60
+
61
+ llm = ChatHuggingFace(llm=endpoint).bind_tools(tools=[], max_tokens=8192)
62
 
63
  # Get Databases
64
  def get_schemas():
 
85
 
86
 
87
  def chat_template(system_prompt, user_prompt, df):
88
+
89
+ messages = [
90
+ SystemMessage(content=system_prompt),
91
+ HumanMessage(content=user_prompt.format(data=df.head().to_json(orient='records'))),
92
  ]
93
  return messages
94
+
95
  def chat_template_user(system_prompt, user_prompt, user_description, df):
96
 
97
+ messages = [
98
+ SystemMessage(content=system_prompt),
99
+ HumanMessage(content=user_prompt.format(data=df.head(1).to_json(orient='records'), user_description=user_description)),
100
  ]
101
+ return messages
102
+
103
 
104
+ @traceable()
105
  def run_llm(messages):
106
  try:
107
+ response = llm.invoke(messages)
108
+ print(response.content)
109
+ tests = json.loads(response.content)
110
  except Exception as e:
111
  return e
112
  return tests
113
 
114
+
115
  # Get Schema
116
  def get_table_schema(table):
117
  result = conn.sql(f"SELECT sql, database_name, schema_name FROM duckdb_tables() where table_name ='{table}';").df()
requirements.txt CHANGED
@@ -1,7 +1,10 @@
1
  torch
2
  huggingface_hub
3
- accelerate
 
4
  transformers==4.44.2
5
- duckdb
6
- pandera
7
- ydata-profiling
 
 
 
1
  torch
2
  huggingface_hub
3
+ langchain_huggingface
4
+ accelerate==0.34.2
5
  transformers==4.44.2
6
+ duckdb==1.1.1
7
+ langsmith==0.1.135
8
+ pandera==0.20.4
9
+ ydata-profiling==v4.11.0
10
+ langchain-core==0.3.12