# Author: stefanjwojcik
# Commit 734777a (verified): "Upload 5 files"
## Train an SVM classifier on text embeddings and use it to predict the label (1 = claim, 0 = counter claim)
# Get the data
import OstreaCultura as OC
# NOTE(review): CSV, DataFrame, stack, select and dropmissing! are not imported in
# this file; they are presumably loaded in the calling session — confirm.
dat = let raw = CSV.read("data/Climate Misinformation Library with counterclaims.csv", DataFrame)
    ## Keep the columns whose names contain "laims" and melt Claims/Counterclaims
    ## into long format: `Type` holds the source column name, `text` the cell value.
    stack(select(raw, r"laims"), [:Claims, :Counterclaims];
          variable_name=:Type, value_name=:text)
end
# Drop rows that have a missing value in any column.
dropmissing!(dat)
# Binary target: 1 for rows from the Claims column, 0 for counter claims.
dat.label = Int.(dat.Type .== "Claims")
## Embeddings of text
model = "multilingual-e5-large"
embeds = OC.multi_embeddings(dat)
#Features
features = convert(Array, embeds.Embeddings)
y = convert(Array, dat.label)
# Build a linear SVM classifier (resvm) and cross-validate it
@sk_import calibration: CalibratedClassifierCV
import ScikitLearn: CrossValidation
@sk_import svm: LinearSVC
import ScikitLearn: CrossValidation
using ScikitLearn.CrossValidation: cross_val_score
# Linear SVM with fixed hyperparameters and a fixed seed, so runs are reproducible.
resvm = LinearSVC(C=0.5, loss="squared_hinge", penalty="l2", multi_class="ovr",
                  random_state=35552, max_iter=2000)
# Shuffled 5-fold split over every sample. The sample count now comes from `y`
# instead of the previously hard-coded 189. That constant broke whenever the rows
# surviving `stack`/`dropmissing!` did not total exactly 189: extra rows were never
# evaluated, or fold indices ran past the end of the data.
cv = ScikitLearn.CrossValidation.KFold(length(y), n_folds=5, random_state=134,
                                      shuffle=true)
# One score per fold, from the estimator's default `score` (mean accuracy).
out = cross_val_score(resvm, features, y, cv=cv)
## get precision and recall
using ScikitLearn: metrics
y_pred = ScikitLearn.CrossValidation.cross_val_predict(resvm, features, y, cv=cv)
## roll your own precision
pre = sum((y .== 1) .& (y_pred .== 1)) / sum(y_pred .== 1)
## roll your own recall
rec = sum((y .== 1) .& (y_pred .== 1)) / sum(y .== 1)
## Fit on all samples and inspect the learned weights.
# LinearSVC does not store support vectors. `coef_` is the weight vector of the
# separating hyperplane (one row for a binary problem), with one weight per
# embedding dimension; the histogram shows how those weights are distributed.
fit!(resvm, features, y)
sv = resvm.coef_
histogram(sv[1, :])
## Calibrate the SVM so that it outputs class probabilities.
calsvm = CalibratedClassifierCV(resvm)
calsvm.fit(features, y)
# predict_proba columns follow sklearn's sorted `classes_`: column 1 is P(label 0,
# counter claim) and column 2 is P(label 1, claim). These are in-sample
# probabilities, computed on the same rows used for fitting.
prob_preds = calsvm.predict_proba(features)

# Indices of the (at most) `k` largest values of `v`, largest first. The result is
# stable (equal values keep row order), so it matches `sortperm(v, rev=true)[1:k]`.
# It avoids a full sort and does not throw a BoundsError when `v` has fewer than
# `k` entries. `collect` returns a plain Vector{Int} rather than a view.
top_k_indices(v, k) = collect(partialsortperm(v, 1:min(k, length(v)), rev=true))

top_k = 5
## Rows with the highest probability in the first column (label 0, counter claim)
top_indices = top_k_indices(@view(prob_preds[:, 1]), top_k)
prob_preds[top_indices, :]
dat[top_indices, :]
## Rows with the highest probability in the second column (label 1, claim)
top_indices = top_k_indices(@view(prob_preds[:, 2]), top_k)
prob_preds[top_indices, :]
dat[top_indices, :]