Hasan Iqbal
committed on
Replaces all gpt-3.5 with gpt-4o

- src/openfactcheck/solvers/factcheckgpt/factcheckgpt_cp.py +18 -10
- src/openfactcheck/solvers/factcheckgpt/factcheckgpt_rtv.py +27 -30
- src/openfactcheck/solvers/factcheckgpt/factcheckgpt_vfr.py +12 -24
- src/openfactcheck/solvers/factool/factool_utils/chat_api.py +65 -55
- src/openfactcheck/solvers/rarr/rarr_agreement_gate.py +8 -6
- src/openfactcheck/solvers/rarr/rarr_editor.py +16 -16
- src/openfactcheck/solvers/rarr/rarr_llm_retriever.py +4 -9
- src/openfactcheck/solvers/rarr/rarr_question_generator.py +4 -5
- src/openfactcheck/solvers/tutorial/utils/api.py +24 -16
- src/openfactcheck/solvers/webservice/factcheckgpt_cp.py +18 -10
- src/openfactcheck/solvers/webservice/factcheckgpt_rtv.py +27 -30
- src/openfactcheck/solvers/webservice/factcheckgpt_vfr.py +12 -24
- src/openfactcheck/solvers/webservice/factool_utils/chat_api.py +45 -40
- src/openfactcheck/solvers/webservice/ftool_cp.py +3 -2
- src/openfactcheck/solvers/webservice/ftool_rtv.py +4 -3
- src/openfactcheck/solvers/webservice/ftool_vfr.py +12 -7
- src/openfactcheck/solvers/webservice/rarr_rtv.py +5 -4
- src/openfactcheck/solvers/webservice/rarr_vfr.py +10 -9
- src/openfactcheck/state.py +57 -20
- src/openfactcheck/templates/solver_configs/webservice.yaml +3 -3
src/openfactcheck/solvers/factcheckgpt/factcheckgpt_cp.py
CHANGED
@@ -5,14 +5,20 @@ from openfactcheck import FactCheckerState, StandardTaskSolver, Solver
 
 from .factcheckgpt_utils.openai_api import gpt
 from .factcheckgpt_utils.data_util import save_to_file
-from .factcheckgpt_utils.prompt import …
-    …
+from .factcheckgpt_utils.prompt import (
+    DOC_TO_INDEPEDENT_SENTENCES_PROMPT,
+    SENTENCES_TO_CLAIMS_PROMPT,
+    DOC_TO_SENTENCES_PROMPT,
+    CHECKWORTHY_PROMPT_BOOL,
+    SPECIFY_CHECKWORTHY_CATEGORY_PROMPT,
+)
+
 
 @Solver.register("factcheckgpt_claimprocessor", "response", "claims")
 class FactCheckGPTClaimProcessor(StandardTaskSolver):
     def __init__(self, args):
         super().__init__(args)
-        self.model = self.global_config.get("factcheckgpt_model", "gpt-…
+        self.model = self.global_config.get("factcheckgpt_model", "gpt-4o")
         self.num_retries = self.global_config.get("num_retries", 3)
         self.mode = args.get("mode", "independent_sentences")
         self.decompose_system_role = "You are good at decomposing and decontextualizing text."
@@ -22,19 +28,19 @@ class FactCheckGPTClaimProcessor(StandardTaskSolver):
         self.prompt = {
             "sentences": DOC_TO_SENTENCES_PROMPT,
             "independent_sentences": DOC_TO_INDEPEDENT_SENTENCES_PROMPT,
-            "claims": SENTENCES_TO_CLAIMS_PROMPT
+            "claims": SENTENCES_TO_CLAIMS_PROMPT,
         }.get(self.mode, DOC_TO_INDEPEDENT_SENTENCES_PROMPT)
         nlp = spacy.load(self.spacy_model)
         self.rule_based_tool = {
             "nltk": lambda x: [x.strip() for x in nltk.sent_tokenize(x) if len(x.strip()) >= 3],
-            "spacy": lambda x: [x.text.strip() for x in nlp(x).sents if len(x.text.strip()) >= 3]
+            "spacy": lambda x: [x.text.strip() for x in nlp(x).sents if len(x.text.strip()) >= 3],
         }.get(self.rule_based_method, "nltk")
 
     def __call__(self, state: FactCheckerState, *args, **kwargs):
         # We have merged the text decomposer and worthiness filter here.
         response = state.get(self.input_name)
         claims = [response]
-
+
         user_input = self.prompt.format(doc=response).strip()
         r = gpt(user_input, model=self.model, system_role=self.decompose_system_role, num_retries=self.num_retries)
         try:
@@ -45,13 +51,15 @@ class FactCheckGPTClaimProcessor(StandardTaskSolver):
 
         if not isinstance(claims, list):
             print(
-                f"{self.model} output {r}. It does not output a list of sentences correctly, return rule-based split results."
+                f"{self.model} output {r}. It does not output a list of sentences correctly, return rule-based split results."
+            )
             claims = self.rule_based_tool(response)
-
+
         worthiness = [True] * len(claims)
         user_input = CHECKWORTHY_PROMPT_BOOL.format(claims=claims)
-        response = gpt(
-            …
+        response = gpt(
+            user_input, model=self.model, system_role=self.worthines_filter_system_role, num_retries=self.num_retries
+        )
         # TODO refine check worthiness prompt, value returned not reasonable.
         try:
             worthiness = eval(response)
src/openfactcheck/solvers/factcheckgpt/factcheckgpt_rtv.py
CHANGED
@@ -17,16 +17,16 @@ from .factcheckgpt_utils.openai_api import gpt
 from .factcheckgpt_utils.prompt import QGEN_PROMPT, QGEN_PROMPT_FMT
 from .factcheckgpt_utils.data_util import save_txt, save_json
 
+
 @Solver.register("factcheckgpt_retriever", "claims", "claims_with_evidences")
 class FactCheckGPTRetriever(StandardTaskSolver):
     def __init__(self, args):
         super().__init__(args)
-        self.model = self.global_config.get("factcheckgpt_model", "gpt-…
+        self.model = self.global_config.get("factcheckgpt_model", "gpt-4o")
         self.num_retries = self.global_config.get("num_retries", 3)
         self.tokenizer = spacy.load("en_core_web_sm", disable=["ner", "tagger", "lemmatizer"])
         self.question_duplicate_model = CrossEncoder(
-            "navteca/quora-roberta-base",
-            device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+            "navteca/quora-roberta-base", device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         )
         self.passage_ranker = CrossEncoder(
             "cross-encoder/ms-marco-MiniLM-L-6-v2",
@@ -44,7 +44,7 @@ class FactCheckGPTRetriever(StandardTaskSolver):
         self.sentences_per_passage = args.get("sentences_per_passage", 5)
         self.max_passages_per_question = args.get("max_passages_per_question", 5)
         self.max_aggregated_evidences = args.get("max_aggregated_evidences", 5)
-        self.question_persist_path = args.get("question_persist_path", …
+        self.question_persist_path = args.get("question_persist_path", "questions.txt")
         self.snippets_persist_path = args.get("snippets_persist_path", "passage.json")
 
     def __call__(self, state: FactCheckerState, *args, **kwargs):
@@ -52,7 +52,7 @@ class FactCheckGPTRetriever(StandardTaskSolver):
         claims_with_evidences = {}
         for i, claim in enumerate(claims):
             evidences = self.get_web_evidences_for_claim(claim)
-            claims_with_evidences[claim] = [(q, e[…
+            claims_with_evidences[claim] = [(q, e["text"]) for q, e in evidences["aggregated"]]
         state.set(self.output_name, claims_with_evidences)
         return True, state
@@ -69,11 +69,9 @@ class FactCheckGPTRetriever(StandardTaskSolver):
         snippets = {}
         for question in questions:
             retrieved_passages = self.get_relevant_snippets(question)
-            snippets[question] = sorted(
-                …
-                …
-                reverse=True
-            )[:self.max_passages_per_question]
+            snippets[question] = sorted(retrieved_passages, key=lambda x: x["retrieval_score"], reverse=True)[
+                : self.max_passages_per_question
+            ]
         save_json(snippets, self.snippets_persist_path)
         return snippets
@@ -110,7 +108,7 @@ class FactCheckGPTRetriever(StandardTaskSolver):
             model=self.model,
             system_role=self.qgen_system_role,
             num_retries=self.num_retries,
-            temperature=self.qgen_temp
+            temperature=self.qgen_temp,
         )
         try:
             cur_round_questions = set(eval(response))
@@ -182,8 +180,8 @@ class FactCheckGPTRetriever(StandardTaskSolver):
                 return False
         return True
 
-    def search_google(self, query: str, num_web_pages: int = 10, timeout: int = 6, save_url: str = …
-        """Searches the query using Google.
+    def search_google(self, query: str, num_web_pages: int = 10, timeout: int = 6, save_url: str = "") -> list[str]:
+        """Searches the query using Google.
         Args:
             query: Search query.
             num_web_pages: the number of web pages to request.
@@ -198,7 +196,7 @@ class FactCheckGPTRetriever(StandardTaskSolver):
         USER_AGENT = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.14; rv:65.0) Gecko/20100101 Firefox/65.0"
         # mobile user-agent
         MOBILE_USER_AGENT = "Mozilla/5.0 (Linux; Android 7.0; SM-G930V Build/NRD90M) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/59.0.3071.125 Mobile Safari/537.36"
-        headers = {…
+        headers = {"User-Agent": USER_AGENT}
 
         # set language
         # set the Google interface language, use &hl=XX
@@ -222,18 +220,18 @@ class FactCheckGPTRetriever(StandardTaskSolver):
 
         # save all url into a txt file
         if not save_url == "":
-            with open(save_url, …
+            with open(save_url, "w") as file:
                 for url in urls:
-                    file.write(url + …
+                    file.write(url + "\n")
         return urls
 
     def chunk_text(
-        …
-        …
-        …
-        …
-        …
-        …
+        self,
+        text: str,
+        tokenizer,
+        sentences_per_passage: int = 5,
+        filter_sentence_len: int = 250,
+        sliding_distance: int = 2,
     ) -> list[str]:
         """Chunks text into passages using a sliding window.
 
@@ -260,15 +258,16 @@ class FactCheckGPTRetriever(StandardTaskSolver):
             ]
             for idx in range(0, len(sents), sliding_distance):
                 passages.append(
-                    (" ".join(sents[idx: idx + sentences_per_passage]), idx, idx + sentences_per_passage - 1)
+                    (" ".join(sents[idx : idx + sentences_per_passage]), idx, idx + sentences_per_passage - 1)
+                )
         except UnicodeEncodeError as _:  # Sometimes run into Unicode error when tokenizing.
             print("Unicode error when using Spacy. Skipping text.")
 
         return passages
 
     def get_relevant_snippets(
-        …
-        …
+        self,
+        query,
     ):
         search_results = self.search_google(query, timeout=self.search_timeout)
 
@@ -278,11 +277,9 @@ class FactCheckGPTRetriever(StandardTaskSolver):
         scraped_results = [r for r in scraped_results if r[0] and ".pdf" not in r[1]]
         # print("Num Bing Search Results: ", len(scraped_results))
         retrieved_passages = list()
-        for webtext, url in scraped_results[:self.max_search_results_per_query]:
+        for webtext, url in scraped_results[: self.max_search_results_per_query]:
             passages = self.chunk_text(
-                text=webtext,
-                tokenizer=self.tokenizer,
-                sentences_per_passage=self.sentences_per_passage
+                text=webtext, tokenizer=self.tokenizer, sentences_per_passage=self.sentences_per_passage
             )
             if not passages:
                 continue
@@ -304,7 +301,7 @@ class FactCheckGPTRetriever(StandardTaskSolver):
                 overlap = True
                 break
 
-            # Only consider top non-overlapping relevant passages to maximise for information
+            # Only consider top non-overlapping relevant passages to maximise for information
             if not overlap:
                 relevant_items.append(deepcopy(passage_item))
                 retrieved_passages.append(
src/openfactcheck/solvers/factcheckgpt/factcheckgpt_vfr.py
CHANGED
@@ -9,24 +9,22 @@ from .factcheckgpt_utils.data_util import save_to_file
 from .factcheckgpt_utils.prompt import IDENTIFY_STANCE_PROMPT, IDENTIFY_STANCE_PROMPT_FUNC
 from .factcheckgpt_utils.nli import nli_infer
 
+
 @Solver.register("factcheckgpt_verifier", "claims_with_evidences", "label")
 class FactCheckGPTVerifier(StandardTaskSolver):
     def __init__(self, args):
         super().__init__(args)
-        self.stance_model = args.get("stance_model", "gpt-…
+        self.stance_model = args.get("stance_model", "gpt-4o")
         self.num_retries = self.global_config.get("num_retries", 3)
         # self.system_role = args.get("system_role", "You are a helpful factchecker assistant.")
         self.system_role = "You are a helpful factchecker assistant."
         self.verify_retries = args.get("verify_retries", 3)
-        self.stance_map = {
-            1: "support",
-            -1: "refute",
-            0: "irrelevant"
-        }
+        self.stance_map = {1: "support", -1: "refute", 0: "irrelevant"}
 
     def verify_by_stance(
-        …
-        …
+        self,
+        claim: str,
+        evidences: list[str],
     ) -> Any:
         labels = []
         for evidence in evidences:
@@ -45,12 +43,7 @@ class FactCheckGPTVerifier(StandardTaskSolver):
 
     def identify_stance_gpt(self, evidence, claim):
         user_input = IDENTIFY_STANCE_PROMPT_FUNC.format(claim=claim, evidence=evidence)
-        r = gpt(
-            user_input,
-            model=self.stance_model,
-            system_role=self.system_role,
-            num_retries=self.num_retries
-        )
+        r = gpt(user_input, model=self.stance_model, system_role=self.system_role, num_retries=self.num_retries)
         label = 0
         try:
             label = eval(r)
@@ -58,9 +51,9 @@ class FactCheckGPTVerifier(StandardTaskSolver):
             print(f"An unexpected error occurred: {e}.")
         return label
 
-    def stance(self, evidence, claim, model="gpt-…
+    def stance(self, evidence, claim, model="gpt-4o"):
         """input: a claim and an evidence
-        …
+        output: label in [support, refute, irrelevant]"""
         label = 0
         if self.stance_model == "nli":
             label = nli_infer(premise=evidence, hypothesis=claim)
@@ -73,7 +66,7 @@ class FactCheckGPTVerifier(StandardTaskSolver):
     def verify_claim(self, claim: str, evidences: list[str]) -> dict[str, Any]:
         results = None
         user_input = VERIFY_PROMPT.format(claim=claim, evidence=evidences)
-        r = …
+        r = ""
         for _ in range(self.verify_retries):
             r = gpt(
                 user_input,
@@ -97,12 +90,7 @@ class FactCheckGPTVerifier(StandardTaskSolver):
         else:
             print(f"Error output {r}. It does not output a dict, return factual label by stance aggregation.")
             factual_label = self.verify_by_stance(claim, evidences)
-            results = {
-                "reasoning": "",
-                "error": "",
-                "correction": "",
-                "factuality": factual_label
-            }
+            results = {"reasoning": "", "error": "", "correction": "", "factuality": factual_label}
         return results
 
     def __call__(self, state: FactCheckerState, *args, **kwargs):
@@ -113,6 +101,6 @@ class FactCheckGPTVerifier(StandardTaskSolver):
         result["claim"] = claim
         result["evidences"] = evidences
         results.append(result)
-        state.set(self.output_name, all([x[…
+        state.set(self.output_name, all([x["factuality"] > 0 for x in results]))
         state.set("detail", results)
         return True, state
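The verifier's collapsed stance_map and the rewritten final line in __call__ determine the solver's output: each evidence stance lands in {1, -1, 0}, and the response is labelled factual only if every claim is. A minimal illustration (the results dicts are made up for the example):

stance_map = {1: "support", -1: "refute", 0: "irrelevant"}

# Hypothetical per-claim results, shaped like FactCheckGPTVerifier's output.
results = [{"claim": "A", "factuality": 1}, {"claim": "B", "factuality": -1}]
for r in results:
    print(r["claim"], "->", stance_map[r["factuality"]])

# Mirrors: state.set(self.output_name, all([x["factuality"] > 0 for x in results]))
print(all(x["factuality"] > 0 for x in results))  # False: one refuted claim sinks the label
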
src/openfactcheck/solvers/factool/factool_utils/chat_api.py
CHANGED
@@ -15,56 +15,57 @@ import openai
 from openai import OpenAI, AsyncOpenAI
 import re
 
-class OpenAIChat():
+
+class OpenAIChat:
     def __init__(
-        …
-        …
-        …
-        …
-        …
-        …
+        self,
+        model_name="gpt-4o",
+        max_tokens=2500,
+        temperature=0,
+        top_p=1,
+        request_timeout=120,
     ):
-        if …
+        if "gpt" not in model_name:
             openai.api_base = "http://localhost:8000/v1"
         else:
-            #openai.api_base = "https://api.openai.com/v1"
+            # openai.api_base = "https://api.openai.com/v1"
             openai.api_key = os.environ.get("OPENAI_API_KEY", None)
             assert openai.api_key is not None, "Please set the OPENAI_API_KEY environment variable."
-            assert openai.api_key != …
+            assert openai.api_key != "", "Please set the OPENAI_API_KEY environment variable."
         self.client = AsyncOpenAI()
 
         self.config = {
-            …
-            …
-            …
-            …
-            …
+            "model_name": model_name,
+            "max_tokens": max_tokens,
+            "temperature": temperature,
+            "top_p": top_p,
+            "request_timeout": request_timeout,
         }
 
     def extract_list_from_string(self, input_string):
-        # pattern = r'\[.*\]'
+        # pattern = r'\[.*\]'
        # result = re.search(pattern, input_string)
        # if result:
        #     return result.group()
        # else:
        #     return None
-        start_index = input_string.find(…
-        end_index = input_string.rfind(…
+        start_index = input_string.find("[")
+        end_index = input_string.rfind("]")
 
         if start_index != -1 and end_index != -1 and start_index < end_index:
-            return input_string[start_index:end_index + 1]
+            return input_string[start_index : end_index + 1]
         else:
             return None
-
+
     def extract_dict_from_string(self, input_string):
-        start_index = input_string.find(…
-        end_index = input_string.rfind(…
+        start_index = input_string.find("{")
+        end_index = input_string.rfind("}")
 
         if start_index != -1 and end_index != -1 and start_index < end_index:
-            return input_string[start_index:end_index + 1]
+            return input_string[start_index : end_index + 1]
         else:
             return None
-
+
     def _boolean_fix(self, output):
         return output.replace("true", "True").replace("false", "False")
 
@@ -75,7 +76,7 @@ class OpenAIChat():
             return None
         return output_eval
     except:
-        …
+        """
         if(expected_type == List):
             valid_output = self.extract_list_from_string(output)
             output_eval = ast.literal_eval(valid_output)
@@ -88,46 +89,47 @@ class OpenAIChat():
         if not isinstance(output_eval, expected_type):
             return None
         return output_eval
-        …
+        """
         return None
 
-    async def dispatch_openai_requests(…
+    async def dispatch_openai_requests(
+        self,
+        messages_list,
+    ) -> list[str]:
         """
         Dispatches requests to OpenAI API asynchronously.
-        …
+
         Args:
             messages_list: List of messages to be sent to OpenAI ChatCompletion API.
         Returns:
             List of responses from OpenAI API.
         """
+
         async def _request_with_retry(messages, retry=3):
             for attempt in range(retry):
                 try:
                     response = await self.client.chat.completions.create(
-                        model=self.config[…
+                        model=self.config["model_name"],
                         messages=messages,
-                        max_tokens=self.config[…
-                        temperature=self.config[…
-                        top_p=self.config[…
+                        max_tokens=self.config["max_tokens"],
+                        temperature=self.config["temperature"],
+                        top_p=self.config["top_p"],
                     )
                     return response
                 except openai.RateLimitError as e:
-                    await asyncio.sleep((2…
+                    await asyncio.sleep((2**attempt) * 0.5)  # exponential backoff
                 except (openai.Timeout, openai.APIError) as e:
-                    await asyncio.sleep((2…
+                    await asyncio.sleep((2**attempt) * 0.5)  # exponential backoff
                 except Exception as e:
                     # Log unexpected exception for further investigation
-                    await asyncio.sleep((2…
-
+                    await asyncio.sleep((2**attempt) * 0.5)  # fallback in case of unknown errors
+
             raise RuntimeError("All retries failed for OpenAI API request")
 
-        async_responses = [
-            _request_with_retry(messages)
-            for messages in messages_list
-        ]
+        async_responses = [_request_with_retry(messages) for messages in messages_list]
 
         return await asyncio.gather(*async_responses, return_exceptions=True)
-
+
     def run(self, messages_list, expected_type):
         retry = 1
         responses = [None for _ in range(len(messages_list))]
@@ -135,24 +137,32 @@ class OpenAIChat():
 
         while retry > 0 and len(messages_list_cur_index) > 0:
             messages_list_cur = [messages_list[i] for i in messages_list_cur_index]
-
-            predictions = asyncio.run(self.dispatch_openai_requests(
-                messages_list=messages_list_cur,
-            ))
 
-            preds = […
+            predictions = asyncio.run(
+                self.dispatch_openai_requests(
+                    messages_list=messages_list_cur,
+                )
+            )
+
+            preds = [
+                self._type_check(self._boolean_fix(prediction.choices[0].message.content), expected_type)
+                if prediction is not None
+                else None
+                for prediction in predictions
+            ]
             finised_index = []
             for i, pred in enumerate(preds):
                 if pred is not None:
                     responses[messages_list_cur_index[i]] = pred
                     finised_index.append(messages_list_cur_index[i])
-
+
             messages_list_cur_index = [i for i in messages_list_cur_index if i not in finised_index]
-
+
             retry -= 1
-
+
         return responses
 
+
 # class OpenAIEmbed():
 #     def __init__():
 #         openai.api_key = os.environ.get("OPENAI_API_KEY", None)
@@ -190,9 +200,9 @@ class OpenAIChat():
 #     ))
 
 #     print(predictions)
-…
-…
-…
-…
-…
-…
+# Usage
+# embed = OpenAIEmbed()
+# batch = ["string1", "string2", "string3", "string4", "string5", "string6", "string7", "string8", "string9", "string10"] # Your batch of strings
+# embeddings = asyncio.run(embed.process_batch(batch, retry=3))
+# for embedding in embeddings:
+#     print(embedding["data"][0]["embedding"])
src/openfactcheck/solvers/rarr/rarr_agreement_gate.py
CHANGED
@@ -3,32 +3,34 @@ from .prompts import rarr_prompts
 
 from openfactcheck import FactCheckerState, StandardTaskSolver, Solver
 
+
 @Solver.register("rarr_agreement_gate", "claims_with_evidences", "claims_with_gates")
 class RARRAgreementGate(StandardTaskSolver):
     def __init__(self, args):
         super().__init__(args)
         self.max_evidences_per_question = args.get("max_evidences_per_question", 1)
-        self.model = self.global_config.get("model", "gpt-…
+        self.model = self.global_config.get("model", "gpt-4o-instruct")
 
     def __call__(self, state: FactCheckerState, *args, **kwargs):
         claims = state.get(self.input_name)
 
         for claim, contents in claims.items():
             context = contents.get("context", None)
-            evidences = contents.get("evidences", [])[:self.max_evidences_per_question]
+            evidences = contents.get("evidences", [])[: self.max_evidences_per_question]
             gates = []
             for evidence in evidences:
                 gate = agreement_gate.run_agreement_gate(
                     claim=claim,
                     context=context,
-                    query=evidence[…
-                    evidence=evidence[…
+                    query=evidence["query"],
+                    evidence=evidence["text"],
                     model=self.model,
                     prompt=rarr_prompts.CONTEXTUAL_AGREEMENT_GATE_PROMPT
-                    if context
-                    …
+                    if context
+                    else rarr_prompts.AGREEMENT_GATE_PROMPT,
                 )
                 gates.append(gate)
-            contents[…
+            contents["gates"] = gates
 
         state.set(self.output_name, claims)
         return True, state
src/openfactcheck/solvers/rarr/rarr_editor.py
CHANGED
@@ -5,12 +5,13 @@ from .prompts import rarr_prompts
 
 from openfactcheck import FactCheckerState, StandardTaskSolver, Solver
 
+
 @Solver.register("rarr_editor", "claims_with_evidences", "revised_claims")
 class RARREditor(StandardTaskSolver):
     def __init__(self, args):
         super().__init__(args)
-        self.model = self.global_config.get("model", "gpt-…
-        # self.model = args.get("model", "gpt-…
+        self.model = self.global_config.get("model", "gpt-4o-instruct")
+        # self.model = args.get("model", "gpt-4o-instruct")
         self.max_evidences_per_question = args.get("max_evidences_per_question", 1)
         self.max_edit_ratio = args.get("max_edit_ratio", 100)
         self.output_claim_only = args.get("output_claim_only", False)
@@ -20,7 +21,7 @@ class RARREditor(StandardTaskSolver):
         final_result = {}
         for claim, contents in claims.items():
             context = contents.get("context", None)
-            evidences = contents.get("evidences", [])[:self.max_evidences_per_question]
+            evidences = contents.get("evidences", [])[: self.max_evidences_per_question]
             agreement_gates = []
             revision_steps = []
             claim_for_iterative_revision = claim
@@ -28,32 +29,31 @@ class RARREditor(StandardTaskSolver):
             gate = agreement_gate.run_agreement_gate(
                 claim=claim_for_iterative_revision,
                 context=context,
-                query=evidence[…
-                evidence=evidence[…
+                query=evidence["query"],
+                evidence=evidence["text"],
                 model=self.model,
                 prompt=rarr_prompts.CONTEXTUAL_AGREEMENT_GATE_PROMPT
-                if context
-                …
+                if context
+                else rarr_prompts.AGREEMENT_GATE_PROMPT,
             )
             agreement_gates.append(gate)
 
-            if gate[…
+            if gate["is_open"]:
                 edited_claim = editor.run_rarr_editor(
                     claim=claim_for_iterative_revision,
                     context=context,
-                    query=evidence[…
-                    evidence=evidence[…
+                    query=evidence["query"],
+                    evidence=evidence["text"],
                     model=self.model,
-                    prompt=rarr_prompts.CONTEXTUAL_EDITOR_PROMPT
-                    …
-                    else rarr_prompts.EDITOR_PROMPT,
-                )['text']
+                    prompt=rarr_prompts.CONTEXTUAL_EDITOR_PROMPT if context else rarr_prompts.EDITOR_PROMPT,
+                )["text"]
                 if Levenshtein.distance(claim, edited_claim) / len(claim) <= self.max_edit_ratio:
                     claim_for_iterative_revision = edited_claim
             revision_steps.append({"text": claim_for_iterative_revision})
             result = {
                 "context": context,
                 "text": claim,
-                "questions": contents[…
+                "questions": contents["questions"],
                 "evidences_for_questions": evidences,
                 "revisions": [
                     {
@@ -66,7 +66,7 @@ class RARREditor(StandardTaskSolver):
                 ],
             }
             selected_evidences = evidence_selection.select_evidences(result)
-            result[…
-            final_result[claim] = result[…
+            result["selected_evidences"] = selected_evidences
+            final_result[claim] = result["revisions"][0]["revised_text"] if self.output_claim_only else result
         state.set(self.output_name, final_result)
         return True, state
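The editor accepts a revision only while the relative edit distance stays within max_edit_ratio; a small illustration using the same Levenshtein package the file imports (the claim strings are made up):

import Levenshtein  # same dependency rarr_editor.py uses

claim = "Paris hosted the 2024 Olympics."
edited_claim = "Paris hosted the 2024 Summer Olympics."
ratio = Levenshtein.distance(claim, edited_claim) / len(claim)
print(round(ratio, 2))  # 0.23, well under the default max_edit_ratio of 100
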
src/openfactcheck/solvers/rarr/rarr_llm_retriever.py
CHANGED
@@ -3,11 +3,12 @@ from .prompts.hallucination_prompts import EVIDENCE_HALLUCINATION
 
 from openfactcheck import FactCheckerState, StandardTaskSolver, Solver
 
+
 @Solver.register("llm_retriever", "claims_with_questions", "claims_with_evidences")
 class RARRLLMRetriever(StandardTaskSolver):
     def __init__(self, args):
         super().__init__(args)
-        self.model = self.global_config.get("model", "gpt-…
+        self.model = self.global_config.get("model", "gpt-4o-instruct")
 
     def __call__(self, state: FactCheckerState, *args, **kwargs):
         claims = state.get(self.input_name)
@@ -16,14 +17,8 @@ class RARRLLMRetriever(StandardTaskSolver):
             questions = contents.get("questions", [])
             evidences = []
             for question in questions:
-                evidences.append(
-                    …
-                        question,
-                        model=self.model,
-                        prompt=EVIDENCE_HALLUCINATION
-                    )
-                )
-            claims[claim]['evidences'] = evidences
+                evidences.append(run_evidence_hallucination(question, model=self.model, prompt=EVIDENCE_HALLUCINATION))
+            claims[claim]["evidences"] = evidences
 
         state.set(self.output_name, claims)
         return True, state
src/openfactcheck/solvers/rarr/rarr_question_generator.py
CHANGED
@@ -3,11 +3,12 @@ from .prompts import rarr_prompts
 
 from openfactcheck import FactCheckerState, StandardTaskSolver, Solver
 
+
 @Solver.register("rarr_question_generator", "claims_with_context", "claims_with_questions")
 class RARRQuestionGenerator(StandardTaskSolver):
     def __init__(self, args):
         super().__init__(args)
-        self.model = self.global_config.get("model", "gpt-…
+        self.model = self.global_config.get("model", "gpt-4o-instruct")
         self.temperature_qgen = args.get("temperature_qgen", 0.7)
         self.num_rounds_qgen = args.get("num_rounds_qgen", 3)
 
@@ -18,13 +19,11 @@ class RARRQuestionGenerator(StandardTaskSolver):
         claims = {c: dict() for c in claims}
         for claim, contents in claims.items():
             context = contents.get("context", None)
-            claims[claim][…
+            claims[claim]["questions"] = run_rarr_question_generation(
                 claim=claim,
                 context=context,
                 model=self.model,
-                prompt=rarr_prompts.CONTEXTUAL_QGEN_PROMPT
-                if context
-                else rarr_prompts.QGEN_PROMPT,
+                prompt=rarr_prompts.CONTEXTUAL_QGEN_PROMPT if context else rarr_prompts.QGEN_PROMPT,
                 temperature=self.temperature_qgen,
                 num_rounds=self.num_rounds_qgen,
             )
src/openfactcheck/solvers/tutorial/utils/api.py
CHANGED
@@ -9,32 +9,36 @@ from typing import Any, Dict, List, Tuple
 # OpenAI ChatGPT and davicci-text
 # ----------------------------------------------------------
 client = None
+
+
 def init_client():
     global client
     if client is None:
-        if openai.api_key is None and …
+        if openai.api_key is None and "OPENAI_API_KEY" not in os.environ:
             print("openai_key not presented, delay to initialize.")
             return
         client = OpenAI()
 
+
 def chatgpt(user_input):
     response = client.chat.completions.create(
-        model="gpt-…
+        model="gpt-4o",
         messages=[
-            …
-            …
-        ]
+            {"role": "system", "content": "You are a NLP expert that is good at fact checking"},
+            {"role": "user", "content": user_input},
+        ],
     )
 
-    result = …
+    result = ""
     for choice in response.choices:
         result += choice.message.content
 
     return result
 
+
 def davinci(prompt):
     # Set up the model and prompt
-    model_engine = "gpt-…
+    model_engine = "gpt-4o-instruct"
 
     # Generate a response
     completion = client.completions.create(
@@ -49,11 +53,13 @@ def davinci(prompt):
     response = completion.choices[0].text
     return response
 
+
 # ----------------------------------------------------------
 # Bing Search
 # ----------------------------------------------------------
 BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search/"
-SUBSCRIPTION_KEY = ""…
+SUBSCRIPTION_KEY = ""  # fill your bing api key
+
 
 def search_bing(query: str, timeout: float = 3) -> List[str]:
     """Searches the query using Bing.
@@ -63,7 +69,7 @@ def search_bing(query: str, timeout: float = 3) -> List[str]:
     Returns:
         search_results: A list of the top URLs relevant to the query.
     """
-    …
+
     headers = {"Ocp-Apim-Subscription-Key": SUBSCRIPTION_KEY}
     params = {"q": query, "textDecorations": True, "textFormat": "HTML"}
     response = requests.get(BING_SEARCH_URL, headers=headers, params=params, timeout=timeout)
@@ -73,7 +79,8 @@ def search_bing(query: str, timeout: float = 3) -> List[str]:
     search_results = [r["url"] for r in response["webPages"]["value"]]
     return search_results
 
-…
+
+# Test Bing search
 # search_results = search_bing("What are the different awards that Preslav Nakov has received")
 # print(search_results)
 
@@ -81,7 +88,7 @@ def search_bing(query: str, timeout: float = 3) -> List[str]:
 # ----------------------------------------------------------
 # Google Search
 # ----------------------------------------------------------
-def search_google(query: str, num_web_pages: int = 10, save_url: str = …
+def search_google(query: str, num_web_pages: int = 10, save_url: str = "") -> List[str]:
     """Searches the query using Google.
     Args:
         query: Search query.
@@ -97,13 +104,13 @@ def search_google(query: str, num_web_pages: int = 10, save_url: str = '') -> Li
     USER_AGENT = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.14; rv:65.0) Gecko/20100101 Firefox/65.0"
     # mobile user-agent
     MOBILE_USER_AGENT = "Mozilla/5.0 (Linux; Android 7.0; SM-G930V Build/NRD90M) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/59.0.3071.125 Mobile Safari/537.36"
-    headers = {…
-    …
+    headers = {"User-Agent": USER_AGENT}
+
     # set language
     # set the Google interface language, use &hl=XX
     # set the preferred language of the search results, use &lr=lang_XX
     # set language as en, otherwise it will return many translation web pages to Arabic that can't be opened correctly.
-    lang = "en"
+    lang = "en"
 
     # scrape google results
     urls = []
@@ -121,11 +128,12 @@ def search_google(query: str, num_web_pages: int = 10, save_url: str = '') -> Li
 
     # save all url into a txt file
     if not save_url == "":
-        with open(save_url, …
+        with open(save_url, "w") as file:
             for url in urls:
-                file.write(url + …
+                file.write(url + "\n")
     return urls
 
+
 # Test google search
 # query = "Google Company Introduction"
 # urls = search_google(query)
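A hypothetical driver for the reformatted helpers above (the import path follows this file's location in the diff; live network access is required, and search_bing additionally needs SUBSCRIPTION_KEY filled in):

from openfactcheck.solvers.tutorial.utils.api import search_google

# Mirrors the commented test at the end of the file; save_url="urls.txt"
# exercises the rewritten `with open(save_url, "w") as file:` branch.
urls = search_google("Google Company Introduction", num_web_pages=10, save_url="urls.txt")
print(len(urls), "URLs fetched and saved to urls.txt")
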
src/openfactcheck/solvers/webservice/factcheckgpt_cp.py
CHANGED
@@ -6,14 +6,20 @@ from openfactcheck.solver import StandardTaskSolver, Solver
 
 from .factcheckgpt_utils.openai_api import gpt
 from .factcheckgpt_utils.data_util import save_to_file
-from .factcheckgpt_utils.prompt import DOC_TO_INDEPEDENT_SENTENCES_PROMPT, SENTENCES_TO_CLAIMS_PROMPT, \
-    DOC_TO_SENTENCES_PROMPT, CHECKWORTHY_PROMPT_BOOL, SPECIFY_CHECKWORTHY_CATEGORY_PROMPT
+from .factcheckgpt_utils.prompt import (
+    DOC_TO_INDEPEDENT_SENTENCES_PROMPT,
+    SENTENCES_TO_CLAIMS_PROMPT,
+    DOC_TO_SENTENCES_PROMPT,
+    CHECKWORTHY_PROMPT_BOOL,
+    SPECIFY_CHECKWORTHY_CATEGORY_PROMPT,
+)
+
 
 @Solver.register("factcheckgpt_claimprocessor", "response", "claims")
 class FactCheckGPTClaimProcessor(StandardTaskSolver):
     def __init__(self, args):
         super().__init__(args)
-        self.model = self.global_config.get("factcheckgpt_model", "gpt-3.5-turbo")
+        self.model = self.global_config.get("factcheckgpt_model", "gpt-4o")
         self.num_retries = self.global_config.get("num_retries", 3)
         self.mode = args.get("mode", "independent_sentences")
         self.decompose_system_role = "You are good at decomposing and decontextualizing text."

@@ -23,19 +29,19 @@ class FactCheckGPTClaimProcessor(StandardTaskSolver):
         self.prompt = {
             "sentences": DOC_TO_SENTENCES_PROMPT,
             "independent_sentences": DOC_TO_INDEPEDENT_SENTENCES_PROMPT,
-            "claims": SENTENCES_TO_CLAIMS_PROMPT
+            "claims": SENTENCES_TO_CLAIMS_PROMPT,
         }.get(self.mode, DOC_TO_INDEPEDENT_SENTENCES_PROMPT)
         nlp = spacy.load(self.spacy_model)
         self.rule_based_tool = {
             "nltk": lambda x: [x.strip() for x in nltk.sent_tokenize(x) if len(x.strip()) >= 3],
-            "spacy": lambda x: [x.text.strip() for x in nlp(x).sents if len(x.text.strip()) >= 3]
+            "spacy": lambda x: [x.text.strip() for x in nlp(x).sents if len(x.text.strip()) >= 3],
         }.get(self.rule_based_method, "nltk")
 
     def __call__(self, state: FactCheckerState, *args, **kwargs):
         # We have merged the text decomposer and worthiness filter here.
         response = state.get(self.input_name)
         claims = [response]
-
+
         user_input = self.prompt.format(doc=response).strip()
         r = gpt(user_input, model=self.model, system_role=self.decompose_system_role, num_retries=self.num_retries)
         try:

@@ -46,13 +52,15 @@ class FactCheckGPTClaimProcessor(StandardTaskSolver):
 
         if not isinstance(claims, list):
             print(
-                f"{self.model} output {r}. It does not output a list of sentences correctly, return rule-based split results.")
+                f"{self.model} output {r}. It does not output a list of sentences correctly, return rule-based split results."
+            )
             claims = self.rule_based_tool(response)
-
+
         worthiness = [True] * len(claims)
         user_input = CHECKWORTHY_PROMPT_BOOL.format(claims=claims)
-        response = gpt(
-            user_input, model=self.model, system_role=self.worthines_filter_system_role, num_retries=self.num_retries)
+        response = gpt(
+            user_input, model=self.model, system_role=self.worthines_filter_system_role, num_retries=self.num_retries
+        )
         # TODO refine check worthiness prompt, value returned not reasonable.
         try:
             worthiness = eval(response)
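The mode-to-prompt dispatch in `__init__` above is a dict literal followed by `.get` with a fallback; a tiny illustration with placeholder strings standing in for the real prompts:

prompts = {
    "sentences": "DOC_TO_SENTENCES_PROMPT",
    "independent_sentences": "DOC_TO_INDEPEDENT_SENTENCES_PROMPT",
    "claims": "SENTENCES_TO_CLAIMS_PROMPT",
}
# Unknown modes fall back to the independent-sentences prompt.
prompt = prompts.get("some_unknown_mode", prompts["independent_sentences"])
print(prompt)  # DOC_TO_INDEPEDENT_SENTENCES_PROMPT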
src/openfactcheck/solvers/webservice/factcheckgpt_rtv.py
CHANGED
@@ -18,16 +18,16 @@ from .factcheckgpt_utils.openai_api import gpt
 from .factcheckgpt_utils.prompt import QGEN_PROMPT, QGEN_PROMPT_FMT
 from .factcheckgpt_utils.data_util import save_txt, save_json
 
+
 @Solver.register("factcheckgpt_retriever", "claims", "claims_with_evidences")
 class FactCheckGPTRetriever(StandardTaskSolver):
     def __init__(self, args):
         super().__init__(args)
-        self.model = self.global_config.get("factcheckgpt_model", "gpt-3.5-turbo")
+        self.model = self.global_config.get("factcheckgpt_model", "gpt-4o")
         self.num_retries = self.global_config.get("num_retries", 3)
         self.tokenizer = spacy.load("en_core_web_sm", disable=["ner", "tagger", "lemmatizer"])
         self.question_duplicate_model = CrossEncoder(
-            "navteca/quora-roberta-base",
-            device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+            "navteca/quora-roberta-base", device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         )
         self.passage_ranker = CrossEncoder(
             "cross-encoder/ms-marco-MiniLM-L-6-v2",

@@ -45,7 +45,7 @@ class FactCheckGPTRetriever(StandardTaskSolver):
         self.sentences_per_passage = args.get("sentences_per_passage", 5)
         self.max_passages_per_question = args.get("max_passages_per_question", 5)
         self.max_aggregated_evidences = args.get("max_aggregated_evidences", 5)
-        self.question_persist_path = args.get("question_persist_path", 'questions.txt')
+        self.question_persist_path = args.get("question_persist_path", "questions.txt")
         self.snippets_persist_path = args.get("snippets_persist_path", "passage.json")
 
     def __call__(self, state: FactCheckerState, *args, **kwargs):

@@ -53,7 +53,7 @@ class FactCheckGPTRetriever(StandardTaskSolver):
         claims_with_evidences = {}
         for i, claim in enumerate(claims):
             evidences = self.get_web_evidences_for_claim(claim)
-            claims_with_evidences[claim] = [(q, e['text']) for q, e in evidences['aggregated']]
+            claims_with_evidences[claim] = [(q, e["text"]) for q, e in evidences["aggregated"]]
         state.set(self.output_name, claims_with_evidences)
         return True, state

@@ -70,11 +70,9 @@ class FactCheckGPTRetriever(StandardTaskSolver):
         snippets = {}
         for question in questions:
             retrieved_passages = self.get_relevant_snippets(question)
-            snippets[question] = sorted(
-                retrieved_passages,
-                key=lambda x: x['retrieval_score'],
-                reverse=True
-            )[:self.max_passages_per_question]
+            snippets[question] = sorted(retrieved_passages, key=lambda x: x["retrieval_score"], reverse=True)[
+                : self.max_passages_per_question
+            ]
         save_json(snippets, self.snippets_persist_path)
         return snippets

@@ -111,7 +109,7 @@ class FactCheckGPTRetriever(StandardTaskSolver):
                 model=self.model,
                 system_role=self.qgen_system_role,
                 num_retries=self.num_retries,
-                temperature=self.qgen_temp
+                temperature=self.qgen_temp,
             )
         try:
             cur_round_questions = set(eval(response))

@@ -183,8 +181,8 @@ class FactCheckGPTRetriever(StandardTaskSolver):
             return False
         return True
 
-    def search_google(self, query: str, num_web_pages: int = 10, timeout: int = 6, save_url: str = '') -> list[str]:
-        """Searches the query using Google.
+    def search_google(self, query: str, num_web_pages: int = 10, timeout: int = 6, save_url: str = "") -> list[str]:
+        """Searches the query using Google.
         Args:
             query: Search query.
             num_web_pages: the number of web pages to request.

@@ -199,7 +197,7 @@ class FactCheckGPTRetriever(StandardTaskSolver):
         USER_AGENT = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.14; rv:65.0) Gecko/20100101 Firefox/65.0"
         # mobile user-agent
         MOBILE_USER_AGENT = "Mozilla/5.0 (Linux; Android 7.0; SM-G930V Build/NRD90M) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/59.0.3071.125 Mobile Safari/537.36"
-        headers = {'User-Agent': USER_AGENT}
+        headers = {"User-Agent": USER_AGENT}
 
         # set language
         # set the Google interface language, use &hl=XX

@@ -223,18 +221,18 @@ class FactCheckGPTRetriever(StandardTaskSolver):
 
         # save all url into a txt file
         if not save_url == "":
-            with open(save_url, 'w') as file:
+            with open(save_url, "w") as file:
                 for url in urls:
-                    file.write(url + '\n')
+                    file.write(url + "\n")
         return urls
 
     def chunk_text(
-            self,
-            text: str,
-            tokenizer,
-            sentences_per_passage: int = 5,
-            filter_sentence_len: int = 250,
-            sliding_distance: int = 2
+        self,
+        text: str,
+        tokenizer,
+        sentences_per_passage: int = 5,
+        filter_sentence_len: int = 250,
+        sliding_distance: int = 2,
     ) -> list[str]:
         """Chunks text into passages using a sliding window.

@@ -261,15 +259,16 @@ class FactCheckGPTRetriever(StandardTaskSolver):
             ]
             for idx in range(0, len(sents), sliding_distance):
                 passages.append(
-                    (" ".join(sents[idx: idx + sentences_per_passage]), idx, idx + sentences_per_passage - 1))
+                    (" ".join(sents[idx : idx + sentences_per_passage]), idx, idx + sentences_per_passage - 1)
+                )
         except UnicodeEncodeError as _:  # Sometimes run into Unicode error when tokenizing.
             print("Unicode error when using Spacy. Skipping text.")
 
         return passages
 
     def get_relevant_snippets(
-            self,
-            query,
+        self,
+        query,
     ):
         search_results = self.search_google(query, timeout=self.search_timeout)

@@ -279,11 +278,9 @@ class FactCheckGPTRetriever(StandardTaskSolver):
         scraped_results = [r for r in scraped_results if r[0] and ".pdf" not in r[1]]
         # print("Num Bing Search Results: ", len(scraped_results))
         retrieved_passages = list()
-        for webtext, url in scraped_results[:self.max_search_results_per_query]:
+        for webtext, url in scraped_results[: self.max_search_results_per_query]:
             passages = self.chunk_text(
-                text=webtext,
-                tokenizer=self.tokenizer,
-                sentences_per_passage=self.sentences_per_passage
+                text=webtext, tokenizer=self.tokenizer, sentences_per_passage=self.sentences_per_passage
             )
             if not passages:
                 continue

@@ -305,7 +302,7 @@ class FactCheckGPTRetriever(StandardTaskSolver):
                     overlap = True
                     break
 
-            # Only consider top non-overlapping relevant passages to maximise for information
+            # Only consider top non-overlapping relevant passages to maximise for information
             if not overlap:
                 relevant_items.append(deepcopy(passage_item))
                 retrieved_passages.append(
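The reformatted `chunk_text` keeps its sliding-window logic unchanged; a self-contained rerun of that loop on toy sentences (spaCy sentence splitting replaced by a ready-made list):

sents = ["S1.", "S2.", "S3.", "S4.", "S5.", "S6.", "S7."]
sentences_per_passage, sliding_distance = 5, 2
passages = []
for idx in range(0, len(sents), sliding_distance):
    # Each passage records its text plus the first and last sentence index it covers.
    passages.append((" ".join(sents[idx : idx + sentences_per_passage]), idx, idx + sentences_per_passage - 1))
print(passages[0])  # ('S1. S2. S3. S4. S5.', 0, 4)
print(passages[1])  # ('S3. S4. S5. S6. S7.', 2, 6)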
src/openfactcheck/solvers/webservice/factcheckgpt_vfr.py
CHANGED
@@ -10,24 +10,22 @@ from .factcheckgpt_utils.data_util import save_to_file
 from .factcheckgpt_utils.prompt import IDENTIFY_STANCE_PROMPT, IDENTIFY_STANCE_PROMPT_FUNC
 from .factcheckgpt_utils.nli import nli_infer
 
+
 @Solver.register("factcheckgpt_verifier", "claims_with_evidences", "label")
 class FactCheckGPTVerifier(StandardTaskSolver):
     def __init__(self, args):
         super().__init__(args)
-        self.stance_model = args.get("stance_model", "gpt-3.5-turbo")
+        self.stance_model = args.get("stance_model", "gpt-4o")
         self.num_retries = self.global_config.get("num_retries", 3)
         # self.system_role = args.get("system_role", "You are a helpful factchecker assistant.")
         self.system_role = "You are a helpful factchecker assistant."
         self.verify_retries = args.get("verify_retries", 3)
-        self.stance_map = {
-            1: "support",
-            -1: "refute",
-            0: "irrelevant"
-        }
+        self.stance_map = {1: "support", -1: "refute", 0: "irrelevant"}
 
     def verify_by_stance(
-            self,
-            claim: str, evidences: list[str],
+        self,
+        claim: str,
+        evidences: list[str],
     ) -> Any:
         labels = []
         for evidence in evidences:

@@ -46,12 +44,7 @@ class FactCheckGPTVerifier(StandardTaskSolver):
 
     def identify_stance_gpt(self, evidence, claim):
         user_input = IDENTIFY_STANCE_PROMPT_FUNC.format(claim=claim, evidence=evidence)
-        r = gpt(
-            user_input,
-            model=self.stance_model,
-            system_role=self.system_role,
-            num_retries=self.num_retries
-        )
+        r = gpt(user_input, model=self.stance_model, system_role=self.system_role, num_retries=self.num_retries)
         label = 0
         try:
             label = eval(r)

@@ -59,9 +52,9 @@ class FactCheckGPTVerifier(StandardTaskSolver):
             print(f"An unexpected error occurred: {e}.")
         return label
 
-    def stance(self, evidence, claim, model="gpt-3.5-turbo"):
+    def stance(self, evidence, claim, model="gpt-4o"):
         """input: a claim and an evidence
-        output: label in [support, refute, irrelevant]"""
+        output: label in [support, refute, irrelevant]"""
         label = 0
         if self.stance_model == "nli":
             label = nli_infer(premise=evidence, hypothesis=claim)

@@ -74,7 +67,7 @@ class FactCheckGPTVerifier(StandardTaskSolver):
     def verify_claim(self, claim: str, evidences: list[str]) -> dict[str, Any]:
         results = None
         user_input = VERIFY_PROMPT.format(claim=claim, evidence=evidences)
-        r = ''
+        r = ""
         for _ in range(self.verify_retries):
             r = gpt(
                 user_input,

@@ -98,12 +91,7 @@ class FactCheckGPTVerifier(StandardTaskSolver):
         else:
             print(f"Error output {r}. It does not output a dict, return factual label by stance aggregation.")
             factual_label = self.verify_by_stance(claim, evidences)
-            results = {
-                "reasoning": "",
-                "error": "",
-                "correction": "",
-                "factuality": factual_label
-            }
+            results = {"reasoning": "", "error": "", "correction": "", "factuality": factual_label}
         return results
 
     def __call__(self, state: FactCheckerState, *args, **kwargs):

@@ -114,6 +102,6 @@ class FactCheckGPTVerifier(StandardTaskSolver):
             result["claim"] = claim
             result["evidences"] = evidences
             results.append(result)
-        state.set(self.output_name, all([x['factuality'] > 0 for x in results]))
+        state.set(self.output_name, all([x["factuality"] > 0 for x in results]))
         state.set("detail", results)
         return True, state
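The verifier's final label above is a strict conjunction over per-claim stance scores keyed by `stance_map`; a sketch with illustrative values:

stance_map = {1: "support", -1: "refute", 0: "irrelevant"}
results = [
    {"reasoning": "", "error": "", "correction": "", "factuality": 1},
    {"reasoning": "", "error": "", "correction": "", "factuality": -1},
]
label = all(x["factuality"] > 0 for x in results)
print(label)  # False: one refuted claim makes the whole response non-factual
print([stance_map[x["factuality"]] for x in results])  # ['support', 'refute']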
src/openfactcheck/solvers/webservice/factool_utils/chat_api.py
CHANGED
@@ -21,53 +21,54 @@ import re
 # env
 # openai.api_key = factool_env_config.openai_api_key
 
-class OpenAIChat():
+
+class OpenAIChat:
     def __init__(
-            self,
-            model_name='gpt-3.5-turbo',
-            max_tokens=2500,
-            temperature=0,
-            top_p=1,
-            request_timeout=120,
+        self,
+        model_name="gpt-4o",
+        max_tokens=2500,
+        temperature=0,
+        top_p=1,
+        request_timeout=120,
     ):
-        if 'gpt' not in model_name:
+        if "gpt" not in model_name:
            openai.api_base = "http://localhost:8000/v1"
         else:
             # openai.api_base = "https://api.openai.com/v1"
             openai.api_key = os.environ.get("OPENAI_API_KEY", None)
             assert openai.api_key is not None, "Please set the OPENAI_API_KEY environment variable."
-            assert openai.api_key != '', "Please set the OPENAI_API_KEY environment variable."
+            assert openai.api_key != "", "Please set the OPENAI_API_KEY environment variable."
         self.client = AsyncOpenAI()
 
         self.config = {
-            'model_name': model_name,
-            'max_tokens': max_tokens,
-            'temperature': temperature,
-            'top_p': top_p,
-            'request_timeout': request_timeout,
+            "model_name": model_name,
+            "max_tokens": max_tokens,
+            "temperature": temperature,
+            "top_p": top_p,
+            "request_timeout": request_timeout,
         }
 
     def extract_list_from_string(self, input_string):
-        # pattern = r'\[.*\]'
+        # pattern = r'\[.*\]'
         # result = re.search(pattern, input_string)
         # if result:
         #     return result.group()
         # else:
         #     return None
-        start_index = input_string.find('[')
-        end_index = input_string.rfind(']')
+        start_index = input_string.find("[")
+        end_index = input_string.rfind("]")
 
         if start_index != -1 and end_index != -1 and start_index < end_index:
-            return input_string[start_index:end_index + 1]
+            return input_string[start_index : end_index + 1]
         else:
             return None
 
     def extract_dict_from_string(self, input_string):
-        start_index = input_string.find('{')
-        end_index = input_string.rfind('}')
+        start_index = input_string.find("{")
+        end_index = input_string.rfind("}")
 
         if start_index != -1 and end_index != -1 and start_index < end_index:
-            return input_string[start_index:end_index + 1]
+            return input_string[start_index : end_index + 1]
         else:
             return None

@@ -81,7 +82,7 @@ class OpenAIChat():
                 return None
             return output_eval
         except:
-            '''
+            """
             if(expected_type == List):
                 valid_output = self.extract_list_from_string(output)
                 output_eval = ast.literal_eval(valid_output)

@@ -94,15 +95,15 @@ class OpenAIChat():
             if not isinstance(output_eval, expected_type):
                 return None
             return output_eval
-            '''
+            """
         return None
 
     async def dispatch_openai_requests(
-            self,
-            messages_list,
+        self,
+        messages_list,
     ) -> list[str]:
         """Dispatches requests to OpenAI API asynchronously.
-
+
         Args:
             messages_list: List of messages to be sent to OpenAI ChatCompletion API.
         Returns:

@@ -113,11 +114,11 @@ class OpenAIChat():
         for _ in range(retry):
             try:
                 response = await self.client.chat.completions.create(
-                    model=self.config['model_name'],
+                    model=self.config["model_name"],
                     messages=messages,
-                    max_tokens=self.config['max_tokens'],
-                    temperature=self.config['temperature'],
-                    top_p=self.config['top_p'],
+                    max_tokens=self.config["max_tokens"],
+                    temperature=self.config["temperature"],
+                    top_p=self.config["top_p"],
                 )
                 return response
             except openai.RateLimitError:

@@ -146,10 +147,7 @@ class OpenAIChat():
 
             return None
 
-        async_responses = [
-            _request_with_retry(messages)
-            for messages in messages_list
-        ]
+        async_responses = [_request_with_retry(messages) for messages in messages_list]
 
         return await asyncio.gather(*async_responses, return_exceptions=True)

@@ -161,12 +159,18 @@ class OpenAIChat():
         while retry > 0 and len(messages_list_cur_index) > 0:
             messages_list_cur = [messages_list[i] for i in messages_list_cur_index]
 
-            predictions = asyncio.run(self.dispatch_openai_requests(
-                messages_list=messages_list_cur,
-            ))
-
-            preds = [self._type_check(self._boolean_fix(prediction.choices[0].message.content), expected_type)
-                     if prediction is not None else None for prediction in predictions]
+            predictions = asyncio.run(
+                self.dispatch_openai_requests(
+                    messages_list=messages_list_cur,
+                )
+            )
+
+            preds = [
+                self._type_check(self._boolean_fix(prediction.choices[0].message.content), expected_type)
+                if prediction is not None
+                else None
+                for prediction in predictions
+            ]
             finised_index = []
             for i, pred in enumerate(preds):
                 if pred is not None:

@@ -179,6 +183,7 @@ class OpenAIChat():
 
         return responses
 
+
 # class OpenAIEmbed():
 #     def __init__():
 #         openai.api_key = os.environ.get("OPENAI_API_KEY", None)
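A hedged usage sketch of the `OpenAIChat` wrapper above, assuming `OPENAI_API_KEY` is set and the module is imported by its path in this repository:

import asyncio

from openfactcheck.solvers.webservice.factool_utils.chat_api import OpenAIChat

chat = OpenAIChat(model_name="gpt-4o")
messages_list = [
    [{"role": "user", "content": "Say hello."}],
    [{"role": "user", "content": "Name one prime number."}],
]
# All requests are dispatched concurrently; failures come back as None or exceptions.
responses = asyncio.run(chat.dispatch_openai_requests(messages_list=messages_list))
print([r.choices[0].message.content for r in responses if hasattr(r, "choices")])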
src/openfactcheck/solvers/webservice/ftool_cp.py
CHANGED
@@ -4,11 +4,12 @@ from openfactcheck.solver import StandardTaskSolver, Solver
 from .factool_utils.chat_api import OpenAIChat
 from .factool_utils.prompt import CLAIM_EXTRACTION_PROMPT
 
+
 @Solver.register("factool_claimprocessor", "response", "claims")
 class FactoolClaimProcessor(StandardTaskSolver):
     def __init__(self, args):
         super().__init__(args)
-        self.gpt_model = self.global_config.get("factool_gpt_model", "gpt-3.5-turbo")
+        self.gpt_model = self.global_config.get("factool_gpt_model", "gpt-4o")
         self.gpt = OpenAIChat(self.gpt_model)
         self.claim_prompt = CLAIM_EXTRACTION_PROMPT

@@ -16,7 +17,7 @@ class FactoolClaimProcessor(StandardTaskSolver):
         response = state.get(self.input_name)
 
         claims = self._claim_extraction(responses=[response])[0]
-
+
         extracted_claims = [claim["claim"] for claim in claims]
 
         state.set(self.output_name, extracted_claims)
src/openfactcheck/solvers/webservice/ftool_rtv.py
CHANGED
@@ -5,11 +5,12 @@ from .factool_utils.chat_api import OpenAIChat
 from .factool_utils.search_api import GoogleSerperAPIWrapper
 from .factool_utils.prompt import QUERY_GENERATION_PROMPT
 
+
 @Solver.register("factool_retriever", "claims", "claims_with_evidences")
 class FactoolRetriever(StandardTaskSolver):
     def __init__(self, args):
         super().__init__(args)
-        self.gpt_model = self.global_config.get("factool_gpt_model", "gpt-3.5-turbo")
+        self.gpt_model = self.global_config.get("factool_gpt_model", "gpt-4o")
         self.snippet_cnt = args.get("snippet_cnt", 10)
         self.gpt = OpenAIChat(self.gpt_model)
         self.query_prompt = QUERY_GENERATION_PROMPT

@@ -22,8 +23,8 @@ class FactoolRetriever(StandardTaskSolver):
         evidences = self.search_engine.run(queries)
         results = {}
         for query, claim, evidence in zip(queries, claims, evidences):
-            merged_query = ' '.join(query) if query and len(query) > 1 else str(query) if query else ''
-            results[claim] = [(merged_query, x['content']) for x in evidence]
+            merged_query = " ".join(query) if query and len(query) > 1 else str(query) if query else ""
+            results[claim] = [(merged_query, x["content"]) for x in evidence]
         state.set(self.output_name, results)
         return True, state
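The reformatted `merged_query` expression chains two conditional expressions; unrolled with illustrative values:

query = ["capital of France", "France capital city"]
merged_query = " ".join(query) if query and len(query) > 1 else str(query) if query else ""
print(merged_query)  # capital of France France capital city

query = []
merged_query = " ".join(query) if query and len(query) > 1 else str(query) if query else ""
print(repr(merged_query))  # '' (an empty query list yields an empty string)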
src/openfactcheck/solvers/webservice/ftool_vfr.py
CHANGED
@@ -4,11 +4,12 @@ from openfactcheck.solver import StandardTaskSolver, Solver
 from .factool_utils.chat_api import OpenAIChat
 from .factool_utils.prompt import VERIFICATION_PROMPT
 
+
 @Solver.register("factool_verifier", "claims_with_evidences", "label")
 class FactoolVerifier(StandardTaskSolver):
     def __init__(self, args):
         super().__init__(args)
-        self.gpt_model = self.global_config.get("factool_gpt_model", "gpt-3.5-turbo")
+        self.gpt_model = self.global_config.get("factool_gpt_model", "gpt-4o")
         self.gpt = OpenAIChat(self.gpt_model)
         self.verification_prompt = VERIFICATION_PROMPT

@@ -16,19 +17,23 @@ class FactoolVerifier(StandardTaskSolver):
         claims_with_evidences = state.get(self.input_name)
         results = self._verification(claims_with_evidences)
         for i, k in enumerate(list(claims_with_evidences.keys())):
-            results[i]['claim'] = k
-            results[i]['evidences'] = claims_with_evidences[k]
+            results[i]["claim"] = k
+            results[i]["evidences"] = claims_with_evidences[k]
         state.set("detail", results)
-        label = all(v['factuality'] for v in results)
+        label = all(v["factuality"] for v in results)
         state.set(self.output_name, label)
         return True, state
 
     def _verification(self, claims_with_evidences):
         messages_list = [
             [
-                {"role": "system", "content": self.verification_prompt['system']},
-                {"role": "user",
-                 "content": self.verification_prompt['user'].format(claim=claim, evidence=str([e[1] for e in evidence]))},
+                {"role": "system", "content": self.verification_prompt["system"]},
+                {
+                    "role": "user",
+                    "content": self.verification_prompt["user"].format(
+                        claim=claim, evidence=str([e[1] for e in evidence])
+                    ),
+                },
             ]
             for claim, evidence in claims_with_evidences.items()
         ]
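For reference, the shape of the data `_verification` consumes (values illustrative): each claim maps to `(query, evidence_text)` pairs, and only the text component `e[1]` reaches the prompt:

claims_with_evidences = {
    "Paris is the capital of France.": [
        ("capital of France", "Paris has been the capital of France since 987."),
    ],
}
for claim, evidence in claims_with_evidences.items():
    print(str([e[1] for e in evidence]))  # evidence texts only, as a stringified list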
src/openfactcheck/solvers/webservice/rarr_rtv.py
CHANGED
@@ -5,11 +5,12 @@ from .rarr_utils.question_generation import run_rarr_question_generation
 from .rarr_utils.functional_prompt import QGEN_PROMPT
 from .rarr_utils import search
 
+
 @Solver.register("rarr_retriever", "claims", "claims_with_evidences")
 class RARRRetriever(StandardTaskSolver):
     def __init__(self, args):
         super().__init__(args)
-        self.model = self.global_config.get("rarr_model", "gpt-3.5-turbo-instruct")
+        self.model = self.global_config.get("rarr_model", "gpt-4o-instruct")
         self.temperature_qgen = args.get("temperature_qgen", 0.7)
         self.num_rounds_qgen = args.get("num_rounds_qgen", 3)
         self.max_search_results_per_query = args.get("max_search_results_per_query", 5)

@@ -19,7 +20,7 @@ class RARRRetriever(StandardTaskSolver):
 
     def __call__(self, state: FactCheckerState, *args, **kwargs):
         claims = state.get(self.input_name)
-
+
         results = dict()
         for claim in claims:
             questions = run_rarr_question_generation(

@@ -39,8 +40,8 @@ class RARRRetriever(StandardTaskSolver):
                 sliding_distance=self.sliding_distance,
                 max_passages_per_search_result_to_return=self.max_passages_per_search_result,
             )
-            evidences.extend([(question, x['text']) for x in q_evidences])
-
+            evidences.extend([(question, x["text"]) for x in q_evidences])
+
         results[claim] = evidences
 
         state.set(self.output_name, results)
src/openfactcheck/solvers/webservice/rarr_vfr.py
CHANGED
@@ -4,19 +4,20 @@ from openfactcheck.solver import StandardTaskSolver, Solver
 from .rarr_utils.agreement_gate import run_agreement_gate
 from .rarr_utils.functional_prompt import AGREEMENT_GATE_PROMPT
 
+
 @Solver.register("rarr_verifier", "claims_with_evidences", "label")
 class RARRAgreementGate(StandardTaskSolver):
     def __init__(self, args):
         super().__init__(args)
         self.max_evidences_per_question = args.get("max_evidences_per_question", 1)
-        self.model = self.global_config.get("rarr_model", "gpt-3.5-turbo-instruct")
+        self.model = self.global_config.get("rarr_model", "gpt-4o-instruct")
 
     def __call__(self, state: FactCheckerState, *args, **kwargs):
         claims_with_evidences = state.get(self.input_name)
         results = []
         for claim, evidences in claims_with_evidences.items():
             result = {}
-            evidences = evidences[:self.max_evidences_per_question]
+            evidences = evidences[: self.max_evidences_per_question]
             labels = []
             for query, evidence in evidences:
                 gate = run_agreement_gate(

@@ -25,14 +26,14 @@ class RARRAgreementGate(StandardTaskSolver):
                     query=query,
                     evidence=evidence,
                     model=self.model,
-                    prompt=AGREEMENT_GATE_PROMPT
+                    prompt=AGREEMENT_GATE_PROMPT,
                 )
-                labels.append(gate['is_open'])
-            result['claim'] = claim
-            result['evidences'] = evidences
-            result['labels'] = labels
-            result['factuality'] = all(labels)
+                labels.append(gate["is_open"])
+            result["claim"] = claim
+            result["evidences"] = evidences
+            result["labels"] = labels
+            result["factuality"] = all(labels)
             results.append(result)
-        state.set(self.output_name, all([x['factuality'] for x in results]))
+        state.set(self.output_name, all([x["factuality"] for x in results]))
         state.set("detail", results)
         return True, state
src/openfactcheck/state.py
CHANGED
@@ -1,52 +1,89 @@
+from typing import Any, Optional
+
 from openfactcheck.utils.logging import get_logger
 
 # Get the logger
 logger = get_logger(__name__)
 
+
 class FactCheckerState:
     """
-    A class to manage the state of a fact checking system.
-    It holds a question and its corresponding response, and provides methods
-    to set and get these attributes dynamically.
-
-    Parameters
-    ----------
-    question : str
-        The question to be fact-checked.
-    response : str
-        The response to the question.
+    A class to manage the state of a fact-checking system.
+
+    It holds a question and its corresponding response, and provides methods
+    to set and get these attributes dynamically.
     """
-    def __init__(self, question: str = None, response: str = None):
+
+    def __init__(self, question: Optional[str] = None, response: Optional[str] = None) -> None:
         """
         Initialize the FactCheckerState object.
+
+        Parameters
+        ----------
+        question : Optional[str]
+            The question to be fact-checked.
+        response : Optional[str]
+            The response to the question.
         """
-        self.question: str = question
-        self.response: str = response
+        self.question: Optional[str] = question
+        self.response: Optional[str] = response
 
-    def set(self, name, value):
+    def set(self, name: str, value: Any) -> None:
         """
         Set an attribute of the state object.
+
+        Parameters
+        ----------
+        name : str
+            The name of the attribute to set.
+        value : Any
+            The value to set for the attribute.
         """
         if hasattr(self, name):
-            logger.warning(f"Modifying existing attribute {name}")
+            logger.warning(f"Modifying existing attribute '{name}'")
         setattr(self, name, value)
 
-    def get(self, name):
+    def get(self, name: str) -> Any:
         """
         Get an attribute of the state object.
+
+        Parameters
+        ----------
+        name : str
+            The name of the attribute to retrieve.
+
+        Returns
+        -------
+        Any
+            The value of the requested attribute.
+
+        Raises
+        ------
+        ValueError
+            If the attribute does not exist.
         """
         if not hasattr(self, name):
-            raise ValueError(f"{name} does not exist")
-        return getattr(self, name, None)
+            raise ValueError(f"Attribute '{name}' does not exist")
+        return getattr(self, name)
 
-    def __str__(self):
+    def __str__(self) -> str:
         """
         Return a string representation of the state object.
+
+        Returns
+        -------
+        str
+            A string representation of the object's dictionary.
         """
         return str(self.__dict__)
 
-    def to_dict(self):
+    def to_dict(self) -> dict[str, Any]:
         """
         Return a dictionary representation of the state object.
+
+        Returns
+        -------
+        Dict[str, Any]
+            A dictionary containing the object's attributes.
         """
         return self.__dict__
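A minimal usage sketch of the annotated `FactCheckerState`; every call shown corresponds to a method in the diff above:

from openfactcheck.state import FactCheckerState

state = FactCheckerState(question="Who wrote Hamlet?", response="Shakespeare wrote Hamlet.")
state.set("claims", ["Shakespeare wrote Hamlet."])  # new attribute, set silently
state.set("response", "Hamlet was written by Shakespeare.")  # warns: Modifying existing attribute 'response'
print(state.get("claims"))
print(state.to_dict())  # {'question': ..., 'response': ..., 'claims': [...]}
state.get("label")  # raises ValueError: Attribute 'label' does not exist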
src/openfactcheck/templates/solver_configs/webservice.yaml
CHANGED
@@ -9,7 +9,7 @@ factool_retriever:
 factool_verifier:
   input_name: claims_with_evidences
   output_name: label
-  factcheckgpt_model: gpt-3.5-turbo
+  factcheckgpt_model: gpt-4o
 factcheckgpt_claimprocessor:
   input_name: response
   output_name: claims

@@ -31,9 +31,9 @@ factcheckgpt_retriever:
 factcheckgpt_verifier:
   input_name: claims_with_evidences
   output_name: label
-  stance_model: gpt-3.5-turbo
+  stance_model: gpt-4o
   verify_retries: 3
-rarr_model: gpt-3.5-turbo-instruct
+rarr_model: gpt-4o-instruct
 rarr_retriever:
   input_name: claims
   output_name: claims_with_evidences
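A short sketch of reading the updated defaults back out of this config (PyYAML assumed available):

import yaml

with open("src/openfactcheck/templates/solver_configs/webservice.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["factcheckgpt_verifier"]["stance_model"])  # gpt-4o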