Mustehson committed on
Commit
e44c00b
·
1 Parent(s): 0eb72dc

Gemma Model

Browse files
Files changed (2) hide show
  1. app.py +15 -17
  2. visualization_prompt.py +4 -1
app.py CHANGED
@@ -31,12 +31,16 @@ else:
31
 
32
  print('Loading Model...')
33
 
34
- tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
 
35
 
36
  quantization_config = BitsAndBytesConfig(
37
- load_in_4bit=True)
 
 
 
38
 
39
- model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3", quantization_config=quantization_config,
40
  device_map="auto", torch_dtype=torch.bfloat16)
41
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, return_full_text=False, max_new_tokens=512)
42
  llm = HuggingFacePipeline(pipeline=pipe)
@@ -45,17 +49,6 @@ llm = HuggingFacePipeline(pipeline=pipe)
45
  print('Model Loaded...')
46
  print(f'Model Device: {model.device}')
47
 
48
-
49
- # def generate_output(schema, prompt, generated_sql_query, result_output, result_visulaized):
50
- # return {
51
- # table_schema: schema,
52
- # input_prompt: prompt,
53
- # generated_query: generated_sql_query,
54
- # result_output: result_output,
55
- # result_visulaized: result_visulaized
56
- # }
57
-
58
- # Get Databases
59
  def get_schemas():
60
  schemas = conn.execute("""
61
  SELECT DISTINCT schema_name
@@ -148,25 +141,29 @@ Recommend a visualization:
148
  final_prompt = prompt.format_prompt(question=text_query,
149
  sql_query=sql_query, results=sql_result)
150
  response = run_llm(final_prompt)
151
-
152
  lines = response.strip().split('\n')
 
153
  visualization = lines[0].split(': ')[1]
154
  reason = lines[1].split(': ')[1]
155
 
156
  return visualization, reason
157
 
 
 
 
158
  def format_data(text_query, sql_query, sql_result, visualization_type):
159
  instruction = graph_instructions[visualization_type]
160
 
161
  template = ChatPromptTemplate.from_messages([
162
  ("system", "You are a Data expert who formats data according to the required needs. You are given the question asked by the user, it's sql query, the result of the query and the format you need to format it in."),
163
- ("human", "For the given question: {question}\n\nSQL query: {sql_query}\n\Result: {results}\n\nUse the following example to structure the data: {instructions}. If there is None in Result please change it to '0'. Just give the json string. Do not format it or add any label or text."),
164
  ])
165
 
166
 
167
  prompt = template.format_prompt(question=text_query, sql_query=sql_query,
168
  results=sql_result, instructions=instruction)
169
-
170
  formatted_data = run_llm(prompt)
171
  print(f'Formatted Data {formatted_data}')
172
  return json.loads(formatted_data.replace('.', '').strip())
@@ -205,6 +202,7 @@ def visualize_result(text_query, visualization_type, sql_query,
205
  return fig
206
 
207
 
 
208
  def main(table, text_query):
209
  if table is None:
210
  return ["", "", "", pd.DataFrame([{"error": "❌ Table is None."}])]
 
31
 
32
  print('Loading Model...')
33
 
34
+
35
+ tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
36
 
37
  quantization_config = BitsAndBytesConfig(
38
+ load_in_4bit=True,
39
+ bnb_4bit_compute_dtype=torch.bfloat16,
40
+ bnb_4bit_use_double_quant=True,
41
+ bnb_4bit_quant_type= "nf4")
42
 
43
+ model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b-it", quantization_config=quantization_config,
44
  device_map="auto", torch_dtype=torch.bfloat16)
45
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, return_full_text=False, max_new_tokens=512)
46
  llm = HuggingFacePipeline(pipeline=pipe)
 
49
  print('Model Loaded...')
50
  print(f'Model Device: {model.device}')
51
 
 
 
 
 
 
 
 
 
 
 
 
52
  def get_schemas():
53
  schemas = conn.execute("""
54
  SELECT DISTINCT schema_name
 
141
  final_prompt = prompt.format_prompt(question=text_query,
142
  sql_query=sql_query, results=sql_result)
143
  response = run_llm(final_prompt)
144
+ response = response.replace('```', '')
145
  lines = response.strip().split('\n')
146
+ print(lines)
147
  visualization = lines[0].split(': ')[1]
148
  reason = lines[1].split(': ')[1]
149
 
150
  return visualization, reason
151
 
152
+
153
+
154
+
155
  def format_data(text_query, sql_query, sql_result, visualization_type):
156
  instruction = graph_instructions[visualization_type]
157
 
158
  template = ChatPromptTemplate.from_messages([
159
  ("system", "You are a Data expert who formats data according to the required needs. You are given the question asked by the user, it's sql query, the result of the query and the format you need to format it in."),
160
+ ("human", "For the given question: {question}\n\nSQL query: {sql_query}\n\Result: {results}\n\nUse the following example to structure the data: {instructions}. If there is None in Result please change it to '0'. Just give the json string. Do not format it. Do not give backticks."),
161
  ])
162
 
163
 
164
  prompt = template.format_prompt(question=text_query, sql_query=sql_query,
165
  results=sql_result, instructions=instruction)
166
+ print(prompt)
167
  formatted_data = run_llm(prompt)
168
  print(f'Formatted Data {formatted_data}')
169
  return json.loads(formatted_data.replace('.', '').strip())
 
202
  return fig
203
 
204
 
205
+
206
  def main(table, text_query):
207
  if table is None:
208
  return ["", "", "", pd.DataFrame([{"error": "❌ Table is None."}])]
visualization_prompt.py CHANGED
@@ -4,7 +4,8 @@ barGraphIntstruction = '''
4
  labels: string[]
5
  values: {\data: number[], label: string}[]
6
  }
7
-
 
8
  // Examples of usage:
9
  Each label represents a column on the x axis.
10
  Each array in values represents a different entity.
@@ -20,6 +21,8 @@ Here we are looking at the performance of american and european players for each
20
  labels: ['series A', 'series B', 'series C'],
21
  values: [{data:[10, 15, 20], label: 'American'}, {data:[20, 25, 30], label: 'European'}],
22
  }
 
 
23
  '''
24
 
25
  horizontalBarGraphIntstruction = '''
 
4
  labels: string[]
5
  values: {\data: number[], label: string}[]
6
  }
7
+
8
+ The output must follow this format strictly, even if the input data differs from the examples below.
9
  // Examples of usage:
10
  Each label represents a column on the x axis.
11
  Each array in values represents a different entity.
 
21
  labels: ['series A', 'series B', 'series C'],
22
  values: [{data:[10, 15, 20], label: 'American'}, {data:[20, 25, 30], label: 'European'}],
23
  }
24
+
25
+ The output format must be consistent with this structure, regardless of the specific input data.
26
  '''
27
 
28
  horizontalBarGraphIntstruction = '''