Mixture of Experts API¶
MixtureOfExperts¶
Bases: Module
Mixture of Experts layer with top-k gating and load balancing.
Each expert is a simple feedforward network. The gating network decides which experts process each input, and a load-balancing mechanism encourages even utilization across experts.
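As a rough illustration of what top-k gating does, here is a minimal sketch in plain PyTorch. It is not this module's actual implementation; `gate` (a linear layer producing one logit per expert) and `experts` (a list of feedforward networks mapping `input_dim` back to `input_dim`) are hypothetical stand-ins for the layer's internals:

```python
import torch
import torch.nn.functional as F

def topk_gating_sketch(x, gate, experts, top_k=2):
    """Illustrative top-k routing; a sketch, not the library's code."""
    logits = gate(x)                                  # (batch, num_experts)
    probs = F.softmax(logits, dim=-1)                 # gating probabilities
    topk_probs, topk_idx = probs.topk(top_k, dim=-1)  # keep the k best experts
    topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)  # renormalize

    out = torch.zeros_like(x)                         # experts map input_dim -> input_dim
    for k in range(top_k):
        for e, expert in enumerate(experts):
            routed = topk_idx[:, k] == e              # inputs whose k-th choice is expert e
            if routed.any():
                out[routed] += topk_probs[routed, k].unsqueeze(-1) * expert(x[routed])
    return out, probs, topk_idx
```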
__init__(input_dim, hidden_dim=1024, num_experts=4, top_k=2)¶
Initialize the Mixture of Experts layer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `input_dim` | `int` | Input dimension (e.g., 768 * 4 for 4 channels) | *required* |
| `hidden_dim` | `int` | Hidden dimension for each expert | `1024` |
| `num_experts` | `int` | Number of expert networks | `4` |
| `top_k` | `int` | Number of experts to select for each input | `2` |
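A minimal construction sketch; the import path `moe` is an assumption, so adjust it to wherever this module lives in your codebase:

```python
import torch
from moe import MixtureOfExperts  # assumed import path

# Four input channels of width 768, matching the example in the table above.
layer = MixtureOfExperts(input_dim=768 * 4, hidden_dim=1024, num_experts=4, top_k=2)
```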
forward(x)¶
Forward pass through the MoE layer.
Uses compiled operations for gating when torch.compile is enabled.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | Input tensor of shape (batch_size, input_dim) | *required* |

Returns:

| Type | Description |
|---|---|
| `Tuple[Tensor, Tensor, Tensor]` | Tuple of: output tensor of shape (batch_size, input_dim); gating probabilities of shape (batch_size, num_experts); top-k indices of shape (batch_size, top_k) |
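Continuing the construction sketch above, a forward call unpacks the three documented tensors (shapes shown for batch_size=8 and the defaults num_experts=4, top_k=2):

```python
x = torch.randn(8, 768 * 4)             # (batch_size, input_dim)
output, gate_probs, topk_idx = layer(x)

assert output.shape == (8, 768 * 4)     # (batch_size, input_dim)
assert gate_probs.shape == (8, 4)       # (batch_size, num_experts)
assert topk_idx.shape == (8, 2)         # (batch_size, top_k)
```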
create_moe_layer(input_dim, hidden_dim=1024, num_experts=4, top_k=2)¶
Factory function to create a MoE layer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `input_dim` | `int` | Input dimension | *required* |
| `hidden_dim` | `int` | Hidden dimension for experts | `1024` |
| `num_experts` | `int` | Number of experts | `4` |
| `top_k` | `int` | Number of experts to activate | `2` |

Returns:

| Type | Description |
|---|---|
| `MixtureOfExperts` | MixtureOfExperts module |
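The factory is a thin alternative to calling the constructor directly; a short sketch, again assuming the hypothetical `moe` import path:

```python
from moe import create_moe_layer  # assumed import path

# Equivalent to MixtureOfExperts(input_dim=768, num_experts=8, top_k=2).
moe_layer = create_moe_layer(input_dim=768, num_experts=8, top_k=2)
```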