# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0
"""Public QDiffusion module for generic discrete-sequence generation."""
from __future__ import annotations
from contextlib import nullcontext
from dataclasses import dataclass
import math
from typing import Any, Protocol
import numpy as np
import torch
from torch import nn
from ._qdiffusion_sampling import (
stochastic_sample_from_categorical,
stochastic_sample_from_categorical_n,
top_k_top_p_filtering,
topk_masking,
)
__all__ = ["QDiffusion", "QDiffusionConfig"]
@dataclass(frozen=True)
class SequenceTokenSpec:
"""Special-token metadata required by generic QDiffusion logic.
Attributes:
mask_id: Token id used as the diffusion noise token.
pad_id: Token id used for sequence padding.
bos_id: Token id used for beginning-of-sequence markers.
eos_id: Token id used for end-of-sequence markers.
x_id: Optional token id reserved by some backbones and excluded from
proposal sampling.
tokenizer: Optional tokenization helper used for encoding and decoding.
"""
mask_id: int
pad_id: int
bos_id: int
eos_id: int
x_id: int | None = None
tokenizer: Any | None = None
class EnergyBackboneAdapter(Protocol):
"""Protocol describing the generic energy-backbone bridge."""
@property
def hidden_size(self) -> int:
"""Returns the backbone hidden size used by the energy head."""
def embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
"""Embeds discrete token ids into hidden representations."""
def encode_conditioned(
self,
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
"""Encodes fused embeddings and returns token-level hidden states."""
def score_conditioned(
self,
noisy_tokens: torch.Tensor,
candidate_tokens: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
"""Optionally scores candidates directly without the fallback path."""
[文档]
@dataclass
class QDiffusionConfig:
"""Configuration for energy-guided discrete generation.
Attributes:
num_diffusion_timesteps: Number of discrete noising steps used by the
training objective.
use_coupled_sampling: Whether to use the coupled corruption variant.
num_candidates: Number of proposal candidates sampled at each decode step.
proposal_temperature: Temperature used for proposal-side sampling.
proposal_noise_scale: Gumbel noise scale used during proposal sampling.
energy_temperature: Temperature used when converting energies into
reranking weights.
disable_resample: Whether to disable repetition-collapse resampling.
resample_ratio: Frequency threshold that triggers resampling.
resample_top_p: Top-p cutoff used during resampling.
decoding_strategy: Skeptical-remasking strategy string.
"""
num_diffusion_timesteps: int = 500
use_coupled_sampling: bool = False
num_candidates: int = 1
proposal_temperature: float = 0.0
proposal_noise_scale: float = 1.0
energy_temperature: float = 1.0
disable_resample: bool = False
resample_ratio: float = 0.25
resample_top_p: float = 0.95
decoding_strategy: str = "reparam-uncond-deterministic-linear"
[文档]
class QDiffusion(nn.Module):
"""Energy-guided discrete diffusion wrapper over generic sequence backbones.
The class combines two backbone roles:
- a proposal model that predicts token logits for the current noisy state
- an energy model that reranks candidate reconstructions
It exposes both training-oriented APIs such as :meth:`objective` and
decoding-oriented APIs such as :meth:`initialize_state`, :meth:`step`, and
:meth:`generate`.
"""
def __init__(
self,
proposal_model: nn.Module,
energy_model: nn.Module,
token_spec: SequenceTokenSpec,
energy_adapter: EnergyBackboneAdapter,
config: QDiffusionConfig | None = None,
dtype: torch.dtype = torch.float32,
device: torch.device | str | None = None,
freeze_proposal: bool = True,
energy_head: nn.Module | None = None,
) -> None:
"""Initializes a QDiffusion model.
Args:
proposal_model: Backbone used to predict proposal logits.
energy_model: Backbone module whose parameters are optimized for
sequence-level energy scoring.
token_spec: Special-token metadata required by the generator.
energy_adapter: Adapter exposing the generic embedding and encoding
hooks used by :meth:`energy`.
config: Optional generation/training configuration.
dtype: Floating point dtype tracked by the wrapper.
device: Optional target device. When omitted, infer from parameters.
freeze_proposal: Whether to freeze proposal model parameters.
energy_head: Optional custom sequence-level energy head.
"""
super().__init__()
self.proposal_model = proposal_model
self.energy_model = energy_model
self.token_spec = token_spec
self.energy_adapter = energy_adapter
self.config = config or QDiffusionConfig()
self.dtype = dtype
if freeze_proposal:
self.proposal_model.eval()
for parameter in self.proposal_model.parameters():
parameter.requires_grad = False
hidden_size = self.energy_adapter.hidden_size
self.energy_head = energy_head or nn.Sequential(
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, 1),
)
self.vocab_proj = nn.Linear(2 * hidden_size, hidden_size, bias=True)
self.tokenizer = token_spec.tokenizer
self.mask_id = token_spec.mask_id
self.pad_id = token_spec.pad_id
self.bos_id = token_spec.bos_id
self.eos_id = token_spec.eos_id
self.x_id = token_spec.x_id
self.softplus = nn.Softplus()
if device is None:
try:
self.device = next(self.parameters()).device
except StopIteration:
self.device = torch.device("cpu")
else:
self.device = torch.device(device)
self.to(device=self.device, dtype=self.dtype)
[文档]
def to(self, *args: Any, **kwargs: Any) -> QDiffusion:
"""Moves the module and refreshes cached device/dtype metadata.
Args:
*args: Positional arguments forwarded to ``nn.Module.to``.
**kwargs: Keyword arguments forwarded to ``nn.Module.to``.
Returns:
QDiffusion: The moved module instance.
"""
module = super().to(*args, **kwargs)
try:
self.device = next(self.parameters()).device
self.dtype = next(self.parameters()).dtype
except StopIteration:
self.device = torch.device("cpu")
return module
[文档]
def forward(self, noisy_tokens: torch.Tensor, **kwargs: Any) -> torch.Tensor:
"""Runs the proposal model on the current noisy state.
Args:
noisy_tokens: Current noisy token tensor.
**kwargs: Additional keyword arguments forwarded to the proposal
model.
Returns:
torch.Tensor: Proposal logits over the token vocabulary.
Raises:
TypeError: If the proposal model does not implement ``forward``.
"""
if hasattr(self.proposal_model, "forward"):
return self.proposal_model.forward(noisy_tokens, **kwargs)
raise TypeError("proposal_model must implement forward().")
[文档]
def proposal(self, noisy_tokens: torch.Tensor, **kwargs: Any) -> torch.Tensor:
"""Semantic alias around :meth:`forward` for proposal-side calls.
Args:
noisy_tokens: Current noisy token tensor.
**kwargs: Additional keyword arguments forwarded to the proposal
model.
Returns:
torch.Tensor: Proposal logits over the token vocabulary.
"""
return self.forward(noisy_tokens, **kwargs)
[文档]
def energy(
self,
noisy_tokens: torch.Tensor,
candidate_tokens: torch.Tensor,
attention_mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""Scores candidate reconstructions conditioned on the noisy state.
Args:
noisy_tokens: Noisy token tensor used as conditioning input.
candidate_tokens: Candidate clean token tensor to score.
attention_mask: Optional attention mask for the energy model.
Returns:
torch.Tensor: A tensor of scalar energies with shape ``[batch, 1]``.
"""
if attention_mask is None:
attention_mask = candidate_tokens.ne(self.pad_id)
score_conditioned = getattr(self.energy_adapter, "score_conditioned", None)
if callable(score_conditioned):
return score_conditioned(
noisy_tokens=noisy_tokens,
candidate_tokens=candidate_tokens,
attention_mask=attention_mask,
)
if noisy_tokens.device.type == "cuda":
outer_context = torch.amp.autocast("cuda", dtype=torch.float32)
inner_context = torch.amp.autocast("cuda", dtype=torch.bfloat16)
else:
outer_context = nullcontext()
inner_context = nullcontext()
with outer_context:
noisy_embeds = self.energy_adapter.embed_tokens(noisy_tokens)
candidate_embeds = self.energy_adapter.embed_tokens(candidate_tokens)
fused_embed = self.vocab_proj(
torch.cat([noisy_embeds, candidate_embeds], dim=-1)
)
with inner_context:
hidden = self.energy_adapter.encode_conditioned(
input_ids=noisy_tokens,
inputs_embeds=fused_embed,
attention_mask=attention_mask,
)
pooled = hidden.mean(dim=1)
return self.energy_head(pooled)
[文档]
def objective(
self, batch: dict[str, torch.Tensor], weighting: str = "constant"
) -> dict[str, torch.Tensor]:
"""Builds the one-step training objective used by an external loop.
Args:
batch: Batch dictionary containing at least ``batch["targets"]``.
weighting: Per-sample timestep weighting mode.
Returns:
dict[str, torch.Tensor]: A dictionary containing proposal logits,
supervision masks, loss weights, and the EBM objective term.
"""
target = batch["targets"]
first_timestep, second_timestep = torch.randint(
1,
self.config.num_diffusion_timesteps + 1,
(2 * target.size(0),),
device=target.device,
).chunk(2)
if self.config.use_coupled_sampling:
sample_outputs = self._sample_coupled(
target,
first_timestep,
second_timestep,
self.get_non_special_symbol_mask(target),
)
target = target.repeat(2, 1)
else:
sample_outputs = self._sample(
target,
first_timestep,
self.get_non_special_symbol_mask(target),
)
noisy_tokens = sample_outputs["x_t"]
timesteps = sample_outputs["t"]
loss_mask = sample_outputs["loss_mask"]
with torch.no_grad():
logits = self.forward(noisy_tokens).detach()
negative_tokens, _ = self._sample_candidates(logits, self.config.num_candidates)
positive_energy = self.energy(noisy_tokens, target, target.ne(self.pad_id))
negative_energy = self._score_candidates(noisy_tokens, negative_tokens).mean(
dim=1, keepdim=True
)
energy_objective = self.softplus(positive_energy) + self.softplus(
-negative_energy
)
weight = self._compute_loss_weight(timesteps, weighting)
return {
"logits": logits,
"targets": target,
"loss_mask": loss_mask,
"weight": weight,
"energy_objective": energy_objective,
}
[文档]
def initialize_state(
self,
input_tokens: torch.Tensor,
partial_masks: torch.Tensor | None = None,
max_steps: int = 500,
temperature: float = 1.0,
) -> dict[str, Any]:
"""Creates the initial decoding state for an external generation loop.
Args:
input_tokens: Initial token tensor.
partial_masks: Optional boolean mask of fixed positions.
max_steps: Planned number of decode iterations.
temperature: Sampling temperature stored in the state payload.
Returns:
dict[str, Any]: A mutable state dictionary suitable for repeated :meth:`step` calls.
"""
output_tokens, output_scores = self._initialize_output_tokens(
input_tokens, partial_masks=partial_masks
)
return {
"output_tokens": output_tokens,
"output_scores": output_scores,
"output_masks": self.get_non_special_symbol_mask(
output_tokens, partial_masks=partial_masks
),
"step": 0,
"max_steps": max_steps,
"history": [output_tokens.clone()],
"temperature": temperature,
"partial_masks": partial_masks,
}
[文档]
def step(
self,
state: dict[str, Any],
partial_masks: torch.Tensor | None = None,
) -> dict[str, Any]:
"""Runs one denoising/reranking step and returns updated state.
Args:
state: Current decode state created by :meth:`initialize_state`.
partial_masks: Optional boolean mask of fixed positions.
Returns:
dict[str, Any]: The updated decode state after one iteration.
"""
partial_masks = (
partial_masks if partial_masks is not None else state.get("partial_masks")
)
step_outputs = self._decode_step(state, partial_masks=partial_masks)
editable_token_mask = self.get_non_special_symbol_mask(
state["output_tokens"], partial_masks=partial_masks
)
output_masks, result_tokens, result_scores = self._reparam_decoding(
output_tokens=state["output_tokens"].clone(),
output_scores=state["output_scores"].clone(),
step_tokens=step_outputs["output_tokens"].clone(),
step_scores=step_outputs["output_scores"].clone(),
decoding_strategy=self.config.decoding_strategy,
still_noisy_mask=state["output_masks"],
editable_token_mask=editable_token_mask,
t=state["step"] + 1,
max_step=state["max_steps"],
noise=self.mask_id,
)
new_state = dict(state)
new_state.update(
output_tokens=result_tokens,
output_scores=result_scores,
output_masks=output_masks,
step=state["step"] + 1,
history=step_outputs["history"],
partial_masks=partial_masks,
)
return new_state
[文档]
def generate(
self,
input_tokens: torch.Tensor,
*,
max_steps: int = 500,
partial_masks: torch.Tensor | None = None,
temperature: float = 1.0,
return_state: bool = False,
) -> torch.Tensor | dict[str, Any]:
"""Runs a complete iterative decoding loop inside the core class.
Args:
input_tokens: Initial token tensor.
max_steps: Number of decode iterations to run.
partial_masks: Optional boolean mask of fixed positions.
temperature: Sampling temperature stored in the decode state.
return_state: Whether to return the full final state dictionary.
Returns:
torch.Tensor | dict[str, Any]: Either the final token tensor or the full decode state.
"""
state = self.initialize_state(
input_tokens=input_tokens,
partial_masks=partial_masks,
max_steps=max_steps,
temperature=temperature,
)
for _ in range(max_steps):
state = self.step(state, partial_masks=partial_masks)
if return_state:
return state
return state["output_tokens"]
# Internal decode and training helpers follow. These stay in the model
# class for now, but they are not part of the intended public surface.
[文档]
def get_non_special_symbol_mask(
self, output_tokens: torch.Tensor, partial_masks: torch.Tensor | None = None
) -> torch.Tensor:
"""Returns a boolean mask of editable non-special-token positions.
Args:
output_tokens: Token tensor to inspect.
partial_masks: Optional boolean mask of positions that should remain
fixed.
Returns:
torch.Tensor: A boolean mask where ``True`` marks editable non-special positions.
"""
editable_token_mask = (
output_tokens.ne(self.pad_id)
& output_tokens.ne(self.bos_id)
& output_tokens.ne(self.eos_id)
)
if partial_masks is not None:
editable_token_mask &= ~partial_masks
return editable_token_mask
def _initialize_output_tokens(
self, input_tokens: torch.Tensor, partial_masks: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""Builds the initial fully masked token state for decoding.
Args:
input_tokens: Initial token tensor.
partial_masks: Optional boolean mask of fixed positions.
Returns:
tuple[torch.Tensor, torch.Tensor]: A tuple
``(output_tokens, output_scores)`` for the initial decode state.
"""
output_mask = self.get_non_special_symbol_mask(
input_tokens, partial_masks=partial_masks
)
output_tokens = input_tokens.masked_fill(output_mask, self.mask_id)
output_scores = torch.zeros_like(output_tokens, dtype=torch.float)
return output_tokens, output_scores
def _sample(
self,
clean_tokens: torch.Tensor,
first_timestep: torch.Tensor,
maskable_mask: torch.Tensor,
) -> dict[str, torch.Tensor]:
"""Applies one-step discrete corruption for training.
Args:
clean_tokens: Clean target tokens.
first_timestep: Sampled diffusion timesteps.
maskable_mask: Boolean mask of positions eligible for masking.
Returns:
dict[str, torch.Tensor]: A dictionary containing corrupted tokens,
timesteps, and loss mask.
"""
noise = torch.rand_like(clean_tokens, dtype=torch.float)
first_timestep_mask = (
noise < (first_timestep / self.config.num_diffusion_timesteps)[:, None]
) & maskable_mask
noisy_tokens = clean_tokens.masked_fill(first_timestep_mask, self.mask_id)
return {
"x_t": noisy_tokens,
"t": first_timestep,
"loss_mask": first_timestep_mask,
}
def _sample_coupled(
self,
clean_tokens: torch.Tensor,
first_timestep: torch.Tensor,
second_timestep: torch.Tensor,
maskable_mask: torch.Tensor,
) -> dict[str, torch.Tensor]:
"""Applies the coupled corruption variant used by RDM-style training.
Args:
clean_tokens: Clean target tokens.
first_timestep: First sampled timestep tensor.
second_timestep: Second sampled timestep tensor.
maskable_mask: Boolean mask of positions eligible for masking.
Returns:
dict[str, torch.Tensor]: A dictionary containing paired
corruptions, timesteps, and loss mask.
"""
same_timestep_mask = first_timestep == second_timestep
first_timestep, second_timestep = (
torch.maximum(first_timestep, second_timestep).float(),
torch.minimum(first_timestep, second_timestep).float(),
)
noise = torch.rand_like(clean_tokens, dtype=torch.float)
first_timestep_mask = (
noise < (first_timestep / self.config.num_diffusion_timesteps)[:, None]
) & maskable_mask
first_noisy_tokens = clean_tokens.masked_fill(first_timestep_mask, self.mask_id)
noise = torch.rand_like(clean_tokens, dtype=torch.float)
second_timestep_mask = first_timestep_mask & (
noise > ((first_timestep - second_timestep) / first_timestep)[:, None]
)
noise = torch.rand_like(clean_tokens[same_timestep_mask], dtype=torch.float)
second_timestep_mask[same_timestep_mask] = (
noise
< (
first_timestep[same_timestep_mask]
/ self.config.num_diffusion_timesteps
)[:, None]
) & maskable_mask[same_timestep_mask]
second_noisy_tokens = clean_tokens.masked_fill(
second_timestep_mask, self.mask_id
)
return {
"x_t": torch.cat([first_noisy_tokens, second_noisy_tokens], dim=0),
"t": torch.cat([first_timestep, second_timestep]),
"loss_mask": torch.cat([first_timestep_mask, second_timestep_mask], dim=0),
}
def _compute_loss_weight(
self, timesteps: torch.Tensor, weighting: str
) -> torch.Tensor:
"""Converts sampled timesteps into per-sample loss weights.
Args:
timesteps: Sampled diffusion timesteps.
weighting: Weighting strategy name.
Returns:
torch.Tensor: A column vector of normalized per-sample weights.
"""
num_timesteps = self.config.num_diffusion_timesteps
weight = {
"linear": num_timesteps - (timesteps - 1),
"constant": num_timesteps * torch.ones_like(timesteps),
}[weighting]
return weight[:, None].float() / num_timesteps
def _mask_logits(self, logits: torch.Tensor) -> torch.Tensor:
"""Suppresses special-token logits before categorical sampling.
Args:
logits: Raw proposal logits.
Returns:
torch.Tensor: A cloned logits tensor with special-token entries masked out.
"""
logits = logits.clone()
logits[..., self.mask_id] = -math.inf
if self.x_id is not None:
logits[..., self.x_id] = -math.inf
logits[..., self.pad_id] = -math.inf
logits[..., self.bos_id] = -math.inf
logits[..., self.eos_id] = -math.inf
return logits
def _reshape_candidates(
self, tensor: torch.Tensor, batch_size: int, num_candidates: int
) -> torch.Tensor:
"""Normalizes sampled candidate tensors to ``[batch, k, seq_len]``.
Args:
tensor: Candidate tensor in one of the supported layouts.
batch_size: Expected batch size.
num_candidates: Expected candidate count.
Returns:
torch.Tensor: A candidate tensor shaped as ``[batch, num_candidates, seq_len]``.
Raises:
ValueError: If the incoming tensor shape is not recognized.
"""
if tensor.shape[0] == num_candidates and tensor.shape[1] == batch_size:
return tensor.permute(1, 0, 2).contiguous()
if tensor.shape[0] == batch_size and tensor.shape[1] == num_candidates:
return tensor.contiguous()
if tensor.shape[0] == batch_size * num_candidates:
return tensor.view(batch_size, num_candidates, -1)
raise ValueError(f"Unexpected candidate tensor shape {tuple(tensor.shape)}.")
def _sample_candidates(
self, logits: torch.Tensor, num_candidates: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""Samples proposal candidates from logits.
Args:
logits: Proposal logits for the current decode step.
num_candidates: Number of candidates to sample per sequence.
Returns:
tuple[torch.Tensor, torch.Tensor]: A tuple ``(tokens, scores)`` shaped as
``[batch, num_candidates, seq_len]``.
"""
samples, scores = stochastic_sample_from_categorical_n(
logits,
temperature=self.config.proposal_temperature,
noise_scale=self.config.proposal_noise_scale,
n=num_candidates,
)
batch_size = logits.size(0)
return (
self._reshape_candidates(samples, batch_size, num_candidates),
self._reshape_candidates(scores, batch_size, num_candidates),
)
def _score_candidates(
self, noisy_tokens: torch.Tensor, candidate_tokens: torch.Tensor
) -> torch.Tensor:
"""Scores a batch of candidate reconstructions with the energy model.
Args:
noisy_tokens: Current noisy token tensor.
candidate_tokens: Candidate reconstructions shaped as
``[batch, num_candidates, seq_len]``.
Returns:
torch.Tensor: Candidate energies shaped as ``[batch, num_candidates]``.
"""
batch_size, num_candidates, seq_len = candidate_tokens.shape
flat_noisy_tokens = (
noisy_tokens.unsqueeze(1)
.expand(-1, num_candidates, -1)
.reshape(batch_size * num_candidates, seq_len)
)
flat_candidate_tokens = candidate_tokens.reshape(
batch_size * num_candidates, seq_len
)
flat_attention_mask = flat_candidate_tokens.ne(self.pad_id)
energy = self.energy(
flat_noisy_tokens, flat_candidate_tokens, flat_attention_mask
)
return energy.view(batch_size, num_candidates)
def _select_candidates(
self,
noisy_tokens: torch.Tensor,
candidate_tokens: torch.Tensor,
candidate_scores: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Selects one candidate per batch element by energy-based reranking.
Args:
noisy_tokens: Current noisy token tensor.
candidate_tokens: Candidate reconstructions.
candidate_scores: Proposal-side candidate scores.
Returns:
tuple[torch.Tensor, torch.Tensor]: A tuple ``(tokens, scores)``
for the selected candidate per sample.
"""
energies = self._score_candidates(noisy_tokens, candidate_tokens)
neg_energies = -energies
neg_energies = neg_energies - neg_energies.max(dim=-1, keepdim=True)[0]
weights = torch.softmax(neg_energies / self.config.energy_temperature, dim=-1)
selected_idx = torch.multinomial(weights, 1).squeeze(-1)
batch_idx = torch.arange(noisy_tokens.size(0), device=noisy_tokens.device)
return (
candidate_tokens[batch_idx, selected_idx],
candidate_scores[batch_idx, selected_idx],
)
def _resample(self, tokens: torch.Tensor, scores: torch.Tensor) -> None:
"""Mitigates repetition collapse by masked resampling in place.
Args:
tokens: Candidate token tensor updated in place.
scores: Candidate score tensor updated in place.
"""
to_be_resampled = []
resample_input = []
resample_masks = []
resample_scores = []
for batch_index, sequence in enumerate(tokens):
token_positions = {}
max_frequency = -1
for position, token in enumerate(sequence):
token = int(token)
token_positions.setdefault(token, []).append(position)
max_frequency = max(max_frequency, len(token_positions[token]))
if max_frequency <= len(sequence) * self.config.resample_ratio:
continue
mask = torch.zeros_like(sequence).bool()
for token, positions in token_positions.items():
if len(positions) > len(sequence) * self.config.resample_ratio:
mask |= sequence.eq(token)
to_be_resampled.append(batch_index)
resample_scores.append(scores[batch_index])
resample_masks.append(mask)
resample_input.append(sequence.masked_fill(mask, self.mask_id))
if not to_be_resampled:
return
resample_input = torch.stack(resample_input, dim=0).type_as(tokens)
resample_scores = torch.stack(resample_scores, dim=0).type_as(scores)
resample_masks = torch.stack(resample_masks, dim=0).bool()
logits = self._mask_logits(self.forward(resample_input))
logits = top_k_top_p_filtering(logits, top_p=self.config.resample_top_p)
new_tokens, new_scores = stochastic_sample_from_categorical(
logits, temperature=0.0
)
resample_input.masked_scatter_(resample_masks, new_tokens[resample_masks])
resample_scores.masked_scatter_(resample_masks, new_scores[resample_masks])
tokens[to_be_resampled] = resample_input
scores[to_be_resampled] = resample_scores
def _decode_step(
self, state: dict[str, Any], partial_masks: torch.Tensor | None = None
) -> dict[str, Any]:
"""Runs proposal, reranking, and optional resampling for one step.
Args:
state: Current decode state.
partial_masks: Optional boolean mask of fixed positions.
Returns:
dict[str, Any]: Intermediate decode outputs before skeptical remasking.
"""
output_tokens = state["output_tokens"].clone()
output_scores = state["output_scores"].clone()
output_masks = self.get_non_special_symbol_mask(
output_tokens, partial_masks=partial_masks
)
logits = self._mask_logits(self.forward(output_tokens))
if logits.dtype != output_scores.dtype:
logits = logits.type_as(output_scores)
candidate_tokens, candidate_scores = self._sample_candidates(
logits, self.config.num_candidates
)
selected_tokens, selected_scores = self._select_candidates(
output_tokens, candidate_tokens, candidate_scores
)
if not self.config.disable_resample:
self._resample(selected_tokens, selected_scores)
output_tokens.masked_scatter_(output_masks, selected_tokens[output_masks])
output_scores.masked_scatter_(output_masks, selected_scores[output_masks])
history = list(state["history"])
history.append(output_tokens.clone())
return {
"output_tokens": output_tokens,
"output_scores": output_scores,
"history": history,
}
def _reparam_decoding(
self,
output_tokens: torch.Tensor,
output_scores: torch.Tensor,
step_tokens: torch.Tensor,
step_scores: torch.Tensor,
decoding_strategy: str,
still_noisy_mask: torch.Tensor,
editable_token_mask: torch.Tensor,
t: int,
max_step: int,
noise: int | float | torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Applies skeptical remasking to produce the next decode state.
Args:
output_tokens: Previous-step output tokens.
output_scores: Previous-step token scores.
step_tokens: Current-step candidate tokens.
step_scores: Current-step candidate scores.
decoding_strategy: Skeptical-remasking strategy string.
still_noisy_mask: Boolean mask tracking which positions remain noisy.
editable_token_mask: Editable non-special-token mask.
t: Current decode step index, starting from ``1``.
max_step: Total decode step count.
noise: Mask token id or per-position noise tensor.
Returns:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple
``(new_mask, new_tokens, new_scores)`` describing the next decode
state.
"""
_, condition, topk_mode, schedule = decoding_strategy.split("-")
if schedule == "linear":
rate = 1 - t / max_step
elif schedule == "cosine":
rate = np.cos(t / max_step * np.pi * 0.5)
else:
raise NotImplementedError(f"Unknown schedule: {schedule}")
cutoff_len = (
editable_token_mask.sum(1, keepdim=True).type_as(output_scores) * rate
).long()
scores_for_topk = step_scores.masked_fill(~editable_token_mask, 1000.0)
if topk_mode.startswith("stochastic"):
noise_scale = float(topk_mode.replace("stochastic", ""))
lowest_k_mask = topk_masking(
scores_for_topk,
cutoff_len,
stochastic=True,
temp=noise_scale * rate,
)
elif topk_mode == "deterministic":
lowest_k_mask = topk_masking(scores_for_topk, cutoff_len, stochastic=False)
else:
raise NotImplementedError(f"Unknown topk mode: {topk_mode}")
if condition == "cond":
keep_masked_from_previous = (
(step_tokens == output_tokens)
& (step_scores < output_scores)
& lowest_k_mask
)
elif condition == "uncond":
keep_masked_from_previous = lowest_k_mask
else:
raise NotImplementedError(f"Unknown condition mode: {condition}")
keep_masked_this_step = lowest_k_mask
masked_to_noise = (~still_noisy_mask & keep_masked_from_previous) | (
still_noisy_mask & keep_masked_this_step
)
if isinstance(noise, torch.Tensor):
output_tokens.masked_scatter_(masked_to_noise, noise[masked_to_noise])
else:
output_tokens.masked_fill_(masked_to_noise, noise)
output_scores.masked_fill_(masked_to_noise, -math.inf)
masked_to_candidate_tokens = still_noisy_mask & ~keep_masked_this_step
output_tokens.masked_scatter_(
masked_to_candidate_tokens, step_tokens[masked_to_candidate_tokens]
)
output_scores.masked_scatter_(
masked_to_candidate_tokens, step_scores[masked_to_candidate_tokens]
)
new_still_noisy_mask = (
still_noisy_mask | keep_masked_from_previous
) & keep_masked_this_step
return new_still_noisy_mask, output_tokens, output_scores