DragonProgrammer committed on
Commit
dd964ab
·
verified ·
1 Parent(s): 7dc08c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -7
app.py CHANGED
@@ -62,32 +62,41 @@ def safe_calculator_func(expression: str) -> str:
62
  print(f"Error during calculation for '{expression}': {e}")
63
  return f"Error calculating '{expression}': Invalid expression or calculation error ({e})."
64
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  class LangChainAgentWrapper:
67
  def __init__(self):
68
  print("Initializing LangChainAgentWrapper...")
69
 
 
70
  model_id = "google/flan-t5-base"
71
 
72
  try:
 
73
  print(f"Loading model pipeline for: {model_id}")
74
 
75
- # --- MODIFICATION: Use the custom pipeline class ---
76
- # Load the tokenizer first
77
  tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
78
- # Load the model
79
  model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_id)
80
 
81
- # Create an instance of our custom pipeline
82
  llm_pipeline = FlanT5Pipeline(
83
  task="text2text-generation",
84
  model=model,
85
  tokenizer=tokenizer,
86
  device_map="auto",
87
- max_new_tokens=512 # Add max_new_tokens to control output length
88
  )
89
- # --- END MODIFICATION ---
90
-
91
  print("Model pipeline loaded successfully.")
92
 
93
  # Wrap the pipeline in a LangChain LLM object
 
62
  print(f"Error during calculation for '{expression}': {e}")
63
  return f"Error calculating '{expression}': Invalid expression or calculation error ({e})."
64
 
65
+ # --- Custom Pipeline to Fix LangChain Integration ---
66
+ class FlanT5Pipeline(transformers.Pipeline):
67
+ def _call(self, *args, **kwargs):
68
+ # The HuggingFacePipeline class in LangChain might not pass the input
69
+ # with the 'inputs' keyword. This custom _call method ensures that
70
+ # whatever is passed as the first argument is correctly forwarded.
71
+ if args and len(args) > 0:
72
+ return super()._call(args[0], **kwargs)
73
+ else:
74
+ # Fallback in case no positional arguments are provided
75
+ return super()._call(kwargs)
76
 
77
  class LangChainAgentWrapper:
78
  def __init__(self):
79
  print("Initializing LangChainAgentWrapper...")
80
 
81
+ # Switched to a smaller, CPU-friendly instruction-tuned model
82
  model_id = "google/flan-t5-base"
83
 
84
  try:
85
+ hf_auth_token = os.getenv("HF_TOKEN")
86
  print(f"Loading model pipeline for: {model_id}")
87
 
88
+ # We load the model and tokenizer objects first
 
89
  tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
 
90
  model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_id)
91
 
92
+ # Now we use our custom FlanT5Pipeline class
93
  llm_pipeline = FlanT5Pipeline(
94
  task="text2text-generation",
95
  model=model,
96
  tokenizer=tokenizer,
97
  device_map="auto",
98
+ max_new_tokens=512
99
  )
 
 
100
  print("Model pipeline loaded successfully.")
101
 
102
  # Wrap the pipeline in a LangChain LLM object