package convert
import (
"errors"
"io"
"io/fs"
"strings"
)
// Tensor is a single named tensor as returned by parseTensors.
//
// Besides reporting its name, shape and kind, a Tensor accepts a repacker
// through SetRepacker and can be written to an io.Writer via WriteTo
// (the embedded io.WriterTo).
type Tensor interface {
	io.WriterTo

	Name() string
	Shape() []uint64
	Kind() uint32
	SetRepacker(repacker)
}
// tensorBase holds the name, shape and repacker of a tensor and implements
// the Name, Shape, Kind and SetRepacker methods of Tensor. It has no WriteTo,
// so presumably concrete Tensor types embed it and supply WriteTo — confirm
// against the implementations elsewhere in the package.
type tensorBase struct {
	name  string   // tensor name; Kind matches on it
	shape []uint64 // dimensions; len(shape) drives Kind
	repacker       // nil until SetRepacker is called
}
// Name returns the tensor's name.
func (t tensorBase) Name() string { return t.name }

// Shape returns the tensor's dimensions. The slice shares its backing array
// with the tensor; it is not a copy.
func (t tensorBase) Shape() []uint64 { return t.shape }
// Tensor kinds reported by tensorBase.Kind.
//
// NOTE(review): these numeric values may be persisted in converted output —
// confirm before renumbering.
const (
	// tensorKindF32 is reported for one-dimensional tensors and for the
	// tensors Kind always treats as F32.
	tensorKindF32 uint32 = 0
	// tensorKindF16 is reported for tensors with two or more dimensions.
	tensorKindF16 uint32 = 1
)
2 years ago
func (t tensorBase) Kind() uint32 {
2 years ago
if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
t.name == "token_types.weight" {
// these tensors are always F32
2 years ago
return 0
2 years ago
}
switch len(t.shape) {
case 0:
panic("invalid tensor shape")
case 1:
2 years ago
return tensorKindF32
2 years ago
default:
2 years ago
return tensorKindF16
2 years ago
}
}
// SetRepacker stores fn as the tensor's repacker, replacing any previously
// set one. Passing nil clears it.
func (t *tensorBase) SetRepacker(fn repacker) {
	t.repacker = fn
}
// repacker is a function that receives a name, float32 data and a shape and
// returns new float32 data or an error. Presumably the arguments are the
// owning tensor's name, data and shape — confirm at the call sites.
type repacker func(name string, data []float32, shape []uint64) ([]float32, error)
2 years ago
func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) {
patterns := []struct {
Pattern string
2 years ago
Func func(fs.FS, *strings.Replacer, ...string) ([]Tensor, error)
}{
{"model-*-of-*.safetensors", parseSafetensors},
{"model.safetensors", parseSafetensors},
{"adapters.safetensors", parseSafetensors},
{"adapter_model.safetensors", parseSafetensors},
{"pytorch_model-*-of-*.bin", parseTorch},
{"pytorch_model.bin", parseTorch},
{"consolidated.*.pth", parseTorch},
2 years ago
}
for _, pattern := range patterns {
matches, err := fs.Glob(fsys, pattern.Pattern)
2 years ago
if err != nil {
return nil, err
}
if len(matches) > 0 {
2 years ago
return pattern.Func(fsys, replacer, matches...)
2 years ago
}
}
return nil, errors.New("unknown tensor format")
}