How to compute word error rate (WER) with OpenAI Whisper
Automatic speech recognition (ASR) is the task of automatically transcribing speech from audio. Whisper is an ASR model introduced by OpenAI in September 2022, and its code is publicly available. The model was trained on 680,000 hours of speech data in a supervised manner, and it stands out for producing good results on out-of-distribution data. Besides English, the training data covers 96 other languages, and the model is also capable of translating speech from any of these languages into English.
Word error rate
Word error rate (WER) is the most commonly used measure for evaluating speech recognition models. It’s expressed as the percentage of errors made in transcriptions and is computed as follows:

WER = (S + D + I) / N

where:
- S, D, and I are the substitution, deletion, and insertion word counts, respectively.
- N is the number of words in the reference (ground truth) transcription.
Note: The WER of state-of-the-art ASR models is approaching that of human transcriptionists (roughly 4%). As of 2023, the ASR field is quite saturated, and only large companies have the resources to perform effective research in this area. Generally, a single-digit WER is considered good enough.
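To make the formula concrete, here is a minimal from-scratch sketch that computes WER via a standard word-level edit-distance (Levenshtein) alignment. The word_error_rate function below is purely illustrative; in practice, you’d use a ready-made metric implementation, as the code later in this article does.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """Compute WER = (S + D + I) / N via a word-level edit distance."""
    ref = reference.split()
    hyp = hypothesis.split()

    # d[i][j] = minimum number of edits needed to turn the first i
    # reference words into the first j hypothesis words
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i  # i deletions
    for j in range(len(hyp) + 1):
        d[0][j] = j  # j insertions

    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            substitution = d[i - 1][j - 1] + (ref[i - 1] != hyp[j - 1])
            deletion = d[i - 1][j] + 1
            insertion = d[i][j - 1] + 1
            d[i][j] = min(substitution, deletion, insertion)

    # Total edits (S + D + I) divided by the number of reference words (N)
    return d[len(ref)][len(hyp)] / len(ref)


print(word_error_rate("the quick brown fox", "the quack brown fox jumps"))  # 0.5
```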
Toy Example
Let’s compute WER for a short example. Consider the following scenario:
Ground Truth
A man said to the universe: Sir, I exist!
Transcription
A mane to universe: Sire I egg exist!
There are 9 words in the ground truth (the reference). Let’s count the errors in the transcription:
- Substitutions:
- “man”/“mane”
- “Sir”/“Sire”
- Deletions:
- “said”
- “the”
- Insertions:
- “egg”
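Plugging these counts into the formula, with 2 substitutions, 2 deletions, 1 insertion, and 9 reference words:

WER = (S + D + I) / N = (2 + 2 + 1) / 9 ≈ 0.56

So roughly 56% of the reference words were transcribed incorrectly. As a sanity check, here is a minimal sketch using the jiwer package (a third-party library, not used elsewhere in this article’s code, that the Hugging Face WER metric used later also relies on):

```python
import jiwer  # pip install jiwer

# Lowercased, punctuation-free versions of the toy example above
reference = "a man said to the universe sir i exist"
hypothesis = "a mane to universe sire i egg exist"

# jiwer.wer returns (S + D + I) / N for the given pair
print(jiwer.wer(reference, hypothesis))  # 0.5555... ≈ 56%
```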
Source of errors
ASR models generally perform well on read speech recorded in a controlled environment. The results are worse when the speech is spontaneous or conversational in nature, or when ambient noise is present. If there’s noise or background music in the recording, the model might try to map it to characters or words, resulting in erroneous transcriptions. Consequently, many models use augmentation strategies that dynamically add noise to the training data during training, which increases the robustness of the model. Another source of errors is ambiguity: two words can have a similar pronunciation but different spellings, or the words might simply not be articulated clearly. ASR models usually use a language model when decoding sentences, which can mitigate this problem because the surrounding context is also used to predict the letters or words. This is where spontaneous or conversational speech causes trouble again: its grammar is often imperfect, so the language model helps less. Consequently, these problems still adversely affect transcription quality.
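As a rough illustration of such an augmentation strategy, here is a minimal sketch that mixes Gaussian noise into a waveform at a chosen signal-to-noise ratio. This is a generic example, not Whisper’s actual training pipeline, and the add_noise helper and its snr_db parameter are invented for illustration:

```python
import torch

def add_noise(waveform: torch.Tensor, snr_db: float = 10.0) -> torch.Tensor:
    """Mix Gaussian noise into a waveform at roughly the given SNR (in dB)."""
    noise = torch.randn_like(waveform)
    signal_power = waveform.pow(2).mean()
    noise_power = noise.pow(2).mean()
    # Scale the noise so that signal power / scaled noise power hits the target SNR
    scale = torch.sqrt(signal_power / (noise_power * 10 ** (snr_db / 10)))
    return waveform + scale * noise

# Example: corrupt one second of a 440 Hz tone sampled at 16 kHz
clean = torch.sin(2 * torch.pi * 440 * torch.arange(16000) / 16000)
noisy = add_noise(clean, snr_db=5.0)
```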
Applications
Let’s say that we have a call center that handles millions of calls per day, most of which are redundant queries, and that we want to at least automate the redundant queries in order to save resources. The first step toward this automation process is to select an ASR model that performs well on our data. We retrieve a sample dataset from the data distribution that our call center is generating and evaluate the WER score for multiple models, such as Google’s Universal Speech Model, OpenAI’s Whisper, and Microsoft’s Azure AI Speech, on our data. The best model to use with our data would be the one with the lowest WER score.
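A minimal sketch of this selection step, assuming we’ve already run each candidate system on the same sampled audio (the model names and transcripts below are placeholders, not real outputs):

```python
import jiwer  # pip install jiwer

# Human-made reference transcripts for the sampled call-center audio
references = [
    "i would like to reset my account password",
    "what is the status of my refund",
]

# Hypotheses produced by each candidate ASR system (placeholder values)
candidate_outputs = {
    "model_a": ["i would like to reset my account password",
                "what is the status of my refund"],
    "model_b": ["i like to reset my count password",
                "what is status of my refund today"],
}

# Compute a corpus-level WER per model and keep the lowest
scores = {name: jiwer.wer(references, hyps) for name, hyps in candidate_outputs.items()}
best_model = min(scores, key=scores.get)
print(scores)
print("Best model by WER:", best_model)
```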
Furthermore, all the models representing the current state of the art in speech recognition are neural networks, and the WER score is used to evaluate and compare these models as they are trained.
Code
Let’s now explore how the WER metric can be used for evaluating transcriptions with the Whisper model.
```python
from datasets import load_dataset, load_metric, Audio
import torch
import pandas as pd
import re
from transformers import WhisperForConditionalGeneration, WhisperProcessor

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Inference device:', device)

dummy = load_dataset('EducativeCS2023/dummy_en_asr')

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny", language='English', task='transcribe')
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").to(device)

results = []

for i in range(len(dummy['train'])):

    inputs = processor(dummy['train'][i]["audio"]["array"], return_tensors="pt", sampling_rate=16000)
    input_features = inputs.input_features

    generated_ids = model.generate(inputs=input_features.to(device), task='transcribe', language='English')

    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    string = ''.join(transcription)
    print('transcript', i, ":", transcription)
    results.append(string)

chars_to_remove_regex = '[\,\'\.\-\,\۔\!\,\؟\:\-\“\%\‘\’\”\،\"]'
predictions = [re.sub(chars_to_remove_regex, '', str(x).lower()) for x in results]

references = [i['sentence'].lower() for i in dummy['train']]

wer_metric = load_metric("wer")

wer_score = wer_metric.compute(predictions=predictions, references=references)
print("wer score", wer_score)
```
Code explanation
- Lines 1–5: We load the following libraries and modules:
  - datasets:
    - load_dataset: To retrieve and load the dataset
    - load_metric: To evaluate the model
    - Audio: To cast audio file paths to features
  - torch: The machine learning framework we’ll be using
  - re: The library required for removing punctuation from the transcriptions
  - transformers:
    - WhisperForConditionalGeneration: To load the whisper model
    - WhisperProcessor: To preprocess the audio and text
- Lines 7–8: We select the compute device.
- Line 10: We download and load the dataset.
- Lines 12–13: We download the whisper model and its processor from the trained checkpoint. The processor contains both the feature_extractor and the tokenizer.
- Line 17: We iterate through the dummy dataset.
- Lines 19–20: We preprocess the input audio array and prepare the input features.
- Line 22: We use the model to generate the sequence of token IDs of the transcription given the input features.
- Line 24: We use the processor to decode the generated IDs into text transcriptions.
- Lines 29–30: We remove punctuation marks from the transcriptions.
- Line 32: We convert references to lowercase.
- Line 34: We load the WER metric.
- Lines 36–37: We compute and display the WER score.
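Note: load_metric has been deprecated in recent versions of the datasets library (and removed in newer releases); its functionality now lives in the separate evaluate package. If load_metric isn’t available in your environment, the metric-related lines can be replaced with the following equivalent (the rest of the code stays the same):

```python
import evaluate  # pip install evaluate jiwer

wer_metric = evaluate.load("wer")
wer_score = wer_metric.compute(predictions=predictions, references=references)
print("wer score", wer_score)
```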