Mixture of Experts API

MixtureOfExperts

Bases: Module

Mixture of Experts layer with top-k gating and load balancing.

Each expert is a simple feedforward network. The gating network decides which experts to use for each input, and load balancing encourages the experts to be utilized evenly.
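
The documentation above describes the routing behavior only. The following is a minimal, self-contained sketch of how a top-k gated mixture of feedforward experts can be implemented, assuming a softmax gate and a two-layer GELU expert; it is illustrative and not the actual internals of MixtureOfExperts.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TopKMoESketch(nn.Module):
    """Illustrative top-k gated MoE; not the library's MixtureOfExperts."""

    def __init__(self, input_dim, hidden_dim=1024, num_experts=4, top_k=2):
        super().__init__()
        self.top_k = top_k
        # Each expert is a simple feedforward network, as described above.
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, input_dim),
            )
            for _ in range(num_experts)
        ])
        # The gating network scores every expert for every input.
        self.gate = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        gate_probs = F.softmax(self.gate(x), dim=-1)                    # (batch, num_experts)
        topk_probs, topk_idx = gate_probs.topk(self.top_k, dim=-1)      # (batch, top_k)
        topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)  # renormalize weights

        out = torch.zeros_like(x)
        for slot in range(self.top_k):
            for e, expert in enumerate(self.experts):
                routed = topk_idx[:, slot] == e  # inputs whose slot-th choice is expert e
                if routed.any():
                    weight = topk_probs[routed, slot].unsqueeze(-1)
                    out[routed] = out[routed] + weight * expert(x[routed])
        return out, gate_probs, topk_idx
```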

__init__(input_dim, hidden_dim=1024, num_experts=4, top_k=2)

Initialize Mixture of Experts layer.

Parameters:

| Name        | Type | Description                                    | Default  |
|-------------|------|------------------------------------------------|----------|
| input_dim   | int  | Input dimension (e.g., 768 * 4 for 4 channels) | required |
| hidden_dim  | int  | Hidden dimension for each expert               | 1024     |
| num_experts | int  | Number of expert networks                      | 4        |
| top_k       | int  | Number of experts to select for each input     | 2        |
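
A minimal construction example follows. The import path is hypothetical; adjust it to wherever your package exposes the class.

```python
import torch
from moe import MixtureOfExperts  # hypothetical import path; adjust to your package

# Four experts with two active per input, matching the defaults documented above.
moe = MixtureOfExperts(input_dim=768 * 4, hidden_dim=1024, num_experts=4, top_k=2)
```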

forward(x)

Forward pass through MoE layer.

Uses compiled operations for gating when torch.compile is enabled.

Parameters:

| Name | Type   | Description                                   | Default  |
|------|--------|-----------------------------------------------|----------|
| x    | Tensor | Input tensor of shape (batch_size, input_dim) | required |

Returns:

| Type                          | Description                                                                                                                                              |
|-------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------|
| Tuple[Tensor, Tensor, Tensor] | Tuple of: output tensor of shape (batch_size, input_dim), gating probabilities of shape (batch_size, num_experts), and top-k indices of shape (batch_size, top_k) |
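
A usage sketch that unpacks the three return values and checks the documented shapes. The auxiliary importance loss at the end is one common way to use the returned gating probabilities for load balancing; it is an assumption, not necessarily what the layer computes internally.

```python
import torch

batch_size, input_dim = 8, 768 * 4
x = torch.randn(batch_size, input_dim)

# `moe` is the layer constructed in the example above (4 experts, top_k=2).
output, gate_probs, topk_idx = moe(x)
assert output.shape == (batch_size, input_dim)
assert gate_probs.shape == (batch_size, 4)  # (batch_size, num_experts)
assert topk_idx.shape == (batch_size, 2)    # (batch_size, top_k)

# One common auxiliary load-balancing term (an assumption, not the layer's own):
# penalize non-uniform average routing weight across experts.
importance = gate_probs.mean(dim=0)                           # average gate probability per expert
balance_loss = gate_probs.size(-1) * (importance ** 2).sum()  # minimized at uniform usage
```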

create_moe_layer(input_dim, hidden_dim=1024, num_experts=4, top_k=2)

Factory function to create a MoE layer.

Parameters:

| Name        | Type | Description                   | Default  |
|-------------|------|-------------------------------|----------|
| input_dim   | int  | Input dimension               | required |
| hidden_dim  | int  | Hidden dimension for experts  | 1024     |
| num_experts | int  | Number of experts             | 4        |
| top_k       | int  | Number of experts to activate | 2        |

Returns:

| Type             | Description              |
|------------------|--------------------------|
| MixtureOfExperts | MixtureOfExperts module  |
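
A brief usage sketch of the factory function; the import path is hypothetical.

```python
from moe import create_moe_layer  # hypothetical import path; adjust to your package

# Equivalent to constructing MixtureOfExperts directly with the same arguments.
moe = create_moe_layer(input_dim=768 * 4, hidden_dim=1024, num_experts=4, top_k=2)
print(type(moe).__name__)  # -> MixtureOfExperts
```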