|
|
|
using CSV, JLD2, DataFrames, OpenAI, StatsBase, Distances, TidierPlots |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
    table(df::DataFrame, cols::Vector{Symbol})

Tabulate `df` by the columns listed in `cols`.

Groups `df` on `cols` and returns the result of combining those groups with `nrow`,
i.e. one output row per group together with that group's row count.
"""
function table(df::DataFrame, cols::Vector{Symbol})
    grouped = groupby(df, cols)
    return combine(grouped, nrow)
end
|
|
|
|
"""
    create_narrative_embeddings(regenerate=false)

Return the narrative table together with an `"Embeddings"` column.

If `regenerate` is `false` and `data/narrative_embeddings.jld2` exists, the object stored
in that file is loaded and returned unchanged. Otherwise:

1. `data/Modified Misinformation Library.csv` is read into a `DataFrame`;
2. its `"Misinformation Narrative"` column is passed to `create_embeddings` along with
   `ENV["OPENAI_API_KEY"]`;
3. the `"embedding"` entry of every item in the response's `"data"` is stored in a new
   `"Embeddings"` column;
4. the table is saved to `data/narrative_embeddings.jld2` and returned.

# Example
    narrative_embeddings = create_narrative_embeddings()
"""
function create_narrative_embeddings(regenerate=false)
    cache_path = "data/narrative_embeddings.jld2"

    # Reuse the cached table unless a rebuild was requested or no cache exists yet.
    cache_usable = !regenerate && isfile(cache_path)
    cache_usable && return load_object(cache_path)

    @info "Regenerating narrative embeddings..."
    narrative_table = CSV.read("data/Modified Misinformation Library.csv", DataFrame)

    api_key = ENV["OPENAI_API_KEY"]
    narrative_texts = narrative_table[!, "Misinformation Narrative"]
    api_reply = create_embeddings(api_key, narrative_texts)

    # One response item per input text; keep only its embedding.
    narrative_table[!, "Embeddings"] =
        [item["embedding"] for item in api_reply.response["data"]]

    save_object(cache_path, narrative_table)
    return narrative_table
end
|
|
"""
    create_test_embeddings(regenerate=false)

Return the test table together with an `"Embeddings"` column.

If `regenerate` is `false` and `data/test_embeddings.jld2` exists, the object stored in
that file is loaded and returned unchanged. Otherwise:

1. `data/Indicator_Test.csv` is read into a `DataFrame`;
2. its `"text"` column is passed to `create_embeddings` along with
   `ENV["OPENAI_API_KEY"]`;
3. the `"embedding"` entry of every item in the response's `"data"` is stored in a new
   `"Embeddings"` column;
4. the table is saved to `data/test_embeddings.jld2` and returned.

# Example
    target_embeddings = create_test_embeddings()
"""
function create_test_embeddings(regenerate=false)
    cache_path = "data/test_embeddings.jld2"

    # Reuse the cached table unless a rebuild was requested or no cache exists yet.
    cache_usable = !regenerate && isfile(cache_path)
    cache_usable && return load_object(cache_path)

    @info "Regenerating test embeddings..."
    test_table = CSV.read("data/Indicator_Test.csv", DataFrame)

    api_key = ENV["OPENAI_API_KEY"]
    test_texts = test_table[!, "text"]
    api_reply = create_embeddings(api_key, test_texts)

    # One response item per input text; keep only its embedding.
    test_table[!, "Embeddings"] =
        [item["embedding"] for item in api_reply.response["data"]]

    save_object(cache_path, test_table)
    return test_table
end
|
|
"""
    one_shot_classification!(narrative_embeddings, target_embeddings)

Label every row of `target_embeddings` with the narrative whose embedding is nearest.

The `"Embeddings"` column of each table is concatenated into a matrix holding one
embedding per column. Cosine distances are computed between every target/narrative pair,
and for each target the `"Misinformation Narrative"` value of the narrative at minimum
distance is written to the `"Closest Narrative"` column of `target_embeddings`.

# Arguments
- `narrative_embeddings`: table with `"Embeddings"` and `"Misinformation Narrative"`
  columns.
- `target_embeddings`: table with an `"Embeddings"` column; modified in place.

# Returns
`target_embeddings`, after `"Closest Narrative"` has been set.

# Example
    narrative_embeddings = create_narrative_embeddings()
    target_embeddings = create_test_embeddings()
    one_shot_classification!(narrative_embeddings, target_embeddings)
    # Show the results - text, closest narrative
    first(target_embeddings[:, ["text", "Closest Narrative", "label"]], 5)
"""
function one_shot_classification!(narrative_embeddings, target_embeddings)
    # `reduce(hcat, ...)` builds the matrix in one pass instead of splatting one
    # argument per table row into `hcat`.
    narrative_matrix = reduce(hcat, narrative_embeddings[:, "Embeddings"])
    target_matrix = reduce(hcat, target_embeddings[:, "Embeddings"])

    # distances[i, j] is the cosine distance between target i and narrative j.
    distances = pairwise(CosineDist(), target_matrix, narrative_matrix, dims=2)

    # One CartesianIndex per target; its second component is the narrative's row number.
    nearest = vec(argmin(distances, dims=2))

    target_embeddings[:, "Closest Narrative"] =
        [narrative_embeddings[idx[2], "Misinformation Narrative"] for idx in nearest]
    return target_embeddings
end
|
|
|
"""
    get_distances!(narrative_embeddings, target_embeddings)

Record, for every row of `target_embeddings`, the cosine distance to its nearest narrative.

The `"Embeddings"` column of each table is concatenated into a matrix holding one
embedding per column. Cosine distances are computed between every target/narrative pair,
and the smallest distance found for each target is written to the `"Dist"` column of
`target_embeddings`.

# Arguments
- `narrative_embeddings`: table with an `"Embeddings"` column.
- `target_embeddings`: table with an `"Embeddings"` column; modified in place.

# Returns
`target_embeddings`, after `"Dist"` has been set.
"""
function get_distances!(narrative_embeddings, target_embeddings)
    # `reduce(hcat, ...)` builds the matrix in one pass instead of splatting one
    # argument per table row into `hcat`.
    narrative_matrix = reduce(hcat, narrative_embeddings[:, "Embeddings"])
    target_matrix = reduce(hcat, target_embeddings[:, "Embeddings"])

    # distances[i, j] is the cosine distance between target i and narrative j.
    distances = pairwise(CosineDist(), target_matrix, narrative_matrix, dims=2)

    # Row-wise minimum = distance from each target to its closest narrative.
    target_embeddings[:, "Dist"] = vec(minimum(distances, dims=2))
    return target_embeddings
end
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Build (or load from cache) both embedding tables, then annotate each test row with its
# closest narrative and the cosine distance to that narrative.
narrative_embeddings = create_narrative_embeddings()
target_embeddings = create_test_embeddings()
one_shot_classification!(narrative_embeddings, target_embeddings)
get_distances!(narrative_embeddings, target_embeddings)

using TidierPlots

# Distribution of the nearest-narrative distance for each value of `label`.
ggplot(target_embeddings, @aes(x = label, y = Dist)) +
    geom_violin() + labs(x="Misinfo Label", y="Distance")

# Predict "misinformation" when the nearest narrative is closer than 0.2.
target_embeddings[!, "MisinfoPred"] = target_embeddings[!, "Dist"] .< 0.2

using MLJ

y_true = target_embeddings[!, "label"]
y_pred = target_embeddings[!, "MisinfoPred"]

# MLJ measures are called as measure(ŷ, y): predictions first, ground truth second.
# With the arguments swapped, "true positive rate" would actually be precision and
# "false positive rate" would be the false omission rate.
# NOTE(review): `label` is compared with 1.0 below while `MisinfoPred` is Bool — confirm
# MLJ treats the two encodings as the same pair of classes.
confusion_matrix(y_pred, y_true)
accuracy(y_pred, y_true)
true_positive_rate(y_pred, y_true)
false_positive_rate(y_pred, y_true)

# The ten rows with `label == 1.0` that lie closest to a narrative.
target_embeddings |>
    (data -> filter(:label => x -> x .== 1.0, data)) |>
    (data -> sort(data, :Dist)) |>
    (data -> first(data, 10)) |>
    (data -> select(data, ["text", "Closest Narrative", "Dist"]))