File size: 4,448 Bytes
48bb68b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
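## Database retrieval based on keywords
## Requires a pinned package: `] add <Package>@<version>` (the exact spec was obscured
## as "[email protected]" in this copy — TODO restore it)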


"""
    distances_and_classification(narrative_matrix, target_matrix)

Compute the cosine distance between every target embedding and every narrative
embedding, and return, for each target, the smallest distance and where it occurs.

Both matrices store one embedding per column (`pairwise(...; dims=2)`).

# Returns
A tuple `(dists, assignment)` of vectors with one entry per column of
`target_matrix`:
- `dists[i]`: the minimum cosine distance from target `i` to any narrative column;
- `assignment[i]`: the `CartesianIndex` `(i, j)` of that minimum in the
  `(n_targets × n_narrative)` distance matrix, so `assignment[i][2]` is the index of
  the closest narrative column.
"""
function distances_and_classification(narrative_matrix, target_matrix)
    # Rows = targets, columns = narrative items.
    distances = pairwise(CosineDist(), target_matrix, narrative_matrix, dims=2)
    # `findmin` over each row yields the minima and their locations in a single pass
    # (previously `argmin` was evaluated twice and the values re-gathered by indexing).
    min_dists, min_locs = findmin(distances; dims=2)
    return vec(min_dists), vec(min_locs)
end

"""
    assignments!(narrative_matrix, target_matrix, narrative_embeddings,
                 target_embeddings; claim_counter_claim="claim")

For every target column, record the distance to, and the entry of, the nearest
narrative column by writing two columns into `target_embeddings`:
- `"<kind>Dist"`: minimum cosine distance (from `distances_and_classification`);
- `"Closest<kind>"`: the `<kind>` value of `narrative_embeddings` in the row whose
  index equals the nearest narrative column,
where `<kind>` is the `claim_counter_claim` keyword (default `"claim"`).
Returns `nothing`.

NOTE(review): assumes row `j` of `narrative_embeddings` corresponds to column `j` of
`narrative_matrix` — confirm against how the caller builds both.
"""
function assignments!(narrative_matrix, target_matrix, narrative_embeddings, target_embeddings; kwargs...)
    kind = get(kwargs, :claim_counter_claim, "claim")
    nearest_dist, nearest_loc =
        distances_and_classification(narrative_matrix, target_matrix)
    target_embeddings[:, "$(kind)Dist"] = nearest_dist
    # Column 2 of each CartesianIndex is the index of the nearest narrative column.
    nearest_entry = map(loc -> narrative_embeddings[loc[2], kind], nearest_loc)
    target_embeddings[:, "Closest$(kind)"] = nearest_entry
    return nothing
end

# Stack equally long embedding vectors as the columns of a matrix.
# `reduce(hcat, ...)` uses Base's specialised method instead of splatting every vector
# into one `hcat` call, which is slow when the candidate table is large.
# `identity.` narrows an abstractly typed container (e.g. `Vector{Any}`) so that the
# specialised method applies and a single vector still yields an n×1 matrix.
_embeddings_to_matrix(vectors) = reduce(hcat, identity.(vectors))

"""
    get_distances!(narrative::Narrative, target_embeddings::DataFrame)

For every row of `target_embeddings`, find the closest claim and the closest
counterclaim of `narrative` and store the results as columns of `target_embeddings`.

Reads the `"Embeddings"` column of `target_embeddings` (one embedding vector per row)
and the `claimembedding` / `counterclaimembedding` of each claim in `narrative`, then
writes `"claimDist"`, `"Closestclaim"`, `"counterclaimDist"` and
`"Closestcounterclaim"` via `assignments!`. Returns `nothing`.

# Example
```julia
include("src/Narrative.jl")
include("src/NarrativeClassification.jl")
include("src/ExampleNarrative.jl")
climate_narrative = create_example_narrative();
generate_claim_embeddings_from_narrative!(climate_narrative)
candidate_data = candidate_embeddings_from_narrative(climate_narrative)
get_distances!(climate_narrative, candidate_data)
```
"""
function get_distances!(narrative::Narrative, target_embeddings::DataFrame)
    # Tabular form of the narrative; `assignments!` looks up the closest
    # claim/counterclaim in it by row index.
    narrative_embeddings = narrative_to_dataframe(narrative)
    # One embedding per column for claims, counterclaims and targets.
    narrative_matrix =
        _embeddings_to_matrix([claim.claimembedding for claim in narrative.claims])
    counternarrative_matrix =
        _embeddings_to_matrix([claim.counterclaimembedding for claim in narrative.claims])
    # `[!, col]` reads the column without copying it.
    target_matrix = _embeddings_to_matrix(target_embeddings[!, "Embeddings"])
    # NOTE(review): assumes row j of `narrative_embeddings` is `narrative.claims[j]`
    # — confirm in `narrative_to_dataframe`.
    # Closest claim for each target.
    assignments!(narrative_matrix, target_matrix, narrative_embeddings,
        target_embeddings; claim_counter_claim="claim")
    # Closest counterclaim for each target.
    assignments!(counternarrative_matrix, target_matrix, narrative_embeddings,
        target_embeddings; claim_counter_claim="counterclaim")
    return nothing
end

"""
    apply_gate_logic!(target_embeddings; threshold=0.2)

Write a binary `"OCLabel"` column into `target_embeddings`: `1` for rows whose
`"claimDist"` is strictly below `threshold` *and* strictly below their
`"counterclaimDist"`, `0` for every other row. Returns `nothing`.
"""
function apply_gate_logic!(target_embeddings; kwargs...)
    threshold = get(kwargs, :threshold, 0.2)
    claim_d = target_embeddings[:, "claimDist"]
    counter_d = target_embeddings[:, "counterclaimDist"]
    # Positive only when under the threshold and nearer a claim than a counterclaim.
    is_positive = (claim_d .< threshold) .& (claim_d .< counter_d)
    target_embeddings[:, "OCLabel"] .= 0
    target_embeddings[findall(is_positive), "OCLabel"] .= 1
    return nothing
end

"""
    return_top_labels(target_embeddings; top_labels=10)

Return at most `top_labels` rows of `target_embeddings` whose `"OCLabel"` equals `1`,
sorted by increasing `"claimDist"` (closest to a claim first).

Expects the `"OCLabel"` column written by `apply_gate_logic!` and the `"claimDist"`
column written by `get_distances!`. Sorting happens on a filtered copy, so
`target_embeddings` itself is not reordered. Extra keyword arguments are accepted
and ignored.

# Example
```julia
include("src/Narrative.jl")
include("src/NarrativeClassification.jl")
include("src/ExampleNarrative.jl")
climate_narrative = create_example_narrative();
generate_claim_embeddings_from_narrative!(climate_narrative)
candidate_data = candidate_embeddings_from_narrative(climate_narrative)
get_distances!(climate_narrative, candidate_data)
apply_gate_logic!(candidate_data; threshold=0.2)
return_top_labels(candidate_data)
```
"""
function return_top_labels(target_embeddings; top_labels=10, kwargs...)
    # Same "OCLabel == 1" filter as `return_positive_candidates`; returns a copy.
    out = return_positive_candidates(target_embeddings)
    sort!(out, :claimDist)
    # Clamp so fewer positives than `top_labels` (or none) is not an error.
    return out[1:min(top_labels, nrow(out)), :]
end

"""
    return_positive_candidates(target_embeddings)

Return a copy of the rows of `target_embeddings` whose `"OCLabel"` equals `1`,
keeping their original order.
"""
function return_positive_candidates(target_embeddings)
    positive_rows = target_embeddings[:, "OCLabel"] .== 1
    return target_embeddings[positive_rows, :]
end

"""
    deploy_narrative_model!(narrative::Narrative; threshold=0.2, db="data/random_300k.csv")

Run the full classification pipeline for `narrative` and return the scored candidates.

Steps:
1. generate claim/counterclaim embeddings on `narrative` (mutates it);
2. build candidate embeddings with `candidate_embeddings_from_narrative`, using `db`;
3. attach closest-claim / closest-counterclaim distances (`get_distances!`);
4. set the `"OCLabel"` column with `apply_gate_logic!` using `threshold`.

# Keywords
- `threshold = 0.2`: forwarded to `apply_gate_logic!`.
- `db = "data/random_300k.csv"`: forwarded to `candidate_embeddings_from_narrative`.

# Returns
The candidate `DataFrame` with the distance columns and `"OCLabel"` added.

# Example
```julia
include("src/Narrative.jl")
include("src/NarrativeClassification.jl")
include("src/ExampleNarrative.jl")
climate_narrative = create_example_narrative();
deploy_narrative_model!(climate_narrative; threshold=0.2)
```
"""
function deploy_narrative_model!(narrative::Narrative; kwargs...)
    gate_threshold = get(kwargs, :threshold, 0.2)
    source = get(kwargs, :db, "data/random_300k.csv")
    # 1. embeddings for the narrative's own claims
    generate_claim_embeddings_from_narrative!(narrative)
    # 2. candidate texts to classify
    candidates = candidate_embeddings_from_narrative(narrative; db=source)
    # 3. nearest claim / counterclaim for each candidate
    get_distances!(narrative, candidates)
    # 4. binary label from the distances
    apply_gate_logic!(candidates; threshold=gate_threshold)
    return candidates
end