HusnaManakkot commited on
Commit
aff9b4b
Β·
verified Β·
1 Parent(s): 3673a2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -5
app.py CHANGED
@@ -10,13 +10,9 @@ model = AutoModelForSeq2SeqLM.from_pretrained("hrshtsharma2012/NL2SQL-Picard-fin
10
  spider_dataset = load_dataset("spider", split='train[:5]')
11
 
12
  def generate_sql(query):
13
- # Add a prefix to prompt the model to generate SQL
14
- prefixed_query = f"T1.SQL: '{query}'"
15
- inputs = tokenizer(prefixed_query, return_tensors="pt", padding=True)
16
  outputs = model.generate(**inputs, max_length=512)
17
  sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
18
- # Remove the prefix from the generated SQL query
19
- sql_query = sql_query.replace("T1.SQL: '", "").rstrip("'")
20
  return sql_query
21
 
22
  # Use examples from the Spider dataset
 
10
  spider_dataset = load_dataset("spider", split='train[:5]')
11
 
12
  def generate_sql(query):
13
+ inputs = tokenizer(query, return_tensors="pt", padding=True)
 
 
14
  outputs = model.generate(**inputs, max_length=512)
15
  sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
16
  return sql_query
17
 
18
  # Use examples from the Spider dataset