Building an Information Retrieval system based on the Covid-19 research challenge dataset: Part 3

In the first post of this series, we went through the ETL process used to transform a collection of papers/articles about Covid-19 and other infectious diseases into a SQLite database. In the second post, we saw how to generate weighted embedding vectors for important sentences in the corpus, which can be used to quickly identify sentences and associated documents relevant to a user query. In this third and final post, I'll show you how to use a BERT model fine-tuned on the SQuAD dataset to identify excerpts relevant to the user query in the body text of the top-matching documents returned by the sentence-embedding search. The application source code is here.

I will not be going through the details of how BERT works; for that, I refer you to the many good references available online.

Below, I'll go through some of my implementation-specific quirks. The flowchart below shows the workflow of applying BERT models to the full text of the documents returned by the sentence-embedding search to retrieve answer spans.

[Flowchart: applying the BERT QA model to the documents returned by the sentence-embedding search]

The diagram below shows this workflow being applied to the query “are antimalarial drugs effective in treating covid-19?”.

Let's go through a few BERT implementation details that took me some time to understand. The first is the idea of maintaining a mapping between token and word indices and vice versa. This is needed because BERT operates at the level of tokens rather than words. BERT uses the wordpiece tokenization scheme, which splits an input word into smaller and smaller units until each unit is found in the vocabulary. BERT outputs the probability of each input token being the start or end of the answer span. These probabilities are fed to a decoding module, which produces the token indices that maximize the product of the start and end probabilities, subject to a maximum answer length constraint. These indices are at the token level rather than the word level, and word spans are what users care about. To recover the original word indices, we maintain maps from token indices to the original word indices, as shown in the picture below.

[Figure: mapping between token indices and word indices]
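To make this concrete, here is a minimal sketch of the bookkeeping, assuming the HuggingFace transformers tokenizer. The function and variable names are illustrative, not the actual ones in src/covid_browser/fast_qa:

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    words = "are antimalarial drugs effective".split()
    tokens, tok_to_word = [], []
    for word_idx, word in enumerate(words):
        for piece in tokenizer.tokenize(word):  # wordpiece may split one word into several tokens
            tokens.append(piece)
            tok_to_word.append(word_idx)        # token index -> index of the originating word

    def decode(start_probs, end_probs, max_answer_len=30):
        # Pick (s, e) maximizing start_probs[s] * end_probs[e],
        # subject to s <= e < s + max_answer_len.
        best_score, best_span = 0.0, (0, 0)
        for s, p_start in enumerate(start_probs):
            for e in range(s, min(s + max_answer_len, len(end_probs))):
                if p_start * end_probs[e] > best_score:
                    best_score, best_span = p_start * end_probs[e], (s, e)
        return best_span

    # Map a decoded token span back to the word span the user actually sees:
    start_tok, end_tok = 1, 3  # stand-in for real decoder output
    answer = words[tok_to_word[start_tok] : tok_to_word[end_tok] + 1]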

Spans are another important concept to understand. The input to a BERT model fine-tuned for QA is a batch of query-context items. Each item in the batch is a concatenation of query tokens and context tokens separated by a special separator token (see diagram above). A context is simply a collection of sentences from the documents over which we wish to identify excerpts. Sentences are split into tokens and added to a context until the total number of tokens exceeds a threshold. When this happens, we finalize that span and start a new one. Because we may hit the token-count threshold midway through a sentence, we start the new span from the beginning of the current sentence. This can leave a portion of a sentence in one span and the complete sentence in the next, but it does no harm other than increasing the number of spans (and thus the processing time) by a small amount. In my implementation (see the preprocess function in src/covid_browser/fast_qa), I create these spans ahead of time, assuming a maximum query length. The flowchart below shows this process:

[Flowchart: span creation]
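In code, the span-creation loop looks roughly like this (a minimal sketch; the real logic lives in preprocess, and the constants here are placeholders):

    def make_spans(sentences, tokenizer, max_tokens=384, max_query_len=64):
        # Reserve room for the query tokens plus the [CLS] and two [SEP] tokens.
        budget = max_tokens - max_query_len - 3
        spans, current = [], []
        for sent in sentences:
            sent_tokens = tokenizer.tokenize(sent)[:budget]  # truncate pathological sentences
            if len(current) + len(sent_tokens) > budget:
                # Fill the remainder of the current span with the front of this
                # sentence, then restart the next span from the sentence's beginning.
                current.extend(sent_tokens[:budget - len(current)])
                spans.append(current)
                current = list(sent_tokens)
            else:
                current.extend(sent_tokens)
        if current:
            spans.append(current)
        return spans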

At query time, I prepend the tokenized query to each span and maintain an offset equal to the query length, which is later used to adjust the token-to-word indices.
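A sketch of the query-time packing, again with illustrative names; in this version the offset also accounts for the [CLS] and [SEP] special tokens:

    def build_input(query_tokens, span_tokens, tokenizer):
        tokens = ["[CLS]"] + query_tokens + ["[SEP]"] + span_tokens + ["[SEP]"]
        offset = len(query_tokens) + 2  # context tokens start after [CLS] query [SEP]
        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        segment_ids = [0] * offset + [1] * (len(span_tokens) + 1)  # 0 = query, 1 = context
        return input_ids, segment_ids, offset

    # A decoded token index i then maps back into the context as i - offset
    # before the token-to-word lookup described earlier.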

This may not make much sense if you are not familiar with how BERT preprocessing works. I recommend running the code in a debugger and stepping through the preprocessing steps to get a feel for what’s going on.

A few pitfalls to be aware of:

  1. If you are using an English vocabulary but processing a document in another language (e.g. Spanish), tokenization can take a long time because the wordpiece tokenizer keeps splitting words into smaller and smaller pieces until the pieces are found in the vocabulary, and the language mismatch makes this slow. I hack around this issue by rejecting documents that contain non-ASCII characters, such as accented characters, which don't occur in English. I also terminate span creation if the total time exceeds a threshold (2 seconds). Both guards are sketched after this list.
  2. Some sentences (such as those containing long mathematical formulas) can exceed the token-count threshold after tokenization. Such sentences must be appropriately truncated.
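Here is a sketch of both guards. str.isascii() requires Python 3.7+, and the helper names are my own, not the ones in the repository:

    import time

    def usable_document(text):
        # Crude English check: accented and other non-ASCII characters send the
        # wordpiece tokenizer into long splitting loops, so reject them outright.
        return text.isascii()

    def tokenize_sentences(sentences, tokenizer, time_budget_s=2.0):
        # Abort span creation if tokenization runs past the time budget.
        deadline = time.monotonic() + time_budget_s
        tokenized = []
        for sent in sentences:
            if time.monotonic() > deadline:
                return None  # signal the caller to skip this document
            tokenized.append(tokenizer.tokenize(sent))
        return tokenized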

Data Parallel Forward Pass

I have two 1080Ti GPUs on my deep learning machine. To speed up the forward pass through the BERT model, I split the batch along the batch dimension and use Python's concurrent.futures library to run the two halves in parallel, one per GPU (see parallel_execute2 in fast_qa).
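A minimal sketch of the idea, with one model replica per GPU; the names are illustrative, the real code is parallel_execute2, and the exact shape of the model outputs depends on the transformers version:

    import torch
    from concurrent.futures import ThreadPoolExecutor

    def parallel_forward(models, input_ids, segment_ids, attention_mask):
        # models: one replica of the fine-tuned BERT QA model per GPU,
        # already moved to cuda:0, cuda:1, ...
        n = len(models)

        def run(i, ids, segs, mask):
            device = torch.device(f"cuda:{i}")
            with torch.no_grad():
                return models[i](ids.to(device),
                                 token_type_ids=segs.to(device),
                                 attention_mask=mask.to(device))

        # Split the batch along the batch dimension, one chunk per GPU.
        chunks = zip(input_ids.chunk(n), segment_ids.chunk(n), attention_mask.chunk(n))
        with ThreadPoolExecutor(max_workers=n) as pool:
            futures = [pool.submit(run, i, *chunk) for i, chunk in enumerate(chunks)]
            return [f.result() for f in futures]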

The total query processing time is ~3 seconds, split across three major components: BERT execution time (~1.8 sec), span creation time (~1 sec), and top-n embedding lookup time (~0.1 sec).

This concludes this series of posts. I think this Covid-19 information retrieval system is a great example of combining conventional natural language processing techniques such as TF-IDF and word embeddings with modern neural-network-based techniques such as BERT. I hope you found these posts useful. Please leave a comment if you did.
