{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Text Similarity Prediction and Analysis\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Objective\n",
"\n",
"The aim of this project is to build a system that measures how similar documents are by analyzing their text. The system will use natural language processing (NLP) methods and similarity metrics to score the semantic similarity of the textual content of different documents. These scores can support applications such as document retrieval, clustering, and recommendation, which return more accurate and relevant results when they can compare documents by content. The overall goal is to improve information management and retrieval workflows with a robust and efficient way of measuring document similarity.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Problem Statement\n",
"\n",
"Effectively organizing, retrieving, and using large volumes of textual documents is a vital challenge in many domains, including digital libraries, knowledge management systems, and content recommendation platforms. With the exponential growth of digital information, it's becoming increasingly difficult to manually categorize, cluster, and identify related records. Without efficient methods to measure the similarity between records based on their textual content, organizations struggle to manage their document repositories effectively, hampering productivity and decision-making processes.\n",
"\n",
"The traditional keyword-based search and retrieval methods often fall short of capturing the true semantic similarities between documents, leading to incomplete or irrelevant results. This creates a pressing need for advanced text analysis techniques that accurately assess the degree of similarity between documents, considering their contextual meaning, themes, and conceptual overlaps.\n",
"\n",
"By developing a robust text similarity analysis system, organizations can unlock multiple benefits, including improved information retrieval, enhanced content clustering and categorization, and more effective recommendation systems. Accurate similarity analysis allows users to identify related documents quickly, facilitating knowledge sharing, collaboration, and decision-making processes. Moreover, such a system can streamline document management workflows, reducing redundancy and enabling more efficient storage and organization of textual information.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Justification\n",
"\n",
"The exponential growth of digital information has resulted in an overwhelming amount of textual data in the form of documents, reports, articles, and other written materials. As this flood of information continues to expand, organizations across various sectors are struggling to manage, organize, and extract value from their document repositories effectively.\n",
"\n",
"Traditional methods of document management, such as manual categorization and keyword-based search, are increasingly inadequate in handling the scale and complexity of modern document collections. These approaches often fail to capture the true semantic similarities between documents, leading to incomplete or irrelevant search results, inefficient clustering, and missed opportunities for knowledge discovery.\n",
"\n",
"Developing a robust system for analyzing document similarity that leverages advanced text analysis techniques and natural language processing methods is essential for several reasons:\n",
"\n",
"1. Improved information retrieval: By accurately measuring the similarity between documents based on their textual content, users can quickly identify and retrieve related materials, enhancing research, decision-making, and knowledge-sharing processes.\n",
"\n",
"2. Efficient document clustering and categorization: Similarity analysis enables automated document clustering and categorization, reducing the need for manual effort and ensuring that related documents are organized together for easier access and navigation.\n",
"\n",
"3. Enhanced recommendation systems: By understanding the semantic relationships between documents, recommendation systems can provide more relevant and personalized suggestions, improving user experience and facilitating content discovery.\n",
"\n",
"4. Reduction of redundancy and duplication: Identifying highly similar or duplicate documents can help organizations streamline their document repositories, reducing storage requirements and improving overall efficiency.\n",
"\n",
"5. Knowledge extraction and insight generation: Analyzing similarities between documents can reveal patterns, trends, and connections that may not be immediately apparent, enabling organizations to uncover valuable insights and make data-driven decisions.\n",
"\n",
"Moreover, as the volume of digital information continues to grow, the importance of effective document similarity analysis will only increase. Failing to address this challenge can lead to inefficient information management, missed opportunities, and a competitive disadvantage for organizations that rely heavily on textual data.\n",
"\n",
"By investing in the development of a robust document similarity analysis system, organizations can future-proof their document management processes, gain a deeper understanding of their information assets, and unlock new opportunities for innovation and growth.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Research Data\n",
"\n",
"The [STS (Semantic Textual Similarity) Benchmark dataset](https://ixa2.si.ehu.eus/stswiki/index.php/STSbenchmark) is a popular resource for evaluating the performance of systems designed to measure the semantic similarity between pairs of sentences. It is widely used in the natural language processing community for tasks such as text understanding, paraphrase detection, and sentence similarity analysis.\n",
"\n",
"The STS Benchmark dataset consists of a collection of sentence pairs, each accompanied by a human-annotated similarity score ranging from 0 (no semantic similarity) to 5 (semantic equivalence). These sentence pairs are drawn from various sources, including news articles, image captions, and online forums, covering a diverse range of topics and domains.\n",
"\n",
"The STS Benchmark dataset has been widely adopted due to its diversity, the availability of human-annotated similarity scores, and its usefulness in evaluating the performance of various semantic similarity models and algorithms. It provides a standardized and well-curated resource for researchers and developers working on natural language processing tasks involving semantic similarity analysis.\n",
"\n",
"## Description of Data\n",
"\n",
"The benchmark comprises 8628 sentence pairs. This is the breakdown according to genres and train-dev-test splits:\n",
"\n",
"| | train | dev | test | total |\n",
"| ------- | ----- | ---- | ---- | ----- |\n",
"| news | 3299 | 500 | 500 | 4299 |\n",
"| caption | 2000 | 625 | 625 | 3250 |\n",
"| forum | 450 | 375 | 254 | 1079 |\n",
"| total | 5749 | 1500 | 1379 | 8628 |\n",
"\n",
"Breakdown according to the original names and task years of the datasets:\n",
"\n",
"| genre | file | years | train | dev | test |\n",
"| -------- | -------------- | ------- | ----- | --- | ---- |\n",
"| news | MSRpar | 2012 | 1000 | 250 | 250 |\n",
"| news | headlines | 2013-16 | 1999 | 250 | 250 |\n",
"| news | deft-news | 2014 | 300 | 0 | 0 |\n",
"| captions | MSRvid | 2012 | 1000 | 250 | 250 |\n",
"| captions | images | 2014-15 | 1000 | 250 | 250 |\n",
"| captions | track5.en-en | 2017 | 0 | 125 | 125 |\n",
"| forum | deft-forum | 2014 | 450 | 0 | 0 |\n",
"| forum | answers-forums | 2015 | 0 | 375 | 0 |\n",
"| forum | answer-answer | 2016 | 0 | 0 | 254 |\n"
]
},
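{
"cell_type": "markdown",
"metadata": {},
"source": [
"The genre and split counts above can be cross-checked against the copy of the data loaded later in this notebook. The helper below is a minimal sketch, assuming each split has been converted to a pandas DataFrame with the `genre` column shown in the loaded `DatasetDict` (for example via `Dataset.to_pandas()`). `genre_split_counts` is a hypothetical name, and the genre labels in the Hugging Face copy may be spelled differently from the tables, so compare the counts rather than the labels. The toy frames only demonstrate the output shape.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"\n",
"def genre_split_counts(splits):\n",
"    # splits: mapping of split name -> DataFrame that has a 'genre' column\n",
"    frames = [df.assign(split=name) for name, df in splits.items()]\n",
"    combined = pd.concat(frames, ignore_index=True)\n",
"    # Rows are genres, columns are splits; margins add a 'total' row and column\n",
"    return pd.crosstab(combined['genre'], combined['split'],\n",
"                       margins=True, margins_name='total')\n",
"\n",
"\n",
"# Toy data; with the real data pass e.g. {'train': dataset['train'].to_pandas(), ...}\n",
"toy = {\n",
"    'train': pd.DataFrame({'genre': ['news', 'news', 'forum']}),\n",
"    'test': pd.DataFrame({'genre': ['news', 'caption']}),\n",
"}\n",
"counts = genre_split_counts(toy)\n",
"print(counts)"
]
},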
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Research Questions\n",
"\n",
"1. What are the most effective natural language processing techniques and algorithms for measuring semantic similarity between documents?\n",
"\n",
"2. What are the computational and scalability challenges associated with performing similarity analysis on text collections, and how can these challenges be addressed?\n",
"\n",
"3. How can user interactions and feedback be effectively incorporated into the similarity analysis system to improve its accuracy and adaptability over time?\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Assumptions\n",
"\n",
"1. A sufficient quantity and variety of textual datasets exist for training and evaluating the document similarity analysis system.\n",
"\n",
"2. Adequate computational resources, including processing power and memory, are available to implement and deploy the system at scale.\n",
"\n",
"3. The natural language processing techniques and similarity metrics selected for the system can effectively capture semantic relationships and nuances within textual documents across different languages and domains.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Scope\n",
"\n",
"This project will develop a minimum viable prototype of a document similarity analysis system based on natural language processing techniques. The prototype will accept text input, either whole documents or shorter text passages, and analyze their semantic similarity.\n",
"\n",
"The core functionality will include:\n",
"\n",
"1. Text extraction and preprocessing\n",
"2. Embedding documents into vector representations using pre-trained language models\n",
"3. Calculating pairwise similarity scores between document embeddings using cosine similarity or other similarity or distance metrics\n",
"4. Returning a ranked list of similar documents for a given input document\n",
"\n",
"It will be developed as a simple web application using Gradio and deployed on Hugging Face Spaces for easy access and testing.\n",
"\n",
"The initial scope is limited to English text input. Advanced features such as multilingual support, domain adaptation, scalability optimizations, and incorporation of user feedback are out of scope for this prototype. The primary goal is to demonstrate the core document similarity analysis capabilities using readily available NLP tools and models.\n"
]
},
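{
"cell_type": "markdown",
"metadata": {},
"source": [
"Steps 3 and 4 of the core functionality can be sketched in plain NumPy once documents are embedded. This is a minimal sketch under stated assumptions, not the prototype's code: the embeddings are assumed to be a 2-D array with one row per document (for instance the output of `SentenceTransformer.encode`), no row is all zeros, and `rank_by_cosine` is a hypothetical helper name.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def rank_by_cosine(query_vec, doc_matrix, top_k=3):\n",
"    # L2-normalize so that a dot product equals cosine similarity\n",
"    q = query_vec / np.linalg.norm(query_vec)\n",
"    d = doc_matrix / np.linalg.norm(doc_matrix, axis=1, keepdims=True)\n",
"    scores = d @ q\n",
"    # Indices of the top_k highest scores, most similar first\n",
"    order = np.argsort(-scores)[:top_k]\n",
"    return [(int(i), float(scores[i])) for i in order]\n",
"\n",
"\n",
"# Toy 2-D embeddings standing in for real document vectors\n",
"docs = np.array([[1.0, 0.0], [0.6, 0.8], [0.0, 1.0]])\n",
"query = np.array([1.0, 0.1])\n",
"ranking = rank_by_cosine(query, docs, top_k=2)\n",
"print(ranking)"
]
},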
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Hypothesis\n",
"\n",
"We hypothesize that, by using NLP models and semantic textual similarity techniques, the prototype will accurately measure and rank the similarity between documents based on their contextual content. Specifically, the prototype should identify semantically related documents better than traditional keyword-based approaches, because it projects documents into high-dimensional vector representations that capture their underlying meaning and concepts, which allows a more robust similarity comparison.\n"
]
},
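{
"cell_type": "markdown",
"metadata": {},
"source": [
"The keyword-based approaches the hypothesis compares against can be illustrated with a TF-IDF baseline built from scikit-learn's `TfidfVectorizer` and `cosine_similarity` (both imported later in this notebook). This is a minimal sketch with made-up sentences, not a baseline the project reports results for, and `tfidf_similarity` is a hypothetical helper: a paraphrase that shares no words scores 0, while a sentence whose meaning is reversed by one extra word still scores high.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.metrics.pairwise import cosine_similarity\n",
"\n",
"\n",
"def tfidf_similarity(s1, s2):\n",
"    # Fit the vocabulary on the pair only, then compare the two TF-IDF vectors\n",
"    vectors = TfidfVectorizer().fit_transform([s1, s2])\n",
"    return float(cosine_similarity(vectors[0], vectors[1])[0, 0])\n",
"\n",
"\n",
"# Same meaning, no shared terms: keyword overlap sees nothing in common\n",
"paraphrase = tfidf_similarity('A man is playing a guitar.',\n",
"                              'Someone strums an instrument.')\n",
"# Opposite meaning, almost identical wording: keyword overlap scores it high\n",
"negation = tfidf_similarity('A man is playing a guitar.',\n",
"                            'A man is not playing a guitar.')\n",
"print(paraphrase, negation)"
]
},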
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Code\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Installation\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# %pip install datasets transformers sentence-transformers\n",
"# %pip install accelerate -U\n",
"# %pip install scikit-learn pandas matplotlib joblib\n",
"# %pip install streamlit\n",
"# %pip install textdistance"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports\n"
]
},
{
"cell_type": "code",
"execution_count": 107,
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"from matplotlib import pyplot as plt\n",
"from sentence_transformers import CrossEncoder, SentenceTransformer, losses, models\n",
"from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator\n",
"from sentence_transformers.readers import InputExample\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.metrics import mean_squared_error\n",
"from sklearn.metrics.pairwise import cosine_similarity\n",
"from sklearn.preprocessing import StandardScaler\n",
"from torch.utils.data import DataLoader\n",
"import math\n",
"import pandas as pd\n",
"import textdistance\n",
"import numpy as np\n",
"import joblib\n",
"from samples import get_samples"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load the English STS Benchmark dataset\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['split', 'genre', 'dataset', 'year', 'sid', 'score', 'sentence1', 'sentence2'],\n",
" num_rows: 5749\n",
" })\n",
" validation: Dataset({\n",
" features: ['split', 'genre', 'dataset', 'year', 'sid', 'score', 'sentence1', 'sentence2'],\n",
" num_rows: 1500\n",
" })\n",
" test: Dataset({\n",
" features: ['split', 'genre', 'dataset', 'year', 'sid', 'score', 'sentence1', 'sentence2'],\n",
" num_rows: 1379\n",
" })\n",
"})"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The dataset is loaded from https://huggingface.co/datasets/mteb/stsbenchmark-sts\n",
"dataset = load_dataset(\"mteb/stsbenchmark-sts\")\n",
"dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prepare for Training\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(5749, 1379, 1500)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_dataset = dataset['train']\n",
"test_dataset = dataset['test']\n",
"validation_dataset = dataset['validation']\n",
"\n",
"# Format the dataset for training\n",
"def format_dataset(row):\n",
" # Normalize score to range 0 ... 1\n",
" score = float(row[\"score\"]) / 5.0\n",
" sentence1, sentence2 = row[\"sentence1\"], row[\"sentence2\"]\n",
" return InputExample(texts=[sentence1, sentence2], label=score)\n",
"\n",
"\n",
"formated_train_dataset = [format_dataset(i) for i in train_dataset]\n",
"formated_test_dataset = [format_dataset(i) for i in test_dataset]\n",
"formated_validation_dataset = [format_dataset(i) for i in validation_dataset]\n",
"len(train_dataset), len(formated_test_dataset), len(formated_validation_dataset)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# Training hyperparameters and output location\n",
"train_batch_size = 32\n",
"num_epochs = 16\n",
"model_save_path = 'trained_model_stsbenchmark_bert-base-uncased'\n",
"# Load the pre-trained BERT Transformer model from Hugging Face\n",
"word_embedding_model = models.Transformer('bert-base-uncased')\n",
"# Apply mean pooling over the token embeddings to get one fixed-size sentence vector\n",
"pooling_model = models.Pooling(\n",
" word_embedding_dimension=word_embedding_model.get_word_embedding_dimension(),\n",
" pooling_mode_mean_tokens=True,\n",
" pooling_mode_cls_token=False,\n",
" pooling_mode_max_tokens=False,\n",
")\n",
"model = SentenceTransformer(modules=[word_embedding_model, pooling_model])"
]
},
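{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `Pooling` module configured above with `pooling_mode_mean_tokens=True` turns per-token embeddings into one sentence vector by averaging them. The NumPy sketch below illustrates mean pooling under simplifying assumptions (one sentence, toy numbers, padding marked with 0 in the attention mask); it is not the library's implementation, and `mean_pool` is a hypothetical name.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def mean_pool(token_embeddings, attention_mask):\n",
"    # token_embeddings: (seq_len, dim); attention_mask: (seq_len,) of 0/1\n",
"    mask = attention_mask[:, None].astype(float)\n",
"    summed = (token_embeddings * mask).sum(axis=0)\n",
"    # Divide by the number of real tokens, guarding against an all-zero mask\n",
"    count = max(mask.sum(), 1e-9)\n",
"    return summed / count\n",
"\n",
"\n",
"tokens = np.array([[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]])  # last row is padding\n",
"mask = np.array([1, 1, 0])\n",
"print(mean_pool(tokens, mask))  # average of the first two rows: [2. 3.]"
]
},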
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Warmup-steps: 288\n"
]
}
],
"source": [
"train_dataloader = DataLoader(formated_train_dataset, shuffle=True, batch_size=train_batch_size)\n",
"train_loss = losses.CosineSimilarityLoss(model=model)\n",
"# Evaluate on the validation (dev) split during training; the test split is held out for the final evaluation\n",
"evaluator = EmbeddingSimilarityEvaluator.from_input_examples(formated_validation_dataset, name=\"sts-dev\")\n",
"# Use 10% of all training steps for learning-rate warm-up:\n",
"# ceil(5749 / 32) = 180 batches per epoch, and 180 * 16 epochs * 0.1 = 288 steps\n",
"warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)\n",
"print(f\"Warmup-steps: {warmup_steps}\")"
]
},
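{
"cell_type": "markdown",
"metadata": {},
"source": [
"`CosineSimilarityLoss` pushes the cosine similarity of the two sentence embeddings towards the normalized gold label (the `score / 5.0` computed in `format_dataset`). As far as we know, the library's default compares the two with a mean squared error; the NumPy sketch below reproduces only that arithmetic on toy vectors, and `cosine_mse_loss` is a hypothetical name rather than library code.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def cosine_mse_loss(emb1, emb2, labels):\n",
"    # Row-wise cosine similarity between paired embeddings\n",
"    cos = (emb1 * emb2).sum(axis=1) / (\n",
"        np.linalg.norm(emb1, axis=1) * np.linalg.norm(emb2, axis=1))\n",
"    # Mean squared error against labels in the range 0-1\n",
"    return float(np.mean((cos - labels) ** 2))\n",
"\n",
"\n",
"u = np.array([[1.0, 0.0], [1.0, 0.0]])\n",
"v = np.array([[1.0, 0.0], [0.0, 1.0]])  # cosines: 1.0 and 0.0\n",
"gold = np.array([4.0, 1.0]) / 5.0  # gold scores 4 and 1 on the 0-5 scale\n",
"print(cosine_mse_loss(u, v, gold))"
]
},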
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train the model\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
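],
"source": [
"model.fit(\n",
" train_objectives=[(train_dataloader, train_loss)],\n",
" evaluator=evaluator,\n",
" epochs=num_epochs,\n",
" # 10,000 exceeds the 180 steps per epoch, so the evaluator only runs at the end of each epoch\n",
" evaluation_steps=10_000,\n",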
" warmup_steps=warmup_steps,\n",
" output_path=model_save_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluate the model"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy on STS test dataset: 0.8759463607276905\n"
]
}
],
"source": [
"# Load the best model saved by model.fit() and evaluate it on the held-out test split\n",
"model = SentenceTransformer(model_save_path)\n",
"test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(formated_test_dataset, name=\"sts-test\")\n",
"test_score = test_evaluator(model, output_path=model_save_path)\n",
"# The score is a Spearman rank correlation with the gold scores, not a classification accuracy\n",
"print(f\"Spearman correlation on STS test split: {test_score}\")"
]
},
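{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the evaluator's score concrete: in the sentence-transformers versions we are familiar with, `EmbeddingSimilarityEvaluator` computes the similarity of each pair's embeddings and correlates it with the gold scores (Pearson and Spearman), which is what the `cosine_spearman` and related columns in the results table below report. A minimal sketch of the cosine/Spearman variant with SciPy on toy vectors, where `cosine_spearman` is a hypothetical helper:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy.stats import spearmanr\n",
"\n",
"\n",
"def cosine_spearman(emb1, emb2, gold_scores):\n",
"    # Cosine similarity of each (sentence1, sentence2) embedding pair\n",
"    cos = (emb1 * emb2).sum(axis=1) / (\n",
"        np.linalg.norm(emb1, axis=1) * np.linalg.norm(emb2, axis=1))\n",
"    # Spearman compares rankings only, so the 0-5 gold scale needs no rescaling\n",
"    return float(spearmanr(cos, gold_scores)[0])\n",
"\n",
"\n",
"e1 = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])\n",
"e2 = np.array([[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])  # cosines 1.0, 0.707, 0.0\n",
"gold = np.array([5.0, 2.5, 0.0])\n",
"print(cosine_spearman(e1, e2, gold))"
]
},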
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### View the model performance"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" epoch steps cosine_pearson cosine_spearman euclidean_pearson \\\n",
"0 0 -1 0.831459 0.822286 0.799320 \n",
"1 1 -1 0.840706 0.834337 0.815210 \n",
"2 2 -1 0.843926 0.835904 0.826207 \n",
"3 3 -1 0.846428 0.839487 0.828132 \n",
"4 0 -1 0.792202 0.771397 0.760367 \n",
"5 1 -1 0.835359 0.827130 0.807031 \n",
"6 2 -1 0.838358 0.830313 0.819444 \n",
"7 3 -1 0.842378 0.837650 0.822014 \n",
"8 4 -1 0.843533 0.837812 0.830953 \n",
"9 5 -1 0.842615 0.838848 0.827527 \n",
"10 6 -1 0.843484 0.838789 0.827800 \n",
"11 7 -1 0.844762 0.840291 0.828616 \n",
"12 8 -1 0.843006 0.839770 0.827860 \n",
"13 9 -1 0.846304 0.842857 0.828914 \n",
"14 10 -1 0.845491 0.841096 0.831268 \n",
"15 11 -1 0.845038 0.841520 0.830022 \n",
"16 12 -1 0.845263 0.842207 0.831057 \n",
"17 13 -1 0.844913 0.841911 0.830002 \n",
"18 14 -1 0.844950 0.842389 0.830668 \n",
"19 15 -1 0.845169 0.842546 0.830544 \n",
"20 0 -1 0.789996 0.771507 0.751155 \n",
"21 1 -1 0.834723 0.825629 0.805897 \n",
"22 2 -1 0.843609 0.836768 0.824553 \n",
"23 3 -1 0.844290 0.837315 0.823369 \n",
"24 4 -1 0.848353 0.843894 0.829264 \n",
"25 5 -1 0.846934 0.841312 0.828725 \n",
"26 6 -1 0.845585 0.841405 0.829897 \n",
"27 7 -1 0.845388 0.840861 0.831278 \n",
"28 8 -1 0.846980 0.843605 0.831718 \n",
"29 9 -1 0.846720 0.843634 0.829811 \n",
"30 10 -1 0.847116 0.843249 0.832140 \n",
"31 11 -1 0.847349 0.844004 0.832437 \n",
"32 12 -1 0.847548 0.844449 0.832724 \n",
"33 13 -1 0.847360 0.843996 0.832769 \n",
"34 14 -1 0.848117 0.844973 0.832809 \n",
"35 15 -1 0.848033 0.844908 0.832758 \n",
"\n",
" euclidean_spearman manhattan_pearson manhattan_spearman dot_pearson \\\n",
"0 0.795001 0.798473 0.794257 0.698020 \n",
"1 0.812220 0.814498 0.811689 0.732746 \n",
"2 0.822151 0.825924 0.821681 0.748794 \n",
"3 0.825238 0.827827 0.824739 0.755207 \n",
"4 0.743911 0.760203 0.744113 0.681647 \n",
"5 0.803834 0.806447 0.803400 0.724916 \n",
"6 0.815761 0.819400 0.815993 0.742720 \n",
"7 0.820538 0.821950 0.820751 0.746923 \n",
"8 0.827699 0.831241 0.828058 0.760023 \n",
"9 0.827053 0.827468 0.827264 0.758369 \n",
"10 0.826239 0.827930 0.826622 0.756345 \n",
"11 0.827343 0.828418 0.827361 0.763227 \n",
"12 0.826893 0.827871 0.827051 0.765544 \n",
"13 0.828065 0.828630 0.828003 0.765577 \n",
"14 0.830173 0.831042 0.829934 0.766516 \n",
"15 0.829324 0.829868 0.829125 0.764627 \n",
"16 0.830302 0.830940 0.830102 0.766376 \n",
"17 0.828979 0.829962 0.829043 0.766809 \n",
"18 0.829802 0.830586 0.829736 0.766818 \n",
"19 0.829747 0.830460 0.829629 0.767047 \n",
"20 0.735840 0.750595 0.735447 0.656377 \n",
"21 0.801195 0.805206 0.800686 0.714148 \n",
"22 0.821647 0.824005 0.821232 0.750288 \n",
"23 0.820839 0.822907 0.820771 0.752619 \n",
"24 0.828715 0.829097 0.828821 0.739662 \n",
"25 0.827166 0.828384 0.827049 0.762402 \n",
"26 0.828822 0.829770 0.828740 0.764267 \n",
"27 0.829617 0.831128 0.829610 0.767783 \n",
"28 0.830741 0.831503 0.830459 0.768833 \n",
"29 0.829512 0.829674 0.829268 0.765590 \n",
"30 0.830753 0.831911 0.830722 0.767915 \n",
"31 0.831659 0.832199 0.831515 0.769308 \n",
"32 0.832165 0.832418 0.831973 0.768521 \n",
"33 0.831874 0.832490 0.831656 0.768880 \n",
"34 0.832048 0.832587 0.831915 0.769845 \n",
"35 0.831945 0.832527 0.831946 0.769407 \n",
"\n",
" dot_spearman \n",
"0 0.691569 \n",
"1 0.730361 \n",
"2 0.737895 \n",
"3 0.744858 \n",
"4 0.667173 \n",
"5 0.715752 \n",
"6 0.731048 \n",
"7 0.737437 \n",
"8 0.746661 \n",
"9 0.746365 \n",
"10 0.745615 \n",
"11 0.751573 \n",
"12 0.753510 \n",
"13 0.754536 \n",
"14 0.754324 \n",
"15 0.752546 \n",
"16 0.754835 \n",
"17 0.755277 \n",
"18 0.755222 \n",
"19 0.755781 \n",
"20 0.647560 \n",
"21 0.704298 \n",
"22 0.740117 \n",
"23 0.743209 \n",
"24 0.729897 \n",
"25 0.751935 \n",
"26 0.753591 \n",
"27 0.756129 \n",
"28 0.757720 \n",
"29 0.754997 \n",
"30 0.754851 \n",
"31 0.758294 \n",
"32 0.757657 \n",
"33 0.757353 \n",
"34 0.758487 \n",
"35 0.758096 "
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the STS-dev evaluation results written during training\n",
"eval_df = pd.read_csv(f\"{model_save_path}/eval/similarity_evaluation_sts-dev_results.csv\")\n",
"eval_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Plot the model performance evaluation"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Set the figure size\n",
"plt.figure(figsize=(12, 6))\n",
"# Plot each evaluation metric against the epoch\n",
"for column in eval_df.drop(columns=['epoch', 'steps']).columns:\n",
" plt.plot(eval_df['epoch'], eval_df[column], label=column)\n",
"# Place the legend outside the plot area\n",
"plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')\n",
"plt.xlabel('epoch')\n",
"plt.ylabel('correlation with gold similarity scores')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test the model\n"
]
},
{
"cell_type": "code",
"execution_count": 116,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Found cached dataset json (/Users/charleskabue/.cache/huggingface/datasets/mteb___json/mteb--stsbenchmark-sts-998a21523b45a16a/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)\n"
]
},
{
"data": {
"text/plain": [
" sentence1 \\\n",
"0 A man with a hard hat is dancing. \n",
"1 A man is fitting silencer on a pistol. \n",
"2 Kittens are eating food. \n",
"3 A woman is mixing ingrediants. \n",
"4 A woman is cooking eggs. \n",
"5 Someone is beating an egg. \n",
"6 A small baby is playing a guitar. \n",
"7 I think it is still feasible to store seeds un... \n",
"8 A man is playing soccer. \n",
"9 Two little girls are talking on the phone. \n",
"10 The man is riding a horse. \n",
"\n",
" sentence2 score \n",
"0 A man wearing a hard hat is dancing. 5.0 \n",
"1 A man is adding a silencer to a gun. 4.5 \n",
"2 Kittens are eating from dishes. 4.0 \n",
"3 A woman is mixing food in a bowl. 3.5 \n",
"4 A woman is cooking something. 3.0 \n",
"5 A woman stirs eggs in a bowl. 2.5 \n",
"6 A boy sits on a bed, sings and plays a guitar. 2.0 \n",
"7 I haven't tried storing tomato seeds myself, b... 1.5 \n",
"8 A man is playing flute. 1.0 \n",
"9 A little girl is walking down the street. 0.5 \n",
"10 A woman is using a hoe. 0.0 "
]
},
"execution_count": 116,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_samples = get_samples()\n",
"test_samples"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Cross-Encoder\n",
"\n",
"We pass both sentences simultaneously to the Transformer network, which then produces a single output value between 0 and 1 indicating the similarity of the input sentence pair; see [cross-encoders-usage](https://github.com/UKPLab/sentence-transformers/tree/master/examples/applications/cross-encoder#cross-encoders-usage).\n"
]
},
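{
"cell_type": "markdown",
"metadata": {},
"source": [
"For a Cross-Encoder with a single output label, sentence-transformers applies (as we understand its default behaviour) a sigmoid to the raw logit from the classification head. That sigmoid keeps the score between 0 and 1. The plain-Python sketch below illustrates the mapping and its inverse. The `sigmoid` and `logit` helpers and the example logits are hypothetical and are not part of the library:\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def sigmoid(value):\n",
"    # Squash a raw logit into the open interval (0, 1)\n",
"    return 1.0 / (1.0 + math.exp(-value))\n",
"\n",
"def logit(probability):\n",
"    # Inverse of the sigmoid: recover the raw logit from a score in (0, 1)\n",
"    return math.log(probability / (1.0 - probability))\n",
"\n",
"# Illustrative logits from a hypothetical one-label classification head\n",
"for raw in (-4.0, 0.0, 0.44, 4.0):\n",
"    print(raw, round(sigmoid(raw), 3))\n",
"```\n",
"\n",
"The sigmoid is monotonic and equals 0.5 at a logit of 0, so a narrow band of scores corresponds to a narrow band of logits. For example, scores of 0.58 and 0.61 correspond to logits of roughly 0.32 and 0.45.\n"
]
},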
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at trained_model_stsbenchmark_bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"cross_encoder_model = CrossEncoder(model_save_path)"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.60892063"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cross_encoder_model.predict([\n",
" 'A man with a hard hat is dancing.',\n",
" 'A man wearing a hard hat is dancing.'])"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.5701721"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cross_encoder_model.predict([\n",
" 'A dog and cat laying down together.',\n",
" 'Two grey dogs are carrying a stick in the water.'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Bi-Encoder\n",
"\n",
"Bi-Encoders produce a sentence embedding for each input sentence independently. These sentence embeddings can then be compared using cosine similarity.\n"
]
},
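{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each sentence is embedded independently, so embeddings can be computed once and compared many times. The following minimal NumPy sketch computes cosine similarity. Toy three-dimensional vectors stand in for real BERT embeddings, and both the vectors and the `cosine_matrix` helper are hypothetical:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def cosine_matrix(a, b):\n",
"    # L2-normalize each row, then take dot products: entry (i, j) is cos(a_i, b_j)\n",
"    a = a / np.linalg.norm(a, axis=1, keepdims=True)\n",
"    b = b / np.linalg.norm(b, axis=1, keepdims=True)\n",
"    return a @ b.T\n",
"\n",
"# Toy embeddings; bert-base-uncased would give 768-dimensional vectors instead\n",
"queries = np.array([[1.0, 0.0, 0.0], [1.0, 1.0, 0.0]])\n",
"corpus = np.array([[2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [-1.0, 0.0, 0.0]])\n",
"\n",
"scores = cosine_matrix(queries, corpus)\n",
"print(scores.round(3))\n",
"```\n",
"\n",
"This is the same quantity that `sklearn.metrics.pairwise.cosine_similarity` returns. The normalize-then-dot-product form makes it cheap to score one query against a large, precomputed corpus, which is why Bi-Encoders scale better than Cross-Encoders.\n"
]
},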
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = SentenceTransformer(model_save_path)"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.9799072]], dtype=float32)"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cosine_similarity(\n",
" [model.encode('A man with a hard hat is dancing.')],\n",
" [model.encode('A man wearing a hard hat is dancing.')])"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.13418931]], dtype=float32)"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cosine_similarity(\n",
" [model.encode('A dog and cat laying down together.')],\n",
" [model.encode('Two grey dogs are carrying a stick in the water.')])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Comparison\n",
"\n",
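"Cross-Encoders usually achieve higher performance than Bi-Encoders, but they do not scale well to large datasets ([Reimers, Nils and Gurevych, Iryna](https://github.com/UKPLab/sentence-transformers/tree/master/examples/applications/cross-encoder#combining-bi--and-cross-encoders)). On this dataset, however, the Bi-Encoder scores track the gold scores much more closely. The small training dataset may partly explain this. More importantly, the Cross-Encoder was loaded from the same checkpoint as the Bi-Encoder (`model_save_path`), and the warning printed when it was loaded reports that its classification head (`classifier.weight`, `classifier.bias`) was newly initialized rather than trained. This likely explains why its predictions below stay in a narrow band between roughly 0.58 and 0.61, regardless of the gold score.\n"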
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" sentence1 \\\n",
"0 A man with a hard hat is dancing. \n",
"1 A man is fitting silencer on a pistol. \n",
"2 Kittens are eating food. \n",
"3 A woman is mixing ingrediants. \n",
"4 A woman is cooking eggs. \n",
"5 Someone is beating an egg. \n",
"6 A small baby is playing a guitar. \n",
"7 I think it is still feasible to store seeds un... \n",
"8 A man is playing soccer. \n",
"9 Two little girls are talking on the phone. \n",
"10 The man is riding a horse. \n",
"\n",
" sentence2 score \\\n",
"0 A man wearing a hard hat is dancing. 5.0 \n",
"1 A man is adding a silencer to a gun. 4.5 \n",
"2 Kittens are eating from dishes. 4.0 \n",
"3 A woman is mixing food in a bowl. 3.5 \n",
"4 A woman is cooking something. 3.0 \n",
"5 A woman stirs eggs in a bowl. 2.5 \n",
"6 A boy sits on a bed, sings and plays a guitar. 2.0 \n",
"7 I haven't tried storing tomato seeds myself, b... 1.5 \n",
"8 A man is playing flute. 1.0 \n",
"9 A little girl is walking down the street. 0.5 \n",
"10 A woman is using a hoe. 0.0 \n",
"\n",
" normalized_score Cross-Encoder Bi-Encoder \n",
"0 1.0 0.608921 0.979907 \n",
"1 0.9 0.614019 0.874845 \n",
"2 0.8 0.613255 0.872530 \n",
"3 0.7 0.602167 0.440890 \n",
"4 0.6 0.599842 0.619852 \n",
"5 0.5 0.593724 0.435095 \n",
"6 0.4 0.593979 0.505967 \n",
"7 0.3 0.603493 0.331625 \n",
"8 0.2 0.576015 0.094140 \n",
"9 0.1 0.614359 0.390625 \n",
"10 0.0 0.599499 -0.028931 "
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_samples['normalized_score'] = test_samples['score'] / 5.0\n",
"test_samples['Cross-Encoder'] = test_samples.apply(\n",
" lambda x: cross_encoder_model.predict([x['sentence1'], x['sentence2']]), axis=1)\n",
"test_samples['Bi-Encoder'] = test_samples.apply(\n",
" lambda x: cosine_similarity([model.encode(x['sentence1'])],[model.encode(x['sentence2'])])[0][0], axis=1)\n",
"test_samples"
]
},
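{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to summarize the table above with a single number per method is the Spearman rank correlation between each method's scores and the normalized gold scores. The evaluation results above report the same kind of statistic. The following minimal pure-Python sketch uses values copied from the table above. The helpers are hypothetical, and `ranks` assumes there are no tied values, which holds for these columns:\n",
"\n",
"```python\n",
"def ranks(values):\n",
"    # Rank positions (1 = smallest); assumes no ties\n",
"    order = sorted(range(len(values)), key=lambda i: values[i])\n",
"    result = [0] * len(values)\n",
"    for rank, i in enumerate(order, start=1):\n",
"        result[i] = rank\n",
"    return result\n",
"\n",
"def pearson(x, y):\n",
"    mx, my = sum(x) / len(x), sum(y) / len(y)\n",
"    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))\n",
"    sx = sum((a - mx) ** 2 for a in x) ** 0.5\n",
"    sy = sum((b - my) ** 2 for b in y) ** 0.5\n",
"    return cov / (sx * sy)\n",
"\n",
"def spearman(x, y):\n",
"    # Spearman's rho is the Pearson correlation of the ranks\n",
"    return pearson(ranks(x), ranks(y))\n",
"\n",
"gold = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0]\n",
"cross_encoder = [0.608921, 0.614019, 0.613255, 0.602167, 0.599842, 0.593724,\n",
"                 0.593979, 0.603493, 0.576015, 0.614359, 0.599499]\n",
"bi_encoder = [0.979907, 0.874845, 0.872530, 0.440890, 0.619852, 0.435095,\n",
"              0.505967, 0.331625, 0.094140, 0.390625, -0.028931]\n",
"\n",
"for name, scores in (('Cross-Encoder', cross_encoder), ('Bi-Encoder', bi_encoder)):\n",
"    print(name, round(spearman(gold, scores), 3))\n",
"```\n",
"\n",
"On these eleven pairs, this gives a rank correlation of about 0.93 for the Bi-Encoder and about 0.36 for the Cross-Encoder, which agrees with the comparison above. With only eleven examples, the numbers are indicative rather than conclusive.\n"
]
},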
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Other Text Comparisons\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Levenshtein Distance\n"
]
},
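{
"cell_type": "markdown",
"metadata": {},
"source": [
"The Levenshtein distance counts the minimum number of single-character insertions, deletions, and substitutions needed to turn one string into the other. `textdistance.levenshtein.normalized_similarity` rescales it as `1 - distance / max(len(a), len(b))`, and that formula reproduces the values in the table below. The following minimal dynamic-programming sketch shows the computation. It is an illustration, not the library's implementation, and the helper names are hypothetical:\n",
"\n",
"```python\n",
"def levenshtein(a, b):\n",
"    # previous[j] holds the edit distance between a[:i-1] and b[:j]\n",
"    previous = list(range(len(b) + 1))\n",
"    for i, char_a in enumerate(a, start=1):\n",
"        current = [i]\n",
"        for j, char_b in enumerate(b, start=1):\n",
"            current.append(min(\n",
"                previous[j] + 1,                       # delete char_a\n",
"                current[j - 1] + 1,                    # insert char_b\n",
"                previous[j - 1] + (char_a != char_b),  # substitute (free if equal)\n",
"            ))\n",
"        previous = current\n",
"    return previous[-1]\n",
"\n",
"def normalized_similarity(a, b):\n",
"    # 1.0 for identical strings, 0.0 when every position must be edited\n",
"    longest = max(len(a), len(b))\n",
"    return 1.0 if longest == 0 else 1.0 - levenshtein(a, b) / longest\n",
"\n",
"print(levenshtein('kitten', 'sitting'))\n",
"print(round(normalized_similarity('A man with a hard hat is dancing.',\n",
"                                  'A man wearing a hard hat is dancing.'), 6))\n",
"```\n",
"\n",
"For the first test pair, the distance is 5 (turning `with` into `wearing`), and 1 - 5/36 gives the 0.861111 shown in the table. Because it compares characters rather than meaning, the measure rates `A man is playing soccer.` and `A man is playing flute.` as highly similar (0.79), even though their gold score is only 1.0 out of 5.\n"
]
},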
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" sentence1 \\\n",
"0 A man with a hard hat is dancing. \n",
"1 A man is fitting silencer on a pistol. \n",
"2 Kittens are eating food. \n",
"3 A woman is mixing ingrediants. \n",
"4 A woman is cooking eggs. \n",
"5 Someone is beating an egg. \n",
"6 A small baby is playing a guitar. \n",
"7 I think it is still feasible to store seeds un... \n",
"8 A man is playing soccer. \n",
"9 Two little girls are talking on the phone. \n",
"10 The man is riding a horse. \n",
"\n",
" sentence2 score \\\n",
"0 A man wearing a hard hat is dancing. 5.0 \n",
"1 A man is adding a silencer to a gun. 4.5 \n",
"2 Kittens are eating from dishes. 4.0 \n",
"3 A woman is mixing food in a bowl. 3.5 \n",
"4 A woman is cooking something. 3.0 \n",
"5 A woman stirs eggs in a bowl. 2.5 \n",
"6 A boy sits on a bed, sings and plays a guitar. 2.0 \n",
"7 I haven't tried storing tomato seeds myself, b... 1.5 \n",
"8 A man is playing flute. 1.0 \n",
"9 A little girl is walking down the street. 0.5 \n",
"10 A woman is using a hoe. 0.0 \n",
"\n",
" normalized_score Cross-Encoder Bi-Encoder Levenshtein \n",
"0 1.0 0.608921 0.979907 0.861111 \n",
"1 0.9 0.614019 0.874845 0.631579 \n",
"2 0.8 0.613255 0.872530 0.741935 \n",
"3 0.7 0.602167 0.440890 0.606061 \n",
"4 0.6 0.599842 0.619852 0.724138 \n",
"5 0.5 0.593724 0.435095 0.310345 \n",
"6 0.4 0.593979 0.505967 0.456522 \n",
"7 0.3 0.603493 0.331625 0.347826 \n",
"8 0.2 0.576015 0.094140 0.791667 \n",
"9 0.1 0.614359 0.390625 0.642857 \n",
"10 0.0 0.599499 -0.028931 0.653846 "
]
},
"execution_count": 62,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_samples['Levenshtein'] = test_samples.apply(\n",
" lambda x: textdistance.levenshtein.normalized_similarity(x['sentence1'], x['sentence2']), axis=1)\n",
"test_samples"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### TF-IDF\n"
]
},
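{
"cell_type": "markdown",
"metadata": {},
"source": [
"With its default settings, scikit-learn's `TfidfVectorizer` does the following:\n",
"\n",
"- It lowercases the text.\n",
"- It keeps tokens of two or more word characters.\n",
"- It weights each term by its count times a smoothed IDF, `ln((1 + n) / (1 + df)) + 1`.\n",
"- It L2-normalizes each row.\n",
"\n",
"The cell below fits the vectorizer on each pair on its own, so n is 2. Words in both sentences therefore get an IDF of 1, and words in only one sentence get about 1.405. The following pure-Python sketch reproduces that computation. The tokenizer is a simplified stand-in for scikit-learn's default regular expression, and the helper names are hypothetical:\n",
"\n",
"```python\n",
"import math\n",
"from collections import Counter\n",
"\n",
"def tokenize(text):\n",
"    # Approximates the default token_pattern: lowercase alphanumeric runs of length >= 2\n",
"    cleaned = ''.join(ch if ch.isalnum() else ' ' for ch in text.lower())\n",
"    return [token for token in cleaned.split() if len(token) >= 2]\n",
"\n",
"def tfidf_cosine(sentence1, sentence2):\n",
"    docs = [Counter(tokenize(sentence1)), Counter(tokenize(sentence2))]\n",
"    n = len(docs)\n",
"    vocab = set(docs[0]) | set(docs[1])\n",
"    # Smoothed IDF, as with smooth_idf=True\n",
"    idf = {t: math.log((1 + n) / (1 + sum(t in d for d in docs))) + 1 for t in vocab}\n",
"    # Raw term counts weighted by IDF\n",
"    vectors = [{t: d[t] * idf[t] for t in d} for d in docs]\n",
"    dot = sum(w * vectors[1].get(t, 0.0) for t, w in vectors[0].items())\n",
"    norms = [math.sqrt(sum(w * w for w in v.values())) for v in vectors]\n",
"    # Cosine of the weighted vectors equals the dot product of the L2-normalized rows\n",
"    return dot / (norms[0] * norms[1]) if norms[0] and norms[1] else 0.0\n",
"\n",
"print(round(tfidf_cosine('A woman is cooking eggs.', 'A woman is cooking something.'), 6))\n",
"```\n",
"\n",
"The sketch reproduces the table values for pairs such as these (0.602975 here, and 0.716812 for the first test pair). It also shows a limitation of token matching. `Someone is beating an egg.` and `A woman stirs eggs in a bowl.` share no tokens, because `egg` and `eggs` count as different tokens, so their similarity is 0 despite a gold score of 2.5.\n"
]
},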
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" sentence1 \\\n",
"0 A man with a hard hat is dancing. \n",
"1 A man is fitting silencer on a pistol. \n",
"2 Kittens are eating food. \n",
"3 A woman is mixing ingrediants. \n",
"4 A woman is cooking eggs. \n",
"5 Someone is beating an egg. \n",
"6 A small baby is playing a guitar. \n",
"7 I think it is still feasible to store seeds un... \n",
"8 A man is playing soccer. \n",
"9 Two little girls are talking on the phone. \n",
"10 The man is riding a horse. \n",
"\n",
" sentence2 score \\\n",
"0 A man wearing a hard hat is dancing. 5.0 \n",
"1 A man is adding a silencer to a gun. 4.5 \n",
"2 Kittens are eating from dishes. 4.0 \n",
"3 A woman is mixing food in a bowl. 3.5 \n",
"4 A woman is cooking something. 3.0 \n",
"5 A woman stirs eggs in a bowl. 2.5 \n",
"6 A boy sits on a bed, sings and plays a guitar. 2.0 \n",
"7 I haven't tried storing tomato seeds myself, b... 1.5 \n",
"8 A man is playing flute. 1.0 \n",
"9 A little girl is walking down the street. 0.5 \n",
"10 A woman is using a hoe. 0.0 \n",
"\n",
" normalized_score Cross-Encoder Bi-Encoder Levenshtein TF-IDF \n",
"0 1.0 0.608921 0.979907 0.861111 0.716812 \n",
"1 0.9 0.614019 0.874845 0.631579 0.336097 \n",
"2 0.8 0.613255 0.872530 0.741935 0.510149 \n",
"3 0.7 0.602167 0.440890 0.606061 0.450176 \n",
"4 0.6 0.599842 0.619852 0.724138 0.602975 \n",
"5 0.5 0.593724 0.435095 0.310345 0.000000 \n",
"6 0.4 0.593979 0.505967 0.456522 0.087044 \n",
"7 0.3 0.603493 0.331625 0.347826 0.143098 \n",
"8 0.2 0.576015 0.094140 0.791667 0.602975 \n",
"9 0.1 0.614359 0.390625 0.642857 0.155929 \n",
"10 0.0 0.599499 -0.028931 0.653846 0.127360 "
]
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
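"# Note: the vectorizer is re-fitted on every sentence pair, so the IDF weights\n",
"# are computed from a two-document corpus (the pair itself)\n",
"tfidf_vectorizer = TfidfVectorizer()\n",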
"\n",
"test_samples['TF-IDF'] = test_samples.apply(\n",
" lambda x: cosine_similarity(tfidf_vectorizer.fit_transform([x['sentence1'], x['sentence2']]))[0][1], axis=1)\n",
"test_samples"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Random Forest\n"
]
},
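{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cells below train a random forest regressor on the element-wise absolute difference between the two sentence embeddings. The following minimal, self-contained sketch runs the same pipeline on synthetic data. The random vectors stand in for sentence embeddings, and the target is an invented similarity in (0, 1), so the numbers say nothing about STS-B performance. The `make_pairs` and `pair_features` helpers are hypothetical. The scaler is fitted on the training split only and then reused for the test split:\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.metrics import mean_squared_error\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"rng = np.random.default_rng(0)\n",
"\n",
"def make_pairs(n, dim=16):\n",
"    # Synthetic embedding pairs: the second vector is a noisy copy of the first\n",
"    noise = rng.uniform(0.1, 2.0, size=(n, 1))\n",
"    emb1 = rng.normal(size=(n, dim))\n",
"    emb2 = emb1 + noise * rng.normal(size=(n, dim))\n",
"    # Invented target: less noise means more similar (values in (0, 1))\n",
"    return emb1, emb2, 1.0 / (1.0 + noise.ravel())\n",
"\n",
"def pair_features(emb1, emb2):\n",
"    # Element-wise absolute difference, mirroring get_features()\n",
"    return np.abs(emb1 - emb2)\n",
"\n",
"train1, train2, y_train = make_pairs(300)\n",
"test1, test2, y_test = make_pairs(50)\n",
"\n",
"scaler = StandardScaler()\n",
"X_train = scaler.fit_transform(pair_features(train1, train2))  # fit on train only\n",
"X_test = scaler.transform(pair_features(test1, test2))         # reuse train statistics\n",
"\n",
"forest = RandomForestRegressor(n_estimators=100, random_state=42)\n",
"forest.fit(X_train, y_train)\n",
"y_pred = forest.predict(X_test)\n",
"print(f'Mean Squared Error: {mean_squared_error(y_test, y_pred):.4f}')\n",
"```\n",
"\n",
"Tree ensembles do not strictly need feature scaling. If a scaler is used, however, fitting it separately on the test split would apply a different transformation to each split.\n"
]
},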
{
"cell_type": "code",
"execution_count": 99,
"metadata": {},
"outputs": [],
"source": [
"# Load the fine-tuned sentence transformer used to embed the sentences\n",
"bert_based_tokenizer = SentenceTransformer(model_save_path)\n",
"# Feature scaler, fitted on the training split only\n",
"scaler = StandardScaler()\n",
"\n",
"def get_features(dataset, fit_scaler=False):\n",
" # Convert sentences to embeddings\n",
" embeddings1 = bert_based_tokenizer.encode(dataset['sentence1'], convert_to_tensor=True).cpu()\n",
" embeddings2 = bert_based_tokenizer.encode(dataset['sentence2'], convert_to_tensor=True).cpu()\n",
" # Use the element-wise absolute difference of the embeddings as features\n",
" features = abs(embeddings1 - embeddings2).numpy()\n",
" # Normalize the 0-5 similarity scores to the 0-1 range\n",
" labels = np.array(dataset['score']) / 5.0\n",
" # Fit the scaler on the training data; reuse its statistics for other splits\n",
" scaled = scaler.fit_transform(features) if fit_scaler else scaler.transform(features)\n",
" return scaled, labels\n",
"\n",
"X_train, y_train = get_features(train_dataset, fit_scaler=True)\n",
"X_test, y_test = get_features(test_dataset)"
]
},
{
"cell_type": "code",
"execution_count": 100,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean Squared Error: 0.03355835914744937\n"
]
}
],
"source": [
"# Initialize the Random Forest regressor\n",
"random_forest = RandomForestRegressor(n_estimators=100, random_state=42)\n",
"\n",
"# Train the model\n",
"random_forest.fit(X_train, y_train)\n",
"\n",
"# Predict on the test set\n",
"y_pred = random_forest.predict(X_test)\n",
"\n",
"# Calculate the mean squared error\n",
"mse = mean_squared_error(y_test, y_pred)\n",
"print(f\"Mean Squared Error: {mse}\")"
]
},
{
"cell_type": "code",
"execution_count": 109,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['trained_model_random_forest.joblib']"
]
},
"execution_count": 109,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Save the model\n",
"joblib.dump(random_forest, 'trained_model_random_forest.joblib')"
]
},
{
"cell_type": "code",
"execution_count": 110,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" sentence1 \\\n",
"0 A man with a hard hat is dancing. \n",
"1 A man is fitting silencer on a pistol. \n",
"2 Kittens are eating food. \n",
"3 A woman is mixing ingrediants. \n",
"4 A woman is cooking eggs. \n",
"5 Someone is beating an egg. \n",
"6 A small baby is playing a guitar. \n",
"7 I think it is still feasible to store seeds un... \n",
"8 A man is playing soccer. \n",
"9 Two little girls are talking on the phone. \n",
"10 The man is riding a horse. \n",
"\n",
" sentence2 score \\\n",
"0 A man wearing a hard hat is dancing. 5.0 \n",
"1 A man is adding a silencer to a gun. 4.5 \n",
"2 Kittens are eating from dishes. 4.0 \n",
"3 A woman is mixing food in a bowl. 3.5 \n",
"4 A woman is cooking something. 3.0 \n",
"5 A woman stirs eggs in a bowl. 2.5 \n",
"6 A boy sits on a bed, sings and plays a guitar. 2.0 \n",
"7 I haven't tried storing tomato seeds myself, b... 1.5 \n",
"8 A man is playing flute. 1.0 \n",
"9 A little girl is walking down the street. 0.5 \n",
"10 A woman is using a hoe. 0.0 \n",
"\n",
" normalized_score Cross-Encoder Bi-Encoder Levenshtein TF-IDF \\\n",
"0 1.0 0.608921 0.979907 0.861111 0.716812 \n",
"1 0.9 0.614019 0.874845 0.631579 0.336097 \n",
"2 0.8 0.613255 0.872530 0.741935 0.510149 \n",
"3 0.7 0.602167 0.440890 0.606061 0.450176 \n",
"4 0.6 0.599842 0.619852 0.724138 0.602975 \n",
"5 0.5 0.593724 0.435095 0.310345 0.000000 \n",
"6 0.4 0.593979 0.505967 0.456522 0.087044 \n",
"7 0.3 0.603493 0.331625 0.347826 0.143098 \n",
"8 0.2 0.576015 0.094140 0.791667 0.602975 \n",
"9 0.1 0.614359 0.390625 0.642857 0.155929 \n",
"10 0.0 0.599499 -0.028931 0.653846 0.127360 \n",
"\n",
" RandomForest \n",
"0 0.969502 \n",
"1 0.748200 \n",
"2 0.723038 \n",
"3 0.372034 \n",
"4 0.453100 \n",
"5 0.312862 \n",
"6 0.345736 \n",
"7 0.440634 \n",
"8 0.217502 \n",
"9 0.411490 \n",
"10 0.235032 "
]
},
"execution_count": 110,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from encode_sentences import encode_sentences\n",
"\n",
"test_samples['RandomForest'] = test_samples.apply(\n",
" lambda x: random_forest.predict(encode_sentences(model, x['sentence1'], x['sentence2']))[0], axis=1)\n",
"test_samples"
]
},
{
"cell_type": "code",
"execution_count": 122,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 125,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.637266])"
]
},
"execution_count": 125,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from encode_sentences import encode_sentences\n",
"\n",
"joblib.load('trained_model_random_forest.joblib').predict(encode_sentences(model, 'sentence1', 'sentence2'))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Deployment\n"
]
},
{
"cell_type": "code",
"execution_count": 111,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[0m\n",
"\u001b[34m\u001b[1m You can now view your Streamlit app in your browser.\u001b[0m\n",
"\u001b[0m\n",
"\u001b[34m Local URL: \u001b[0m\u001b[1mhttp://localhost:8501\u001b[0m\n",
"\u001b[34m Network URL: \u001b[0m\u001b[1mhttp://192.168.1.107:8501\u001b[0m\n",
"\u001b[0m\n",
"\u001b[34m\u001b[1m For better performance, install the Watchdog module:\u001b[0m\n",
"\n",
" $ xcode-select --install\n",
" $ pip install watchdog\n",
" \u001b[0m\n",
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at trained_model_stsbenchmark_bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at trained_model_stsbenchmark_bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"^C\n",
"\u001b[34m Stopping...\u001b[0m\n"
]
}
],
"source": [
"!streamlit run app.py"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The code and the deployed app are available at: https://huggingface.co/spaces/mckabue/text-similarity-prediction-and-analysis\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "dss-env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}