Skip to content

Commit 49199b6

Browse files
committed
chore: refactor to extend to other types
Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent 749a299 commit 49199b6

File tree

4 files changed

+97
-98
lines changed

4 files changed

+97
-98
lines changed

core/gallery/gallery.go

Lines changed: 24 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -6,82 +6,40 @@ import (
66
"path/filepath"
77
"strings"
88

9-
"dario.cat/mergo"
109
"github.com/mudler/LocalAI/core/config"
1110
"github.com/mudler/LocalAI/pkg/downloader"
1211
"github.com/rs/zerolog/log"
1312
"gopkg.in/yaml.v2"
1413
)
1514

16-
// Installs a model from the gallery
17-
func InstallModelFromGallery(galleries []config.Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64), enforceScan bool) error {
18-
19-
applyModel := func(model *GalleryModel) error {
20-
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
21-
22-
var config Config
23-
24-
if len(model.URL) > 0 {
25-
var err error
26-
config, err = GetGalleryConfigFromURL(model.URL, basePath)
27-
if err != nil {
28-
return err
29-
}
30-
config.Description = model.Description
31-
config.License = model.License
32-
} else if len(model.ConfigFile) > 0 {
33-
// TODO: is this worse than using the override method with a blank cfg yaml?
34-
reYamlConfig, err := yaml.Marshal(model.ConfigFile)
35-
if err != nil {
36-
return err
37-
}
38-
config = Config{
39-
ConfigFile: string(reYamlConfig),
40-
Description: model.Description,
41-
License: model.License,
42-
URLs: model.URLs,
43-
Name: model.Name,
44-
Files: make([]File, 0), // Real values get added below, must be blank
45-
// Prompt Template Skipped for now - I expect in this mode that they will be delivered as files.
46-
}
47-
} else {
48-
return fmt.Errorf("invalid gallery model %+v", model)
49-
}
50-
51-
installName := model.Name
52-
if req.Name != "" {
53-
installName = req.Name
54-
}
55-
56-
// Copy the model configuration from the request schema
57-
config.URLs = append(config.URLs, model.URLs...)
58-
config.Icon = model.Icon
59-
config.Files = append(config.Files, req.AdditionalFiles...)
60-
config.Files = append(config.Files, model.AdditionalFiles...)
61-
62-
// TODO model.Overrides could be merged with user overrides (not defined yet)
63-
if err := mergo.Merge(&model.Overrides, req.Overrides, mergo.WithOverride); err != nil {
64-
return err
65-
}
66-
67-
if err := InstallModel(basePath, installName, &config, model.Overrides, downloadStatus, enforceScan); err != nil {
68-
return err
69-
}
70-
71-
return nil
15+
func GetGalleryConfigFromURL[T any](url string, basePath string) (T, error) {
16+
var config T
17+
uri := downloader.URI(url)
18+
err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error {
19+
return yaml.Unmarshal(d, &config)
20+
})
21+
if err != nil {
22+
log.Error().Err(err).Str("url", url).Msg("failed to get gallery config for url")
23+
return config, err
7224
}
25+
return config, nil
26+
}
7327

74-
models, err := AvailableGalleryModels(galleries, basePath)
28+
func ReadConfigFile[T any](filePath string) (*T, error) {
29+
// Read the YAML file
30+
yamlFile, err := os.ReadFile(filePath)
7531
if err != nil {
76-
return err
32+
return nil, fmt.Errorf("failed to read YAML file: %v", err)
7733
}
7834

79-
model := FindGalleryElement(models, name, basePath)
80-
if model == nil {
81-
return fmt.Errorf("no model found with name %q", name)
35+
// Unmarshal YAML data into a Config struct
36+
var config T
37+
err = yaml.Unmarshal(yamlFile, &config)
38+
if err != nil {
39+
return nil, fmt.Errorf("failed to unmarshal YAML: %v", err)
8240
}
8341

84-
return applyModel(model)
42+
return &config, nil
8543
}
8644

8745
type GalleryElement interface {
@@ -123,7 +81,7 @@ func AvailableGalleryModels(galleries []config.Gallery, basePath string) (Galler
12381

12482
// Get models from galleries
12583
for _, gallery := range galleries {
126-
galleryModels, err := getGalleryModels[*GalleryModel](gallery, basePath)
84+
galleryModels, err := getGalleryElements[*GalleryModel](gallery, basePath)
12785
if err != nil {
12886
return nil, err
12987
}
@@ -139,7 +97,7 @@ func AvailableBackends(galleries []config.Gallery, basePath string) (GalleryBack
13997

14098
// Get models from galleries
14199
for _, gallery := range galleries {
142-
galleryModels, err := getGalleryModels[*GalleryBackend](gallery, basePath)
100+
galleryModels, err := getGalleryElements[*GalleryBackend](gallery, basePath)
143101
if err != nil {
144102
return nil, err
145103
}
@@ -164,7 +122,7 @@ func findGalleryURLFromReferenceURL(url string, basePath string) (string, error)
164122
return refFile, err
165123
}
166124

167-
func getGalleryModels[T GalleryElement](gallery config.Gallery, basePath string) ([]T, error) {
125+
func getGalleryElements[T GalleryElement](gallery config.Gallery, basePath string) ([]T, error) {
168126
var models []T = []T{}
169127

170128
if strings.HasSuffix(gallery.URL, ".ref") {

core/gallery/models.go

Lines changed: 68 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ prompt_templates:
4343
content: ""
4444
4545
*/
46-
// Config is the model configuration which contains all the model details
46+
// ModelConfig is the model configuration which contains all the model details
4747
// This configuration is read from the gallery endpoint and is used to download and install the model
4848
// It is the internal structure, separated from the request
49-
type Config struct {
49+
type ModelConfig struct {
5050
Description string `yaml:"description"`
5151
Icon string `yaml:"icon"`
5252
License string `yaml:"license"`
@@ -68,37 +68,78 @@ type PromptTemplate struct {
6868
Content string `yaml:"content"`
6969
}
7070

71-
func GetGalleryConfigFromURL(url string, basePath string) (Config, error) {
72-
var config Config
73-
uri := downloader.URI(url)
74-
err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error {
75-
return yaml.Unmarshal(d, &config)
76-
})
77-
if err != nil {
78-
log.Error().Err(err).Str("url", url).Msg("failed to get gallery config for url")
79-
return config, err
71+
// Installs a model from the gallery
72+
func InstallModelFromGallery(galleries []config.Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64), enforceScan bool) error {
73+
74+
applyModel := func(model *GalleryModel) error {
75+
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
76+
77+
var config ModelConfig
78+
79+
if len(model.URL) > 0 {
80+
var err error
81+
config, err = GetGalleryConfigFromURL[ModelConfig](model.URL, basePath)
82+
if err != nil {
83+
return err
84+
}
85+
config.Description = model.Description
86+
config.License = model.License
87+
} else if len(model.ConfigFile) > 0 {
88+
// TODO: is this worse than using the override method with a blank cfg yaml?
89+
reYamlConfig, err := yaml.Marshal(model.ConfigFile)
90+
if err != nil {
91+
return err
92+
}
93+
config = ModelConfig{
94+
ConfigFile: string(reYamlConfig),
95+
Description: model.Description,
96+
License: model.License,
97+
URLs: model.URLs,
98+
Name: model.Name,
99+
Files: make([]File, 0), // Real values get added below, must be blank
100+
// Prompt Template Skipped for now - I expect in this mode that they will be delivered as files.
101+
}
102+
} else {
103+
return fmt.Errorf("invalid gallery model %+v", model)
104+
}
105+
106+
installName := model.Name
107+
if req.Name != "" {
108+
installName = req.Name
109+
}
110+
111+
// Copy the model configuration from the request schema
112+
config.URLs = append(config.URLs, model.URLs...)
113+
config.Icon = model.Icon
114+
config.Files = append(config.Files, req.AdditionalFiles...)
115+
config.Files = append(config.Files, model.AdditionalFiles...)
116+
117+
// TODO model.Overrides could be merged with user overrides (not defined yet)
118+
if err := mergo.Merge(&model.Overrides, req.Overrides, mergo.WithOverride); err != nil {
119+
return err
120+
}
121+
122+
if err := InstallModel(basePath, installName, &config, model.Overrides, downloadStatus, enforceScan); err != nil {
123+
return err
124+
}
125+
126+
return nil
80127
}
81-
return config, nil
82-
}
83128

84-
func ReadConfigFile(filePath string) (*Config, error) {
85-
// Read the YAML file
86-
yamlFile, err := os.ReadFile(filePath)
129+
models, err := AvailableGalleryModels(galleries, basePath)
87130
if err != nil {
88-
return nil, fmt.Errorf("failed to read YAML file: %v", err)
131+
return err
89132
}
90133

91-
// Unmarshal YAML data into a Config struct
92-
var config Config
93-
err = yaml.Unmarshal(yamlFile, &config)
94-
if err != nil {
95-
return nil, fmt.Errorf("failed to unmarshal YAML: %v", err)
134+
model := FindGalleryElement(models, name, basePath)
135+
if model == nil {
136+
return fmt.Errorf("no model found with name %q", name)
96137
}
97138

98-
return &config, nil
139+
return applyModel(model)
99140
}
100141

101-
func InstallModel(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) error {
142+
func InstallModel(basePath, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) error {
102143
// Create base path if it doesn't exist
103144
err := os.MkdirAll(basePath, 0750)
104145
if err != nil {
@@ -222,10 +263,10 @@ func galleryFileName(name string) string {
222263
return "._gallery_" + name + ".yaml"
223264
}
224265

225-
func GetLocalModelConfiguration(basePath string, name string) (*Config, error) {
266+
func GetLocalModelConfiguration(basePath string, name string) (*ModelConfig, error) {
226267
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
227268
galleryFile := filepath.Join(basePath, galleryFileName(name))
228-
return ReadConfigFile(galleryFile)
269+
return ReadConfigFile[ModelConfig](galleryFile)
229270
}
230271

231272
func DeleteModelFromSystem(basePath string, name string, additionalFiles []string) error {
@@ -245,7 +286,7 @@ func DeleteModelFromSystem(basePath string, name string, additionalFiles []strin
245286
var err error
246287
// Delete all the files associated to the model
247288
// read the model config
248-
galleryconfig, err := ReadConfigFile(galleryFile)
289+
galleryconfig, err := ReadConfigFile[ModelConfig](galleryFile)
249290
if err != nil {
250291
log.Error().Err(err).Msgf("failed to read gallery file %s", configFile)
251292
}

core/gallery/models_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ var _ = Describe("Model test", func() {
2121
tempdir, err := os.MkdirTemp("", "test")
2222
Expect(err).ToNot(HaveOccurred())
2323
defer os.RemoveAll(tempdir)
24-
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
24+
c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
2525
Expect(err).ToNot(HaveOccurred())
2626
err = InstallModel(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
2727
Expect(err).ToNot(HaveOccurred())
@@ -107,7 +107,7 @@ var _ = Describe("Model test", func() {
107107
tempdir, err := os.MkdirTemp("", "test")
108108
Expect(err).ToNot(HaveOccurred())
109109
defer os.RemoveAll(tempdir)
110-
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
110+
c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
111111
Expect(err).ToNot(HaveOccurred())
112112

113113
err = InstallModel(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
@@ -123,7 +123,7 @@ var _ = Describe("Model test", func() {
123123
tempdir, err := os.MkdirTemp("", "test")
124124
Expect(err).ToNot(HaveOccurred())
125125
defer os.RemoveAll(tempdir)
126-
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
126+
c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
127127
Expect(err).ToNot(HaveOccurred())
128128

129129
err = InstallModel(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
@@ -149,7 +149,7 @@ var _ = Describe("Model test", func() {
149149
tempdir, err := os.MkdirTemp("", "test")
150150
Expect(err).ToNot(HaveOccurred())
151151
defer os.RemoveAll(tempdir)
152-
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
152+
c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
153153
Expect(err).ToNot(HaveOccurred())
154154

155155
err = InstallModel(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)

core/gallery/request_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ var _ = Describe("Gallery API tests", func() {
1414
URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main",
1515
},
1616
}
17-
e, err := GetGalleryConfigFromURL(req.URL, "")
17+
e, err := GetGalleryConfigFromURL[ModelConfig](req.URL, "")
1818
Expect(err).ToNot(HaveOccurred())
1919
Expect(e.Name).To(Equal("gpt4all-j"))
2020
})

0 commit comments

Comments
 (0)