{ "cells": [ { "cell_type": "markdown", "id": "2e1c1ce8", "metadata": {}, "source": [ "# LSH and Document Similarity" ] }, { "cell_type": "markdown", "id": "cb10f331", "metadata": {}, "source": [ "In this section, we test the `alis.similarity.LSH` class on the [20 Newsgroups dataset](https://scikit-learn.org/stable/datasets/real_world.html#newsgroups-dataset) available in `sklearn`. We begin by following the steps in the Minhashing section to obtain the minhash signatures of the dataset. Only the first $2000$ documents are used here, which keeps the inspection of similar items manageable.\n", "\n", "Since `alis.similarity` and `alis.feature_extraction` make use of `dask`, we first initialize a `dask.distributed.Client`." ] }, { "cell_type": "markdown", "id": "fcb05e9e", "metadata": {}, "source": [ "## Minhashing" ] }, { "cell_type": "code", "execution_count": 1, "id": "ba8b1a4d", "metadata": { "ExecuteTime": { "end_time": "2022-04-01T04:56:16.145772Z", "start_time": "2022-04-01T04:56:12.797168Z" } }, "outputs": [], "source": [ "from dask.distributed import Client\n", "import dask.bag as db\n", "\n", "client = Client()" ] }, { "cell_type": "markdown", "id": "b74cf3e7", "metadata": {}, "source": [ "We first load the dataset from `sklearn` and create a `dask.bag` whose elements are `(identifier, text)` tuples, where the identifier is the document's position in the loaded list. 
" ] }, { "cell_type": "code", "execution_count": 2, "id": "a66a4b80", "metadata": { "ExecuteTime": { "end_time": "2022-04-01T04:56:18.344573Z", "start_time": "2022-04-01T04:56:16.150650Z" } }, "outputs": [], "source": [ "from sklearn.datasets import fetch_20newsgroups\n", "\n", "# Load the news group dataset without the headers, footers and quotes\n", "newsgroup = fetch_20newsgroups(\n", "    subset='train', remove=('headers', 'footers', 'quotes'),\n", "    random_state=11,\n", ")\n", "newsgroup_data = newsgroup['data'][0:2000]  # load only the first 2000 documents\n", "\n", "# pair each document with its index, which serves as its identifier\n", "newsgroup_bag = db.from_sequence(\n", "    zip(range(len(newsgroup_data)), newsgroup_data))" ] }, { "cell_type": "markdown", "id": "9c250f26", "metadata": {}, "source": [ "We then implement a simple cleaning step that removes every character other than letters, digits, underscores, and whitespace, and then replaces each whitespace character (newline, tab, and so on) with a single space; runs of whitespace are not collapsed. We also lowercase all the text for simplicity. `dask.bag` lets us apply a function to each of its elements with the `.map()` method." ] }, { "cell_type": "code", "execution_count": 3, "id": "ed0bd3d7", "metadata": { "ExecuteTime": { "end_time": "2022-04-01T04:56:18.355305Z", "start_time": "2022-04-01T04:56:18.348032Z" } }, "outputs": [], "source": [ "import re\n", "\n", "def clean_text(text):\n", "    \"\"\"Clean text by removing non-alphanumeric characters, replacing each\n", "    whitespace character with a space, and lowercasing the result.\n", "    \"\"\"\n", "    return re.sub(r'\\s', r' ', re.sub(r'[^\\w\\s]', r'', text)).lower()" ] }, { "cell_type": "code", "execution_count": 4, "id": "0749651a", "metadata": { "ExecuteTime": { "end_time": "2022-04-01T04:56:18.380522Z", "start_time": "2022-04-01T04:56:18.357698Z" } }, "outputs": [], "source": [ "newsgroup_bag_cleaned = newsgroup_bag.map(lambda x: (x[0], clean_text(x[1])))" ] }, { "cell_type": "markdown", "id": "d2c0c788", "metadata": {}, "source": [ "We now apply minhashing to the text of each document in `newsgroup_bag_cleaned`. 
This is done by calling the `transform()` method of the `MinhashLSH` class on the `dask.bag`. In the example below, we instantiate the class with the shingling parameter (`shingle_size`) and the minhashing parameters (`num_shingle_bucket`, `num_hash`, and `hash_size`). Since `num_hash=12`, we expect the signature returned by `transform()` for each document to have length $12$." ] }, { "cell_type": "code", "execution_count": 5, "id": "0cc3209e", "metadata": { "ExecuteTime": { "end_time": "2022-04-01T04:56:18.394535Z", "start_time": "2022-04-01T04:56:18.383435Z" } }, "outputs": [], "source": [ "from alis.feature_extraction import MinhashLSH" ] }, { "cell_type": "code", "execution_count": 6, "id": "8fc0424d", "metadata": { "ExecuteTime": { "end_time": "2022-04-01T04:56:18.399899Z", "start_time": "2022-04-01T04:56:18.396081Z" } }, "outputs": [], "source": [ "minhasher = MinhashLSH(shingle_size=3, num_shingle_bucket=12, num_hash=12,\n", "                       hash_size=2**12)" ] }, { "cell_type": "code", "execution_count": 7, "id": "a1cb0a34", "metadata": { "ExecuteTime": { "end_time": "2022-04-01T04:56:18.407771Z", "start_time": "2022-04-01T04:56:18.402105Z" } }, "outputs": [], "source": [ "newsgroup_signatures = minhasher.transform(newsgroup_bag_cleaned)" ] }, { "cell_type": "markdown", "id": "49115ea1", "metadata": {}, "source": [ "We can inspect the first 10 elements of the bag of signatures using the `take(10)` method. As expected, each element is an `(identifier, signature)` tuple and every signature is a list of length $12$. 
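
The internals of `MinhashLSH` are not shown in this section. As a rough illustration of what a length-12 signature is, here is a minimal pure-Python sketch. It makes several assumptions: character 3-shingles, CRC32 to turn each shingle into an integer, and $12$ random hash functions of the form `(a*x + b) mod p` with a Mersenne prime `p`. These choices and the helper names `shingles`, `minhash_signature`, and `estimated_similarity` are hypothetical and not necessarily what `alis` does internally. In particular, the sketch ignores the `num_shingle_bucket` and `hash_size` parameters.

```python
import random
import zlib

# Hypothetical helpers for illustration only; not part of alis.
MERSENNE_PRIME = (1 << 61) - 1


def shingles(text, k=3):
    # Assumes character-level k-shingles (every substring of length k).
    return {text[i:i + k] for i in range(len(text) - k + 1)}


def minhash_signature(text, num_hash=12, k=3, seed=11):
    # A fixed seed gives every document the same hash functions,
    # which is what makes signatures comparable position by position.
    rng = random.Random(seed)
    params = [(rng.randrange(1, MERSENNE_PRIME), rng.randrange(0, MERSENNE_PRIME))
              for _ in range(num_hash)]
    # Map each shingle to a stable integer (Python's str hash is salted per run).
    ids = [zlib.crc32(s.encode('utf-8')) for s in shingles(text, k)]
    if not ids:
        return None  # no shingle at all: such a document would be dropped
    # For each hash function, keep the minimum hash value over all shingles.
    return [min((a * x + b) % MERSENNE_PRIME for x in ids) for a, b in params]


def estimated_similarity(sig_u, sig_v):
    # The fraction of agreeing positions estimates the Jaccard similarity.
    return sum(x == y for x, y in zip(sig_u, sig_v)) / len(sig_u)


docs = ['the quick brown fox jumps', 'the quick brown fox leaps', 'hi']
sigs = [minhash_signature(d) for d in docs]
print([len(s) if s else None for s in sigs])  # [12, 12, None]
print(estimated_similarity(sigs[0], sigs[1]))
```

As in the `transform()` output, every non-empty signature has length $12$. A document too short to form a single shingle gets no signature at all. The fraction of positions where two signatures agree estimates the Jaccard similarity of their shingle sets.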
" ] }, { "cell_type": "code", "execution_count": 8, "id": "c0a70eb9", "metadata": { "ExecuteTime": { "end_time": "2022-04-01T04:56:19.266025Z", "start_time": "2022-04-01T04:56:18.410081Z" } }, "outputs": [ { "data": { "text/plain": [ "((0, [1494, 2469, 2651, 671, 507, 1425, 159, 417, 1644, 1514, 1163, 350]),\n", " (1, [33, 12, 196, 37, 8, 6, 6, 39, 6, 64, 15, 10]),\n", " (2, [4, 12, 26, 7, 38, 27, 15, 39, 6, 101, 1, 45]),\n", " (3, [47, 75, 46, 8, 21, 15, 69, 39, 66, 95, 88, 5]),\n", " (4, [40, 12, 36, 2, 85, 3, 15, 39, 11, 4, 0, 5]),\n", " (5, [8, 12, 1, 1, 13, 3, 15, 39, 28, 16, 15, 5]),\n", " (6, [1323, 264, 1836, 1364, 264, 999, 978, 1614, 2111, 2854, 799, 3095]),\n", " (7, [79, 12, 31, 87, 140, 144, 6, 39, 146, 62, 5, 135]),\n", " (8, [240, 138, 231, 47, 36, 324, 6, 228, 61, 24, 48, 15]),\n", " (9, [246, 75, 691, 94, 142, 765, 357, 291, 10, 56, 37, 405]))" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "newsgroup_signatures.take(10)" ] }, { "cell_type": "markdown", "id": "9f4f3942", "metadata": {}, "source": [ "The `minhasher.transform` step also keeps only the documents that contain at least one shingle. Documents that are too short (or empty after cleaning) to form a single shingle of size `shingle_size` are therefore dropped. Let's check how many of the original $2000$ documents remain." 
] }, { "cell_type": "code", "execution_count": 9, "id": "b5e5210f", "metadata": { "ExecuteTime": { "end_time": "2022-04-01T04:56:21.067652Z", "start_time": "2022-04-01T04:56:19.269754Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "New Number of Documents: 1927\n" ] } ], "source": [ "signatures_list = newsgroup_signatures.compute()\n", "print(\"New Number of Documents: \", len(signatures_list))" ] }, { "cell_type": "markdown", "id": "3c024e47", "metadata": {}, "source": [ "## LSH Banding Technique" ] }, { "cell_type": "markdown", "id": "3fb35a76", "metadata": {}, "source": [ "We now perform locality-sensitive hashing on the minhash signatures using the banding technique, which slices each signature of length $12$ into $b$ bands of $r$ rows each. Here we use `bands=3`, so each band should have $r = 12 / 3 = 4$ rows. We verify these dimensions using the `r` and `bands` attributes." ] }, { "cell_type": "code", "execution_count": 10, "id": "daab2c3f", "metadata": { "ExecuteTime": { "end_time": "2022-04-01T04:56:22.349098Z", "start_time": "2022-04-01T04:56:21.070660Z" }, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Rows per band: 4\n", "Number of bands: 3\n" ] } ], "source": [ "from alis.similarity import LSH\n", "\n", "lsh = LSH(newsgroup_signatures)\n", "lsh.make_bands(bands=3)\n", "print(\"Rows per band: \", lsh.r)\n", "print(\"Number of bands: \", lsh.bands)" ] }, { "cell_type": "markdown", "id": "a67c0378", "metadata": {}, "source": [ "We then apply the `.get_buckets()` method, which performs the per-band hashing step of traditional LSH. Within each band, the slice of every signature is hashed to a bucket, so documents whose slices in that band are identical land in the same bucket and become candidate pairs.\n", "\n", "Inspecting `buckets` shows that it is a dictionary whose keys are the band identifiers and whose values are `dask.bag` objects." 
] }, { "cell_type": "code", "execution_count": 11, "id": "2d7296d7", "metadata": { "ExecuteTime": { "end_time": "2022-04-01T04:56:22.503481Z", "start_time": "2022-04-01T04:56:22.351457Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Group of buckets: 3\n" ] }, { "data": { "text/plain": [ "{0: dask.bag,\n", " 1: dask.bag,\n", " 2: dask.bag}" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "buckets = lsh.get_buckets()\n", "print(\"Group of buckets: \", len(buckets.keys()))\n", "\n", "# inspect the structure of the buckets\n", "display(buckets)" ] }, { "cell_type": "markdown", "id": "8618d99a", "metadata": {}, "source": [ "Let's further inspect the contents of band `0` using the `.take(10)` method. Each element is a `(hash bucket, list of document identifiers)` tuple, where the list holds the documents whose band slice was hashed to that bucket. It follows that any list containing more than one identifier surfaces candidate pairs." 
] }, { "cell_type": "code", "execution_count": 12, "id": "701a655c", "metadata": { "ExecuteTime": { "end_time": "2022-04-01T04:56:26.266182Z", "start_time": "2022-04-01T04:56:22.505423Z" } }, "outputs": [ { "data": { "text/plain": [ "((5349837036471001827, [0]),\n", " (4061131334253544622, [1]),\n", " (-3201416287804991259, [2]),\n", " (6009770189650563317, [3]),\n", " (-6239275907353270376, [4]),\n", " (7074500024779176862, [5]),\n", " (9150421557748798632, [6]),\n", " (-1140712557129957939, [7]),\n", " (-5415297283791330710, [8]),\n", " (1150195895645989385, [9]))" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "buckets[0].repartition(10).take(10)" ] }, { "cell_type": "markdown", "id": "1bd4d8a7", "metadata": {}, "source": [ "Next, we inspect the candidate pairs in each of the $3$ bands by keeping only the hash buckets whose lists contain more than one document identifier. Every band surfaces at least one candidate pair, and the pair `[1362, 1674]` shares a bucket in both band 0 and band 1." 
] }, { "cell_type": "code", "execution_count": 13, "id": "41d90e57", "metadata": { "ExecuteTime": { "end_time": "2022-04-01T04:56:28.660010Z", "start_time": "2022-04-01T04:56:26.269149Z" } }, "outputs": [ { "data": { "text/plain": [ "[(5323470252749634725, [1362, 1674]),\n", " (9133247501793010434, [1376, 1639, 1973]),\n", " (-5131043885278292129, [1402, 1530])]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# candidate pairs for band 0\n", "buckets[0].filter(lambda x: len(x[1]) > 1).compute()" ] }, { "cell_type": "code", "execution_count": 14, "id": "72a0d0b9", "metadata": { "ExecuteTime": { "end_time": "2022-04-01T04:56:31.250886Z", "start_time": "2022-04-01T04:56:28.665916Z" } }, "outputs": [ { "data": { "text/plain": [ "[(9133247501793010434, [211, 1720]),\n", " (3328690425592746002, [431, 1855]),\n", " (5447225194283883763, [676, 1040]),\n", " (5323470252749634725, [1362, 1674])]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# candidate pairs for band 1\n", "buckets[1].filter(lambda x: len(x[1]) > 1).compute()" ] }, { "cell_type": "code", "execution_count": 15, "id": "fb125d59", "metadata": { "ExecuteTime": { "end_time": "2022-04-01T04:56:33.971550Z", "start_time": "2022-04-01T04:56:31.253916Z" } }, "outputs": [ { "data": { "text/plain": [ "[(-8431313624916672304, [320, 1673]), (413562760039724022, [516, 1281, 1781])]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# candidate pairs for band 2\n", "buckets[2].filter(lambda x: len(x[1]) > 1).compute()" ] }, { "cell_type": "markdown", "id": "547ec50f", "metadata": {}, "source": [ "You can also look at the actual text of the candidate pairs by printing `newsgroup_data[idx]`, where `idx` is one of the document identifiers surfaced above.\n", "\n", "We now inspect the S-curve. It plots the probability $1 - (1 - s^r)^b$ that a pair with actual similarity $s$ becomes a candidate pair. The plot also marks the similarity threshold, roughly where the curve rises most steeply; pairs above it are likely to become candidates. This threshold depends on the number of bands $b$ and the number of rows per band $r$, and is approximated by $(1/b)^{1/r}$. With $b = 3$ and $r = 4$, this gives $(1/3)^{1/4} \\approx 0.76$."
] }, { "cell_type": "code", "execution_count": 16, "id": "8c8b252f", "metadata": { "ExecuteTime": { "end_time": "2022-04-01T04:56:34.355561Z", "start_time": "2022-04-01T04:56:33.975628Z" } }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# plotting the s-curve\n", "ax, thresh = lsh.plot_thresh(display_thresh=True, lw=3, c='deeppink')" ] }, { "cell_type": "markdown", "id": "35829854", "metadata": {}, "source": [ "## Exhaustive Search" ] }, { "cell_type": "markdown", "id": "6cda8e5f", "metadata": {}, "source": [ "We can also compute all pairwise similarities exhaustively and verify that the similarity threshold shown in the figure above is consistent with the distribution of actual similarities in our corpus. For simplicity, we compute the similarities between the minhash signatures rather than between the original shingle sets. *Note: The number of pairs grows quadratically with the number of documents ($n(n-1)/2$, or about $1.86$ million pairs for the $1927$ documents here). Depending on your machine, this exhaustive step may crash due to memory constraints.*" ] }, { "cell_type": "code", "execution_count": 17, "id": "d43cae86", "metadata": { "ExecuteTime": { "end_time": "2022-04-01T04:56:34.361877Z", "start_time": "2022-04-01T04:56:34.357792Z" } }, "outputs": [], "source": [ "# map each document identifier to its signature\n", "signatures_dict = dict(signatures_list)" ] }, { "cell_type": "code", "execution_count": 18, "id": "8198edc7", "metadata": { "ExecuteTime": { "end_time": "2022-04-01T04:56:59.510524Z", "start_time": "2022-04-01T04:56:34.364441Z" } }, "outputs": [], "source": [ "from itertools import combinations\n", "from scipy.spatial.distance import jaccard\n", "\n", "def jaccard_sim(u, v):\n", "    # similarity = 1 - Jaccard distance, computed on minhash signatures\n", "    return 1 - jaccard(u, v)\n", "\n", "# one similarity per unordered pair of documents: n(n-1)/2 values\n", "similarities = [jaccard_sim(signatures_dict[u_idx], signatures_dict[v_idx])\n", "                for u_idx, v_idx in combinations(signatures_dict, 2)]\n" ] }, { "cell_type": "code", "execution_count": 19, "id": "687e5ac8", "metadata": { "ExecuteTime": { "end_time": "2022-04-01T04:57:06.437226Z", "start_time": "2022-04-01T04:56:59.513388Z" } }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# jaccard similarity distribution\n", "import matplotlib.pyplot as plt\n", "\n", "_, ax = plt.subplots(figsize=(8, 5))\n", "ax.hist(similarities, bins=10, log=True, color='deeppink')\n", "ax.set_title('Distribution of similarities', fontsize=18)\n", "ax.set_xlabel('Similarity Score (Jaccard)', fontsize=13)\n", "ax.set_ylabel('Count in log scale', fontsize=13)\n", "ax.set_xlim(0,1)\n", "ax.axvline(thresh, color='black', linestyle='--',\n", " label=f'Similarity Threshold: {thresh:.2f}')\n", "\n", "# Hide the right and top spines\n", "ax.spines['right'].set_visible(False)\n", "ax.spines['top'].set_visible(False)\n", "\n", "# set spines lw\n", "ax.spines['left'].set_linewidth(3)\n", "ax.spines['bottom'].set_linewidth(3)\n", "\n", "ax.legend(fontsize=13)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "5ddc34a5", "metadata": {}, "source": [ "To verify further, let's count the document pairs whose similarity exceeds the threshold. You can also check the actual similarity of each surfaced candidate pair, for example with `jaccard_sim(signatures_dict[i], signatures_dict[j])`." 
] }, { "cell_type": "code", "execution_count": 20, "id": "2fe1c687", "metadata": { "ExecuteTime": { "end_time": "2022-04-01T04:57:06.552662Z", "start_time": "2022-04-01T04:57:06.439258Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of pairs exceeding the similarity threshold: 7\n" ] } ], "source": [ "import numpy as np\n", "\n", "# count the pairs whose signature similarity exceeds the approximate threshold\n", "num_pairs_above_thresh = (np.array(similarities) > thresh).sum()\n", "print(\"Number of pairs exceeding the similarity threshold: \", num_pairs_above_thresh)" ] }, { "cell_type": "markdown", "id": "f271b4dd", "metadata": {}, "source": [ "## Final note for LSH" ] }, { "cell_type": "markdown", "id": "746c5c46", "metadata": {}, "source": [ "You may observe that some pairs that `LSH` flags as candidates have actual similarities below the threshold. This reflects the trade-off between efficiency and accuracy: the banding technique can produce both false positives and false negatives. False negatives are similar pairs that never share a bucket in any band. They become likely when the banding parameters do not match the similarity threshold we care about. Computing the approximate threshold $(1/b)^{1/r}$, as in the plot above, helps choose $b$ and $r$ accordingly. False positives arise from the tolerant candidate condition: a pair becomes a candidate as soon as it shares a bucket in at least one band. They can be filtered out by running an exact similarity check on the much smaller set of candidate pairs instead of on all pairs."
] } ], "metadata": { "hide_input": false, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }