|
|
|
@ -63,9 +63,9 @@ func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOp |
|
|
|
} |
|
|
|
|
|
|
|
// TextExperts groups the gate, up and down expert weights used by
// (*TextExperts).Forward. Each field is bound to a GGUF tensor name through
// its `gguf` struct tag.
//
// NOTE(review): this listing looks like diff residue, since every field name
// is declared twice. The ml.Tensor declarations (tags ending in ".weight")
// appear to be the removed lines and the *nn.Linear declarations the added
// ones; this matches Forward, where e.Up.MulmatID(...) is paired with
// e.Up.Weight.MulmatID(...). Only one of the two sets can exist in the real
// struct — confirm against the actual file before relying on either.
type TextExperts struct { |
|
|
|
Gate ml.Tensor `gguf:"ffn_gate_exps.weight"` |
|
|
|
Up ml.Tensor `gguf:"ffn_up_exps.weight"` |
|
|
|
Down ml.Tensor `gguf:"ffn_down_exps.weight"` |
|
|
|
Gate *nn.Linear `gguf:"ffn_gate_exps"` |
|
|
|
Up *nn.Linear `gguf:"ffn_up_exps"` |
|
|
|
Down *nn.Linear `gguf:"ffn_down_exps"` |
|
|
|
} |
|
|
|
|
|
|
|
func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor { |
|
|
|
@ -76,9 +76,9 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens |
|
|
|
hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed) |
|
|
|
hiddenStates = hiddenStates.Mul(ctx, scores) |
|
|
|
|
|
|
|
upStates := e.Up.MulmatID(ctx, hiddenStates, experts) |
|
|
|
gateStates := e.Gate.MulmatID(ctx, hiddenStates, experts) |
|
|
|
downStates := e.Down.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts) |
|
|
|
upStates := e.Up.Weight.MulmatID(ctx, hiddenStates, experts) |
|
|
|
gateStates := e.Gate.Weight.MulmatID(ctx, hiddenStates, experts) |
|
|
|
downStates := e.Down.Weight.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts) |
|
|
|
|
|
|
|
nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2)) |
|
|
|
for i := 1; i < opts.numExpertsUsed; i++ { |
|
|
|
|