| | from typing import Any, Mapping, Optional |
| |
|
| | import torch |
| | from peft import LoraConfig, get_peft_model |
| | from transformers import WhisperConfig, WhisperModel |
| |
|
| |
|
class WhisperEncoderBackbone(torch.nn.Module):
    def __init__(
        self,
        model: str = "openai/whisper-small.en",
        hf_config: Optional[Mapping[str, Any]] = None,
        lora_params: Optional[Mapping[str, Any]] = None,
    ):
        """Whisper encoder model with optional Low-Rank Adaptation.

        Parameters
        ----------
        model: Name of WhisperModel whose encoder to load from HuggingFace
        hf_config: Optional config for HuggingFace model
        lora_params: Parameters for Low-Rank Adaptation

        """
        super().__init__()
        # Overrides (if any) are applied on top of the pretrained config.
        config_overrides = dict(hf_config) if hf_config is not None else {}
        backbone_config = WhisperConfig.from_pretrained(model, **config_overrides)

        # Only the encoder half of the seq2seq Whisper model is kept;
        # it is put in train mode at construction time.
        full_model = WhisperModel.from_pretrained(model, config=backbone_config)
        encoder = full_model.get_encoder()
        encoder.train()
        self.backbone = encoder

        # A non-empty lora_params mapping wraps the encoder with PEFT adapters.
        if lora_params:
            self.backbone = get_peft_model(self.backbone, LoraConfig(**lora_params))
        self.backbone_dim = backbone_config.hidden_size

    def forward(self, whisper_feature_batch):
        """Encode features and mean-pool the hidden states over the time axis."""
        encoded = self.backbone(whisper_feature_batch)
        return torch.mean(encoded.last_hidden_state, dim=1)
| |
|
| |
|
class SharedLayers(torch.nn.Module):
    def __init__(self, input_dim: int, proj_dims: list[int]):
        """Fully connected network with Mish nonlinearities between linear layers. No nonlinearity at input or output.

        Parameters
        ----------
        input_dim: Dimension of input features
        proj_dims: Dimensions of layers to create; must be non-empty

        Raises
        ------
        ValueError
            If `proj_dims` is empty (would otherwise fail with a cryptic IndexError).

        """
        super().__init__()
        if not proj_dims:
            raise ValueError("proj_dims must contain at least one output dimension")
        modules = []
        # Hidden layers: Linear followed by Mish; the final Linear has no activation.
        for output_dim in proj_dims[:-1]:
            modules.extend([torch.nn.Linear(input_dim, output_dim), torch.nn.Mish()])
            input_dim = output_dim
        modules.append(torch.nn.Linear(input_dim, proj_dims[-1]))
        self.shared_layers = torch.nn.Sequential(*modules)

    def forward(self, x):
        """Apply the stack of shared linear layers to `x`."""
        return self.shared_layers(x)
| |
|
class TaskHead(torch.nn.Module):
    def __init__(self, input_dim: int, proj_dim: int, dropout: float = 0.0):
        """Fully connected network with one hidden layer, dropout, and a scalar output."""
        super().__init__()
        # Submodule names are part of the state-dict contract; keep them stable.
        self.linear = torch.nn.Linear(input_dim, proj_dim)
        self.activation = torch.nn.Mish()
        self.dropout = torch.nn.Dropout(dropout)
        # Scalar output without bias.
        self.final_layer = torch.nn.Linear(proj_dim, 1, bias=False)

    def forward(self, x):
        """Project to the hidden dim, apply Mish + dropout, then emit one scalar per row."""
        hidden = self.dropout(self.activation(self.linear(x)))
        return self.final_layer(hidden)
| |
|
| |
|
class MultitaskHead(torch.nn.Module):
    def __init__(
        self,
        backbone_dim: int,
        shared_projection_dim: list[int],
        tasks: Mapping[str, Mapping[str, Any]],
    ):
        """Fully connected network with multiple named scalar outputs."""
        super().__init__()
        # One trunk shared by all tasks, then an independent head per task.
        self.shared_layers = SharedLayers(backbone_dim, shared_projection_dim)
        heads = {}
        for task, task_config in tasks.items():
            # Each head takes the trunk's final dimension as its input width.
            heads[task] = TaskHead(shared_projection_dim[-1], **task_config)
        self.classifier_head = torch.nn.ModuleDict(heads)

    def forward(self, x):
        """Return a dict mapping each task name to its head's output."""
        shared = self.shared_layers(x)
        outputs = {}
        for task, head in self.classifier_head.items():
            outputs[task] = head(shared)
        return outputs
| |
|
| |
|
| | def average_tensor_in_segments(tensor: torch.Tensor, lengths: list[int] | torch.Tensor): |
| | """Average segments of a `tensor` along dimension 0 based on a list of `lengths` |
| | |
| | For example, with input tensor `t` and `lengths` [1, 3, 2], the output would be |
| | [t[0], (t[1] + t[2] + t[3]) / 3, (t[4] + t[5]) / 2] |
| | |
| | Parameters |
| | ---------- |
| | tensor : torch.Tensor |
| | The tensor to average |
| | lengths : list of ints |
| | The lengths of each segment to average in the tensor, in order |
| | |
| | Returns |
| | ------- |
| | torch.Tensor |
| | The tensor with relevant segments averaged |
| | """ |
| | if not torch.is_tensor(lengths): |
| | lengths = torch.tensor(lengths, device=tensor.device) |
| | index = torch.repeat_interleave( |
| | torch.arange(len(lengths), device=tensor.device), lengths |
| | ) |
| | out = torch.zeros( |
| | lengths.shape + tensor.shape[1:], device=tensor.device, dtype=tensor.dtype |
| | ) |
| | out.index_add_(0, index, tensor) |
| | broadcastable_lengths = lengths.view((-1,) + (1,) * (len(out.shape) - 1)) |
| | return out / broadcastable_lengths |
| |
|
| |
|
class Classifier(torch.nn.Module):
    def __init__(
        self,
        backbone_configs: Mapping[str, Mapping[str, Any]],
        classifier_config: Mapping[str, Any],
        inference_thresholds: Mapping[str, Any],
        preprocessor_config: Mapping[str, Any],
    ):
        """Full Kintsugi Depression and Anxiety model.

        Whisper encoder -> Mean pooling over time -> Layers shared across tasks -> Per-task heads

        Parameters
        ----------
        backbone_configs:
        classifier_config:
        inference_thresholds:
        preprocessor_config:

        """
        super().__init__()
        # One Whisper encoder per config key; sorted keys give a deterministic
        # ModuleDict order (and therefore a deterministic concat order in forward).
        encoders = {}
        for key in sorted(backbone_configs.keys()):
            encoders[key] = WhisperEncoderBackbone(**backbone_configs[key])
        self.backbone = torch.nn.ModuleDict(encoders)

        # The multitask head consumes the concatenation of all encoder outputs.
        backbone_dim = sum(encoder.backbone_dim for encoder in self.backbone.values())
        self.head = MultitaskHead(backbone_dim, **classifier_config)
        self.inference_thresholds = inference_thresholds
        self.preprocessor_config = preprocessor_config

    def forward(self, x, lengths):
        # Pool each encoder's per-chunk outputs into one vector per segment,
        # then concatenate across encoders along the feature dimension.
        pooled = []
        for encoder in self.backbone.values():
            pooled.append(average_tensor_in_segments(encoder(x), lengths))
        combined = torch.cat(pooled, dim=1)
        # NOTE(review): the second return value is all-ones shaped like `lengths`;
        # presumably a per-segment count/weight placeholder — confirm with callers.
        return self.head(combined), torch.ones_like(lengths)

    def quantize_scores(self, scores: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
        """Map per-task scores to discrete predictions per `inference_thresholds` config."""
        quantized = {}
        for task, score in scores.items():
            boundaries = torch.tensor(self.inference_thresholds[task], device=score.device)
            # Bucket index of the mean score among the configured thresholds.
            quantized[task] = torch.searchsorted(boundaries, score.mean(), out_int32=True)
        return quantized
| |
|