This model was released on 2026-04-23 and added to Hugging Face Transformers on 2026-04-29.
Granite Speech Plus
Overview
Granite Speech Plus is a variant of Granite Speech whose projector consumes the concatenation of
the encoder’s final hidden states with an arbitrary subset of its intermediate hidden states (along the feature
dimension). The selected intermediate layers are controlled by the cat_hidden_layers config field on
GraniteSpeechPlusEncoderConfig; when it is None, the model behaves identically to Granite Speech. When it is set, the
projector’s encoder_hidden_size must equal encoder_config.hidden_dim * (len(cat_hidden_layers) + 1).
The rest of the architecture — speech encoder, query transformer projector, language model, and optional LoRA adapter — is inherited unchanged from Granite Speech. See the Granite Speech documentation for usage examples; the same GraniteSpeechProcessor and GraniteSpeechFeatureExtractor are used here.
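For example, a minimal configuration sketch (assuming the default Blip2QFormerConfig projector and the config fields described above) might wire the dimensions together like this:
from transformers import Blip2QFormerConfig, GraniteSpeechPlusConfig, GraniteSpeechPlusEncoderConfig

# Concatenate the outputs of conformer layers 3 and 7 with the final encoder output.
encoder_config = GraniteSpeechPlusEncoderConfig(hidden_dim=1024, cat_hidden_layers=[3, 7])
# The projector then consumes hidden_dim * (len(cat_hidden_layers) + 1) = 1024 * 3 = 3072 features.
projector_config = Blip2QFormerConfig(encoder_hidden_size=1024 * 3)
config = GraniteSpeechPlusConfig(encoder_config=encoder_config, projector_config=projector_config)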
Usage
Granite Speech Plus is a multimodal speech-to-text model that can transcribe audio, attribute speech to speakers, and produce word-level timestamps in response to text prompts. Here is how to use the different functions:
Setup — load the model and a test audio clip:
import re
import torch
from datasets import Audio, load_dataset
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
SAMPLE_RATE = 16000
MODEL_NAME = "ibm-granite/granite-speech-4.1-2b-plus"
Define the prompts used for the different tasks:
SYSTEM_PROMPT = "Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant"
ASR_PROMPT = "<|audio|> can you transcribe the speech into a written format?"
SAA_PROMPT = "<|audio|> Speaker attribution: Transcribe and denote who is speaking by adding [Speaker 1]: and [Speaker 2]: tags before speaker turns."
TS_PROMPT = "<|audio|> Timestamps: Transcribe the speech. After each word, add a timestamp tag showing the end time in centiseconds, e.g. hello [T:45] world [T:82]"
Load the model and define a general function for decoding the audio:
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME, device_map="auto")
device = model.device

@torch.inference_mode()
def transcribe(audio, prompt, max_new_tokens=2000, prefix_text=None):
    chat = [{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": prompt}]
    extra = {"prefix_text": prefix_text} if prefix_text is not None else {}
    prompt_text = processor.apply_chat_template(chat, tokenize=False, add_generation_prompt=True, **extra)
    inputs = processor(prompt_text, audio, device=device, return_tensors="pt").to(device)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, num_beams=1)
    new_tokens = outputs[0, inputs["input_ids"].shape[-1]:]
    output_text = processor.decode(new_tokens, add_special_tokens=False, skip_special_tokens=True)
    return output_text
Load some example audio data from the AMI dataset:
ds = load_dataset("diarizers-community/ami", "ihm", split="test")
ds = ds.cast_column("audio", Audio(sampling_rate=SAMPLE_RATE, num_channels=1))
TEST_SAMPLE = 0
START_TIME, END_TIME = 5 * 60, 6 * 60
audio = ds["audio"][TEST_SAMPLE].get_samples_played_in_range(START_TIME, END_TIME)
Task 1: ASR — plain speech-to-text transcription:
asr_text = transcribe(audio.data, ASR_PROMPT)
print(asr_text)
Task 2: Speaker Attributed ASR — transcription with speaker labels:
saa_text = transcribe(audio.data, SAA_PROMPT)
for segment in re.split(r"(\[Speaker \d+\]:)", saa_text):
    print(segment.strip())
Task 3: Word-level timestamps — transcription with per-word timing:
The timestamps are given in centiseconds and wrap around modulo 1000 (i.e. every 10 seconds), so we need to unwrap them by adding multiples of 10 seconds.
ts_text = transcribe(audio.data, TS_PROMPT, max_new_tokens=10000)
ts_words = re.split(r"\[T:(\d+)\]", ts_text)
last_word_end_time = 0
offset_time = 0
for word, ts in zip(ts_words[::2], ts_words[1::2]):
    word_end_time = float(ts) / 100
    while word_end_time + offset_time < last_word_end_time:
        offset_time += 10
    last_word_end_time = word_end_time + offset_time
    print(f"{word}\t{last_word_end_time:.2f}s")
Task 4: Incremental decoding — transcribe segments while accumulating audio context:
NUM_SEGMENTS = 3
previous_transcript = ""
all_audio = None
for k in range(NUM_SEGMENTS):
    t1 = START_TIME + (END_TIME - START_TIME) * k / NUM_SEGMENTS
    t2 = START_TIME + (END_TIME - START_TIME) * (k + 1) / NUM_SEGMENTS
    new_audio = ds["audio"][TEST_SAMPLE].get_samples_played_in_range(t1, t2)
    all_audio = new_audio.data if all_audio is None else torch.cat([all_audio, new_audio.data], dim=-1)
    saa_text = transcribe(all_audio, SAA_PROMPT, prefix_text=previous_transcript)
    print(f"{t1:06.2f}-{t2:06.2f}:\t{saa_text}")
    previous_transcript = (previous_transcript + " " + saa_text).strip()
GraniteSpeechPlusConfig
class transformers.GraniteSpeechPlusConfig
< source >( transformers_version: str | None = None architectures: list[str] | None = None output_hidden_states: bool | None = False return_dict: bool | None = True dtype: typing.Union[str, ForwardRef('torch.dtype'), NoneType] = None chunk_size_feed_forward: int = 0 is_encoder_decoder: bool = False id2label: dict[int, str] | dict[str, str] | None = None label2id: dict[str, int] | dict[str, str] | None = None problem_type: typing.Optional[typing.Literal['regression', 'single_label_classification', 'multi_label_classification']] = None text_config: dict | transformers.configuration_utils.PreTrainedConfig | None = None encoder_config: dict | transformers.configuration_utils.PreTrainedConfig | None = None projector_config: dict | transformers.configuration_utils.PreTrainedConfig | None = None audio_token_index: int = 49155 initializer_range: float = 0.02 has_lora_adapter: bool = True downsample_rate: int = 5 window_size: int = 15 )
Parameters
- text_config (Union[dict, PreTrainedConfig], optional) — The config object or dictionary of the text backbone.
- encoder_config (Union[dict, PreTrainedConfig], optional) — The config object or dictionary of the encoder backbone.
- projector_config (Union[AutoConfig, dict], optional, defaults to Blip2QFormerConfig) — The config object or dictionary of the audio projector.
- audio_token_index (int, optional, defaults to 49155) — The audio token index used as a placeholder for input audio.
- initializer_range (float, optional, defaults to 0.02) — The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- has_lora_adapter (bool, optional, defaults to True) — Indicates whether or not the model has a LoRA adapter that should only be activated when processing audio inputs.
- downsample_rate (int, optional, defaults to 5) — Downsample rate for the audio feature extractor.
- window_size (int, optional, defaults to 15) — Window size for the audio feature projector.
This is the configuration class to store the configuration of a Granite Speech Plus model. It is used to instantiate a Granite Speech Plus model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration similar to that of ibm-granite/granite-speech-4.1-2b-plus.
Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.
Example:
>>> from transformers import GraniteSpeechPlusConfig, GraniteSpeechPlusForConditionalGeneration
>>> # Initializing a GraniteSpeechPlusConfig
>>> configuration = GraniteSpeechPlusConfig()
>>> # Initializing a GraniteSpeechPlusForConditionalGeneration (with random weights)
>>> model = GraniteSpeechPlusForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
GraniteSpeechPlusEncoderConfig
class transformers.GraniteSpeechPlusEncoderConfig
< source >( transformers_version: str | None = None architectures: list[str] | None = None output_hidden_states: bool | None = False return_dict: bool | None = True dtype: typing.Union[str, ForwardRef('torch.dtype'), NoneType] = None chunk_size_feed_forward: int = 0 is_encoder_decoder: bool = False id2label: dict[int, str] | dict[str, str] | None = None label2id: dict[str, int] | dict[str, str] | None = None problem_type: typing.Optional[typing.Literal['regression', 'single_label_classification', 'multi_label_classification']] = None input_dim: int = 160 num_layers: int = 10 hidden_dim: int = 1024 feedforward_mult: int = 4 num_heads: int = 8 dim_head: int = 128 output_dim: int = 42 context_size: int = 200 max_pos_emb: int = 512 dropout: float | int = 0.1 conv_kernel_size: int = 15 conv_expansion_factor: int = 2 cat_hidden_layers: list[int] | None = None )
Parameters
- input_dim (int, optional, defaults to 160) — Dimensionality of the input acoustic features (e.g., number of mel-filterbank channels).
- num_layers (int, optional, defaults to 10) — Number of hidden layers in the conformer encoder.
- hidden_dim (int, optional, defaults to 1024) — Dimension of the hidden representations.
- feedforward_mult (int, optional, defaults to 4) — Multiplier for the up/down projections in the encoder’s feedforward layers; the projections have an intermediate dimension of hidden_dim * feedforward_mult.
- num_heads (int, optional, defaults to 8) — Number of attention heads for each attention layer in the conformer encoder.
- dim_head (int, optional, defaults to 128) — The attention head dimension. If None, it defaults to hidden_dim // num_heads.
- output_dim (int, optional, defaults to 42) — Intermediate dimension of the feedforward projections in the conformer to be added to every other encoder block’s output.
- context_size (int, optional, defaults to 200) — Context size to be used in conformer attention.
- max_pos_emb (int, optional, defaults to 512) — Max pos embeds to be used in attention (shaw’s relative positional encoding).
- dropout (Union[float, int], optional, defaults to 0.1) — The ratio for all dropout layers.
- conv_kernel_size (int, optional, defaults to 15) — The size of the convolutional kernel.
- conv_expansion_factor (int, optional, defaults to 2) — Intermediate dimension to be used in conformer convolutions.
- cat_hidden_layers (list[int], optional) — Indices of encoder conformer layers whose outputs are concatenated with the final encoder output (along the feature dimension) before being passed to the projector. When set, the projector’s encoder_hidden_size must equal encoder_config.hidden_dim * (len(cat_hidden_layers) + 1), as illustrated below.
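For illustration only (using the field names above), the required projector width follows directly from the encoder config:
from transformers import GraniteSpeechPlusEncoderConfig

encoder_config = GraniteSpeechPlusEncoderConfig(hidden_dim=1024, cat_hidden_layers=[2, 5, 8])
# Final output plus three intermediate layers, concatenated along the feature dimension.
expected_projector_width = encoder_config.hidden_dim * (len(encoder_config.cat_hidden_layers) + 1)
print(expected_projector_width)  # 4096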
This is the configuration class to store the configuration of a GraniteSpeechPlusCTCEncoder, the speech encoder of a Granite Speech Plus model. It is used to instantiate the encoder according to the specified arguments, defining the encoder architecture. Instantiating a configuration with the defaults will yield a configuration similar to that of the encoder used in ibm-granite/granite-speech-4.1-2b-plus.
Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.
Example:
>>> from transformers import GraniteSpeechPlusEncoderConfig, GraniteSpeechPlusCTCEncoder
>>> # Initializing a GraniteSpeechPlusEncoderConfig
>>> configuration = GraniteSpeechPlusEncoderConfig()
>>> # Initializing a GraniteSpeechPlusCTCEncoder (with random weights)
>>> model = GraniteSpeechPlusCTCEncoder(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
GraniteSpeechPlusForConditionalGeneration
class transformers.GraniteSpeechPlusForConditionalGeneration
< source >( config: GraniteSpeechPlusConfig )
Parameters
- config (GraniteSpeechPlusConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
The Granite Speech Plus model, a Granite Speech variant whose projector consumes the concatenation of the encoder’s final hidden states with an arbitrary subset of its intermediate hidden states.
This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).
This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
forward
< source >( input_ids: torch.LongTensor | None = None input_features: torch.FloatTensor | None = None input_features_mask: torch.Tensor | None = None attention_mask: torch.Tensor | None = None position_ids: torch.LongTensor | None = None past_key_values: transformers.cache_utils.Cache | None = None inputs_embeds: torch.FloatTensor | None = None labels: torch.LongTensor | None = None use_cache: bool | None = None output_attentions: bool | None = None output_hidden_states: bool | None = None return_dict: bool | None = None logits_to_keep: int | torch.Tensor = 0 **lm_kwargs ) → GraniteSpeechPlusCausalLMOutputWithPast or tuple(torch.FloatTensor)
Parameters
- input_ids (torch.LongTensor of shape (batch_size, sequence_length), optional) — Indices of input sequence tokens in the vocabulary. Padding will be ignored by default. Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details.
- input_features (torch.FloatTensor of shape (batch_size, sequence_length, feature_dim), optional) — The tensors corresponding to the input audio features. Audio features can be obtained using GraniteSpeechFeatureExtractor. See GraniteSpeechFeatureExtractor.__call__() for details (GraniteSpeechProcessor uses GraniteSpeechFeatureExtractor for processing audio).
- input_features_mask (torch.Tensor, optional) — Mask to be applied to audio features prior to scattering into the language embeddings.
- attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional) — Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]: 1 for tokens that are not masked, 0 for tokens that are masked.
- position_ids (torch.LongTensor of shape (batch_size, sequence_length), optional) — Indices of positions of each input sequence token in the position embeddings. Selected in the range [0, config.n_positions - 1].
- past_key_values (~cache_utils.Cache, optional) — Pre-computed hidden states (keys and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists of the past_key_values returned by the model at a previous stage of decoding, when use_cache=True or config.use_cache=True. Only a Cache instance is allowed as input; see the kv cache guide. If no past_key_values are passed, a DynamicCache will be initialized by default. The model will output the same cache format that is fed as input. If past_key_values are used, the user is expected to input only unprocessed input_ids (those that don’t have their past key value states given to this model) of shape (batch_size, unprocessed_length) instead of all input_ids of shape (batch_size, sequence_length).
- inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional) — Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert input_ids indices into associated vectors than the model’s internal embedding lookup matrix.
- labels (torch.LongTensor of shape (batch_size, sequence_length), optional) — Labels for computing the masked language modeling loss. Indices should either be in [0, ..., config.vocab_size] or -100 (see the input_ids docstring). Tokens with indices set to -100 are ignored (masked); the loss is only computed for the tokens with labels in [0, ..., config.vocab_size].
- use_cache (bool, optional) — If set to True, past_key_values key value states are returned and can be used to speed up decoding (see past_key_values).
- output_attentions (bool, optional) — Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
- logits_to_keep (Union[int, torch.Tensor], optional, defaults to 0) — If an int, compute logits for the last logits_to_keep tokens. If 0, calculate logits for all input_ids (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes quite significant for long sequences or large vocabulary sizes. If a torch.Tensor, must be 1D, corresponding to the indices to keep in the sequence length dimension. This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns
GraniteSpeechPlusCausalLMOutputWithPast or tuple(torch.FloatTensor)
A GraniteSpeechPlusCausalLMOutputWithPast or a tuple of
torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various
elements depending on the configuration (GraniteSpeechPlusConfig) and inputs.
The GraniteSpeechPlusForConditionalGeneration forward method overrides the __call__ special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
- loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided) — Language modeling loss (for next-token prediction).
- logits (torch.FloatTensor of shape (batch_size, sequence_length, config.vocab_size)) — Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
- past_key_values (Cache, optional, returned when use_cache=True is passed or when config.use_cache=True) — A Cache instance; for more details, see the kv cache guide. Contains pre-computed hidden states (keys and values in the self-attention blocks) that can be used (see past_key_values input) to speed up sequential decoding.
- hidden_states (tuple[torch.FloatTensor], optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, plus one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). Hidden states of the model at the output of each layer plus the optional initial embedding outputs.
- attentions (tuple[torch.FloatTensor], optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Example:
>>> from transformers import AutoProcessor, GraniteSpeechPlusForConditionalGeneration
>>> from datasets import load_dataset
>>> import torch
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> dataset = dataset.sort("id")
>>> sampling_rate = dataset.features["audio"].sampling_rate
>>> processor = AutoProcessor.from_pretrained("ibm-granite/granite-speech-4.1-2b-plus")
>>> model = GraniteSpeechPlusForConditionalGeneration.from_pretrained("ibm-granite/granite-speech-4.1-2b-plus")
>>> # audio file is decoded on the fly
>>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
>>> with torch.no_grad():
... logits = model(**inputs).logits
>>> predicted_ids = torch.argmax(logits, dim=-1)
>>> # transcribe speech
>>> transcription = processor.batch_decode(predicted_ids)
>>> transcription[0]
...
>>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="pt").input_ids
>>> # compute loss
>>> loss = model(**inputs).loss
>>> round(loss.item(), 2)
...
get_audio_features
< source >( input_features: Tensor **kwargs: typing_extensions.Unpack[transformers.utils.generic.TransformersKwargs] ) → BaseModelOutputWithPooling or tuple(torch.FloatTensor)
Parameters
- input_features (torch.Tensor of shape (batch_size, sequence_length, feature_dim)) — The tensors corresponding to the input audio features. Audio features can be obtained using GraniteSpeechFeatureExtractor. See GraniteSpeechFeatureExtractor.__call__() for details (GraniteSpeechProcessor uses GraniteSpeechFeatureExtractor for processing audio).
Returns
BaseModelOutputWithPooling or tuple(torch.FloatTensor)
A BaseModelOutputWithPooling or a tuple of
torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various
elements depending on the configuration (GraniteSpeechPlusConfig) and inputs.
- last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)) — Sequence of hidden states at the output of the last layer of the model.
- pooler_output (torch.FloatTensor of shape (batch_size, hidden_size)) — Last layer hidden state of the first token of the sequence (classification token) after further processing through the layers used for the auxiliary pretraining task. E.g. for the BERT family of models, this returns the classification token after processing through a linear layer and a tanh activation function. The linear layer weights are trained from the next sentence prediction (classification) objective during pretraining.
- hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, plus one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). Hidden states of the model at the output of each layer plus the optional initial embedding outputs.
- attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Example:
>>> from transformers import AutoProcessor, GraniteSpeechPlusForConditionalGeneration
>>> from datasets import load_dataset
>>> import torch
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> dataset = dataset.sort("id")
>>> sampling_rate = dataset.features["audio"].sampling_rate
>>> processor = AutoProcessor.from_pretrained("ibm-granite/granite-speech-4.1-2b-plus")
>>> model = GraniteSpeechPlusForConditionalGeneration.from_pretrained("ibm-granite/granite-speech-4.1-2b-plus")
>>> # audio file is decoded on the fly
>>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
>>> with torch.no_grad():
... logits = model(**inputs).logits
>>> predicted_ids = torch.argmax(logits, dim=-1)
>>> # transcribe speech
>>> transcription = processor.batch_decode(predicted_ids)
>>> transcription[0]
...
>>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="pt").input_ids
>>> # compute loss
>>> loss = model(**inputs).loss
>>> round(loss.item(), 2)
...