# Spaces: Sleeping  (web-page status text captured with this file; not part of the script)
## Write an SVM classifier to take embeddings, and use them to predict the label
# NOTE(review): `CSV`, `DataFrame`, `stack`, `select`, `dropmissing!`, the ScikitLearn
# macros/functions and `histogram` used in this script are not imported in the visible
# file — presumably loaded in the session (e.g. `using CSV, DataFrames, ScikitLearn`
# plus a plotting package); confirm.
# Get the data
import OstreaCultura as OC
dat = CSV.read("data/Climate Misinformation Library with counterclaims.csv", DataFrame)
## Stack claims and counter claims into a single column, label as 1 for claim, 0 for counter claim
# Select exactly the two text columns being stacked. A name pattern such as r"laims" would
# also keep any other column whose name contains "laims"; `stack` would carry it along as
# an id column and `dropmissing!` would then drop rows based on it too.
dat = stack(select(dat, :Claims, :Counterclaims), [:Claims, :Counterclaims],
            variable_name=:Type, value_name=:text)
# Remove rows whose stacked `text` (or `Type`) is missing.
dropmissing!(dat)
# `Type` holds the name of the source column, so rows from `Claims` are the positive class.
dat.label = ifelse.(dat.Type .== "Claims", 1, 0)
## Embeddings of text
# NOTE(review): `model` is not passed to `OC.multi_embeddings` below, so it does not
# choose the embedding model in this call — confirm whether it should be an argument.
model = "multilingual-e5-large"
embeds = OC.multi_embeddings(dat)

# -- Features / targets handed to scikit-learn ------------------------------------------
# NOTE(review): if `embeds.Embeddings` is a vector of per-row vectors, `convert(Array, …)`
# leaves it as a Vector{Vector} rather than an n×d matrix — confirm the shape sklearn sees.
features = convert(Array, getproperty(embeds, :Embeddings))
# Targets: 1 for a claim, 0 for a counterclaim.
y = convert(Array, getproperty(dat, :label))
# Linear SVM classifier (`resvm`), scored with shuffled 5-fold cross-validation
@sk_import calibration: CalibratedClassifierCV
@sk_import svm: LinearSVC
import ScikitLearn: CrossValidation
using ScikitLearn.CrossValidation: cross_val_score
resvm = LinearSVC(C=.5, loss="squared_hinge", penalty="l2", multi_class="ovr",
                  random_state=35552, max_iter=2000)
# Size the folds from the data rather than a hard-coded sample count (previously 189), so
# the split stays valid when the CSV or the `dropmissing!` filtering changes the row count.
cv = ScikitLearn.CrossValidation.KFold(length(y), n_folds=5, random_state=134, shuffle=true)
# One score per fold, using the estimator's default `score` (mean accuracy for LinearSVC).
out = cross_val_score(resvm, features, y, cv=cv)
## get precision and recall
using ScikitLearn: metrics
# Out-of-fold predictions: each sample is predicted by a model fit on the other folds.
y_pred = ScikitLearn.CrossValidation.cross_val_predict(resvm, features, y, cv=cv)
## roll your own precision and recall for the positive class (label 1)
# precision = TP / predicted positives; recall = TP / actual positives.
# Either one is NaN (0/0) when its denominator is zero.
pre, rec = let hits = count((y .== 1) .& (y_pred .== 1))
    (hits / count(==(1), y_pred), hits / count(==(1), y))
end
## Fit on all data and inspect the learned weights
fit!(resvm, features, y)
# LinearSVC does not keep support vectors: `coef_` is the primal weight vector, with one
# weight per embedding dimension (a 1×d matrix for this binary problem). The name `sv` is
# kept unchanged for any later code that refers to it.
sv = resvm.coef_
# Distribution of the per-dimension weights.
# NOTE(review): `histogram` is not imported in this file — presumably from a plotting package.
histogram(sv[1, :])
# Wrap the SVM in a calibrator so it yields class probabilities
# (sigmoid/Platt calibration is scikit-learn's default method).
calsvm = CalibratedClassifierCV(resvm)
fit!(calsvm, features, y)
# NOTE(review): probabilities are computed on the same rows used for fitting, so they are
# in-sample and likely optimistic; score held-out rows for an honest estimate.
prob_preds = predict_proba(calsvm, features)
## Get indices of top highest probabilities in the first column
"""
    topk_indices(scores, k)

Return the indices of the `k` largest values in `scores`, highest first. If `scores` has
fewer than `k` entries, all indices are returned. Ties keep ascending index order.
"""
topk_indices(scores::AbstractVector, k::Integer) = first(sortperm(scores; rev=true), k)

# `predict_proba` columns follow `calsvm.classes_` (sorted labels [0, 1]). Column 1 is
# therefore P(counterclaim) and column 2 is P(claim).
top_k = 5
top_indices = topk_indices(prob_preds[:, 1], top_k)
prob_preds[top_indices, :]
dat[top_indices, :]
## Get the indices of the top k highest probabilities in second column
top_indices = topk_indices(prob_preds[:, 2], top_k)
prob_preds[top_indices, :]
dat[top_indices, :]