mirror of https://gitee.com/namelin2022/ollama
committed by
Michael Yang
13 changed files with 833 additions and 15 deletions
@ -0,0 +1,167 @@ |
|||
package convert |
|||
|
|||
import ( |
|||
"slices" |
|||
"strings" |
|||
|
|||
"github.com/pdevine/tensor" |
|||
"github.com/pdevine/tensor/native" |
|||
|
|||
"github.com/ollama/ollama/fs/ggml" |
|||
) |
|||
|
|||
// llama4Model holds the parsed checkpoint metadata for Llama 4, which
// pairs a llama-style text decoder (with MoE extensions) with a vision tower.
type llama4Model struct {
	ModelParameters

	// TextModel mirrors the "text_config" section of config.json.
	TextModel struct {
		llamaModel
		NumExpertsPerToken     uint32 `json:"num_experts_per_tok"`
		NumLocalExperts        uint32 `json:"num_local_experts"`
		InterleaveMOELayerStep uint32 `json:"interleave_moe_layer_step"`
		UseQKNorm              bool   `json:"use_qk_norm"`
		IntermediateSizeMLP    uint32 `json:"intermediate_size_mlp"`
	} `json:"text_config"`

	// VisionModel mirrors the "vision_config" section of config.json.
	VisionModel struct {
		NumHiddenLayers   uint32  `json:"num_hidden_layers"`
		HiddenSize        uint32  `json:"hidden_size"`
		IntermediateSize  uint32  `json:"intermediate_size"`
		NumAttentionHeads uint32  `json:"num_attention_heads"`
		ImageSize         uint32  `json:"image_size"`
		PatchSize         uint32  `json:"patch_size"`
		RopeTheta         float32 `json:"rope_theta"`
		NormEpsilon       float32 `json:"norm_eps"`
		PixelShuffleRatio float32 `json:"pixel_shuffle_ratio"`
	} `json:"vision_config"`
}
|||
|
|||
// KV implements ModelConverter. It builds the GGUF key/value metadata,
// renaming the embedded llamaModel's "llama." keys to the "llama4."
// namespace and adding the MoE and vision-tower parameters.
func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
	kv := p.ModelParameters.KV(t)
	kv["general.architecture"] = "llama4"

	// Reuse the inherited llama keys under the llama4 prefix.
	for k, v := range p.TextModel.KV(t) {
		if strings.HasPrefix(k, "llama.") {
			kv[strings.ReplaceAll(k, "llama.", "llama4.")] = v
		}
	}

	// Dense-layer FFN width comes from intermediate_size_mlp; the embedded
	// model's IntermediateSize is the per-expert (MoE) FFN width.
	kv["llama4.intermediate_size"] = p.TextModel.IntermediateSizeMLP
	kv["llama4.intermediate_size_moe"] = p.TextModel.IntermediateSize

	kv["llama4.expert_count"] = p.TextModel.NumLocalExperts
	kv["llama4.expert_used_count"] = p.TextModel.NumExpertsPerToken
	kv["llama4.interleave_moe_layer_step"] = p.TextModel.InterleaveMOELayerStep
	kv["llama4.use_qk_norm"] = p.TextModel.UseQKNorm

	kv["llama4.vision.block_count"] = p.VisionModel.NumHiddenLayers
	kv["llama4.vision.embedding_length"] = p.VisionModel.HiddenSize
	kv["llama4.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
	kv["llama4.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
	kv["llama4.vision.image_size"] = p.VisionModel.ImageSize
	kv["llama4.vision.patch_size"] = p.VisionModel.PatchSize
	kv["llama4.vision.rope.freq_base"] = p.VisionModel.RopeTheta
	kv["llama4.vision.layer_norm_epsilon"] = p.VisionModel.NormEpsilon
	kv["llama4.vision.pixel_shuffle_ratio"] = p.VisionModel.PixelShuffleRatio
	return kv
}
|||
|
|||
// Replacements implements ModelConverter. It extends the llama tensor-name
// replacements with llama4's vision, projector, MoE, and shared-expert
// names. Entries are (old, new) pairs applied in order, so the more
// specific "feed_forward.*_proj" rules precede the generic "feed_forward.".
func (p *llama4Model) Replacements() []string {
	return append(
		p.TextModel.Replacements(),
		"language_model.", "",
		"vision_model", "v",
		"multi_modal_projector", "mm",
		"feed_forward.down_proj", "ffn_down",
		"feed_forward.up_proj", "ffn_up",
		"feed_forward.gate_proj", "ffn_gate",
		"feed_forward.", "ffn_",
		"shared_expert.down_proj", "down_shexp",
		"shared_expert.gate_proj", "gate_shexp",
		"shared_expert.up_proj", "up_shexp",
		"experts.down_proj", "down_exps.weight",
		"experts.gate_up_proj", "gate_up_exps.weight",
		"router", "gate_inp",
		"patch_embedding.linear", "patch_embedding",
	)
}
|||
|
|||
// Tensors implements ModelConverter. Vision ("v.") and projector ("mm.")
// tensors pass through unchanged; fused expert tensors are split and/or
// repacked with their last two dims swapped; everything else is delegated
// to the embedded llamaModel converter.
func (p *llama4Model) Tensors(ts []Tensor) []ggml.Tensor {
	var out []ggml.Tensor

	var textTensors []Tensor
	for _, t := range ts {
		if strings.HasPrefix(t.Name(), "v.") || strings.HasPrefix(t.Name(), "mm.") {
			out = append(out, ggml.Tensor{
				Name:     t.Name(),
				Kind:     t.Kind(),
				Shape:    t.Shape(),
				WriterTo: t,
			})
		} else if strings.Contains(t.Name(), "ffn_gate_up_exps") {
			// gate and up projectors are fused
			// dims[1], dims[2] must be swapped
			// [experts, hidden_size, intermediate_size * 2] --> [experts, intermediate_size, hidden_size]
			halfDim := int(t.Shape()[2]) / 2

			newShape := slices.Clone(t.Shape())
			newShape[1], newShape[2] = newShape[2]/2, newShape[1]
			// Emit two tensors: gate takes the first half of the fused dim,
			// up takes the second half.
			for i, name := range []string{"ffn_gate_exps", "ffn_up_exps"} {
				// clone tensor since we need separate repackers
				tt := t.Clone()
				tt.SetRepacker(p.repack(nil, nil, tensor.S(i*halfDim, (i+1)*halfDim)))
				out = append(out, ggml.Tensor{
					Name:     strings.ReplaceAll(tt.Name(), "ffn_gate_up_exps", name),
					Kind:     tt.Kind(),
					Shape:    newShape,
					WriterTo: tt,
				})
			}
		} else if strings.Contains(t.Name(), "ffn_down_exps") {
			// dims[1], dims[2] must be swapped
			// [experts, intermediate_size, hidden_size] --> [experts, hidden_size, intermediate_size]
			t.SetRepacker(p.repack())
			newShape := slices.Clone(t.Shape())
			newShape[1], newShape[2] = newShape[2], newShape[1]
			out = append(out, ggml.Tensor{
				Name:     t.Name(),
				Kind:     t.Kind(),
				Shape:    newShape,
				WriterTo: t,
			})
		} else {
			textTensors = append(textTensors, t)
		}
	}

	// NOTE(review): presumably tells the embedded llama converter not to
	// re-permute Q/K weights for these checkpoints — confirm against
	// llamaModel.Tensors.
	p.TextModel.skipRepack = true
	out = append(out, p.TextModel.Tensors(textTensors)...)
	return out
}
|||
|
|||
// repack returns a Repacker that optionally slices the raw tensor data
// (per the given slice specs) and then swaps its last two dimensions,
// returning the result as a flat float32 vector.
func (p *llama4Model) repack(slice ...tensor.Slice) Repacker {
	return func(name string, data []float32, shape []uint64) ([]float32, error) {
		dims := make([]int, len(shape))
		for i, dim := range shape {
			dims[i] = int(dim)
		}

		var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
		t, err := t.Slice(slice...)
		if err != nil {
			return nil, err
		}

		// Transpose dims 1 and 2; dim 0 (the expert axis) stays put.
		if err := t.T(0, 2, 1); err != nil {
			return nil, err
		}

		t = tensor.Materialize(t)
		// flatten tensor so it can be returned as a vector
		if err := t.Reshape(t.Shape().TotalSize()); err != nil {
			return nil, err
		}

		return native.VectorF32(t.(*tensor.Dense))
	}
}
|||
@ -0,0 +1,100 @@ |
|||
package llama4 |
|||
|
|||
import ( |
|||
"bytes" |
|||
"image" |
|||
|
|||
"github.com/ollama/ollama/fs" |
|||
"github.com/ollama/ollama/kvcache" |
|||
"github.com/ollama/ollama/ml" |
|||
"github.com/ollama/ollama/ml/nn" |
|||
"github.com/ollama/ollama/model" |
|||
"github.com/ollama/ollama/model/input" |
|||
) |
|||
|
|||
// Model combines the Llama 4 text decoder with an optional vision tower
// and projector for multimodal input.
type Model struct {
	model.Base
	model.BytePairEncoding

	*VisionModel `gguf:"v,vision"`
	*Projector   `gguf:"mm"`
	*TextModel
}
|||
|
|||
// Projector maps vision-encoder embeddings into the text-model
// embedding space with a single linear layer.
type Projector struct {
	Linear1 *nn.Linear `gguf:"linear_1"`
}
|||
|
|||
func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor { |
|||
return p.Linear1.Forward(ctx, visionOutputs) |
|||
} |
|||
|
|||
// New constructs a Llama 4 model from GGUF metadata: BPE tokenizer,
// vision tower, text decoder, and a two-slot wrapper KV cache.
func New(c fs.Config) (model.Model, error) {
	m := Model{
		BytePairEncoding: model.NewBytePairEncoding(
			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
			&model.Vocabulary{
				Values: c.Strings("tokenizer.ggml.tokens"),
				Types:  c.Uints("tokenizer.ggml.token_type"),
				Merges: c.Strings("tokenizer.ggml.merges"),
				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
			},
		),
		VisionModel: newVisionModel(c),
		TextModel:   newTextModel(c),
	}

	// Slot 0 serves the chunked-attention layers, slot 1 the full causal
	// layers; TextModel.Forward selects the slot per layer.
	m.Cache = kvcache.NewWrapperCache(
		// TODO: pretend this is chunked attention for now
		kvcache.NewSWACache(8192, m.Shift),
		kvcache.NewCausalCache(m.Shift),
	)

	return &m, nil
}
|||
|
|||
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) { |
|||
if len(m.VisionModel.Layers) < 1 { |
|||
return nil, model.ErrNoVisionModel |
|||
} |
|||
|
|||
img, _, err := image.Decode(bytes.NewReader(multimodalData)) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
f32s, aspectRatio, err := m.ProcessImage(ctx, img) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
pixelValues, err := ctx.Input().FromFloatSlice(f32s, len(f32s)) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
visionOutputs := m.VisionModel.Forward(ctx, pixelValues) |
|||
visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0), visionOutputs.Dim(1)*visionOutputs.Dim(2)*visionOutputs.Dim(3)) |
|||
return m.Projector.Forward(ctx, visionOutputs), nil |
|||
} |
|||
|
|||
// Forward uploads the batch's positions and output indices to the
// backend and delegates to the text decoder.
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
	if err != nil {
		return nil, err
	}

	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
	if err != nil {
		return nil, err
	}

	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
}
|||
|
|||
// init registers the architecture so the model loader can instantiate it
// by the "llama4" name stored in GGUF metadata.
func init() {
	model.Register("llama4", New)
}
|||
@ -0,0 +1,223 @@ |
|||
package llama4 |
|||
|
|||
import ( |
|||
"cmp" |
|||
"math" |
|||
|
|||
"github.com/ollama/ollama/fs" |
|||
"github.com/ollama/ollama/kvcache" |
|||
"github.com/ollama/ollama/ml" |
|||
"github.com/ollama/ollama/ml/nn" |
|||
"github.com/ollama/ollama/model/input" |
|||
) |
|||
|
|||
// TextAttention is grouped-query self-attention for the text decoder.
type TextAttention struct {
	Query       *nn.Linear `gguf:"attn_q"`
	Key         *nn.Linear `gguf:"attn_k"`
	Value       *nn.Linear `gguf:"attn_v"`
	Output      *nn.Linear `gguf:"attn_output"`
	RopeFactors ml.Tensor  `gguf:"rope_factors"` // optional long-context frequency scaling
}
|||
|
|||
// Forward computes cached self-attention over hiddenStates. useRope
// controls whether rotary embeddings (and the optional QK RMS norm) are
// applied; the caller passes false for the periodic NoPE layers.
func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, useRope bool, opts *TextOptions) ml.Tensor {
	// headDim falls back to hiddenSize/numHeads when not set explicitly.
	batchSize, headDim := hiddenStates.Dim(1), cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)

	query := sa.Query.Forward(ctx, hiddenStates)
	key := sa.Key.Forward(ctx, hiddenStates)
	value := sa.Value.Forward(ctx, hiddenStates)

	// Split into per-head layout: [headDim, heads, tokens].
	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

	if useRope {
		// Args: rotation dim count first, then rope type 0.
		query = query.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)
		key = key.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)

		// RMS-normalize rotated Q/K (nil weight: pure normalization).
		if opts.useQKNorm {
			query = query.RMSNorm(ctx, nil, opts.eps)
			key = key.RMSNorm(ctx, nil, opts.eps)
		}
	}

	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), cache)
	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
	return sa.Output.Forward(ctx, attention)
}
|||
|
|||
// TextMLP is the dense (non-MoE) SwiGLU feed-forward block.
type TextMLP struct {
	Gate *nn.Linear `gguf:"ffn_gate"`
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
}
|||
|
|||
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor { |
|||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) |
|||
return mlp.Down.Forward(ctx, hiddenStates) |
|||
} |
|||
|
|||
// TextExperts holds the stacked per-expert SwiGLU weights, indexed per
// token via MulmatID.
type TextExperts struct {
	Gate ml.Tensor `gguf:"ffn_gate_exps.weight"`
	Up   ml.Tensor `gguf:"ffn_up_exps.weight"`
	Down ml.Tensor `gguf:"ffn_down_exps.weight"`
}
|||
|
|||
func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor { |
|||
experts := routerLogits.TopK(ctx, opts.numExpertsUsed) |
|||
scores := routerLogits.Sigmoid(ctx).Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, experts) |
|||
|
|||
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1)) |
|||
hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed) |
|||
hiddenStates = hiddenStates.Mul(ctx, scores) |
|||
|
|||
upStates := e.Up.MulmatID(ctx, hiddenStates, experts) |
|||
gateStates := e.Gate.MulmatID(ctx, hiddenStates, experts) |
|||
downStates := e.Down.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts) |
|||
|
|||
nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2)) |
|||
for i := 1; i < opts.numExpertsUsed; i++ { |
|||
nextStates.Add(ctx, downStates.View(ctx, i*downStates.Stride(1), hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))) |
|||
} |
|||
|
|||
return nextStates |
|||
} |
|||
|
|||
// TextSharedExpert is TextMLP with different names: the always-active
// shared expert that runs alongside the routed experts in MoE layers.
type TextSharedExpert struct {
	Gate *nn.Linear `gguf:"ffn_gate_shexp"`
	Up   *nn.Linear `gguf:"ffn_up_shexp"`
	Down *nn.Linear `gguf:"ffn_down_shexp"`
}
|||
|
|||
func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor { |
|||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) |
|||
return mlp.Down.Forward(ctx, hiddenStates) |
|||
} |
|||
|
|||
// TextMOE is the mixture-of-experts feed-forward block: a router, the
// routed experts, and one always-active shared expert.
type TextMOE struct {
	Router       *nn.Linear `gguf:"ffn_gate_inp"`
	Experts      *TextExperts
	SharedExpert *TextSharedExpert
}
|||
|
|||
func (moe *TextMOE) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor { |
|||
hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2) |
|||
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize) |
|||
routerLogits := moe.Router.Forward(ctx, hiddenStates) |
|||
|
|||
sharedStates := moe.SharedExpert.Forward(ctx, hiddenStates, opts) |
|||
routedStates := moe.Experts.Forward(ctx, hiddenStates, routerLogits, opts) |
|||
return sharedStates.Add(ctx, routedStates) |
|||
} |
|||
|
|||
// TextFeedForward abstracts the per-layer feed-forward block; it is
// implemented by TextMLP (dense) and TextMOE (mixture of experts).
type TextFeedForward interface {
	Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor
}
|||
|
|||
// TextLayer is one pre-norm decoder layer: attention followed by either
// a dense MLP or an MoE block, each with a residual connection.
type TextLayer struct {
	AttentionNorm *nn.LayerNorm `gguf:"attn_norm"`
	Attention     *TextAttention

	FFNNorm     *nn.LayerNorm `gguf:"ffn_norm"`
	FeedForward TextFeedForward
}
|||
|
|||
// Forward runs one decoder layer. When outputs is non-nil (final layer
// only), both the attention result and the residual are gathered down to
// the requested output rows before the feed-forward runs, so the last
// FFN only processes tokens whose logits are needed.
func (d *TextLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, useRope bool, opts *TextOptions) ml.Tensor {
	residual := hiddenStates

	// self attention
	hiddenStates = d.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = d.Attention.Forward(ctx, hiddenStates, positions, cache, useRope, opts)

	if outputs != nil {
		hiddenStates = hiddenStates.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenStates = hiddenStates.Add(ctx, residual)
	residual = hiddenStates

	// feed forward (dense or MoE, depending on the layer)
	hiddenStates = d.FFNNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = d.FeedForward.Forward(ctx, hiddenStates, opts)

	return residual.Add(ctx, hiddenStates)
}
|||
|
|||
// TextOptions carries text-decoder hyperparameters read from GGUF metadata.
type TextOptions struct {
	hiddenSize                    int
	numHeads, numKVHeads, headDim int
	numExperts, numExpertsUsed    int
	ropeDim                       int
	ropeBase, ropeScale           float32
	eps                           float32
	interleaveLayerStep           int  // every n-th layer is MoE
	useQKNorm                     bool // RMS-normalize rotated Q/K
}
|||
|
|||
// TextModel is the Llama 4 text decoder stack.
type TextModel struct {
	Layers []TextLayer `gguf:"blk"`

	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	OutputNorm     *nn.LayerNorm `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output,alt:token_embd"` // falls back to tied embeddings

	*TextOptions
}
|||
|
|||
// newTextModel builds the text decoder from GGUF metadata. Every
// interleaveLayerStep-th layer (1-indexed) gets a mixture-of-experts
// feed-forward; the rest use a dense MLP.
func newTextModel(c fs.Config) *TextModel {
	layers := make([]TextLayer, c.Uint("block_count"))
	interleaveLayerStep := c.Uint("interleave_moe_layer_step", 1)
	for i := range layers {
		if (i+1)%int(interleaveLayerStep) == 0 {
			layers[i] = TextLayer{FeedForward: &TextMOE{}}
		} else {
			layers[i] = TextLayer{FeedForward: &TextMLP{}}
		}
	}

	return &TextModel{
		Layers: layers,
		TextOptions: &TextOptions{
			hiddenSize:          int(c.Uint("embedding_length")),
			numHeads:            int(c.Uint("attention.head_count")),
			numKVHeads:          int(c.Uint("attention.head_count_kv")),
			headDim:             int(c.Uint("attention.head_dim", 128)),
			numExperts:          int(c.Uint("expert_count")),
			numExpertsUsed:      int(c.Uint("expert_used_count")),
			ropeDim:             int(c.Uint("rope.dimension_count")),
			ropeBase:            c.Float("rope.freq_base"),
			ropeScale:           c.Float("rope.freq_scale", 1),
			eps:                 c.Float("attention.layer_norm_rms_epsilon"),
			interleaveLayerStep: int(c.Uint("interleave_moe_layer_step", 1)),
			useQKNorm:           c.Bool("use_qk_norm", true),
		},
	}
}
|||
|
|||
// Forward runs the decoder stack and projects to vocabulary logits.
// Three of every four layers use chunked attention (currently emulated
// with a sliding-window cache, slot 0); each fourth layer is full causal
// attention without RoPE (slot 1).
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
	hiddenStates := m.TokenEmbedding.Forward(ctx, inputs)

	for i, layer := range m.Layers {
		cache.SetLayer(i)
		wc := cache.(*kvcache.WrapperCache)
		wc.SetLayerType(1) // default: full causal cache
		useChunkedAttention := (i+1)%4 != 0
		if useChunkedAttention {
			wc.SetLayerType(0) // sliding-window stand-in for chunked attention
		}

		// Only the final layer gathers down to the requested output rows.
		var lastLayerOutputs ml.Tensor
		if i == len(m.Layers)-1 {
			lastLayerOutputs = outputs
		}

		// useChunkedAttention doubles as useRope: the full-attention
		// layers are the NoPE layers.
		hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, cache, useChunkedAttention, m.TextOptions)
	}

	hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
	return m.Output.Forward(ctx, hiddenStates)
}
|||
|
|||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { |
|||
return key.RoPE(ctx, shift, m.Layers[layer].Attention.RopeFactors, uint32(0), uint32(m.ropeDim), m.ropeBase, m.ropeScale), nil |
|||
} |
|||
@ -0,0 +1,256 @@ |
|||
package llama4 |
|||
|
|||
import ( |
|||
"math" |
|||
|
|||
"github.com/ollama/ollama/fs" |
|||
"github.com/ollama/ollama/ml" |
|||
"github.com/ollama/ollama/ml/nn" |
|||
) |
|||
|
|||
// VisionAttention is multi-head self-attention for the vision encoder.
type VisionAttention struct {
	Query  *nn.Linear `gguf:"attn_q"`
	Key    *nn.Linear `gguf:"attn_k"`
	Value  *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_output"`
}
|||
|
|||
// applyVisionRotaryEmbedding applies 2D rotary embedding to the input tensor.
// This is equivalent to the Pytorch implementation using half rotations:
//
//	cos, sin = torch.cos(freqs), torch.sin(freqs)
//	cos = cos.unsqueeze(-1)
//	sin = sin.unsqueeze(-1)
//	t = t.reshape(*t.shape[:-1], -1, 2)
//	t_out = (t * cos) + (_rotate_half(t) * sin)
//	t_out = t_out.flatten(3)
//
// Which is equivalent to the Pytorch implementation using complex numbers:
//
//	t_ = torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))
//	freqs_ci = reshape_for_broadcast(freqs_ci=freq_cis, t=t_) # freqs_ci[:,:,None,:]
//	freqs_ci = freqs_ci.to(t_.device)
//	t_out = torch.view_as_real(t_ * freqs_ci).flatten(3)
//
// Due to the 1) the dimensional and 2) the datatype limitations of current backends,
// we need to use a different approach to achieve the same result.
func applyVisionRotaryEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
	width, height, channels, tiles := t.Dim(0), t.Dim(1), t.Dim(2), t.Dim(3)

	// Pair adjacent elements along dim 0 so even/odd views can be taken.
	t = t.Reshape(ctx, 2, t.Dim(0)/2, t.Dim(1)*t.Dim(2)*t.Dim(3))

	// t1 = t[..., 0::2] — the even (real) halves of each pair
	t1 := t.View(ctx, 0, 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
	t1 = t1.Reshape(ctx, width/2, height, channels, tiles)

	// t2 = t[..., 1::2] — the odd (imaginary) halves of each pair
	t2 := t.View(ctx, t.Stride(0), 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
	t2 = t2.Reshape(ctx, width/2, height, channels, tiles)

	// cos_out = torch.stack((t1 * cos, t2 * cos), dim=-1)
	// (concat along dim 0, then reshape+permute to interleave the halves)
	cosOut := t1.Mul(ctx, cos).Concat(ctx, t2.Mul(ctx, cos), 0)
	cosOut = cosOut.Reshape(ctx, cosOut.Dim(0)/2, 2, cosOut.Dim(1)*cosOut.Dim(2)*cosOut.Dim(3))
	cosOut = cosOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	cosOut = cosOut.Reshape(ctx, width, height, channels, tiles)

	// sin_out = torch.stack((-t2 * sin, t1 * sin), dim=-1)
	sinOut := t2.Neg(ctx).Mul(ctx, sin).Concat(ctx, t1.Mul(ctx, sin), 0)
	sinOut = sinOut.Reshape(ctx, sinOut.Dim(0)/2, 2, sinOut.Dim(1)*sinOut.Dim(2)*sinOut.Dim(3))
	sinOut = sinOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	sinOut = sinOut.Reshape(ctx, width, height, channels, tiles)

	return cosOut.Add(ctx, sinOut)
}
|||
|
|||
// Forward computes full (non-causal) self-attention over patch tokens,
// applying the 2D rotary embedding to queries and keys.
func (sa *VisionAttention) Forward(ctx ml.Context, hiddenState, cos, sin ml.Tensor, opts *VisionOptions) ml.Tensor {
	headDim := opts.hiddenSize / opts.numHeads

	query := sa.Query.Forward(ctx, hiddenState)
	key := sa.Key.Forward(ctx, hiddenState)
	value := sa.Value.Forward(ctx, hiddenState)

	// Split into per-head layout (same head count for Q, K, and V).
	query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), query.Dim(2))
	key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), key.Dim(2))
	value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), value.Dim(2))

	query = applyVisionRotaryEmbedding(ctx, query, cos, sin)
	key = applyVisionRotaryEmbedding(ctx, key, cos, sin)

	// nil cache: vision attention is stateless and bidirectional.
	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil)
	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), attention.Dim(3))
	return sa.Output.Forward(ctx, attention)
}
|||
|
|||
// VisionMLP is the two-layer GELU feed-forward block of a vision layer.
type VisionMLP struct {
	FC1 *nn.Linear `gguf:"fc1"`
	FC2 *nn.Linear `gguf:"fc2"`
}
|||
|
|||
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor { |
|||
hiddenStates = mlp.FC1.Forward(ctx, hiddenStates).GELU(ctx) |
|||
hiddenStates = mlp.FC2.Forward(ctx, hiddenStates) |
|||
return hiddenStates |
|||
} |
|||
|
|||
// VisionLayer is one pre-norm vision encoder layer: attention + MLP,
// each with a residual connection.
type VisionLayer struct {
	InputLayerNorm *nn.LayerNorm `gguf:"attn_norm"`
	*VisionAttention

	PostAttentionNorm *nn.LayerNorm `gguf:"ffn_norm"`
	*VisionMLP        `gguf:"mlp"`
}
|||
|
|||
// Forward runs one vision encoder layer over the patch tokens.
func (e *VisionLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionOptions) ml.Tensor {
	residual := hiddenStates

	// self attention
	hiddenStates = e.InputLayerNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = e.VisionAttention.Forward(ctx, hiddenStates, cos, sin, opts)
	hiddenStates = hiddenStates.Add(ctx, residual)

	// MLP
	residual = hiddenStates
	hiddenStates = e.PostAttentionNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = e.VisionMLP.Forward(ctx, hiddenStates, opts)
	hiddenStates = hiddenStates.Add(ctx, residual)

	return hiddenStates
}
|||
|
|||
// VisionAdapter is the post-encoder MLP applied after pixel shuffling.
type VisionAdapter struct {
	FC1 *nn.Linear `gguf:"mlp.fc1"`
	FC2 *nn.Linear `gguf:"mlp.fc2"`
}
|||
|
|||
// Forward applies a pixel shuffle (trading spatial resolution for
// channels per pixelShuffleRatio) followed by a two-layer GELU MLP.
func (a *VisionAdapter) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor {
	patches := hiddenStates.Dim(1)
	// Patches are laid back out as a square grid — assumes a square image.
	patchSize := int(math.Sqrt(float64(patches)))

	hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), patchSize, patchSize, hiddenStates.Dim(2))

	channels, width, height, tiles := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2), hiddenStates.Dim(3)

	// First shuffle: fold width into channels (ratio < 1 shrinks width).
	channels, width = int(float32(channels)/opts.pixelShuffleRatio), int(float32(width)*opts.pixelShuffleRatio)
	hiddenStates = hiddenStates.Reshape(ctx, channels, width, height, tiles)
	hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)

	// Second shuffle: fold height into channels the same way.
	channels, height = int(float32(channels)/opts.pixelShuffleRatio), int(float32(height)*opts.pixelShuffleRatio)
	hiddenStates = hiddenStates.Reshape(ctx, channels, width, height, tiles)
	hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)

	// Flatten the spatial grid back into a token axis.
	hiddenStates = hiddenStates.Reshape(ctx, channels, width*height, tiles)

	hiddenStates = a.FC1.Forward(ctx, hiddenStates).GELU(ctx)
	hiddenStates = a.FC2.Forward(ctx, hiddenStates).GELU(ctx)
	return hiddenStates
}
|||
|
|||
// VisionOptions carries vision-tower hyperparameters read from GGUF metadata.
type VisionOptions struct {
	hiddenSize, numHeads int
	imageSize, patchSize int

	ropeTheta         float32
	eps               float32
	pixelShuffleRatio float32
}
|||
|
|||
// PatchEmbedding linearly embeds flattened image patches (unfold + linear
// rather than a convolution).
type PatchEmbedding struct {
	*nn.Linear
}
|||
|
|||
// Forward unfolds the image into non-overlapping patchSize x patchSize
// patches (stride == patch size, no padding) and embeds each one.
func (p *PatchEmbedding) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor {
	// The empty kernel presumably only supplies the geometry for IM2Col;
	// its values appear unused — confirm against the ml backend.
	kernel := ctx.Input().Empty(ml.DTypeF32, opts.patchSize, opts.patchSize, hiddenStates.Dim(2))
	hiddenStates = kernel.IM2Col(ctx, hiddenStates, opts.patchSize, opts.patchSize, 0, 0, 1, 1)
	hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), hiddenStates.Dim(1)*hiddenStates.Dim(2), hiddenStates.Dim(3))
	return p.Linear.Forward(ctx, hiddenStates)
}
|||
|
|||
// VisionModel is the Llama 4 vision tower: patch embedding, class token,
// transformer layers with 2D RoPE, and a pixel-shuffle adapter.
type VisionModel struct {
	Layers []VisionLayer `gguf:"blk"`

	*PatchEmbedding     `gguf:"patch_embedding"`
	ClassEmbedding      ml.Tensor `gguf:"class_embedding"`
	PositionalEmbedding ml.Tensor `gguf:"positional_embedding_vlm"`

	LayerNormPre  *nn.LayerNorm `gguf:"layernorm_pre"`
	LayerNormPost *nn.LayerNorm `gguf:"layernorm_post"`

	*VisionAdapter `gguf:"vision_adapter"`

	*VisionOptions
}
|||
|
|||
// newVisionModel builds the vision tower from GGUF metadata. When
// vision.block_count is zero the tower has no layers and
// EncodeMultimodal rejects image input.
func newVisionModel(c fs.Config) *VisionModel {
	return &VisionModel{
		Layers: make([]VisionLayer, c.Uint("vision.block_count")),
		VisionOptions: &VisionOptions{
			hiddenSize:        int(c.Uint("vision.embedding_length")),
			numHeads:          int(c.Uint("vision.attention.head_count")),
			imageSize:         int(c.Uint("vision.image_size")),
			patchSize:         int(c.Uint("vision.patch_size")),
			ropeTheta:         float32(c.Float("vision.rope.freq_base")),
			eps:               c.Float("vision.layer_norm_epsilon"),
			pixelShuffleRatio: float32(c.Float("vision.pixel_shuffle_ratio")),
		},
	}
}
|||
|
|||
// Forward encodes pixelValues: patch embedding, class token, positional
// embedding, the transformer stack, then the pixel-shuffle adapter.
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
	hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionOptions)
	// Append the class embedding as one extra token per tile.
	hiddenStates = hiddenStates.Concat(ctx, m.ClassEmbedding.Repeat(ctx, 2, hiddenStates.Dim(2)), 1)

	hiddenStates = hiddenStates.Add(ctx, m.PositionalEmbedding)
	hiddenStates = m.LayerNormPre.Forward(ctx, hiddenStates, m.eps)

	// cos/sin tables are shared by every layer.
	cos, sin := m.rotaryEmbedding(ctx)
	for _, layer := range m.Layers {
		hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionOptions)
	}

	hiddenStates = m.LayerNormPost.Forward(ctx, hiddenStates, m.eps)
	// Drop the trailing class token before the adapter.
	hiddenStates = hiddenStates.Unpad(ctx, 0, 1, 0, 0)
	hiddenStates = m.VisionAdapter.Forward(ctx, hiddenStates, m.VisionOptions)
	return hiddenStates
}
|||
|
|||
// floorDiv returns a divided by b rounded toward negative infinity,
// mirroring PyTorch's div(round_mode='floor') and Python's // operator.
// It panics when b is zero.
func floorDiv[T int | int16 | int32 | int64 | uint | uint16 | uint32 | uint64](a, b T) T {
	if b == 0 {
		panic("division by zero")
	}

	// Go's / truncates toward zero, which already equals floor division
	// when the result is exact or the operands share a sign.
	q, r := a/b, a%b
	if r == 0 || (a >= 0 && b > 0) || (a <= 0 && b < 0) {
		return q
	}

	// Signs differ and the division was inexact: truncation rounded up.
	return q - 1
}
|||
|
|||
// rotaryEmbedding precomputes the cos/sin tables for the 2D rotary
// embedding: the first half of the table encodes each patch's x (column)
// coordinate, the second half its y (row) coordinate. The final position
// (the class token, i == numPatches-1) is skipped and left at zero.
func (m *VisionModel) rotaryEmbedding(ctx ml.Context) (ml.Tensor, ml.Tensor) {
	patchesPerSide := m.imageSize / m.patchSize
	numPatches := patchesPerSide*patchesPerSide + 1 // +1 for the class token

	headDim := m.hiddenSize / m.numHeads
	freqDim := headDim / 2

	freqs := make([]float32, numPatches*freqDim)
	for i := range numPatches - 1 {
		for j := 0; j < freqDim; j += 2 {
			// x frequencies fill the first numPatches*freqDim/2 entries,
			// y frequencies the second half.
			positionX := i*freqDim/2 + j/2
			positionY := (i+numPatches)*freqDim/2 + j/2
			ropeFreq := math.Pow(float64(m.ropeTheta), float64(j)*2/float64(headDim))
			// 1-based column index within the row / row index of patch i.
			freqs[positionX] = float32(float64(1+i-floorDiv(i, patchesPerSide)*patchesPerSide) / ropeFreq)
			freqs[positionY] = float32(float64(1+floorDiv(i, patchesPerSide)) / ropeFreq)
		}
	}

	ropeFreqs, err := ctx.Input().FromFloatSlice(freqs, freqDim/2, numPatches, 2)
	if err != nil {
		panic(err)
	}

	// Interleave the x and y halves per position:
	// [freqDim/2, numPatches, 2] -> [freqDim, 1, numPatches].
	ropeFreqs = ropeFreqs.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	ropeFreqs = ropeFreqs.Reshape(ctx, freqDim, 1, numPatches)
	return ropeFreqs.Cos(ctx), ropeFreqs.Sin(ctx)
}
|||
Loading…
Reference in new issue