sivakum4 committed
Commit 7017d8a · 1 Parent(s): 2e0cda6

Feat: HF Inference API

Files changed (1)
  1. buffalo_rag/model/rag.py +53 -67
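The commit swaps the locally loaded transformers pipeline for a hosted call through the Hugging Face Inference API. As a minimal sketch (not part of the commit) of the huggingface_hub call pattern the new code relies on: the provider, token variable, and model id mirror the diff below, and the sample question is made up.

import os
from huggingface_hub import InferenceClient

# Sketch of the call pattern used in the diff below.
# Assumes HUGGINGFACEHUB_API_TOKEN is exported in the environment.
client = InferenceClient(
    provider="cerebras",
    api_key=os.environ["HUGGINGFACEHUB_API_TOKEN"],
)
completion = client.chat.completions.create(
    model="meta-llama/Llama-3.3-70B-Instruct",
    messages=[{"role": "user", "content": "When can F-1 students apply for OPT?"}],  # hypothetical query
    max_tokens=512,
)
print(completion.choices[0].message.content)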
buffalo_rag/model/rag.py CHANGED
@@ -2,49 +2,30 @@ import os
 import json
 from typing import List, Dict, Any, Optional, Tuple
 
-import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+# from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+
+from huggingface_hub import InferenceClient
 
 from buffalo_rag.vector_store.db import VectorStore
 
 class BuffaloRAG:
-    def __init__(self,
-                 model_name: str = "Qwen/Qwen1.5-1.8B-Chat",
-                 vector_store: Optional[VectorStore] = None):
+    def __init__(
+        self,
+        model_name: str = "meta-llama/Llama-2-7b-chat-hf",
+        vector_store: Optional[VectorStore] = None
+    ):
+        # 1. Vector store
         self.vector_store = vector_store or VectorStore()
-
-        try:
-            # Load model and tokenizer
-            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-            self.model = AutoModelForCausalLM.from_pretrained(
-                model_name,
-                torch_dtype=torch.float16,
-                device_map="auto",
-                trust_remote_code=True,
-                low_cpu_mem_usage=True
-            )
-
-            # More conservative generation parameters for stability
-            self.pipe = pipeline(
-                "text-generation",
-                model=self.model,
-                tokenizer=self.tokenizer,
-                max_new_tokens=256,  # Shorter outputs for stability
-                do_sample=False,     # Use greedy decoding instead of sampling
-                pad_token_id=self.tokenizer.eos_token_id
-            )
-        except Exception as e:
-            print(f"Error loading main model: {str(e)}")
-            print("Falling back to smaller model...")
-            # Fallback to a smaller, more stable model
-            self.tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
-            self.model = AutoModelForCausalLM.from_pretrained("distilgpt2")
-            self.pipe = pipeline(
-                "text-generation",
-                model=self.model,
-                tokenizer=self.tokenizer,
-                max_new_tokens=256
+
+        # 2. Hugging Face Inference client
+        hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
+        if not hf_token:
+            raise ValueError("Please set HUGGINGFACEHUB_API_TOKEN in your environment.")
+        self.client = InferenceClient(
+            provider="cerebras",
+            api_key=hf_token,
         )
+
 
     def retrieve(self,
                  query: str,
@@ -54,46 +35,51 @@
         return self.vector_store.hybrid_search(query, k=k, filter_categories=filter_categories)
 
     def format_context(self, results: List[Dict[str, Any]]) -> str:
-        """Format retrieved results into context."""
-        context = ""
-
-        for i, result in enumerate(results):
-            chunk = result['chunk']
-            context += f"Source {i+1}: {chunk['title']}\n"
-            context += f"URL: {chunk['url']}\n"
-            context += f"Content: {chunk['content'][:500]}...\n\n"
-
-        return context
+        """Concatenate retrieved passages into context."""
+        ctx = []
+        for i, r in enumerate(results, start=1):
+            c = r["chunk"]
+            ctx.append(
+                f"Source {i}: {c['title']}\n"
+                f"URL: {c['url']}\n"
+                f"Content: {c['content'][:500]}...\n"
+            )
+        return "\n".join(ctx)
 
     def generate_response(self, query: str, context: str) -> str:
         """Generate response using the language model with error handling."""
         prompt = f"""You are a friendly and professional counselor for international students at the University at Buffalo. Respond to the student's query in a supportive, detailed, and well-structured manner.
 
-For your responses:
-1. Address the student respectfully and empathetically
-2. Provide clear, accurate information with specific details and steps when applicable
-3. Organize your answer with appropriate headings, bullet points, or numbered lists when helpful
-4. If the student's question is unclear or lacks essential details, ask 1-2 specific clarifying questions to better understand their situation
-5. Include relevant deadlines, contacts, or resources when appropriate
-6. Conclude with a brief encouraging statement
-7. Only answer related to international students at UB, if it's not related to international students at UB, just say "I'm sorry, I don't have information about that."
-8. Do not entertain any questions that are not related to students at UB.
+For your responses:
+1. Address the student respectfully and empathetically
+2. Provide clear, accurate information with specific details and steps when applicable
+3. Organize your answer with appropriate headings, bullet points, or numbered lists when helpful
+4. If the student's question is unclear or lacks essential details, ask 1-2 specific clarifying questions to better understand their situation
+5. Include relevant deadlines, contacts, or resources when appropriate
+6. Conclude with a brief encouraging statement
+7. Only answer related to international students at UB, if it's not related to international students at UB, just say "I'm sorry, I don't have information about that."
+8. Do not entertain any questions that are not related to students at UB.
 
-Question: {query}
+Question: {query}
 
-Relevant Information:
-{context}
+Relevant Information:
+{context}
 
-Answer:"""
+Answer:"""
 
         try:
-            # Generate response
-            response = self.pipe(prompt)[0]['generated_text']
-
-            # Extract only the generated part (after the prompt)
-            generated = response[len(prompt):].strip()
-
-            return generated
+            completion = self.client.chat.completions.create(
+                model="meta-llama/Llama-3.3-70B-Instruct",
+                messages=[
+                    {
+                        "role": "user",
+                        "content": prompt
+                    }
+                ],
+                max_tokens=512,
+            )
+
+            return completion.choices[0].message.content
         except Exception as e:
            print(f"Error during generation: {str(e)}")
            # Fallback response
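For reference, a hypothetical end-to-end use of the updated class, assuming a populated VectorStore, the buffalo_rag package importable, and HUGGINGFACEHUB_API_TOKEN set (otherwise __init__ raises ValueError); the question and k value are illustrative only.

from buffalo_rag.model.rag import BuffaloRAG

rag = BuffaloRAG()                                    # defaults to a fresh VectorStore()
query = "How do I apply for OPT as an F-1 student?"   # illustrative question
results = rag.retrieve(query, k=5)                    # hybrid search over indexed chunks
context = rag.format_context(results)                 # "Source N: title / URL / excerpt" blocks
answer = rag.generate_response(query, context)        # chat completion via the HF Inference API
print(answer)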