Spaces:
Runtime error
Runtime error
Mustehson
committed on
Commit
·
e44c00b
1
Parent(s):
0eb72dc
Gemma Model
Browse files
- app.py +15 -17
- visualization_prompt.py +4 -1
app.py
CHANGED
@@ -31,12 +31,16 @@ else:
|
|
31 |
|
32 |
print('Loading Model...')
|
33 |
|
34 |
-
|
|
|
35 |
|
36 |
quantization_config = BitsAndBytesConfig(
|
37 |
-
load_in_4bit=True
|
|
|
|
|
|
|
38 |
|
39 |
-
model = AutoModelForCausalLM.from_pretrained("
|
40 |
device_map="auto", torch_dtype=torch.bfloat16)
|
41 |
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, return_full_text=False, max_new_tokens=512)
|
42 |
llm = HuggingFacePipeline(pipeline=pipe)
|
@@ -45,17 +49,6 @@ llm = HuggingFacePipeline(pipeline=pipe)
|
|
45 |
print('Model Loaded...')
|
46 |
print(f'Model Device: {model.device}')
|
47 |
|
48 |
-
|
49 |
-
# def generate_output(schema, prompt, generated_sql_query, result_output, result_visulaized):
|
50 |
-
# return {
|
51 |
-
# table_schema: schema,
|
52 |
-
# input_prompt: prompt,
|
53 |
-
# generated_query: generated_sql_query,
|
54 |
-
# result_output: result_output,
|
55 |
-
# result_visulaized: result_visulaized
|
56 |
-
# }
|
57 |
-
|
58 |
-
# Get Databases
|
59 |
def get_schemas():
|
60 |
schemas = conn.execute("""
|
61 |
SELECT DISTINCT schema_name
|
@@ -148,25 +141,29 @@ Recommend a visualization:
|
|
148 |
final_prompt = prompt.format_prompt(question=text_query,
|
149 |
sql_query=sql_query, results=sql_result)
|
150 |
response = run_llm(final_prompt)
|
151 |
-
|
152 |
lines = response.strip().split('\n')
|
|
|
153 |
visualization = lines[0].split(': ')[1]
|
154 |
reason = lines[1].split(': ')[1]
|
155 |
|
156 |
return visualization, reason
|
157 |
|
|
|
|
|
|
|
158 |
def format_data(text_query, sql_query, sql_result, visualization_type):
|
159 |
instruction = graph_instructions[visualization_type]
|
160 |
|
161 |
template = ChatPromptTemplate.from_messages([
|
162 |
("system", "You are a Data expert who formats data according to the required needs. You are given the question asked by the user, it's sql query, the result of the query and the format you need to format it in."),
|
163 |
-
("human", "For the given question: {question}\n\nSQL query: {sql_query}\n\Result: {results}\n\nUse the following example to structure the data: {instructions}. If there is None in Result please change it to '0'. Just give the json string. Do not format it
|
164 |
])
|
165 |
|
166 |
|
167 |
prompt = template.format_prompt(question=text_query, sql_query=sql_query,
|
168 |
results=sql_result, instructions=instruction)
|
169 |
-
|
170 |
formatted_data = run_llm(prompt)
|
171 |
print(f'Formatted Data {formatted_data}')
|
172 |
return json.loads(formatted_data.replace('.', '').strip())
|
@@ -205,6 +202,7 @@ def visualize_result(text_query, visualization_type, sql_query,
|
|
205 |
return fig
|
206 |
|
207 |
|
|
|
208 |
def main(table, text_query):
|
209 |
if table is None:
|
210 |
return ["", "", "", pd.DataFrame([{"error": "β Table is None."}])]
|
|
|
31 |
|
32 |
print('Loading Model...')
|
33 |
|
34 |
+
|
35 |
+
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
|
36 |
|
37 |
quantization_config = BitsAndBytesConfig(
|
38 |
+
load_in_4bit=True,
|
39 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
40 |
+
bnb_4bit_use_double_quant=True,
|
41 |
+
bnb_4bit_quant_type= "nf4")
|
42 |
|
43 |
+
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b-it", quantization_config=quantization_config,
|
44 |
device_map="auto", torch_dtype=torch.bfloat16)
|
45 |
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, return_full_text=False, max_new_tokens=512)
|
46 |
llm = HuggingFacePipeline(pipeline=pipe)
|
|
|
49 |
print('Model Loaded...')
|
50 |
print(f'Model Device: {model.device}')
|
51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
def get_schemas():
|
53 |
schemas = conn.execute("""
|
54 |
SELECT DISTINCT schema_name
|
|
|
141 |
final_prompt = prompt.format_prompt(question=text_query,
|
142 |
sql_query=sql_query, results=sql_result)
|
143 |
response = run_llm(final_prompt)
|
144 |
+
response = response.replace('```', '')
|
145 |
lines = response.strip().split('\n')
|
146 |
+
print(lines)
|
147 |
visualization = lines[0].split(': ')[1]
|
148 |
reason = lines[1].split(': ')[1]
|
149 |
|
150 |
return visualization, reason
|
151 |
|
152 |
+
|
153 |
+
|
154 |
+
|
155 |
def format_data(text_query, sql_query, sql_result, visualization_type):
|
156 |
instruction = graph_instructions[visualization_type]
|
157 |
|
158 |
template = ChatPromptTemplate.from_messages([
|
159 |
("system", "You are a Data expert who formats data according to the required needs. You are given the question asked by the user, it's sql query, the result of the query and the format you need to format it in."),
|
160 |
+
("human", "For the given question: {question}\n\nSQL query: {sql_query}\n\Result: {results}\n\nUse the following example to structure the data: {instructions}. If there is None in Result please change it to '0'. Just give the json string. Do not format it. Do not give backticks."),
|
161 |
])
|
162 |
|
163 |
|
164 |
prompt = template.format_prompt(question=text_query, sql_query=sql_query,
|
165 |
results=sql_result, instructions=instruction)
|
166 |
+
print(prompt)
|
167 |
formatted_data = run_llm(prompt)
|
168 |
print(f'Formatted Data {formatted_data}')
|
169 |
return json.loads(formatted_data.replace('.', '').strip())
|
|
|
202 |
return fig
|
203 |
|
204 |
|
205 |
+
|
206 |
def main(table, text_query):
|
207 |
if table is None:
|
208 |
return ["", "", "", pd.DataFrame([{"error": "β Table is None."}])]
|
visualization_prompt.py
CHANGED
@@ -4,7 +4,8 @@ barGraphIntstruction = '''
|
|
4 |
labels: string[]
|
5 |
values: {\data: number[], label: string}[]
|
6 |
}
|
7 |
-
|
|
|
8 |
// Examples of usage:
|
9 |
Each label represents a column on the x axis.
|
10 |
Each array in values represents a different entity.
|
@@ -20,6 +21,8 @@ Here we are looking at the performance of american and european players for each
|
|
20 |
labels: ['series A', 'series B', 'series C'],
|
21 |
values: [{data:[10, 15, 20], label: 'American'}, {data:[20, 25, 30], label: 'European'}],
|
22 |
}
|
|
|
|
|
23 |
'''
|
24 |
|
25 |
horizontalBarGraphIntstruction = '''
|
|
|
4 |
labels: string[]
|
5 |
values: {\data: number[], label: string}[]
|
6 |
}
|
7 |
+
|
8 |
+
The output must follow this format strictly, even if the input data differs from the examples below.
|
9 |
// Examples of usage:
|
10 |
Each label represents a column on the x axis.
|
11 |
Each array in values represents a different entity.
|
|
|
21 |
labels: ['series A', 'series B', 'series C'],
|
22 |
values: [{data:[10, 15, 20], label: 'American'}, {data:[20, 25, 30], label: 'European'}],
|
23 |
}
|
24 |
+
|
25 |
+
The output format must be consistent with this structure, regardless of the specific input data.
|
26 |
'''
|
27 |
|
28 |
horizontalBarGraphIntstruction = '''
|