# File size: 4,448 Bytes (commit 48bb68b)
## Database retrieval based on keywords
## need to ] add [email protected]
"""
    distances_and_classification(narrative_matrix, target_matrix)

Compute cosine distances between every target embedding and every narrative
embedding, and return the closest narrative item for each target.

# Arguments
- `narrative_matrix`: embeddings stored column-wise (`dims=2`), one column per
  narrative item.
- `target_matrix`: embeddings stored column-wise, one column per target.

# Returns
A tuple `(dists, idx)`:
- `dists`: the minimum cosine distance for each target (one entry per column of
  `target_matrix`).
- `idx`: a vector of `CartesianIndex{2}` locating that minimum in the
  `(target, narrative)` distance matrix; `idx[i][2]` is the column of
  `narrative_matrix` closest to target `i`.
"""
function distances_and_classification(narrative_matrix, target_matrix)
    # Rows index targets, columns index narrative items.
    distances = pairwise(CosineDist(), target_matrix, narrative_matrix, dims=2)
    # A single `findmin` pass over each row yields both the smallest distance and
    # its CartesianIndex (`argmin(...; dims)` is `findmin(...; dims)[2]`).
    mindists, minidx = findmin(distances; dims=2)
    return mindists[:, 1], minidx[:, 1]
end
"""
    assignments!(narrative_matrix, target_matrix, narrative_embeddings, target_embeddings; claim_counter_claim="claim")

Attach the nearest narrative item to every row of `target_embeddings`.

With `label` being the `claim_counter_claim` keyword (default `"claim"`), two
columns are assigned on `target_embeddings`:
- `"<label>Dist"`: the minimum distance returned by `distances_and_classification`;
- `"Closest<label>"`: the value in column `label` of `narrative_embeddings` at the
  row of the nearest column of `narrative_matrix`.

Returns `nothing`.
"""
function assignments!(narrative_matrix, target_matrix, narrative_embeddings, target_embeddings; kwargs...)
    label = get(kwargs, :claim_counter_claim, "claim")
    nearest_dist, nearest_pos = distances_and_classification(narrative_matrix, target_matrix)
    dist_col = "$(label)Dist"
    closest_col = "Closest$(label)"
    target_embeddings[:, dist_col] = nearest_dist
    # `pos[2]` is the narrative (column) index of each CartesianIndex minimum.
    target_embeddings[:, closest_col] = map(pos -> narrative_embeddings[pos[2], label], nearest_pos[:, 1])
    return nothing
end
"""
    get_distances!(narrative::Narrative, target_embeddings::DataFrame)

Find, for every row of `target_embeddings`, its closest claim and closest
counterclaim from `narrative`, and store the results as columns.

Reads the `"Embeddings"` column of `target_embeddings` and the
`claimembedding` / `counterclaimembedding` of each claim in `narrative.claims`.
Via `assignments!`, writes the columns `"claimDist"`, `"Closestclaim"`,
`"counterclaimDist"` and `"Closestcounterclaim"`. Returns `nothing`.

Throws `ArgumentError` if there are no claims or no target rows to compare.

# Example
```julia
include("src/Narrative.jl")
include("src/NarrativeClassification.jl")
include("src/ExampleNarrative.jl")
climate_narrative = create_example_narrative();
generate_claim_embeddings_from_narrative!(climate_narrative)
candidate_data = candidate_embeddings_from_narrative(climate_narrative)
get_distances!(climate_narrative, candidate_data)
```
"""
function get_distances!(narrative::Narrative, target_embeddings::DataFrame)
    narrative_embeddings = narrative_to_dataframe(narrative)
    # Embedding matrices: one column per claim / counterclaim / target row.
    narrative_matrix = _embedding_matrix(claim.claimembedding for claim in narrative.claims)
    counternarrative_matrix = _embedding_matrix(claim.counterclaimembedding for claim in narrative.claims)
    target_matrix = _embedding_matrix(target_embeddings[:, "Embeddings"])
    # Assign the closest claim to the test data
    assignments!(narrative_matrix, target_matrix, narrative_embeddings, target_embeddings, claim_counter_claim="claim")
    # Assign the closest counterclaim to the test data
    assignments!(counternarrative_matrix, target_matrix, narrative_embeddings, target_embeddings, claim_counter_claim="counterclaim")
    return nothing
end

# Stack equal-length embedding vectors as the columns of a matrix.
# `reduce(hcat, ...)` avoids splatting one argument per row into `hcat`, which is
# very slow for large candidate tables. The comprehension narrows the element
# type so Base's specialized `reduce(hcat, ::AbstractVector{<:AbstractVecOrMat})`
# applies, which also keeps a single vector as an n×1 matrix.
function _embedding_matrix(vectors)
    columns = [v for v in vectors]
    isempty(columns) && throw(ArgumentError("no embeddings to stack into a matrix"))
    return reduce(hcat, columns)
end
"""
    apply_gate_logic!(target_embeddings; threshold=0.2)

Set the `"OCLabel"` column of `target_embeddings` to `1` for rows whose
`"claimDist"` is strictly below `threshold` AND strictly below their
`"counterclaimDist"`, and to `0` for every other row. The distance columns are
the ones written by `get_distances!`. Returns `nothing`.
"""
function apply_gate_logic!(target_embeddings; kwargs...)
    cutoff = get(kwargs, :threshold, 0.2)
    claim_d = target_embeddings[:, "claimDist"]
    counter_d = target_embeddings[:, "counterclaimDist"]
    # Gate: near enough to the closest claim, and nearer to it than to the
    # closest counterclaim.
    passes = (claim_d .< counter_d) .& (claim_d .< cutoff)
    target_embeddings[:, "OCLabel"] .= 0
    target_embeddings[passes, "OCLabel"] .= 1
    return nothing
end
"""
    return_top_labels(target_embeddings; top_labels=10)

Return the positively gated rows (`"OCLabel" == 1`) of `target_embeddings`,
sorted by ascending `"claimDist"` and truncated to at most `top_labels` rows.
The input table is not reordered; a new table is returned.

# Example
```julia
include("src/Narrative.jl")
include("src/NarrativeClassification.jl")
include("src/ExampleNarrative.jl")
climate_narrative = create_example_narrative();
generate_claim_embeddings_from_narrative!(climate_narrative)
candidate_data = candidate_embeddings_from_narrative(climate_narrative)
get_distances!(climate_narrative, candidate_data)
apply_gate_logic!(candidate_data; threshold=0.2)
return_top_labels(candidate_data)
```
"""
function return_top_labels(target_embeddings; kwargs...)
    top_labels = get(kwargs, :top_labels, 10)
    # Row-indexing returns a copy, so the in-place sort below does not touch the
    # caller's table.
    out = return_positive_candidates(target_embeddings)
    sort!(out, :claimDist)
    return out[1:min(top_labels, nrow(out)), :]
end
"""
    return_positive_candidates(target_embeddings)

Return a new table holding only the rows of `target_embeddings` whose
`"OCLabel"` column equals `1`, in their original order.
"""
function return_positive_candidates(target_embeddings)
    is_positive = target_embeddings[:, "OCLabel"] .== 1
    return target_embeddings[is_positive, :]
end
"""
    deploy_narrative_model!(narrative::Narrative; threshold=0.2, db="data/random_300k.csv")

Run the full narrative classification pipeline and return the labelled candidates.

Steps:
1. `generate_claim_embeddings_from_narrative!(narrative)` (mutates `narrative`;
   presumably fills the claim/counterclaim embeddings read by `get_distances!`).
2. Build the candidate table with `candidate_embeddings_from_narrative(narrative; db=db)`.
3. `get_distances!` adds the closest-claim / closest-counterclaim columns.
4. `apply_gate_logic!` adds the `"OCLabel"` column using `threshold`.

# Keywords
- `threshold = 0.2`: gate threshold passed to `apply_gate_logic!`.
- `db = "data/random_300k.csv"`: passed to `candidate_embeddings_from_narrative`.

# Returns
The candidate table with the distance and `"OCLabel"` columns added.

# Example
```julia
include("src/Narrative.jl")
include("src/NarrativeClassification.jl")
include("src/ExampleNarrative.jl")
climate_narrative = create_example_narrative();
deploy_narrative_model!(climate_narrative; threshold=0.2)
```
"""
function deploy_narrative_model!(narrative::Narrative; kwargs...)
    threshold = get(kwargs, :threshold, 0.2)
    db = get(kwargs, :db, "data/random_300k.csv")
    generate_claim_embeddings_from_narrative!(narrative)
    candidate_data = candidate_embeddings_from_narrative(narrative; db=db)
    get_distances!(narrative, candidate_data)
    apply_gate_logic!(candidate_data, threshold=threshold)
    return candidate_data
end