🎧 Techniques for fine‑tuning Whisper on French telephone conversations
Language models: Focus on Whisper
Let’s first talk about Whisper. It is an Automatic Speech Recognition (ASR) model that transcribes spoken audio into text. It belongs to the family of modern language models, which generally fall into three broad categories:
- Masked Language Models (MLMs): These models learn to predict missing tokens within a text by leveraging bidirectional context. The highly influential BERT model (Devlin et al., 2019) is the best-known example.
- Causal Language Models (CLMs): Also known as autoregressive models, CLMs generate text one token at a time, conditioning each prediction on all preceding tokens. The GPT family is the most famous representative, optimized for fluent, left‑to‑right language generation.
- Sequence‑to‑Sequence Models (Seq2Seq): Combining an encoder and a decoder, Seq2Seq architectures transform an input sequence into a target sequence. They are particularly well‑suited for tasks like machine translation, text summarization, and speech‑to‑text transcription.
Whisper belongs to this last family. As a Seq2Seq model, it first encodes 30 seconds of raw audio into a latent context representation, then decodes this representation autoregressively, predicting text tokens from left to right.
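To make this fixed 30-second window concrete, here is a minimal sketch (assuming the openai/whisper-small checkpoint and its default feature extractor): audio shorter than 30 seconds is padded, longer audio is truncated, and the result is an 80-channel log-Mel spectrogram with 3,000 frames.
import numpy as np
from transformers import WhisperFeatureExtractor

fe = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
five_seconds_of_silence = np.zeros(5 * 16000, dtype=np.float32)  # 5 s of 16 kHz audio
features = fe(
    five_seconds_of_silence, sampling_rate=16000, return_tensors="np"
).input_features
print(features.shape)  # expected (1, 80, 3000): padded to the full 30-second window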
Whisper prompt decoder system 🎯
Whisper was trained on 30-second audio chunks, learning to predict both regular and special tokens for each segment. The decoding process begins with the special token <|startoftranscript|> and continues until it generates the final token <|endoftranscript|>.
To better handle long-range dependencies, Whisper maintains previously transcribed 30-second chunks in its history. During training, these previous text tokens were included with a certain probability, which helps the model maintain context across longer audio files.
Thanks to its extensive multilingual training data, the model can identify and predict language tags (<|fr|> or <|en|>) from among 99 different languages. These language tags then condition the rest of the transcription process, ensuring appropriate language-specific processing.
Whisper was also trained to predict silence, generating the special token <|nospeech|> when no speech is detected. However, in practice, Whisper sometimes hallucinates content during silent passages, so it’s generally recommended to use a dedicated Voice Activity Detection (VAD) model to pre-filter audio chunks before passing them to Whisper, as sketched below.
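As an illustration, one commonly used option is Silero VAD. The following is a rough sketch based on its published torch.hub interface (the entry-point and helper names may differ between releases, and the audio file name is hypothetical):
import torch

# load the Silero VAD model and its helper functions via torch.hub (assumed API)
vad_model, utils = torch.hub.load("snakers4/silero-vad", "silero_vad")
get_speech_timestamps, _, read_audio, *_ = utils

wav = read_audio("example_call.wav", sampling_rate=16000)  # hypothetical input file
speech_segments = get_speech_timestamps(wav, vad_model, sampling_rate=16000)
print(speech_segments)  # list of {'start': ..., 'end': ...} sample offsets containing speech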
When speech is detected, we can use special tokens:
- <|transcribe|> instructs the model to transcribe the audio in its original language
- <|translate|> instructs the model to translate the content into English
If timestamp information isn’t needed, we can provide the <|notimestamps|> token to skip timestamp prediction entirely. Otherwise, the model can approximate start and end times for each detected group of words, thanks to its multitask training data. These timestamps are quantized to the nearest 20 milliseconds for efficiency.
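Putting these pieces together, we can inspect the special-token prefix that the Hugging Face tokenizer builds for a given language and task. A small sketch (the exact token ids are checkpoint-dependent):
from transformers import WhisperTokenizer

tok = WhisperTokenizer.from_pretrained("openai/whisper-small", language="French", task="transcribe")
print(tok.convert_ids_to_tokens(tok.prefix_tokens))
# expected something like: ['<|startoftranscript|>', '<|fr|>', '<|transcribe|>', '<|notimestamps|>']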
For applications requiring more precise alignment between audio and text, more advanced approaches exist, such as Connectionist Temporal Classification (CTC), which directly aligns audio frames with text tokens for greater accuracy.
In this notebook we will see how to fine-tune the Whisper small model (we select this size for demonstration purposes) on French telephone conversations, using techniques such as LoRA for parameter-efficient adaptation and speculative decoding to accelerate inference.
Setup & Authentication
model_name_or_path = "openai/whisper-small"
language = "French"
language_abbr = "fr"
task = "transcribe"
Your Hugging Face token can be found in your Hugging Face Hub account settings.
import os
from huggingface_hub import login
login()
Install Dependencies
Install core libraries for audio, datasets, training, and deployment.
!pip install transformers peft torchaudio datasets accelerate gradio jiwer evaluate torchinfo
Data Loading
We load a French telephone speech-to-text dataset processed by Diabolocom from the Hugging Face Hub: https://huggingface.co/datasets/diabolocom/talkbank_4_stt.
The dataset is structured into segment and switch parts; here we are interested in the segments and in the “train” and “test” splits.
- Loads the Diabolocom French telephone dataset.
- Selects the first 1,000 samples for train/test (for demonstration).
- Ensures all audio is resampled to 16 kHz.
- Drops unused metadata columns.
from datasets import load_dataset, DatasetDict, Audio
talkbank_fr = load_dataset(
    "diabolocom/talkbank_4_stt",
    data_dir="fr/segment",
    verification_mode="no_checks",
)
talkbank = DatasetDict({
    "train": talkbank_fr["train"].select(range(1000)),
    "test": talkbank_fr["test"].select(range(1000)),
})
talkbank = talkbank.cast_column("audio", Audio(sampling_rate=16000))
Explore the dataset
Explore audio and transcripts in the dataset.
def format_seconds(seconds_input):
    """
    Turn a number of seconds (possibly fractional) into a compact string,
    omitting zero-value day/hour/minute units, and always showing seconds
    (with up to millisecond precision, dropping trailing zeros).
    """
    remaining = float(seconds_input)
    parts = []
    for label, unit_secs in (("d", 86400), ("h", 3600), ("m", 60)):
        qty, remaining = divmod(remaining, unit_secs)
        if qty >= 1:
            parts.append(f"{int(qty)}{label}")
    parts.append(f"{remaining:.3f}s")
    return " ".join(parts)
import torch
from datasets import Audio
import IPython.display as ipd
id = 42
example = talkbank["test"][id]
audio_array = example["audio"]["array"]
sampling_rate = example["audio"]["sampling_rate"]
ipd.display(ipd.Audio(audio_array, rate=sampling_rate))
print(f"Duration: {format_seconds(audio_array.shape[0]/sampling_rate)}")
print(f"Language: {example['full_language']}")
print(f"subset: {example['subset']}")
Duration: 0.480s
Language: French - Quebecois
subset: CallFriend
import matplotlib.pyplot as plt
durations = {}
for split in ["train","test"]:
durations[split] = [
ex["audio"]["array"].shape[0] / ex["audio"]["sampling_rate"]
for ex in talkbank[split]
]
# two-panel histogram
fig, axes = plt.subplots(1, 2, figsize=(12, 4), sharey=True)
for ax, split in zip(axes, ["train","test"]):
ax.hist(durations[split], bins=30)
ax.set_title(f"{split.capitalize()} split")
ax.set_xlabel("Duration (seconds)")
axes[0].set_ylabel("Count")
plt.tight_layout()
plt.show()
Preprocessing
Before we can fine-tune Whisper, we need to transform our raw dataset into exactly the inputs the model expects:
- input_features: 80-dimensional log-Mel spectrogram frames.
- labels: tokenized target transcripts.
# Just clean up unneeded columns, keeping the audio and transcript
talkbank = talkbank.remove_columns([
    "language_code", "subset", "full_language",
    "switch_id", "transcript_filename", "orig_file_start",
    "orig_file_end", "channel",
])
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name_or_path)
tokenizer = WhisperTokenizer.from_pretrained(model_name_or_path, language=language, task=task)
processor = WhisperProcessor.from_pretrained(model_name_or_path, language=language, task=task)
def prepare_dataset(batch):
    audio = batch["audio"]
    # compute log-Mel input features from the input audio array
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    # encode target text to label ids
    batch["labels"] = tokenizer(batch["transcript"]).input_ids
    return batch
talkbank = talkbank.map(prepare_dataset, remove_columns=talkbank.column_names["train"], num_proc=1)
talkbank['train']
Dataset({ features: ['input_features', 'labels'], num_rows: 1000 })
Data collator
When you train with variable-length audio and text, you need a custom “collator” that dynamically pads each batch.
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to the longest sequence in the batch
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 so padded positions are ignored by the loss
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if a bos token was prepended in the previous tokenization step,
        # cut it here since it is appended again later anyway
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
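As a quick sanity check (illustrative only; the feature shape assumes Whisper's 80 x 3000 log-Mel inputs), we can run the collator on two processed examples and look at the padded shapes:
sample_batch = data_collator([talkbank["train"][0], talkbank["train"][1]])
print(sample_batch["input_features"].shape)  # expected torch.Size([2, 80, 3000])
print(sample_batch["labels"].shape)          # (2, length of the longer label sequence)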
Load the Pre-Trained Model and Tokenizer
In this step we load the pre-trained Whisper model with 8-bit precision. We print a summary of the model’s layers and parameter counts.
from transformers import WhisperForConditionalGeneration
model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, load_in_8bit=True, device_map="auto")
from torchinfo import summary
summary(model)
The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
================================================================================
Layer (type:depth-idx)                                       Param #
================================================================================
WhisperForConditionalGeneration                              --
├─WhisperModel: 1-1                                          --
│    └─WhisperEncoder: 2-1                                   --
│    │    └─Conv1d: 3-1                                      185,088
│    │    └─Conv1d: 3-2                                      1,770,240
│    │    └─Embedding: 3-3                                   (1,152,000)
│    │    └─ModuleList: 3-4                                  85,045,248
│    │    └─LayerNorm: 3-5                                   1,536
│    └─WhisperDecoder: 2-2                                   --
│    │    └─Embedding: 3-6                                   39,832,320
│    │    └─WhisperPositionalEmbedding: 3-7                  344,064
│    │    └─ModuleList: 3-8                                  113,402,880
│    │    └─LayerNorm: 3-9                                   1,536
├─Linear: 1-2                                                39,832,320
================================================================================
Total params: 281,567,232
Trainable params: 82,059,264
Non-trainable params: 199,507,968
================================================================================
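The deprecation warning above suggests passing a BitsAndBytesConfig instead of load_in_8bit. A minimal sketch of the equivalent loading call (behaviour should match the cell above):
from transformers import BitsAndBytesConfig, WhisperForConditionalGeneration

# request 8-bit quantization through the non-deprecated quantization_config argument
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model = WhisperForConditionalGeneration.from_pretrained(
    model_name_or_path,
    quantization_config=bnb_config,
    device_map="auto",
)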
The model has generation-related arguments that we will override, such as:
- forced_decoder_ids (List[List[int]], optional): a list of integer pairs mapping generation indices to token indices that will be forced before sampling. For example, [[1, 123]] means the second generated token will always be the token with index 123.
- suppress_tokens (List[int], optional): a list of tokens that will be suppressed at generation. The SuppressTokens logit processor sets their log probabilities to -inf so that they are never sampled.
print(f"model forced_decoder_ids: {model.config.forced_decoder_ids}")
print(f"model suppress_tokens: {model.config.suppress_tokens}")
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
model forced_decoder_ids: [[1, 50259], [2, 50359], [3, 50363]] model suppress_tokens: [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362]
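Optionally, we can decode the forced ids printed above to see which special tokens they correspond to (a small sketch; the exact ids depend on the checkpoint's vocabulary):
forced = [[1, 50259], [2, 50359], [3, 50363]]  # values printed above
print(tokenizer.convert_ids_to_tokens([token_id for _, token_id in forced]))
# expected something like: ['<|en|>', '<|transcribe|>', '<|notimestamps|>']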
LoRA Fine-Tuning ⚙️
We will now integrate LoRA adapters into our pre-trained Whisper model via the PEFT library.
LoRA (“Low-Rank Adaptation”) freezes the original weights and learns only a small, additive low-rank update in selected layers. This reduces the number of trainable parameters and accelerates fine-tuning.
Moreover, because the LoRA updates are stored separately, you can easily swap in different adapter checkpoints—for example, to support new languages or specialized domains—while keeping the same underlying base model.
from peft import LoraConfig, get_peft_model, TaskType
# Configure LoRA settings for sequence-to-sequence language modeling
lora_config = LoraConfig(
    r=8,                          # Rank of the low-rank matrices
    lora_alpha=32,                # Scaling factor for LoRA updates
    lora_dropout=0.1,             # Dropout probability for LoRA layers
    target_modules=["q_proj", "v_proj"],  # Target modules in the Transformer to adapt (adjust as needed)
)
# With a frozen 8-bit base model, make sure the input embeddings' outputs require
# gradients so that gradients can flow back into the LoRA adapters.
model.enable_input_require_grads()
# Wrap the model with LoRA adapters; this makes only the LoRA parameters trainable.
model = get_peft_model(model, lora_config)
# Optionally, print trainable parameters to verify that only the LoRA layers are being updated.
print("Trainable parameters after applying LoRA:")
model.print_trainable_parameters()
Trainable parameters after applying LoRA: trainable params: 884,736 || all params: 242,619,648 || trainable%: 0.3647
How it works
- Freeze the pre-trained weights: \( W \in \mathbb{R}^{d \times k} \)
- Integrate LoRA adapters
  For each layer name in target_modules, PEFT creates two small weight matrices:
  - \( B \in \mathbb{R}^{d \times r} \)
  - \( A \in \mathbb{R}^{r \times k} \)
- Training
  Each adapted layer now computes:
  \( h = Wx + BAx \)
  - \( W \) is the frozen original weight matrix.
  - \( B \) and \( A \) are the trainable low-rank weights, with \( r \ll \min(d, k) \); in PEFT the update \( BAx \) is additionally scaled by lora_alpha / r.
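As a back-of-the-envelope check of the trainable-parameter count printed earlier, here is a small sketch assuming whisper-small dimensions (d_model = 768, 12 encoder layers with one self-attention each, 12 decoder layers with self- and cross-attention, adapting q_proj and v_proj in each):
d_model, r = 768, 8
adapted_linears = 12 * 2 + 12 * 2 * 2           # encoder: q+v per layer; decoder: q+v per self- and cross-attention
params_per_adapter = r * d_model + d_model * r  # A (r x d_model) + B (d_model x r)
print(adapted_linears * params_per_adapter)     # 884,736 -> matches print_trainable_parameters()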
Define the Evaluation Metric (WER)
We use the Word Error Rate (WER) to evaluate the performance of our speech recognition system. The compute_metrics function decodes model predictions and ground-truth labels, computes the WER, and returns it as a percentage.
import evaluate
# Load the WER metric from the evaluate library
metric = evaluate.load("wer")
def compute_metrics(pred):
    """
    Compute Word Error Rate (WER) for model predictions.

    Args:
        pred: The prediction output from the trainer containing predictions and label_ids.

    Returns:
        A dictionary containing the WER score.
    """
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}
Setup Training Arguments
We use Hugging Face’s Seq2SeqTrainingArguments to define our training configuration. These settings include batch size, learning rate, number of steps, evaluation strategy, and logging.
from transformers import Seq2SeqTrainingArguments
training_args = Seq2SeqTrainingArguments(
    output_dir="results",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    learning_rate=1e-3,
    warmup_steps=50,
    num_train_epochs=3,
    evaluation_strategy="epoch",
    fp16=True,
    per_device_eval_batch_size=8,
    generation_max_length=128,
    logging_steps=25,
    remove_unused_columns=False,  # required for the PeftModel forward
    label_names=["labels"],       # same reason as above
    report_to=["none"],           # disable logging to avoid cluttering the output
)
/root/miniconda3/envs/canary/lib/python3.10/site-packages/transformers/training_args.py:1575: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead warnings.warn(
Initialize the Trainer
We instantiate Hugging Face’s Seq2SeqTrainer with our model, training arguments, datasets and data collator.
We are just using the first 100 segments for demonstration; for production-quality transcription, you should continue fine-tuning with a larger dataset and more training time.
from transformers import Seq2SeqTrainer
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=talkbank["train"].select(range(100)),
    eval_dataset=talkbank["test"].select(range(100)),
    data_collator=data_collator,
    tokenizer=processor.feature_extractor,
)
/tmp/ipykernel_45417/2961733102.py:3: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Seq2SeqTrainer.__init__`. Use `processing_class` instead.
trainer = Seq2SeqTrainer(
Fine-Tune the Model
We now start the training process. The trainer will fine-tune the model using the LoRA adapters, updating only the LoRA-specific parameters.
trainer.train()
[39/39 01:20, Epoch 3/3]
Epoch | Training Loss | Validation Loss |
---|---|---|
1 | No log | 3.890726 |
2 | 3.955700 | 2.560604 |
3 | 3.955700 | 1.943963 |
/root/miniconda3/envs/canary/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization warnings.warn(f”MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization”)
TrainOutput(global_step=39, training_loss=3.2936521676870494, metrics={'train_runtime': 82.7204, 'train_samples_per_second': 3.627, 'train_steps_per_second': 0.471, 'total_flos': 8.6957826048e+16, 'train_loss': 3.2936521676870494, 'epoch': 3.0})
(Optional) Push the model on the Hugging Face Hub
model_name_or_path = "openai/whisper-small"
peft_model_id = "diabolocom/" + f"{model_name_or_path}-LoRA".replace("/", "-")
model.push_to_hub(peft_model_id)
print(peft_model_id)
diabolocom/openai-whisper-small-LoRA
Evaluation
In this section we load our PEFT‐LoRA-fine-tuned Whisper model, run it on the test split, and compute the Word Error Rate (WER) metric.
from peft import PeftModel, PeftConfig
from transformers import WhisperForConditionalGeneration, Seq2SeqTrainer
peft_model_id = "diabolocom/openai-whisper-small-LoRA"
peft_config = PeftConfig.from_pretrained(peft_model_id)
model = WhisperForConditionalGeneration.from_pretrained(
    peft_config.base_model_name_or_path, load_in_8bit=True, device_map="auto"
)
model = PeftModel.from_pretrained(model, peft_model_id)
The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import gc
eval_dataloader = DataLoader(talkbank["test"], batch_size=8, collate_fn=data_collator)
model.eval()
for step, batch in enumerate(tqdm(eval_dataloader)):
    with torch.cuda.amp.autocast():
        with torch.no_grad():
            generated_tokens = (
                model.generate(
                    input_features=batch["input_features"].to("cuda"),
                    decoder_input_ids=batch["labels"][:, :4].to("cuda"),
                    max_new_tokens=255,
                )
                .cpu()
                .numpy()
            )
            labels = batch["labels"].cpu().numpy()
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
            decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
            metric.add_batch(
                predictions=decoded_preds,
                references=decoded_labels,
            )
    del generated_tokens, labels, batch
    gc.collect()
wer = 100 * metric.compute()
print(f"{wer=}")
Demonstration with Gradio
Here we build a simple web interface so users can speak into their microphone and see live transcription from our PEFT-LoRA Whisper model.
This French‐language Whisper-small model includes LoRA adapters to illustrate the fine-tuning process.
It isn’t fully trained; if you want production-quality transcription, you should continue fine-tuning with a larger dataset and more training time.
import torch
import gradio as gr
from transformers import (
    AutomaticSpeechRecognitionPipeline,
    WhisperForConditionalGeneration,
    WhisperTokenizer,
    WhisperProcessor,
)
from peft import PeftModel, PeftConfig
peft_model_id = "diabolocom/openai-whisper-small-LoRA"
language = "French"
task = "transcribe"
peft_config = PeftConfig.from_pretrained(peft_model_id)
model = WhisperForConditionalGeneration.from_pretrained(
    peft_config.base_model_name_or_path, load_in_8bit=True, device_map="auto"
)
model = PeftModel.from_pretrained(model, peft_model_id)
tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
feature_extractor = processor.feature_extractor
forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)
pipe = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
def transcribe(audio):
    with torch.cuda.amp.autocast():
        text = pipe(audio, generate_kwargs={"forced_decoder_ids": forced_decoder_ids}, max_new_tokens=255)["text"]
    return text

iface = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(sources=["microphone"], type="filepath"),
    outputs="text",
    title="PEFT LoRA",
    description="Tutorial demo only: this French-language Whisper-small model includes LoRA adapters to illustrate the fine-tuning process. It isn't fully trained; for production-quality transcription, continue fine-tuning with a larger dataset and more training time.",
)
iface.launch(share=True)
The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.Device set to use cuda:0
* Running on local URL: http://127.0.0.1:7860
* Running on public URL: https://9cd5f0c486285ae32d.gradio.live
This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)