From aac0e4a3ab534d080728b6d9338ca9ceef20beee Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Thu, 14 Aug 2025 15:24:54 +0200 Subject: [PATCH] feat(backends): add system backend, refactor - Add a system backend path - Refactor and consolidate system information in system state - Use system state in all the components to figure out the system paths to use whenever needed - Refactor BackendConfig -> ModelConfig. This was otherwise misleading as now we do have a backend configuration which is not the model config. Signed-off-by: Ettore Di Giacinto --- core/application/application.go | 10 +- core/application/startup.go | 29 +- core/backend/detection.go | 4 +- core/backend/embeddings.go | 6 +- core/backend/image.go | 8 +- core/backend/llm.go | 6 +- core/backend/llm_test.go | 4 +- core/backend/options.go | 8 +- core/backend/rerank.go | 4 +- core/backend/soundgeneration.go | 6 +- core/backend/token_metrics.go | 4 +- core/backend/tokenize.go | 6 +- core/backend/transcript.go | 10 +- core/backend/tts.go | 12 +- core/backend/vad.go | 4 +- core/backend/video.go | 4 +- core/cli/backends.go | 36 ++- core/cli/models.go | 25 +- core/cli/run.go | 14 +- core/cli/soundgeneration.go | 14 +- core/cli/transcript.go | 19 +- core/cli/tts.go | 15 +- core/cli/util.go | 19 +- core/cli/worker/worker.go | 5 +- core/cli/worker/worker_llamacpp.go | 17 +- core/cli/worker/worker_p2p.go | 11 +- core/config/application_config.go | 16 +- core/config/backend_config.go | 100 +++---- core/config/backend_config_filter.go | 12 +- core/config/backend_config_loader.go | 76 ++--- core/config/backend_config_test.go | 24 +- core/config/config_test.go | 20 +- core/config/gguf.go | 2 +- core/config/guesser.go | 2 +- core/gallery/backends.go | 84 ++++-- core/gallery/backends_test.go | 261 +++++++++++++----- core/gallery/gallery.go | 23 +- core/gallery/models.go | 78 ++++-- core/gallery/models_test.go | 39 ++- core/http/app.go | 2 +- core/http/app_test.go | 38 ++- 
.../endpoints/elevenlabs/soundgeneration.go | 4 +- core/http/endpoints/elevenlabs/tts.go | 4 +- core/http/endpoints/jina/rerank.go | 4 +- core/http/endpoints/localai/backend.go | 25 +- core/http/endpoints/localai/detection.go | 4 +- core/http/endpoints/localai/gallery.go | 9 +- .../endpoints/localai/get_token_metrics.go | 4 +- core/http/endpoints/localai/tokenize.go | 4 +- core/http/endpoints/localai/tts.go | 4 +- core/http/endpoints/localai/vad.go | 4 +- core/http/endpoints/localai/video.go | 4 +- core/http/endpoints/localai/welcome.go | 8 +- core/http/endpoints/openai/chat.go | 10 +- core/http/endpoints/openai/completion.go | 6 +- core/http/endpoints/openai/edit.go | 4 +- core/http/endpoints/openai/embeddings.go | 4 +- core/http/endpoints/openai/image.go | 4 +- core/http/endpoints/openai/inference.go | 4 +- core/http/endpoints/openai/list.go | 2 +- core/http/endpoints/openai/realtime.go | 12 +- core/http/endpoints/openai/realtime_model.go | 39 ++- core/http/endpoints/openai/transcription.go | 4 +- core/http/middleware/request.go | 28 +- core/http/routes/elevenlabs.go | 2 +- core/http/routes/jina.go | 2 +- core/http/routes/localai.go | 15 +- core/http/routes/ui.go | 50 ++-- core/http/routes/ui_backend_gallery.go | 4 +- core/http/routes/ui_gallery.go | 8 +- core/services/backend_monitor.go | 18 +- core/services/backends.go | 8 +- core/services/gallery.go | 10 +- core/services/list_models.go | 6 +- core/services/models.go | 67 ++--- core/startup/backend_preload.go | 14 +- core/startup/model_preload.go | 29 +- core/startup/model_preload_test.go | 11 +- core/templates/evaluator.go | 4 +- core/templates/evaluator_test.go | 24 +- pkg/model/loader.go | 5 +- pkg/model/loader_test.go | 8 +- pkg/system/capabilities.go | 23 -- pkg/system/state.go | 61 ++++ tests/integration/stores_test.go | 10 +- 85 files changed, 999 insertions(+), 652 deletions(-) create mode 100644 pkg/system/state.go diff --git a/core/application/application.go b/core/application/application.go index 
a990866019a0..d49260eae4ad 100644 --- a/core/application/application.go +++ b/core/application/application.go @@ -7,7 +7,7 @@ import ( ) type Application struct { - backendLoader *config.BackendConfigLoader + backendLoader *config.ModelConfigLoader modelLoader *model.ModelLoader applicationConfig *config.ApplicationConfig templatesEvaluator *templates.Evaluator @@ -15,14 +15,14 @@ type Application struct { func newApplication(appConfig *config.ApplicationConfig) *Application { return &Application{ - backendLoader: config.NewBackendConfigLoader(appConfig.ModelPath), - modelLoader: model.NewModelLoader(appConfig.ModelPath, appConfig.SingleBackend), + backendLoader: config.NewModelConfigLoader(appConfig.SystemState.Model.ModelsPath), + modelLoader: model.NewModelLoader(appConfig.SystemState, appConfig.SingleBackend), applicationConfig: appConfig, - templatesEvaluator: templates.NewEvaluator(appConfig.ModelPath), + templatesEvaluator: templates.NewEvaluator(appConfig.SystemState.Model.ModelsPath), } } -func (a *Application) BackendLoader() *config.BackendConfigLoader { +func (a *Application) BackendLoader() *config.ModelConfigLoader { return a.backendLoader } diff --git a/core/application/startup.go b/core/application/startup.go index cbb0be1eb53d..8ebd44071597 100644 --- a/core/application/startup.go +++ b/core/application/startup.go @@ -20,7 +20,7 @@ func New(opts ...config.AppOption) (*Application, error) { options := config.NewApplicationConfig(opts...) 
application := newApplication(options) - log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.ModelPath) + log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.SystemState.Model.ModelsPath) log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion()) caps, err := xsysinfo.CPUCapabilities() if err == nil { @@ -35,10 +35,11 @@ func New(opts ...config.AppOption) (*Application, error) { } // Make sure directories exists - if options.ModelPath == "" { - return nil, fmt.Errorf("options.ModelPath cannot be empty") + if options.SystemState.Model.ModelsPath == "" { + return nil, fmt.Errorf("models path cannot be empty") } - err = os.MkdirAll(options.ModelPath, 0750) + + err = os.MkdirAll(options.SystemState.Model.ModelsPath, 0750) if err != nil { return nil, fmt.Errorf("unable to create ModelPath: %q", err) } @@ -55,50 +56,50 @@ func New(opts ...config.AppOption) (*Application, error) { } } - if err := coreStartup.InstallModels(options.Galleries, options.BackendGalleries, options.ModelPath, options.BackendsPath, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil { + if err := coreStartup.InstallModels(options.Galleries, options.BackendGalleries, options.SystemState, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil { log.Error().Err(err).Msg("error installing models") } for _, backend := range options.ExternalBackends { - if err := coreStartup.InstallExternalBackends(options.BackendGalleries, options.BackendsPath, nil, backend, "", ""); err != nil { + if err := coreStartup.InstallExternalBackends(options.BackendGalleries, options.SystemState, nil, backend, "", ""); err != nil { log.Error().Err(err).Msg("error installing external backend") } } configLoaderOpts := options.ToConfigLoaderOptions() - if err := 
application.BackendLoader().LoadBackendConfigsFromPath(options.ModelPath, configLoaderOpts...); err != nil { + if err := application.BackendLoader().LoadModelConfigsFromPath(options.SystemState.Model.ModelsPath, configLoaderOpts...); err != nil { log.Error().Err(err).Msg("error loading config files") } - if err := gallery.RegisterBackends(options.BackendsPath, application.ModelLoader()); err != nil { + if err := gallery.RegisterBackends(options.SystemState, application.ModelLoader()); err != nil { log.Error().Err(err).Msg("error registering external backends") } if options.ConfigFile != "" { - if err := application.BackendLoader().LoadMultipleBackendConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil { + if err := application.BackendLoader().LoadMultipleModelConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil { log.Error().Err(err).Msg("error loading config file") } } - if err := application.BackendLoader().Preload(options.ModelPath); err != nil { + if err := application.BackendLoader().Preload(options.SystemState.Model.ModelsPath); err != nil { log.Error().Err(err).Msg("error downloading models") } if options.PreloadJSONModels != "" { - if err := services.ApplyGalleryFromString(options.ModelPath, options.BackendsPath, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil { + if err := services.ApplyGalleryFromString(options.SystemState, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil { return nil, err } } if options.PreloadModelsFromPath != "" { - if err := services.ApplyGalleryFromFile(options.ModelPath, options.BackendsPath, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil { + if err := 
services.ApplyGalleryFromFile(options.SystemState, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil { return nil, err } } if options.Debug { - for _, v := range application.BackendLoader().GetAllBackendConfigs() { + for _, v := range application.BackendLoader().GetAllModelsConfigs() { log.Debug().Msgf("Model: %s (config: %+v)", v.Name, v) } } @@ -131,7 +132,7 @@ func New(opts ...config.AppOption) (*Application, error) { if options.LoadToMemory != nil && !options.SingleBackend { for _, m := range options.LoadToMemory { - cfg, err := application.BackendLoader().LoadBackendConfigFileByNameDefaultOptions(m, options) + cfg, err := application.BackendLoader().LoadModelConfigFileByNameDefaultOptions(m, options) if err != nil { return nil, err } diff --git a/core/backend/detection.go b/core/backend/detection.go index dc89560a89c9..a3a443952734 100644 --- a/core/backend/detection.go +++ b/core/backend/detection.go @@ -13,9 +13,9 @@ func Detection( sourceFile string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, - backendConfig config.BackendConfig, + modelConfig config.ModelConfig, ) (*proto.DetectResponse, error) { - opts := ModelOptions(backendConfig, appConfig) + opts := ModelOptions(modelConfig, appConfig) detectionModel, err := loader.Load(opts...) 
if err != nil { return nil, err diff --git a/core/backend/embeddings.go b/core/backend/embeddings.go index aece0cdd9e5c..c809992a4a4f 100644 --- a/core/backend/embeddings.go +++ b/core/backend/embeddings.go @@ -9,9 +9,9 @@ import ( model "github.com/mudler/LocalAI/pkg/model" ) -func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) { +func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) { - opts := ModelOptions(backendConfig, appConfig) + opts := ModelOptions(modelConfig, appConfig) inferenceModel, err := loader.Load(opts...) if err != nil { @@ -23,7 +23,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendCo switch model := inferenceModel.(type) { case grpc.Backend: fn = func() ([]float32, error) { - predictOptions := gRPCPredictOpts(backendConfig, loader.ModelPath) + predictOptions := gRPCPredictOpts(modelConfig, loader.ModelPath) if len(tokens) > 0 { embeds := []int32{} diff --git a/core/backend/image.go b/core/backend/image.go index 9f838a37356d..3a80563cf277 100644 --- a/core/backend/image.go +++ b/core/backend/image.go @@ -7,9 +7,9 @@ import ( model "github.com/mudler/LocalAI/pkg/model" ) -func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) { +func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) { - opts := ModelOptions(backendConfig, appConfig) + opts := ModelOptions(modelConfig, appConfig) 
inferenceModel, err := loader.Load( opts..., ) @@ -27,12 +27,12 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat Mode: int32(mode), Step: int32(step), Seed: int32(seed), - CLIPSkip: int32(backendConfig.Diffusers.ClipSkip), + CLIPSkip: int32(modelConfig.Diffusers.ClipSkip), PositivePrompt: positive_prompt, NegativePrompt: negative_prompt, Dst: dst, Src: src, - EnableParameters: backendConfig.Diffusers.EnableParameters, + EnableParameters: modelConfig.Diffusers.EnableParameters, RefImages: refImages, }) return err diff --git a/core/backend/llm.go b/core/backend/llm.go index 05151383354e..a74141fe675e 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -35,7 +35,7 @@ type TokenUsage struct { TimingTokenGeneration float64 } -func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c *config.BackendConfig, cl *config.BackendConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) { +func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) { modelFile := c.Model // Check if the modelFile exists, if it doesn't try to load it from the gallery @@ -47,7 +47,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im if !slices.Contains(modelNames, c.Name) { utils.ResetDownloadTimers() // if we failed to load the model, we try to download it - err := gallery.InstallModelFromGallery(o.Galleries, o.BackendGalleries, c.Name, loader.ModelPath, o.BackendsPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries) + err := gallery.InstallModelFromGallery(o.Galleries, 
o.BackendGalleries, o.SystemState, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries) if err != nil { log.Error().Err(err).Msgf("failed to install model %q from gallery", modelFile) //return nil, err @@ -201,7 +201,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) var mu sync.Mutex = sync.Mutex{} -func Finetune(config config.BackendConfig, input, prediction string) string { +func Finetune(config config.ModelConfig, input, prediction string) string { if config.Echo { prediction = input + prediction } diff --git a/core/backend/llm_test.go b/core/backend/llm_test.go index f7630702e2a0..ea68a93156ef 100644 --- a/core/backend/llm_test.go +++ b/core/backend/llm_test.go @@ -12,14 +12,14 @@ import ( var _ = Describe("LLM tests", func() { Context("Finetune LLM output", func() { var ( - testConfig config.BackendConfig + testConfig config.ModelConfig input string prediction string result string ) BeforeEach(func() { - testConfig = config.BackendConfig{ + testConfig = config.ModelConfig{ PredictionOptions: schema.PredictionOptions{ Echo: false, }, diff --git a/core/backend/options.go b/core/backend/options.go index cfe7b35e4902..a64fbb74bbcb 100644 --- a/core/backend/options.go +++ b/core/backend/options.go @@ -11,7 +11,7 @@ import ( "github.com/rs/zerolog/log" ) -func ModelOptions(c config.BackendConfig, so *config.ApplicationConfig, opts ...model.Option) []model.Option { +func ModelOptions(c config.ModelConfig, so *config.ApplicationConfig, opts ...model.Option) []model.Option { name := c.Name if name == "" { name = c.Model @@ -58,7 +58,7 @@ func ModelOptions(c config.BackendConfig, so *config.ApplicationConfig, opts ... return append(defOpts, opts...) 
} -func getSeed(c config.BackendConfig) int32 { +func getSeed(c config.ModelConfig) int32 { var seed int32 = config.RAND_SEED if c.Seed != nil { @@ -72,7 +72,7 @@ func getSeed(c config.BackendConfig) int32 { return seed } -func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions { +func grpcModelOpts(c config.ModelConfig) *pb.ModelOptions { b := 512 if c.Batch != 0 { b = c.Batch @@ -195,7 +195,7 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions { } } -func gRPCPredictOpts(c config.BackendConfig, modelPath string) *pb.PredictOptions { +func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions { promptCachePath := "" if c.PromptCachePath != "" { p := filepath.Join(modelPath, c.PromptCachePath) diff --git a/core/backend/rerank.go b/core/backend/rerank.go index d7937ce45f52..068d05e68bc2 100644 --- a/core/backend/rerank.go +++ b/core/backend/rerank.go @@ -9,8 +9,8 @@ import ( model "github.com/mudler/LocalAI/pkg/model" ) -func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) { - opts := ModelOptions(backendConfig, appConfig) +func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig) (*proto.RerankResult, error) { + opts := ModelOptions(modelConfig, appConfig) rerankModel, err := loader.Load(opts...) 
if err != nil { return nil, err diff --git a/core/backend/soundgeneration.go b/core/backend/soundgeneration.go index 6379fb289781..29ba856bc429 100644 --- a/core/backend/soundgeneration.go +++ b/core/backend/soundgeneration.go @@ -21,10 +21,10 @@ func SoundGeneration( sourceDivisor *int32, loader *model.ModelLoader, appConfig *config.ApplicationConfig, - backendConfig config.BackendConfig, + modelConfig config.ModelConfig, ) (string, *proto.Result, error) { - opts := ModelOptions(backendConfig, appConfig) + opts := ModelOptions(modelConfig, appConfig) soundGenModel, err := loader.Load(opts...) if err != nil { return "", nil, err @@ -49,7 +49,7 @@ func SoundGeneration( res, err := soundGenModel.SoundGeneration(context.Background(), &proto.SoundGenerationRequest{ Text: text, - Model: backendConfig.Model, + Model: modelConfig.Model, Dst: filePath, Sample: doSample, Duration: duration, diff --git a/core/backend/token_metrics.go b/core/backend/token_metrics.go index ac34e34fdd4d..c3e15d773fc5 100644 --- a/core/backend/token_metrics.go +++ b/core/backend/token_metrics.go @@ -13,9 +13,9 @@ func TokenMetrics( modelFile string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, - backendConfig config.BackendConfig) (*proto.MetricsResponse, error) { + modelConfig config.ModelConfig) (*proto.MetricsResponse, error) { - opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile)) + opts := ModelOptions(modelConfig, appConfig, model.WithModel(modelFile)) model, err := loader.Load(opts...) 
if err != nil { return nil, err diff --git a/core/backend/tokenize.go b/core/backend/tokenize.go index 43c46134624f..e85958b27053 100644 --- a/core/backend/tokenize.go +++ b/core/backend/tokenize.go @@ -7,19 +7,19 @@ import ( "github.com/mudler/LocalAI/pkg/model" ) -func ModelTokenize(s string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (schema.TokenizeResponse, error) { +func ModelTokenize(s string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (schema.TokenizeResponse, error) { var inferenceModel grpc.Backend var err error - opts := ModelOptions(backendConfig, appConfig) + opts := ModelOptions(modelConfig, appConfig) inferenceModel, err = loader.Load(opts...) if err != nil { return schema.TokenizeResponse{}, err } defer loader.Close() - predictOptions := gRPCPredictOpts(backendConfig, loader.ModelPath) + predictOptions := gRPCPredictOpts(modelConfig, loader.ModelPath) predictOptions.Prompt = s // tokenize the string diff --git a/core/backend/transcript.go b/core/backend/transcript.go index 64f9c5e2b5d3..77aa7d0354da 100644 --- a/core/backend/transcript.go +++ b/core/backend/transcript.go @@ -12,13 +12,13 @@ import ( "github.com/mudler/LocalAI/pkg/model" ) -func ModelTranscription(audio, language string, translate bool, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) { +func ModelTranscription(audio, language string, translate bool, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) { - if backendConfig.Backend == "" { - backendConfig.Backend = model.WhisperBackend + if modelConfig.Backend == "" { + modelConfig.Backend = model.WhisperBackend } - opts := ModelOptions(backendConfig, appConfig) + opts := ModelOptions(modelConfig, appConfig) transcriptionModel, err := ml.Load(opts...) 
if err != nil { @@ -34,7 +34,7 @@ func ModelTranscription(audio, language string, translate bool, ml *model.ModelL Dst: audio, Language: language, Translate: translate, - Threads: uint32(*backendConfig.Threads), + Threads: uint32(*modelConfig.Threads), }) if err != nil { return nil, err diff --git a/core/backend/tts.go b/core/backend/tts.go index 5793957e3d1e..2a9a9c93d0de 100644 --- a/core/backend/tts.go +++ b/core/backend/tts.go @@ -19,9 +19,9 @@ func ModelTTS( language string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, - backendConfig config.BackendConfig, + modelConfig config.ModelConfig, ) (string, *proto.Result, error) { - opts := ModelOptions(backendConfig, appConfig) + opts := ModelOptions(modelConfig, appConfig) ttsModel, err := loader.Load(opts...) if err != nil { return "", nil, err @@ -29,7 +29,7 @@ func ModelTTS( defer loader.Close() if ttsModel == nil { - return "", nil, fmt.Errorf("could not load tts model %q", backendConfig.Model) + return "", nil, fmt.Errorf("could not load tts model %q", modelConfig.Model) } audioDir := filepath.Join(appConfig.GeneratedContentDir, "audio") @@ -47,14 +47,14 @@ func ModelTTS( // Checking first that it exists and is not outside ModelPath // TODO: we should actually first check if the modelFile is looking like // a FS path - mp := filepath.Join(loader.ModelPath, backendConfig.Model) + mp := filepath.Join(loader.ModelPath, modelConfig.Model) if _, err := os.Stat(mp); err == nil { - if err := utils.VerifyPath(mp, appConfig.ModelPath); err != nil { + if err := utils.VerifyPath(mp, appConfig.SystemState.Model.ModelsPath); err != nil { return "", nil, err } modelPath = mp } else { - modelPath = backendConfig.Model // skip this step if it fails????? + modelPath = modelConfig.Model // skip this step if it fails????? 
} res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{ diff --git a/core/backend/vad.go b/core/backend/vad.go index 741dbb19e61e..91f70bbc3c79 100644 --- a/core/backend/vad.go +++ b/core/backend/vad.go @@ -13,8 +13,8 @@ func VAD(request *schema.VADRequest, ctx context.Context, ml *model.ModelLoader, appConfig *config.ApplicationConfig, - backendConfig config.BackendConfig) (*schema.VADResponse, error) { - opts := ModelOptions(backendConfig, appConfig) + modelConfig config.ModelConfig) (*schema.VADResponse, error) { + opts := ModelOptions(modelConfig, appConfig) vadModel, err := ml.Load(opts...) if err != nil { return nil, err diff --git a/core/backend/video.go b/core/backend/video.go index 49241070491c..b5a4dbc041c9 100644 --- a/core/backend/video.go +++ b/core/backend/video.go @@ -7,9 +7,9 @@ import ( model "github.com/mudler/LocalAI/pkg/model" ) -func VideoGeneration(height, width int32, prompt, startImage, endImage, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) { +func VideoGeneration(height, width int32, prompt, startImage, endImage, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() error, error) { - opts := ModelOptions(backendConfig, appConfig) + opts := ModelOptions(modelConfig, appConfig) inferenceModel, err := loader.Load( opts..., ) diff --git a/core/cli/backends.go b/core/cli/backends.go index 59f0462cd47e..37d719c14be3 100644 --- a/core/cli/backends.go +++ b/core/cli/backends.go @@ -6,6 +6,7 @@ import ( cliContext "github.com/mudler/LocalAI/core/cli/context" "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/startup" @@ -14,8 +15,9 @@ import ( ) type BackendsCMDFlags struct { - BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of 
backend galleries" group:"backends" default:"${backends}"` - BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"storage"` + BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"` + BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"storage"` + BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/usr/share/localai/backends" help:"Path containing system backends used for inferencing" group:"backends"` } type BackendsList struct { @@ -48,7 +50,15 @@ func (bl *BackendsList) Run(ctx *cliContext.Context) error { log.Error().Err(err).Msg("unable to load galleries") } - backends, err := gallery.AvailableBackends(galleries, bl.BackendsPath) + systemState, err := system.GetSystemState( + system.WithBackendSystemPath(bl.BackendsSystemPath), + system.WithBackendPath(bl.BackendsPath), + ) + if err != nil { + return err + } + + backends, err := gallery.AvailableBackends(galleries, systemState) if err != nil { return err } @@ -68,6 +78,14 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error { log.Error().Err(err).Msg("unable to load galleries") } + systemState, err := system.GetSystemState( + system.WithBackendSystemPath(bi.BackendsSystemPath), + system.WithBackendPath(bi.BackendsPath), + ) + if err != nil { + return err + } + progressBar := progressbar.NewOptions( 1000, progressbar.OptionSetDescription(fmt.Sprintf("downloading backend %s", bi.BackendArgs)), @@ -82,7 +100,7 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error { } } - err := startup.InstallExternalBackends(galleries, bi.BackendsPath, progressCallback, bi.BackendArgs, bi.Name, bi.Alias) + err = 
startup.InstallExternalBackends(galleries, systemState, progressCallback, bi.BackendArgs, bi.Name, bi.Alias) if err != nil { return err } @@ -94,7 +112,15 @@ func (bu *BackendsUninstall) Run(ctx *cliContext.Context) error { for _, backendName := range bu.BackendArgs { log.Info().Str("backend", backendName).Msg("uninstalling backend") - err := gallery.DeleteBackendFromSystem(bu.BackendsPath, backendName) + systemState, err := system.GetSystemState( + system.WithBackendSystemPath(bu.BackendsSystemPath), + system.WithBackendPath(bu.BackendsPath), + ) + if err != nil { + return err + } + + err = gallery.DeleteBackendFromSystem(systemState, backendName) if err != nil { return err } diff --git a/core/cli/models.go b/core/cli/models.go index 1dc018572159..9fe9163831ba 100644 --- a/core/cli/models.go +++ b/core/cli/models.go @@ -11,6 +11,7 @@ import ( "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/startup" "github.com/mudler/LocalAI/pkg/downloader" + "github.com/mudler/LocalAI/pkg/system" "github.com/rs/zerolog/log" "github.com/schollz/progressbar/v3" ) @@ -45,7 +46,14 @@ func (ml *ModelsList) Run(ctx *cliContext.Context) error { log.Error().Err(err).Msg("unable to load galleries") } - models, err := gallery.AvailableGalleryModels(galleries, ml.ModelsPath) + systemState, err := system.GetSystemState( + system.WithModelPath(ml.ModelsPath), + system.WithBackendPath(ml.BackendsPath), + ) + if err != nil { + return err + } + models, err := gallery.AvailableGalleryModels(galleries, systemState) if err != nil { return err } @@ -60,6 +68,15 @@ func (ml *ModelsList) Run(ctx *cliContext.Context) error { } func (mi *ModelsInstall) Run(ctx *cliContext.Context) error { + + systemState, err := system.GetSystemState( + system.WithModelPath(mi.ModelsPath), + system.WithBackendPath(mi.BackendsPath), + ) + if err != nil { + return err + } + var galleries []config.Gallery if err := json.Unmarshal([]byte(mi.Galleries), &galleries); err != nil { 
log.Error().Err(err).Msg("unable to load galleries") @@ -86,7 +103,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error { } } //startup.InstallModels() - models, err := gallery.AvailableGalleryModels(galleries, mi.ModelsPath) + models, err := gallery.AvailableGalleryModels(galleries, systemState) if err != nil { return err } @@ -94,7 +111,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error { modelURI := downloader.URI(modelName) if !modelURI.LooksLikeOCI() { - model := gallery.FindGalleryElement(models, modelName, mi.ModelsPath) + model := gallery.FindGalleryElement(models, modelName) if model == nil { log.Error().Str("model", modelName).Msg("model not found") return err @@ -108,7 +125,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error { log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model") } - err = startup.InstallModels(galleries, backendGalleries, mi.ModelsPath, mi.BackendsPath, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName) + err = startup.InstallModels(galleries, backendGalleries, systemState, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName) if err != nil { return err } diff --git a/core/cli/run.go b/core/cli/run.go index 377185a21bc4..eb94becd7d99 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -13,6 +13,7 @@ import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http" "github.com/mudler/LocalAI/core/p2p" + "github.com/mudler/LocalAI/pkg/system" "github.com/rs/zerolog" "github.com/rs/zerolog/log" ) @@ -22,6 +23,7 @@ type RunCMD struct { ExternalBackends []string `env:"LOCALAI_EXTERNAL_BACKENDS,EXTERNAL_BACKENDS" help:"A list of external backends to load from gallery on boot" group:"backends"` BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"backends"` + 
BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/usr/share/localai/backends" help:"Path containing system backends used for inferencing" group:"backends"` ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"` GeneratedContentPath string `env:"LOCALAI_GENERATED_CONTENT_PATH,GENERATED_CONTENT_PATH" type:"path" default:"/tmp/generated/content" help:"Location for generated content (e.g. images, audio, videos)" group:"storage"` UploadPath string `env:"LOCALAI_UPLOAD_PATH,UPLOAD_PATH" type:"path" default:"/tmp/localai/upload" help:"Path to store uploads from files api" group:"storage"` @@ -77,12 +79,20 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { os.MkdirAll(r.BackendsPath, 0750) os.MkdirAll(r.ModelsPath, 0750) + systemState, err := system.GetSystemState( + system.WithBackendSystemPath(r.BackendsSystemPath), + system.WithModelPath(r.ModelsPath), + system.WithBackendPath(r.BackendsPath), + ) + if err != nil { + return err + } + opts := []config.AppOption{ config.WithConfigFile(r.ModelsConfigFile), config.WithJSONStringPreload(r.PreloadModels), config.WithYAMLConfigPreload(r.PreloadModelsConfig), - config.WithModelPath(r.ModelsPath), - config.WithBackendsPath(r.BackendsPath), + config.WithSystemState(systemState), config.WithContextSize(r.ContextSize), config.WithDebug(zerolog.GlobalLevel() <= zerolog.DebugLevel), config.WithGeneratedContentDir(r.GeneratedContentPath), diff --git a/core/cli/soundgeneration.go b/core/cli/soundgeneration.go index 1193b329f64c..a0f96b4fbcaf 100644 --- a/core/cli/soundgeneration.go +++ b/core/cli/soundgeneration.go @@ -12,6 +12,7 @@ import ( cliContext "github.com/mudler/LocalAI/core/cli/context" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/system" "github.com/rs/zerolog/log" ) @@ -56,6 +57,13 
@@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error { } text := strings.Join(t.Text, " ") + systemState, err := system.GetSystemState( + system.WithModelPath(t.ModelsPath), + ) + if err != nil { + return err + } + externalBackends := make(map[string]string) // split ":" to get backend name and the uri for _, v := range t.ExternalGRPCBackends { @@ -66,12 +74,12 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error { } opts := &config.ApplicationConfig{ - ModelPath: t.ModelsPath, + SystemState: systemState, Context: context.Background(), GeneratedContentDir: outputDir, ExternalGRPCBackends: externalBackends, } - ml := model.NewModelLoader(opts.ModelPath, opts.SingleBackend) + ml := model.NewModelLoader(systemState, opts.SingleBackend) defer func() { err := ml.StopAllGRPC() @@ -80,7 +88,7 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error { } }() - options := config.BackendConfig{} + options := config.ModelConfig{} options.SetDefaults() options.Backend = t.Backend options.Model = t.Model diff --git a/core/cli/transcript.go b/core/cli/transcript.go index 3e5ee6d44662..9c74425dbfe7 100644 --- a/core/cli/transcript.go +++ b/core/cli/transcript.go @@ -9,6 +9,7 @@ import ( cliContext "github.com/mudler/LocalAI/core/cli/context" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/system" "github.com/rs/zerolog/log" ) @@ -24,18 +25,24 @@ type TranscriptCMD struct { } func (t *TranscriptCMD) Run(ctx *cliContext.Context) error { + systemState, err := system.GetSystemState( + system.WithModelPath(t.ModelsPath), + ) + if err != nil { + return err + } opts := &config.ApplicationConfig{ - ModelPath: t.ModelsPath, - Context: context.Background(), + SystemState: systemState, + Context: context.Background(), } - cl := config.NewBackendConfigLoader(t.ModelsPath) - ml := model.NewModelLoader(opts.ModelPath, opts.SingleBackend) - if err := 
cl.LoadBackendConfigsFromPath(t.ModelsPath); err != nil { + cl := config.NewModelConfigLoader(t.ModelsPath) + ml := model.NewModelLoader(systemState, opts.SingleBackend) + if err := cl.LoadModelConfigsFromPath(t.ModelsPath); err != nil { return err } - c, exists := cl.GetBackendConfig(t.Model) + c, exists := cl.GetModelConfig(t.Model) if !exists { return errors.New("model not found") } diff --git a/core/cli/tts.go b/core/cli/tts.go index 552fdf018881..ed0266714a02 100644 --- a/core/cli/tts.go +++ b/core/cli/tts.go @@ -11,6 +11,7 @@ import ( cliContext "github.com/mudler/LocalAI/core/cli/context" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/system" "github.com/rs/zerolog/log" ) @@ -34,12 +35,20 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error { text := strings.Join(t.Text, " ") + systemState, err := system.GetSystemState( + system.WithModelPath(t.ModelsPath), + ) + if err != nil { + return err + } + opts := &config.ApplicationConfig{ - ModelPath: t.ModelsPath, + SystemState: systemState, Context: context.Background(), GeneratedContentDir: outputDir, } - ml := model.NewModelLoader(opts.ModelPath, opts.SingleBackend) + + ml := model.NewModelLoader(systemState, opts.SingleBackend) defer func() { err := ml.StopAllGRPC() @@ -48,7 +57,7 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error { } }() - options := config.BackendConfig{} + options := config.ModelConfig{} options.SetDefaults() options.Backend = t.Backend options.Model = t.Model diff --git a/core/cli/util.go b/core/cli/util.go index 4e01ea2c25a8..aaffe1f9e954 100644 --- a/core/cli/util.go +++ b/core/cli/util.go @@ -17,6 +17,7 @@ import ( "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/oci" + "github.com/mudler/LocalAI/pkg/system" ) type UtilCMD struct { @@ -108,6 +109,14 @@ func (u *GGUFInfoCMD) Run(ctx *cliContext.Context) error { } func (hfscmd 
*HFScanCMD) Run(ctx *cliContext.Context) error { + + systemState, err := system.GetSystemState( + system.WithModelPath(hfscmd.ModelsPath), + ) + if err != nil { + return err + } + log.Info().Msg("LocalAI Security Scanner - This is BEST EFFORT functionality! Currently limited to huggingface models!") if len(hfscmd.ToScan) == 0 { log.Info().Msg("Checking all installed models against galleries") @@ -116,7 +125,7 @@ func (hfscmd *HFScanCMD) Run(ctx *cliContext.Context) error { log.Error().Err(err).Msg("unable to load galleries") } - err := gallery.SafetyScanGalleryModels(galleries, hfscmd.ModelsPath) + err := gallery.SafetyScanGalleryModels(galleries, systemState) if err == nil { log.Info().Msg("No security warnings were detected for your installed models. Please note that this is a BEST EFFORT tool, and all issues may not be detected.") } else { @@ -150,17 +159,17 @@ func (uhcmd *UsecaseHeuristicCMD) Run(ctx *cliContext.Context) error { log.Error().Msg("ModelsPath is a required parameter") return fmt.Errorf("model path is a required parameter") } - bcl := config.NewBackendConfigLoader(uhcmd.ModelsPath) - err := bcl.LoadBackendConfig(uhcmd.ConfigName) + bcl := config.NewModelConfigLoader(uhcmd.ModelsPath) + err := bcl.ReadModelConfig(uhcmd.ConfigName) if err != nil { log.Error().Err(err).Str("ConfigName", uhcmd.ConfigName).Msg("error while loading backend") return err } - bc, exists := bcl.GetBackendConfig(uhcmd.ConfigName) + bc, exists := bcl.GetModelConfig(uhcmd.ConfigName) if !exists { log.Error().Str("ConfigName", uhcmd.ConfigName).Msg("ConfigName not found") } - for name, uc := range config.GetAllBackendConfigUsecases() { + for name, uc := range config.GetAllModelConfigUsecases() { if bc.HasUsecases(uc) { log.Info().Str("Usecase", name) } diff --git a/core/cli/worker/worker.go b/core/cli/worker/worker.go index 33813db06422..77bb35a1b525 100644 --- a/core/cli/worker/worker.go +++ b/core/cli/worker/worker.go @@ -1,8 +1,9 @@ package worker type WorkerFlags struct { - 
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"backends"` - ExtraLLamaCPPArgs string `name:"llama-cpp-args" env:"LOCALAI_EXTRA_LLAMA_CPP_ARGS,EXTRA_LLAMA_CPP_ARGS" help:"Extra arguments to pass to llama-cpp-rpc-server"` + BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"backends"` + BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/usr/share/localai/backends" help:"Path containing system backends used for inferencing" group:"backends"` + ExtraLLamaCPPArgs string `name:"llama-cpp-args" env:"LOCALAI_EXTRA_LLAMA_CPP_ARGS,EXTRA_LLAMA_CPP_ARGS" help:"Extra arguments to pass to llama-cpp-rpc-server"` } type Worker struct { diff --git a/core/cli/worker/worker_llamacpp.go b/core/cli/worker/worker_llamacpp.go index 95ca9c6bda26..8f8c8cd09fce 100644 --- a/core/cli/worker/worker_llamacpp.go +++ b/core/cli/worker/worker_llamacpp.go @@ -10,6 +10,7 @@ import ( cliContext "github.com/mudler/LocalAI/core/cli/context" "github.com/mudler/LocalAI/core/gallery" + "github.com/mudler/LocalAI/pkg/system" "github.com/rs/zerolog/log" ) @@ -21,20 +22,19 @@ const ( llamaCPPRPCBinaryName = "llama-cpp-rpc-server" ) -func findLLamaCPPBackend(backendSystemPath string) (string, error) { - backends, err := gallery.ListSystemBackends(backendSystemPath) +func findLLamaCPPBackend(systemState *system.SystemState) (string, error) { + backends, err := gallery.ListSystemBackends(systemState) if err != nil { log.Warn().Msgf("Failed listing system backends: %s", err) return "", err } log.Debug().Msgf("System backends: %v", backends) - backendPath := "" backend, ok := backends.Get("llama-cpp") if !ok { return "", errors.New("llama-cpp backend not found, install it first") } - backendPath = 
filepath.Dir(backend.RunFile) + backendPath := filepath.Dir(backend.RunFile) if backendPath == "" { return "", errors.New("llama-cpp backend not found, install it first") @@ -54,7 +54,14 @@ func (r *LLamaCPP) Run(ctx *cliContext.Context) error { return fmt.Errorf("usage: local-ai worker llama-cpp-rpc -- ") } - grpcProcess, err := findLLamaCPPBackend(r.BackendsPath) + systemState, err := system.GetSystemState( + system.WithBackendPath(r.BackendsPath), + system.WithBackendSystemPath(r.BackendsSystemPath), + ) + if err != nil { + return err + } + grpcProcess, err := findLLamaCPPBackend(systemState) if err != nil { return err } diff --git a/core/cli/worker/worker_p2p.go b/core/cli/worker/worker_p2p.go index 1533de4e5074..c65d80c9f2fb 100644 --- a/core/cli/worker/worker_p2p.go +++ b/core/cli/worker/worker_p2p.go @@ -10,6 +10,7 @@ import ( cliContext "github.com/mudler/LocalAI/core/cli/context" "github.com/mudler/LocalAI/core/p2p" + "github.com/mudler/LocalAI/pkg/system" "github.com/phayes/freeport" "github.com/rs/zerolog/log" ) @@ -25,6 +26,14 @@ type P2P struct { func (r *P2P) Run(ctx *cliContext.Context) error { + systemState, err := system.GetSystemState( + system.WithBackendPath(r.BackendsPath), + system.WithBackendSystemPath(r.BackendsSystemPath), + ) + if err != nil { + return err + } + // Check if the token is set // as we always need it. 
if r.Token == "" { @@ -60,7 +69,7 @@ func (r *P2P) Run(ctx *cliContext.Context) error { for { log.Info().Msgf("Starting llama-cpp-rpc-server on '%s:%d'", address, port) - grpcProcess, err := findLLamaCPPBackend(r.BackendsPath) + grpcProcess, err := findLLamaCPPBackend(systemState) if err != nil { log.Error().Err(err).Msg("Failed to find llama-cpp-rpc-server") return diff --git a/core/config/application_config.go b/core/config/application_config.go index ab2526d2691e..775e30f66034 100644 --- a/core/config/application_config.go +++ b/core/config/application_config.go @@ -6,6 +6,7 @@ import ( "regexp" "time" + "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/LocalAI/pkg/xsysinfo" "github.com/rs/zerolog/log" ) @@ -13,8 +14,7 @@ import ( type ApplicationConfig struct { Context context.Context ConfigFile string - ModelPath string - BackendsPath string + SystemState *system.SystemState ExternalBackends []string UploadLimitMB, Threads, ContextSize int F16 bool @@ -86,15 +86,9 @@ func WithModelsURL(urls ...string) AppOption { } } -func WithModelPath(path string) AppOption { +func WithSystemState(state *system.SystemState) AppOption { return func(o *ApplicationConfig) { - o.ModelPath = path - } -} - -func WithBackendsPath(path string) AppOption { - return func(o *ApplicationConfig) { - o.BackendsPath = path + o.SystemState = state } } @@ -379,7 +373,7 @@ func (o *ApplicationConfig) ToConfigLoaderOptions() []ConfigLoaderOption { LoadOptionDebug(o.Debug), LoadOptionF16(o.F16), LoadOptionThreads(o.Threads), - ModelPath(o.ModelPath), + ModelPath(o.SystemState.Model.ModelsPath), } } diff --git a/core/config/backend_config.go b/core/config/backend_config.go index 739f876a0c1a..e39e828fa51f 100644 --- a/core/config/backend_config.go +++ b/core/config/backend_config.go @@ -24,20 +24,20 @@ type TTSConfig struct { AudioPath string `yaml:"audio_path"` } -type BackendConfig struct { +type ModelConfig struct { schema.PredictionOptions `yaml:"parameters"` Name string 
`yaml:"name"` - F16 *bool `yaml:"f16"` - Threads *int `yaml:"threads"` - Debug *bool `yaml:"debug"` - Roles map[string]string `yaml:"roles"` - Embeddings *bool `yaml:"embeddings"` - Backend string `yaml:"backend"` - TemplateConfig TemplateConfig `yaml:"template"` - KnownUsecaseStrings []string `yaml:"known_usecases"` - KnownUsecases *BackendConfigUsecases `yaml:"-"` - Pipeline Pipeline `yaml:"pipeline"` + F16 *bool `yaml:"f16"` + Threads *int `yaml:"threads"` + Debug *bool `yaml:"debug"` + Roles map[string]string `yaml:"roles"` + Embeddings *bool `yaml:"embeddings"` + Backend string `yaml:"backend"` + TemplateConfig TemplateConfig `yaml:"template"` + KnownUsecaseStrings []string `yaml:"known_usecases"` + KnownUsecases *ModelConfigUsecases `yaml:"-"` + Pipeline Pipeline `yaml:"pipeline"` PromptStrings, InputStrings []string `yaml:"-"` InputToken [][]int `yaml:"-"` @@ -217,18 +217,18 @@ type TemplateConfig struct { ReplyPrefix string `yaml:"reply_prefix"` } -func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error { - type BCAlias BackendConfig +func (c *ModelConfig) UnmarshalYAML(value *yaml.Node) error { + type BCAlias ModelConfig var aux BCAlias if err := value.Decode(&aux); err != nil { return err } - *c = BackendConfig(aux) + *c = ModelConfig(aux) c.KnownUsecases = GetUsecasesFromYAML(c.KnownUsecaseStrings) // Make sure the usecases are valid, we rewrite with what we identified c.KnownUsecaseStrings = []string{} - for k, usecase := range GetAllBackendConfigUsecases() { + for k, usecase := range GetAllModelConfigUsecases() { if c.HasUsecases(usecase) { c.KnownUsecaseStrings = append(c.KnownUsecaseStrings, k) } @@ -236,25 +236,25 @@ func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error { return nil } -func (c *BackendConfig) SetFunctionCallString(s string) { +func (c *ModelConfig) SetFunctionCallString(s string) { c.functionCallString = s } -func (c *BackendConfig) SetFunctionCallNameString(s string) { +func (c *ModelConfig) 
SetFunctionCallNameString(s string) { c.functionCallNameString = s } -func (c *BackendConfig) ShouldUseFunctions() bool { +func (c *ModelConfig) ShouldUseFunctions() bool { return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction()) } -func (c *BackendConfig) ShouldCallSpecificFunction() bool { +func (c *ModelConfig) ShouldCallSpecificFunction() bool { return len(c.functionCallNameString) > 0 } // MMProjFileName returns the filename of the MMProj file // If the MMProj is a URL, it will return the MD5 of the URL which is the filename -func (c *BackendConfig) MMProjFileName() string { +func (c *ModelConfig) MMProjFileName() string { uri := downloader.URI(c.MMProj) if uri.LooksLikeURL() { f, _ := uri.FilenameFromUrl() @@ -264,19 +264,19 @@ func (c *BackendConfig) MMProjFileName() string { return c.MMProj } -func (c *BackendConfig) IsMMProjURL() bool { +func (c *ModelConfig) IsMMProjURL() bool { uri := downloader.URI(c.MMProj) return uri.LooksLikeURL() } -func (c *BackendConfig) IsModelURL() bool { +func (c *ModelConfig) IsModelURL() bool { uri := downloader.URI(c.Model) return uri.LooksLikeURL() } // ModelFileName returns the filename of the model // If the model is a URL, it will return the MD5 of the URL which is the filename -func (c *BackendConfig) ModelFileName() string { +func (c *ModelConfig) ModelFileName() string { uri := downloader.URI(c.Model) if uri.LooksLikeURL() { f, _ := uri.FilenameFromUrl() @@ -286,7 +286,7 @@ func (c *BackendConfig) ModelFileName() string { return c.Model } -func (c *BackendConfig) FunctionToCall() string { +func (c *ModelConfig) FunctionToCall() string { if c.functionCallNameString != "" && c.functionCallNameString != "none" && c.functionCallNameString != "auto" { return c.functionCallNameString @@ -295,7 +295,7 @@ func (c *BackendConfig) FunctionToCall() string { return c.functionCallString } -func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) { +func (cfg *ModelConfig) 
SetDefaults(opts ...ConfigLoaderOption) { lo := &LoadOptions{} lo.Apply(opts...) @@ -411,7 +411,7 @@ func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) { guessDefaultsFromFile(cfg, lo.modelPath, ctx) } -func (c *BackendConfig) Validate() bool { +func (c *ModelConfig) Validate() bool { downloadedFileNames := []string{} for _, f := range c.DownloadFiles { downloadedFileNames = append(downloadedFileNames, f.Filename) @@ -438,34 +438,34 @@ func (c *BackendConfig) Validate() bool { return true } -func (c *BackendConfig) HasTemplate() bool { +func (c *ModelConfig) HasTemplate() bool { return c.TemplateConfig.Completion != "" || c.TemplateConfig.Edit != "" || c.TemplateConfig.Chat != "" || c.TemplateConfig.ChatMessage != "" } -type BackendConfigUsecases int +type ModelConfigUsecases int const ( - FLAG_ANY BackendConfigUsecases = 0b000000000000 - FLAG_CHAT BackendConfigUsecases = 0b000000000001 - FLAG_COMPLETION BackendConfigUsecases = 0b000000000010 - FLAG_EDIT BackendConfigUsecases = 0b000000000100 - FLAG_EMBEDDINGS BackendConfigUsecases = 0b000000001000 - FLAG_RERANK BackendConfigUsecases = 0b000000010000 - FLAG_IMAGE BackendConfigUsecases = 0b000000100000 - FLAG_TRANSCRIPT BackendConfigUsecases = 0b000001000000 - FLAG_TTS BackendConfigUsecases = 0b000010000000 - FLAG_SOUND_GENERATION BackendConfigUsecases = 0b000100000000 - FLAG_TOKENIZE BackendConfigUsecases = 0b001000000000 - FLAG_VAD BackendConfigUsecases = 0b010000000000 - FLAG_VIDEO BackendConfigUsecases = 0b100000000000 - FLAG_DETECTION BackendConfigUsecases = 0b1000000000000 + FLAG_ANY ModelConfigUsecases = 0b000000000000 + FLAG_CHAT ModelConfigUsecases = 0b000000000001 + FLAG_COMPLETION ModelConfigUsecases = 0b000000000010 + FLAG_EDIT ModelConfigUsecases = 0b000000000100 + FLAG_EMBEDDINGS ModelConfigUsecases = 0b000000001000 + FLAG_RERANK ModelConfigUsecases = 0b000000010000 + FLAG_IMAGE ModelConfigUsecases = 0b000000100000 + FLAG_TRANSCRIPT ModelConfigUsecases = 0b000001000000 + FLAG_TTS 
ModelConfigUsecases = 0b000010000000 + FLAG_SOUND_GENERATION ModelConfigUsecases = 0b000100000000 + FLAG_TOKENIZE ModelConfigUsecases = 0b001000000000 + FLAG_VAD ModelConfigUsecases = 0b010000000000 + FLAG_VIDEO ModelConfigUsecases = 0b100000000000 + FLAG_DETECTION ModelConfigUsecases = 0b1000000000000 // Common Subsets - FLAG_LLM BackendConfigUsecases = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT + FLAG_LLM ModelConfigUsecases = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT ) -func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases { - return map[string]BackendConfigUsecases{ +func GetAllModelConfigUsecases() map[string]ModelConfigUsecases { + return map[string]ModelConfigUsecases{ "FLAG_ANY": FLAG_ANY, "FLAG_CHAT": FLAG_CHAT, "FLAG_COMPLETION": FLAG_COMPLETION, @@ -488,12 +488,12 @@ func stringToFlag(s string) string { return "FLAG_" + strings.ToUpper(s) } -func GetUsecasesFromYAML(input []string) *BackendConfigUsecases { +func GetUsecasesFromYAML(input []string) *ModelConfigUsecases { if len(input) == 0 { return nil } result := FLAG_ANY - flags := GetAllBackendConfigUsecases() + flags := GetAllModelConfigUsecases() for _, str := range input { flag, exists := flags[stringToFlag(str)] if exists { @@ -503,8 +503,8 @@ func GetUsecasesFromYAML(input []string) *BackendConfigUsecases { return &result } -// HasUsecases examines a BackendConfig and determines which endpoints have a chance of success. -func (c *BackendConfig) HasUsecases(u BackendConfigUsecases) bool { +// HasUsecases examines a ModelConfig and determines which endpoints have a chance of success. +func (c *ModelConfig) HasUsecases(u ModelConfigUsecases) bool { if (c.KnownUsecases != nil) && ((u & *c.KnownUsecases) == u) { return true } @@ -514,7 +514,7 @@ func (c *BackendConfig) HasUsecases(u BackendConfigUsecases) bool { // GuessUsecases is a **heuristic based** function, as the backend in question may not be loaded yet, and the config may not record what it's useful at. 
// In its current state, this function should ideally check for properties of the config like templates, rather than the direct backend name checks for the lower half. // This avoids the maintenance burden of updating this list for each new backend - but unfortunately, that's the best option for some services currently. -func (c *BackendConfig) GuessUsecases(u BackendConfigUsecases) bool { +func (c *ModelConfig) GuessUsecases(u ModelConfigUsecases) bool { if (u & FLAG_CHAT) == FLAG_CHAT { if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" { return false diff --git a/core/config/backend_config_filter.go b/core/config/backend_config_filter.go index f1eb24883b0c..4252ebfa3e8c 100644 --- a/core/config/backend_config_filter.go +++ b/core/config/backend_config_filter.go @@ -2,11 +2,11 @@ package config import "regexp" -type BackendConfigFilterFn func(string, *BackendConfig) bool +type ModelConfigFilterFn func(string, *ModelConfig) bool -func NoFilterFn(_ string, _ *BackendConfig) bool { return true } +func NoFilterFn(_ string, _ *ModelConfig) bool { return true } -func BuildNameFilterFn(filter string) (BackendConfigFilterFn, error) { +func BuildNameFilterFn(filter string) (ModelConfigFilterFn, error) { if filter == "" { return NoFilterFn, nil } @@ -14,7 +14,7 @@ func BuildNameFilterFn(filter string) (BackendConfigFilterFn, error) { if err != nil { return nil, err } - return func(name string, config *BackendConfig) bool { + return func(name string, config *ModelConfig) bool { if config != nil { return rxp.MatchString(config.Name) } @@ -22,11 +22,11 @@ func BuildNameFilterFn(filter string) (BackendConfigFilterFn, error) { }, nil } -func BuildUsecaseFilterFn(usecases BackendConfigUsecases) BackendConfigFilterFn { +func BuildUsecaseFilterFn(usecases ModelConfigUsecases) ModelConfigFilterFn { if usecases == FLAG_ANY { return NoFilterFn } - return func(name string, config *BackendConfig) bool { + return func(name string, config *ModelConfig) bool { if config 
== nil { return false // TODO: Potentially make this a param, for now, no known usecase to include } diff --git a/core/config/backend_config_loader.go b/core/config/backend_config_loader.go index 410810f4e0ba..5b16cadc8758 100644 --- a/core/config/backend_config_loader.go +++ b/core/config/backend_config_loader.go @@ -18,15 +18,15 @@ import ( "gopkg.in/yaml.v3" ) -type BackendConfigLoader struct { - configs map[string]BackendConfig +type ModelConfigLoader struct { + configs map[string]ModelConfig modelPath string sync.Mutex } -func NewBackendConfigLoader(modelPath string) *BackendConfigLoader { - return &BackendConfigLoader{ - configs: make(map[string]BackendConfig), +func NewModelConfigLoader(modelPath string) *ModelConfigLoader { + return &ModelConfigLoader{ + configs: make(map[string]ModelConfig), modelPath: modelPath, } } @@ -77,14 +77,14 @@ func (lo *LoadOptions) Apply(options ...ConfigLoaderOption) { } // TODO: either in the next PR or the next commit, I want to merge these down into a single function that looks at the first few characters of the file to determine if we need to deserialize to []BackendConfig or BackendConfig -func readMultipleBackendConfigsFromFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) { - c := &[]*BackendConfig{} +func readMultipleModelConfigsFromFile(file string, opts ...ConfigLoaderOption) ([]*ModelConfig, error) { + c := &[]*ModelConfig{} f, err := os.ReadFile(file) if err != nil { - return nil, fmt.Errorf("readMultipleBackendConfigsFromFile cannot read config file %q: %w", file, err) + return nil, fmt.Errorf("readMultipleModelConfigsFromFile cannot read config file %q: %w", file, err) } if err := yaml.Unmarshal(f, c); err != nil { - return nil, fmt.Errorf("readMultipleBackendConfigsFromFile cannot unmarshal config file %q: %w", file, err) + return nil, fmt.Errorf("readMultipleModelConfigsFromFile cannot unmarshal config file %q: %w", file, err) } for _, cc := range *c { @@ -94,17 +94,17 @@ func 
readMultipleBackendConfigsFromFile(file string, opts ...ConfigLoaderOption) return *c, nil } -func readBackendConfigFromFile(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) { +func readModelConfigFromFile(file string, opts ...ConfigLoaderOption) (*ModelConfig, error) { lo := &LoadOptions{} lo.Apply(opts...) - c := &BackendConfig{} + c := &ModelConfig{} f, err := os.ReadFile(file) if err != nil { - return nil, fmt.Errorf("readBackendConfigFromFile cannot read config file %q: %w", file, err) + return nil, fmt.Errorf("readModelConfigFromFile cannot read config file %q: %w", file, err) } if err := yaml.Unmarshal(f, c); err != nil { - return nil, fmt.Errorf("readBackendConfigFromFile cannot unmarshal config file %q: %w", file, err) + return nil, fmt.Errorf("readModelConfigFromFile cannot unmarshal config file %q: %w", file, err) } c.SetDefaults(opts...) @@ -112,10 +112,10 @@ func readBackendConfigFromFile(file string, opts ...ConfigLoaderOption) (*Backen } // Load a config file for a model -func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) { +func (bcl *ModelConfigLoader) LoadModelConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*ModelConfig, error) { // Load a config file if present after the model name - cfg := &BackendConfig{ + cfg := &ModelConfig{ PredictionOptions: schema.PredictionOptions{ BasicModelRequest: schema.BasicModelRequest{ Model: modelName, @@ -123,19 +123,19 @@ func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath }, } - cfgExisting, exists := bcl.GetBackendConfig(modelName) + cfgExisting, exists := bcl.GetModelConfig(modelName) if exists { cfg = &cfgExisting } else { // Try loading a model config file modelConfig := filepath.Join(modelPath, modelName+".yaml") if _, err := os.Stat(modelConfig); err == nil { - if err := bcl.LoadBackendConfig( + if err := bcl.ReadModelConfig( modelConfig, 
opts..., ); err != nil { return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) } - cfgExisting, exists = bcl.GetBackendConfig(modelName) + cfgExisting, exists = bcl.GetModelConfig(modelName) if exists { cfg = &cfgExisting } @@ -147,20 +147,20 @@ func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath return cfg, nil } -func (bcl *BackendConfigLoader) LoadBackendConfigFileByNameDefaultOptions(modelName string, appConfig *ApplicationConfig) (*BackendConfig, error) { - return bcl.LoadBackendConfigFileByName(modelName, appConfig.ModelPath, +func (bcl *ModelConfigLoader) LoadModelConfigFileByNameDefaultOptions(modelName string, appConfig *ApplicationConfig) (*ModelConfig, error) { + return bcl.LoadModelConfigFileByName(modelName, appConfig.SystemState.Model.ModelsPath, LoadOptionDebug(appConfig.Debug), LoadOptionThreads(appConfig.Threads), LoadOptionContextSize(appConfig.ContextSize), LoadOptionF16(appConfig.F16), - ModelPath(appConfig.ModelPath)) + ModelPath(appConfig.SystemState.Model.ModelsPath)) } // This format is currently only used when reading a single file at startup, passed in via ApplicationConfig.ConfigFile -func (bcl *BackendConfigLoader) LoadMultipleBackendConfigsSingleFile(file string, opts ...ConfigLoaderOption) error { +func (bcl *ModelConfigLoader) LoadMultipleModelConfigsSingleFile(file string, opts ...ConfigLoaderOption) error { bcl.Lock() defer bcl.Unlock() - c, err := readMultipleBackendConfigsFromFile(file, opts...) + c, err := readMultipleModelConfigsFromFile(file, opts...) 
if err != nil { return fmt.Errorf("cannot load config file: %w", err) } @@ -173,12 +173,12 @@ func (bcl *BackendConfigLoader) LoadMultipleBackendConfigsSingleFile(file string return nil } -func (bcl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error { +func (bcl *ModelConfigLoader) ReadModelConfig(file string, opts ...ConfigLoaderOption) error { bcl.Lock() defer bcl.Unlock() - c, err := readBackendConfigFromFile(file, opts...) + c, err := readModelConfigFromFile(file, opts...) if err != nil { - return fmt.Errorf("LoadBackendConfig cannot read config file %q: %w", file, err) + return fmt.Errorf("ReadModelConfig cannot read config file %q: %w", file, err) } if c.Validate() { @@ -190,17 +190,17 @@ func (bcl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoa return nil } -func (bcl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) { +func (bcl *ModelConfigLoader) GetModelConfig(m string) (ModelConfig, bool) { bcl.Lock() defer bcl.Unlock() v, exists := bcl.configs[m] return v, exists } -func (bcl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig { +func (bcl *ModelConfigLoader) GetAllModelsConfigs() []ModelConfig { bcl.Lock() defer bcl.Unlock() - var res []BackendConfig + var res []ModelConfig for _, v := range bcl.configs { res = append(res, v) } @@ -212,10 +212,10 @@ func (bcl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig { return res } -func (bcl *BackendConfigLoader) GetBackendConfigsByFilter(filter BackendConfigFilterFn) []BackendConfig { +func (bcl *ModelConfigLoader) GetModelConfigsByFilter(filter ModelConfigFilterFn) []ModelConfig { bcl.Lock() defer bcl.Unlock() - var res []BackendConfig + var res []ModelConfig if filter == nil { filter = NoFilterFn @@ -232,14 +232,14 @@ func (bcl *BackendConfigLoader) GetBackendConfigsByFilter(filter BackendConfigFi return res } -func (bcl *BackendConfigLoader) RemoveBackendConfig(m string) { +func (bcl *ModelConfigLoader) 
RemoveModelConfig(m string) { bcl.Lock() defer bcl.Unlock() delete(bcl.configs, m) } // Preload prepare models if they are not local but url or huggingface repositories -func (bcl *BackendConfigLoader) Preload(modelPath string) error { +func (bcl *ModelConfigLoader) Preload(modelPath string) error { bcl.Lock() defer bcl.Unlock() @@ -330,15 +330,15 @@ func (bcl *BackendConfigLoader) Preload(modelPath string) error { return nil } -// LoadBackendConfigsFromPath reads all the configurations of the models from a path +// LoadModelConfigsFromPath reads all the configurations of the models from a path // (non-recursive) -func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error { +func (bcl *ModelConfigLoader) LoadModelConfigsFromPath(path string, opts ...ConfigLoaderOption) error { bcl.Lock() defer bcl.Unlock() entries, err := os.ReadDir(path) if err != nil { - return fmt.Errorf("LoadBackendConfigsFromPath cannot read directory '%s': %w", path, err) + return fmt.Errorf("LoadModelConfigsFromPath cannot read directory '%s': %w", path, err) } files := make([]fs.FileInfo, 0, len(entries)) for _, entry := range entries { @@ -354,9 +354,9 @@ func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ... strings.HasPrefix(file.Name(), ".") { continue } - c, err := readBackendConfigFromFile(filepath.Join(path, file.Name()), opts...) + c, err := readModelConfigFromFile(filepath.Join(path, file.Name()), opts...) 
if err != nil { - log.Error().Err(err).Str("File Name", file.Name()).Msgf("LoadBackendConfigsFromPath cannot read config file") + log.Error().Err(err).Str("File Name", file.Name()).Msgf("LoadModelConfigsFromPath cannot read config file") continue } if c.Validate() { diff --git a/core/config/backend_config_test.go b/core/config/backend_config_test.go index 6afc6a5d5b39..9d49e270c751 100644 --- a/core/config/backend_config_test.go +++ b/core/config/backend_config_test.go @@ -25,7 +25,7 @@ known_usecases: - COMPLETION `) Expect(err).ToNot(HaveOccurred()) - config, err := readBackendConfigFromFile(tmp.Name()) + config, err := readModelConfigFromFile(tmp.Name()) Expect(err).To(BeNil()) Expect(config).ToNot(BeNil()) Expect(config.Validate()).To(BeFalse()) @@ -41,7 +41,7 @@ backend: "foo-bar" parameters: model: "foo-bar"`) Expect(err).ToNot(HaveOccurred()) - config, err := readBackendConfigFromFile(tmp.Name()) + config, err := readModelConfigFromFile(tmp.Name()) Expect(err).To(BeNil()) Expect(config).ToNot(BeNil()) // two configs in config.yaml @@ -58,7 +58,7 @@ parameters: defer os.Remove(tmp.Name()) _, err = io.Copy(tmp, resp.Body) Expect(err).To(BeNil()) - config, err = readBackendConfigFromFile(tmp.Name()) + config, err = readModelConfigFromFile(tmp.Name()) Expect(err).To(BeNil()) Expect(config).ToNot(BeNil()) // two configs in config.yaml @@ -68,12 +68,12 @@ parameters: }) It("Properly handles backend usecase matching", func() { - a := BackendConfig{ + a := ModelConfig{ Name: "a", } Expect(a.HasUsecases(FLAG_ANY)).To(BeTrue()) // FLAG_ANY just means the config _exists_ essentially. 
- b := BackendConfig{ + b := ModelConfig{ Name: "b", Backend: "stablediffusion", } @@ -81,7 +81,7 @@ parameters: Expect(b.HasUsecases(FLAG_IMAGE)).To(BeTrue()) Expect(b.HasUsecases(FLAG_CHAT)).To(BeFalse()) - c := BackendConfig{ + c := ModelConfig{ Name: "c", Backend: "llama-cpp", TemplateConfig: TemplateConfig{ @@ -93,7 +93,7 @@ parameters: Expect(c.HasUsecases(FLAG_COMPLETION)).To(BeFalse()) Expect(c.HasUsecases(FLAG_CHAT)).To(BeTrue()) - d := BackendConfig{ + d := ModelConfig{ Name: "d", Backend: "llama-cpp", TemplateConfig: TemplateConfig{ @@ -107,7 +107,7 @@ parameters: Expect(d.HasUsecases(FLAG_CHAT)).To(BeTrue()) trueValue := true - e := BackendConfig{ + e := ModelConfig{ Name: "e", Backend: "llama-cpp", TemplateConfig: TemplateConfig{ @@ -122,7 +122,7 @@ parameters: Expect(e.HasUsecases(FLAG_CHAT)).To(BeFalse()) Expect(e.HasUsecases(FLAG_EMBEDDINGS)).To(BeTrue()) - f := BackendConfig{ + f := ModelConfig{ Name: "f", Backend: "piper", } @@ -130,7 +130,7 @@ parameters: Expect(f.HasUsecases(FLAG_TTS)).To(BeTrue()) Expect(f.HasUsecases(FLAG_CHAT)).To(BeFalse()) - g := BackendConfig{ + g := ModelConfig{ Name: "g", Backend: "whisper", } @@ -138,7 +138,7 @@ parameters: Expect(g.HasUsecases(FLAG_TRANSCRIPT)).To(BeTrue()) Expect(g.HasUsecases(FLAG_TTS)).To(BeFalse()) - h := BackendConfig{ + h := ModelConfig{ Name: "h", Backend: "transformers-musicgen", } @@ -148,7 +148,7 @@ parameters: Expect(h.HasUsecases(FLAG_SOUND_GENERATION)).To(BeTrue()) knownUsecases := FLAG_CHAT | FLAG_COMPLETION - i := BackendConfig{ + i := ModelConfig{ Name: "i", Backend: "whisper", // Earlier test checks parsing, this just needs to set final values diff --git a/core/config/config_test.go b/core/config/config_test.go index 85f18eaea3cb..f127f8f56830 100644 --- a/core/config/config_test.go +++ b/core/config/config_test.go @@ -16,7 +16,7 @@ var _ = Describe("Test cases for config related functions", func() { Context("Test Read configuration functions", func() { configFile = 
os.Getenv("CONFIG_FILE") It("Test readConfigFile", func() { - config, err := readMultipleBackendConfigsFromFile(configFile) + config, err := readMultipleModelConfigsFromFile(configFile) Expect(err).To(BeNil()) Expect(config).ToNot(BeNil()) // two configs in config.yaml @@ -26,11 +26,11 @@ var _ = Describe("Test cases for config related functions", func() { It("Test LoadConfigs", func() { - bcl := NewBackendConfigLoader(os.Getenv("MODELS_PATH")) - err := bcl.LoadBackendConfigsFromPath(os.Getenv("MODELS_PATH")) + bcl := NewModelConfigLoader(os.Getenv("MODELS_PATH")) + err := bcl.LoadModelConfigsFromPath(os.Getenv("MODELS_PATH")) Expect(err).To(BeNil()) - configs := bcl.GetAllBackendConfigs() + configs := bcl.GetAllModelsConfigs() loadedModelNames := []string{} for _, v := range configs { loadedModelNames = append(loadedModelNames, v.Name) @@ -51,10 +51,10 @@ var _ = Describe("Test cases for config related functions", func() { It("Test new loadconfig", func() { - bcl := NewBackendConfigLoader(os.Getenv("MODELS_PATH")) - err := bcl.LoadBackendConfigsFromPath(os.Getenv("MODELS_PATH")) + bcl := NewModelConfigLoader(os.Getenv("MODELS_PATH")) + err := bcl.LoadModelConfigsFromPath(os.Getenv("MODELS_PATH")) Expect(err).To(BeNil()) - configs := bcl.GetAllBackendConfigs() + configs := bcl.GetAllModelsConfigs() loadedModelNames := []string{} for _, v := range configs { loadedModelNames = append(loadedModelNames, v.Name) @@ -90,14 +90,14 @@ options: err = os.WriteFile(modelFile, []byte(model), 0644) Expect(err).ToNot(HaveOccurred()) - err = bcl.LoadBackendConfigsFromPath(tmpdir) + err = bcl.LoadModelConfigsFromPath(tmpdir) Expect(err).ToNot(HaveOccurred()) - configs = bcl.GetAllBackendConfigs() + configs = bcl.GetAllModelsConfigs() Expect(len(configs)).ToNot(Equal(totalModels)) loadedModelNames = []string{} - var testModel BackendConfig + var testModel ModelConfig for _, v := range configs { loadedModelNames = append(loadedModelNames, v.Name) if v.Name == "test-model" { diff 
--git a/core/config/gguf.go b/core/config/gguf.go index 99be69be8c10..edc7d523083f 100644 --- a/core/config/gguf.go +++ b/core/config/gguf.go @@ -146,7 +146,7 @@ var knownTemplates = map[string]familyType{ `{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}`: Mistral03, } -func guessGGUFFromFile(cfg *BackendConfig, f *gguf.GGUFFile, defaultCtx int) { +func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) { if defaultCtx == 0 && cfg.ContextSize == nil { ctxSize := f.EstimateLLaMACppRun().ContextSize diff --git a/core/config/guesser.go b/core/config/guesser.go index 48b0fe16404c..b2e1e6a6a5d3 100644 --- a/core/config/guesser.go +++ b/core/config/guesser.go @@ -8,7 +8,7 @@ import ( "github.com/rs/zerolog/log" ) -func guessDefaultsFromFile(cfg *BackendConfig, modelPath string, defaultCtx int) { +func guessDefaultsFromFile(cfg *ModelConfig, modelPath string, defaultCtx int) { if os.Getenv("LOCALAI_DISABLE_GUESSING") == "true" { log.Debug().Msgf("guessDefaultsFromFile: %s", "guessing disabled with LOCALAI_DISABLE_GUESSING") return diff --git a/core/gallery/backends.go b/core/gallery/backends.go index 2d7c98a6de6f..f7af7bbae8d7 100644 --- a/core/gallery/backends.go +++ b/core/gallery/backends.go @@ -59,10 +59,10 @@ func writeBackendMetadata(backendPath string, metadata *BackendMetadata) error { } // Installs a model from the gallery -func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.SystemState, name string, basePath string, downloadStatus func(string, string, string, float64), force bool) error { +func 
InstallBackendFromGallery(galleries []config.Gallery, systemState *system.SystemState, name string, downloadStatus func(string, string, string, float64), force bool) error { if !force { // check if we already have the backend installed - backends, err := ListSystemBackends(basePath) + backends, err := ListSystemBackends(systemState) if err != nil { return err } @@ -77,12 +77,12 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S log.Debug().Interface("galleries", galleries).Str("name", name).Msg("Installing backend from gallery") - backends, err := AvailableBackends(galleries, basePath) + backends, err := AvailableBackends(galleries, systemState) if err != nil { return err } - backend := FindGalleryElement(backends, name, basePath) + backend := FindGalleryElement(backends, name) if backend == nil { return fmt.Errorf("no backend found with name %q", name) } @@ -99,12 +99,12 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S log.Debug().Str("name", name).Str("bestBackend", bestBackend.Name).Msg("Installing backend from meta backend") // Then, let's install the best backend - if err := InstallBackend(basePath, bestBackend, downloadStatus); err != nil { + if err := InstallBackend(systemState, bestBackend, downloadStatus); err != nil { return err } // we need now to create a path for the meta backend, with the alias to the installed ones so it can be used to remove it - metaBackendPath := filepath.Join(basePath, name) + metaBackendPath := filepath.Join(systemState.Backend.BackendsPath, name) if err := os.MkdirAll(metaBackendPath, 0750); err != nil { return fmt.Errorf("failed to create meta backend path %q: %v", metaBackendPath, err) } @@ -124,12 +124,12 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S return nil } - return InstallBackend(basePath, backend, downloadStatus) + return InstallBackend(systemState, backend, downloadStatus) } -func InstallBackend(basePath string, 
config *GalleryBackend, downloadStatus func(string, string, string, float64)) error { +func InstallBackend(systemState *system.SystemState, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error { // Create base path if it doesn't exist - err := os.MkdirAll(basePath, 0750) + err := os.MkdirAll(systemState.Backend.BackendsPath, 0750) if err != nil { return fmt.Errorf("failed to create base path: %v", err) } @@ -139,7 +139,7 @@ func InstallBackend(basePath string, config *GalleryBackend, downloadStatus func } name := config.Name - backendPath := filepath.Join(basePath, name) + backendPath := filepath.Join(systemState.Backend.BackendsPath, name) err = os.MkdirAll(backendPath, 0750) if err != nil { return fmt.Errorf("failed to create base path: %v", err) @@ -188,14 +188,28 @@ func InstallBackend(basePath string, config *GalleryBackend, downloadStatus func return nil } -func DeleteBackendFromSystem(basePath string, name string) error { - backendDirectory := filepath.Join(basePath, name) +func DeleteBackendFromSystem(systemState *system.SystemState, name string) error { + backends, err := ListSystemBackends(systemState) + if err != nil { + return err + } + + backend, ok := backends.Get(name) + if !ok { + return fmt.Errorf("backend %q not found", name) + } + + if backend.IsSystem { + return fmt.Errorf("system backend %q cannot be deleted", name) + } + + backendDirectory := filepath.Join(systemState.Backend.BackendsPath, name) // check if the backend dir exists if _, err := os.Stat(backendDirectory); os.IsNotExist(err) { // if doesn't exist, it might be an alias, so we need to check if we have a matching alias in // all the backends in the basePath - backends, err := os.ReadDir(basePath) + backends, err := os.ReadDir(systemState.Backend.BackendsPath) if err != nil { return err } @@ -203,12 +217,12 @@ func DeleteBackendFromSystem(basePath string, name string) error { for _, backend := range backends { if backend.IsDir() { - metadata, err := 
readBackendMetadata(filepath.Join(basePath, backend.Name())) + metadata, err := readBackendMetadata(filepath.Join(systemState.Backend.BackendsPath, backend.Name())) if err != nil { return err } if metadata != nil && metadata.Alias == name { - backendDirectory = filepath.Join(basePath, backend.Name()) + backendDirectory = filepath.Join(systemState.Backend.BackendsPath, backend.Name()) foundBackend = true break } @@ -228,7 +242,7 @@ func DeleteBackendFromSystem(basePath string, name string) error { } if metadata != nil && metadata.MetaBackendFor != "" { - metaBackendDirectory := filepath.Join(basePath, metadata.MetaBackendFor) + metaBackendDirectory := filepath.Join(systemState.Backend.BackendsPath, metadata.MetaBackendFor) log.Debug().Str("backendDirectory", metaBackendDirectory).Msg("Deleting meta backend") if _, err := os.Stat(metaBackendDirectory); os.IsNotExist(err) { return fmt.Errorf("meta backend %q not found", metadata.MetaBackendFor) @@ -243,6 +257,7 @@ type SystemBackend struct { Name string RunFile string IsMeta bool + IsSystem bool Metadata *BackendMetadata } @@ -266,30 +281,51 @@ func (b SystemBackends) GetAll() []SystemBackend { return backends } -func ListSystemBackends(basePath string) (SystemBackends, error) { - potentialBackends, err := os.ReadDir(basePath) +func ListSystemBackends(systemState *system.SystemState) (SystemBackends, error) { + potentialBackends, err := os.ReadDir(systemState.Backend.BackendsPath) if err != nil { return nil, err } backends := make(SystemBackends) + systemBackends, err := os.ReadDir(systemState.Backend.BackendsSystemPath) + if err == nil { + // system backends are special, they are provided by the system and not managed by LocalAI + for _, systemBackend := range systemBackends { + if systemBackend.IsDir() { + systemBackendRunFile := filepath.Join(systemState.Backend.BackendsSystemPath, systemBackend.Name(), runFile) + if _, err := os.Stat(systemBackendRunFile); err == nil { + backends[systemBackend.Name()] = 
SystemBackend{ + Name: systemBackend.Name(), + RunFile: filepath.Join(systemState.Backend.BackendsSystemPath, systemBackend.Name(), runFile), + IsMeta: false, + IsSystem: true, + Metadata: nil, + } + } + } + } + } else { + log.Warn().Err(err).Msg("Failed to read system backends, but that's ok, we will just use the backends managed by LocalAI") + } + for _, potentialBackend := range potentialBackends { if potentialBackend.IsDir() { - potentialBackendRunFile := filepath.Join(basePath, potentialBackend.Name(), runFile) + potentialBackendRunFile := filepath.Join(systemState.Backend.BackendsPath, potentialBackend.Name(), runFile) var metadata *BackendMetadata // If metadata file does not exist, we just use the directory name // and we do not fill the other metadata (such as potential backend Aliases) - metadataFilePath := filepath.Join(basePath, potentialBackend.Name(), metadataFile) + metadataFilePath := filepath.Join(systemState.Backend.BackendsPath, potentialBackend.Name(), metadataFile) if _, err := os.Stat(metadataFilePath); os.IsNotExist(err) { metadata = &BackendMetadata{ Name: potentialBackend.Name(), } } else { // Check for alias in metadata - metadata, err = readBackendMetadata(filepath.Join(basePath, potentialBackend.Name())) + metadata, err = readBackendMetadata(filepath.Join(systemState.Backend.BackendsPath, potentialBackend.Name())) if err != nil { return nil, err } @@ -323,7 +359,7 @@ func ListSystemBackends(basePath string) (SystemBackends, error) { if metadata.MetaBackendFor != "" { backends[metadata.Name] = SystemBackend{ Name: metadata.Name, - RunFile: filepath.Join(basePath, metadata.MetaBackendFor, runFile), + RunFile: filepath.Join(systemState.Backend.BackendsPath, metadata.MetaBackendFor, runFile), IsMeta: true, Metadata: metadata, } @@ -334,8 +370,8 @@ func ListSystemBackends(basePath string) (SystemBackends, error) { return backends, nil } -func RegisterBackends(basePath string, modelLoader *model.ModelLoader) error { - backends, err := 
ListSystemBackends(basePath) +func RegisterBackends(systemState *system.SystemState, modelLoader *model.ModelLoader) error { + backends, err := ListSystemBackends(systemState) if err != nil { return err } diff --git a/core/gallery/backends_test.go b/core/gallery/backends_test.go index 39a0521c7396..ea3462011eff 100644 --- a/core/gallery/backends_test.go +++ b/core/gallery/backends_test.go @@ -43,13 +43,21 @@ var _ = Describe("Gallery Backends", func() { Describe("InstallBackendFromGallery", func() { It("should return error when backend is not found", func() { - err := InstallBackendFromGallery(galleries, nil, "non-existent", tempDir, nil, true) + systemState, err := system.GetSystemState( + system.WithBackendPath(tempDir), + ) + Expect(err).NotTo(HaveOccurred()) + err = InstallBackendFromGallery(galleries, systemState, "non-existent", nil, true) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("no backend found with name \"non-existent\"")) }) It("should install backend from gallery", func() { - err := InstallBackendFromGallery(galleries, nil, "test-backend", tempDir, nil, true) + systemState, err := system.GetSystemState( + system.WithBackendPath(tempDir), + ) + Expect(err).NotTo(HaveOccurred()) + err = InstallBackendFromGallery(galleries, systemState, "test-backend", nil, true) Expect(err).ToNot(HaveOccurred()) Expect(filepath.Join(tempDir, "test-backend", "run.sh")).To(BeARegularFile()) }) @@ -220,26 +228,32 @@ var _ = Describe("Gallery Backends", func() { Expect(err).NotTo(HaveOccurred()) // Test with NVIDIA system state - nvidiaSystemState := &system.SystemState{GPUVendor: "nvidia", VRAM: 1000000000000} - err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", tempDir, nil, true) + nvidiaSystemState := &system.SystemState{ + GPUVendor: "nvidia", + VRAM: 1000000000000, + Backend: system.Backend{BackendsPath: tempDir}, + } + err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, 
"meta-backend", nil, true) Expect(err).NotTo(HaveOccurred()) metaBackendPath := filepath.Join(tempDir, "meta-backend") Expect(metaBackendPath).To(BeADirectory()) - metaBackendPath = filepath.Join(tempDir, "meta-backend", "metadata.json") - Expect(metaBackendPath).To(BeARegularFile()) - concreteBackendPath := filepath.Join(tempDir, "nvidia-backend") Expect(concreteBackendPath).To(BeADirectory()) - allBackends, err := ListSystemBackends(tempDir) + systemState, err := system.GetSystemState( + system.WithBackendPath(tempDir), + ) Expect(err).NotTo(HaveOccurred()) - Expect(allBackends.Exists("meta-backend")).To(BeTrue()) - Expect(allBackends.Exists("nvidia-backend")).To(BeTrue()) + + allBackends, err := ListSystemBackends(systemState) + Expect(err).NotTo(HaveOccurred()) + Expect(allBackends).To(HaveKey("meta-backend")) + Expect(allBackends).To(HaveKey("nvidia-backend")) // Delete meta backend by name - err = DeleteBackendFromSystem(tempDir, "meta-backend") + err = DeleteBackendFromSystem(systemState, "meta-backend") Expect(err).NotTo(HaveOccurred()) // Verify meta backend directory is deleted @@ -294,8 +308,12 @@ var _ = Describe("Gallery Backends", func() { Expect(err).NotTo(HaveOccurred()) // Test with NVIDIA system state - nvidiaSystemState := &system.SystemState{GPUVendor: "nvidia", VRAM: 1000000000000} - err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", tempDir, nil, true) + nvidiaSystemState := &system.SystemState{ + GPUVendor: "nvidia", + VRAM: 1000000000000, + Backend: system.Backend{BackendsPath: tempDir}, + } + err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", nil, true) Expect(err).NotTo(HaveOccurred()) metaBackendPath := filepath.Join(tempDir, "meta-backend") @@ -304,19 +322,22 @@ var _ = Describe("Gallery Backends", func() { concreteBackendPath := filepath.Join(tempDir, "nvidia-backend") Expect(concreteBackendPath).To(BeADirectory()) - allBackends, err := 
ListSystemBackends(tempDir) + systemState, err := system.GetSystemState( + system.WithBackendPath(tempDir), + ) Expect(err).NotTo(HaveOccurred()) - Expect(allBackends.Exists("meta-backend")).To(BeTrue()) - Expect(allBackends.Exists("nvidia-backend")).To(BeTrue()) - backend, ok := allBackends.Get("meta-backend") - Expect(ok).To(BeTrue()) - Expect(backend.Metadata.MetaBackendFor).To(Equal("nvidia-backend")) - Expect(backend.RunFile).To(Equal(filepath.Join(tempDir, "nvidia-backend", "run.sh"))) - Expect(backend.IsMeta).To(BeTrue()) + allBackends, err := ListSystemBackends(systemState) + Expect(err).NotTo(HaveOccurred()) + Expect(allBackends).To(HaveKey("meta-backend")) + Expect(allBackends).To(HaveKey("nvidia-backend")) + mback, exists := allBackends.Get("meta-backend") + Expect(exists).To(BeTrue()) + Expect(mback.IsMeta).To(BeTrue()) + Expect(mback.Metadata.MetaBackendFor).To(Equal("nvidia-backend")) // Delete meta backend by name - err = DeleteBackendFromSystem(tempDir, "meta-backend") + err = DeleteBackendFromSystem(systemState, "meta-backend") Expect(err).NotTo(HaveOccurred()) // Verify meta backend directory is deleted @@ -371,8 +392,12 @@ var _ = Describe("Gallery Backends", func() { Expect(err).NotTo(HaveOccurred()) // Test with NVIDIA system state - nvidiaSystemState := &system.SystemState{GPUVendor: "nvidia", VRAM: 1000000000000} - err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", tempDir, nil, true) + nvidiaSystemState := &system.SystemState{ + GPUVendor: "nvidia", + VRAM: 1000000000000, + Backend: system.Backend{BackendsPath: tempDir}, + } + err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", nil, true) Expect(err).NotTo(HaveOccurred()) metaBackendPath := filepath.Join(tempDir, "meta-backend") @@ -381,16 +406,21 @@ var _ = Describe("Gallery Backends", func() { concreteBackendPath := filepath.Join(tempDir, "nvidia-backend") Expect(concreteBackendPath).To(BeADirectory()) - 
allBackends, err := ListSystemBackends(tempDir) + systemState, err := system.GetSystemState( + system.WithBackendPath(tempDir), + ) Expect(err).NotTo(HaveOccurred()) - Expect(allBackends.Exists("meta-backend")).To(BeTrue()) - Expect(allBackends.Exists("nvidia-backend")).To(BeTrue()) - backend, ok := allBackends.Get("meta-backend") - Expect(ok).To(BeTrue()) - Expect(backend.RunFile).To(Equal(filepath.Join(tempDir, "nvidia-backend", "run.sh"))) + + allBackends, err := ListSystemBackends(systemState) + Expect(err).NotTo(HaveOccurred()) + Expect(allBackends).To(HaveKey("meta-backend")) + Expect(allBackends).To(HaveKey("nvidia-backend")) + mback, exists := allBackends.Get("meta-backend") + Expect(exists).To(BeTrue()) + Expect(mback.RunFile).To(Equal(filepath.Join(tempDir, "nvidia-backend", "run.sh"))) // Delete meta backend by name - err = DeleteBackendFromSystem(tempDir, "meta-backend") + err = DeleteBackendFromSystem(systemState, "meta-backend") Expect(err).NotTo(HaveOccurred()) // Verify meta backend directory is deleted @@ -427,25 +457,28 @@ var _ = Describe("Gallery Backends", func() { Expect(err).NotTo(HaveOccurred()) // List system backends - backends, err := ListSystemBackends(tempDir) + systemState, err := system.GetSystemState( + system.WithBackendPath(tempDir), + ) Expect(err).NotTo(HaveOccurred()) + backends, err := ListSystemBackends(systemState) + Expect(err).NotTo(HaveOccurred()) + + metaBackend, exists := backends.Get("meta-backend") + concreteBackendRunFile := filepath.Join(tempDir, "concrete-backend", "run.sh") + // Should include both the meta backend name and concrete backend name - Expect(backends.Exists("meta-backend")).To(BeTrue()) + Expect(exists).To(BeTrue()) Expect(backends.Exists("concrete-backend")).To(BeTrue()) - // meta-backend should point to concrete-backend - Expect(backends.Exists("meta-backend")).To(BeTrue()) - backend, ok := backends.Get("meta-backend") - Expect(ok).To(BeTrue()) - 
Expect(backend.Metadata.MetaBackendFor).To(Equal("concrete-backend")) - Expect(backend.RunFile).To(Equal(filepath.Join(tempDir, "concrete-backend", "run.sh"))) - Expect(backend.IsMeta).To(BeTrue()) - + // meta-backend should resolve to the concrete backend's run file + Expect(metaBackend.IsMeta).To(BeTrue()) + Expect(metaBackend.RunFile).To(Equal(concreteBackendRunFile)) // concrete-backend should point to its own run.sh - backend, ok = backends.Get("concrete-backend") - Expect(ok).To(BeTrue()) - Expect(backend.RunFile).To(Equal(filepath.Join(tempDir, "concrete-backend", "run.sh"))) + concreteBackend, exists := backends.Get("concrete-backend") + Expect(exists).To(BeTrue()) + Expect(concreteBackend.RunFile).To(Equal(concreteBackendRunFile)) }) }) @@ -459,11 +492,80 @@ var _ = Describe("Gallery Backends", func() { URI: "test-uri", } - err := InstallBackend(newPath, &backend, nil) + systemState, err := system.GetSystemState( + system.WithBackendPath(newPath), + ) + Expect(err).NotTo(HaveOccurred()) + err = InstallBackend(systemState, &backend, nil) Expect(err).To(HaveOccurred()) // Will fail due to invalid URI, but path should be created Expect(newPath).To(BeADirectory()) }) + It("should overwrite existing backend", func() { + if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" { + Skip("Skipping test on darwin/arm64") + } + newPath := filepath.Join(tempDir, "test-backend") + + // Create a dummy backend directory + err := os.MkdirAll(newPath, 0750) + Expect(err).NotTo(HaveOccurred()) + + err = os.WriteFile(filepath.Join(newPath, "metadata.json"), []byte("foo"), 0644) + Expect(err).NotTo(HaveOccurred()) + err = os.WriteFile(filepath.Join(newPath, "run.sh"), []byte(""), 0644) + Expect(err).NotTo(HaveOccurred()) + + backend := GalleryBackend{ + Metadata: Metadata{ + Name: "test-backend", + }, + URI: "quay.io/mudler/tests:localai-backend-test", + Alias: "test-alias", + } + + systemState, err := system.GetSystemState( + system.WithBackendPath(tempDir), + ) + Expect(err).NotTo(HaveOccurred()) + err =
InstallBackend(systemState, &backend, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile()) + dat, err := os.ReadFile(filepath.Join(tempDir, "test-backend", "metadata.json")) + Expect(err).ToNot(HaveOccurred()) + Expect(string(dat)).ToNot(Equal("foo")) + }) + + It("should overwrite existing backend", func() { + if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" { + Skip("Skipping test on darwin/arm64") + } + newPath := filepath.Join(tempDir, "test-backend") + + // Create a dummy backend directory + err := os.MkdirAll(newPath, 0750) + Expect(err).NotTo(HaveOccurred()) + + backend := GalleryBackend{ + Metadata: Metadata{ + Name: "test-backend", + }, + URI: "quay.io/mudler/tests:localai-backend-test", + Alias: "test-alias", + } + + systemState, err := system.GetSystemState( + system.WithBackendPath(tempDir), + ) + Expect(err).NotTo(HaveOccurred()) + + Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).ToNot(BeARegularFile()) + + err = InstallBackend(systemState, &backend, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile()) + }) + It("should create alias file when specified", func() { if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" { Skip("Skipping test on darwin/arm64") @@ -476,7 +578,11 @@ var _ = Describe("Gallery Backends", func() { Alias: "test-alias", } - err := InstallBackend(tempDir, &backend, nil) + systemState, err := system.GetSystemState( + system.WithBackendPath(tempDir), + ) + Expect(err).NotTo(HaveOccurred()) + err = InstallBackend(systemState, &backend, nil) Expect(err).ToNot(HaveOccurred()) Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile()) @@ -492,16 +598,14 @@ var _ = Describe("Gallery Backends", func() { Expect(filepath.Join(tempDir, "test-backend", "run.sh")).To(BeARegularFile()) // Check that the alias was recognized - backends, err := 
ListSystemBackends(tempDir) + backends, err := ListSystemBackends(systemState) Expect(err).ToNot(HaveOccurred()) - Expect(backends.Exists("test-alias")).To(BeTrue()) - Expect(backends.Exists("test-backend")).To(BeTrue()) - b, ok := backends.Get("test-alias") - Expect(ok).To(BeTrue()) - Expect(b.RunFile).To(Equal(filepath.Join(tempDir, "test-backend", "run.sh"))) - b, ok = backends.Get("test-backend") - Expect(ok).To(BeTrue()) - Expect(b.RunFile).To(Equal(filepath.Join(tempDir, "test-backend", "run.sh"))) + aliasBackend, exists := backends.Get("test-alias") + Expect(exists).To(BeTrue()) + Expect(aliasBackend.RunFile).To(Equal(filepath.Join(tempDir, "test-backend", "run.sh"))) + testB, exists := backends.Get("test-backend") + Expect(exists).To(BeTrue()) + Expect(testB.RunFile).To(Equal(filepath.Join(tempDir, "test-backend", "run.sh"))) }) }) @@ -514,13 +618,26 @@ var _ = Describe("Gallery Backends", func() { err := os.MkdirAll(backendPath, 0750) Expect(err).NotTo(HaveOccurred()) - err = DeleteBackendFromSystem(tempDir, backendName) + err = os.WriteFile(filepath.Join(backendPath, "metadata.json"), []byte("{}"), 0644) + Expect(err).NotTo(HaveOccurred()) + err = os.WriteFile(filepath.Join(backendPath, "run.sh"), []byte(""), 0644) + Expect(err).NotTo(HaveOccurred()) + + systemState, err := system.GetSystemState( + system.WithBackendPath(tempDir), + ) + Expect(err).NotTo(HaveOccurred()) + err = DeleteBackendFromSystem(systemState, backendName) Expect(err).NotTo(HaveOccurred()) Expect(backendPath).NotTo(BeADirectory()) }) It("should not error when backend doesn't exist", func() { - err := DeleteBackendFromSystem(tempDir, "non-existent") + systemState, err := system.GetSystemState( + system.WithBackendPath(tempDir), + ) + Expect(err).NotTo(HaveOccurred()) + err = DeleteBackendFromSystem(systemState, "non-existent") Expect(err).To(HaveOccurred()) }) }) @@ -538,14 +655,17 @@ var _ = Describe("Gallery Backends", func() { Expect(err).NotTo(HaveOccurred()) } - backends, err := 
ListSystemBackends(tempDir) + systemState, err := system.GetSystemState( + system.WithBackendPath(tempDir), + ) + Expect(err).NotTo(HaveOccurred()) + backends, err := ListSystemBackends(systemState) Expect(err).NotTo(HaveOccurred()) - Expect(backends.GetAll()).To(HaveLen(len(backendNames))) + Expect(backends).To(HaveLen(len(backendNames))) for _, name := range backendNames { - Expect(backends.Exists(name)).To(BeTrue()) - backend, ok := backends.Get(name) - Expect(ok).To(BeTrue()) + backend, exists := backends.Get(name) + Expect(exists).To(BeTrue()) Expect(backend.RunFile).To(Equal(filepath.Join(tempDir, name, "run.sh"))) } }) @@ -572,16 +692,23 @@ var _ = Describe("Gallery Backends", func() { err = os.WriteFile(filepath.Join(backendPath, "run.sh"), []byte(""), 0755) Expect(err).NotTo(HaveOccurred()) - backends, err := ListSystemBackends(tempDir) + systemState, err := system.GetSystemState( + system.WithBackendPath(tempDir), + ) Expect(err).NotTo(HaveOccurred()) - Expect(backends.Exists(alias)).To(BeTrue()) - backend, ok := backends.Get(alias) - Expect(ok).To(BeTrue()) + backends, err := ListSystemBackends(systemState) + Expect(err).NotTo(HaveOccurred()) + backend, exists := backends.Get(alias) + Expect(exists).To(BeTrue()) Expect(backend.RunFile).To(Equal(filepath.Join(tempDir, backendName, "run.sh"))) }) It("should return error when base path doesn't exist", func() { - _, err := ListSystemBackends(filepath.Join(tempDir, "non-existent")) + systemState, err := system.GetSystemState( + system.WithBackendPath("foobardir"), + ) + Expect(err).NotTo(HaveOccurred()) + _, err = ListSystemBackends(systemState) Expect(err).To(HaveOccurred()) }) }) diff --git a/core/gallery/gallery.go b/core/gallery/gallery.go index a46920f1aff9..a80550102b17 100644 --- a/core/gallery/gallery.go +++ b/core/gallery/gallery.go @@ -8,6 +8,7 @@ import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/pkg/downloader" + "github.com/mudler/LocalAI/pkg/system" 
"github.com/rs/zerolog/log" "gopkg.in/yaml.v2" ) @@ -89,7 +90,7 @@ func (gm GalleryElements[T]) Paginate(pageNum int, itemsNum int) GalleryElements return gm[start:end] } -func FindGalleryElement[T GalleryElement](models []T, name string, basePath string) T { +func FindGalleryElement[T GalleryElement](models []T, name string) T { var model T name = strings.ReplaceAll(name, string(os.PathSeparator), "__") @@ -116,13 +117,13 @@ func FindGalleryElement[T GalleryElement](models []T, name string, basePath stri // List available models // Models galleries are a list of yaml files that are hosted on a remote server (for example github). // Each yaml file contains a list of models that can be downloaded and optionally overrides to define a new model setting. -func AvailableGalleryModels(galleries []config.Gallery, basePath string) (GalleryElements[*GalleryModel], error) { +func AvailableGalleryModels(galleries []config.Gallery, systemState *system.SystemState) (GalleryElements[*GalleryModel], error) { var models []*GalleryModel // Get models from galleries for _, gallery := range galleries { - galleryModels, err := getGalleryElements[*GalleryModel](gallery, basePath, func(model *GalleryModel) bool { - if _, err := os.Stat(filepath.Join(basePath, fmt.Sprintf("%s.yaml", model.GetName()))); err == nil { + galleryModels, err := getGalleryElements[*GalleryModel](gallery, systemState.Model.ModelsPath, func(model *GalleryModel) bool { + if _, err := os.Stat(filepath.Join(systemState.Model.ModelsPath, fmt.Sprintf("%s.yaml", model.GetName()))); err == nil { return true } return false @@ -137,13 +138,13 @@ func AvailableGalleryModels(galleries []config.Gallery, basePath string) (Galler } // List available backends -func AvailableBackends(galleries []config.Gallery, basePath string) (GalleryElements[*GalleryBackend], error) { - var models []*GalleryBackend +func AvailableBackends(galleries []config.Gallery, systemState *system.SystemState) (GalleryElements[*GalleryBackend], error) 
{ + var backends []*GalleryBackend - // Get models from galleries + // Get backends from galleries for _, gallery := range galleries { - galleryModels, err := getGalleryElements[*GalleryBackend](gallery, basePath, func(backend *GalleryBackend) bool { - backends, err := ListSystemBackends(basePath) + galleryBackends, err := getGalleryElements[*GalleryBackend](gallery, systemState.Backend.BackendsPath, func(backend *GalleryBackend) bool { + backends, err := ListSystemBackends(systemState) if err != nil { return false } @@ -152,10 +153,10 @@ func AvailableBackends(galleries []config.Gallery, basePath string) (GalleryElem if err != nil { return nil, err } - models = append(models, galleryModels...) + backends = append(backends, galleryBackends...) } - return models, nil + return backends, nil } func findGalleryURLFromReferenceURL(url string, basePath string) (string, error) { diff --git a/core/gallery/models.go b/core/gallery/models.go index 54afbfdfdfcf..f161e0cf3d5e 100644 --- a/core/gallery/models.go +++ b/core/gallery/models.go @@ -72,7 +72,8 @@ type PromptTemplate struct { // Installs a model from the gallery func InstallModelFromGallery( modelGalleries, backendGalleries []config.Gallery, - name string, basePath, backendBasePath string, req GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool) error { + systemState *system.SystemState, + name string, req GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool) error { applyModel := func(model *GalleryModel) error { name = strings.ReplaceAll(name, string(os.PathSeparator), "__") @@ -81,7 +82,7 @@ func InstallModelFromGallery( if len(model.URL) > 0 { var err error - config, err = GetGalleryConfigFromURL[ModelConfig](model.URL, basePath) + config, err = GetGalleryConfigFromURL[ModelConfig](model.URL, systemState.Model.ModelsPath) if err != nil { return err } @@ -122,19 +123,15 @@ func 
InstallModelFromGallery( return err } - installedModel, err := InstallModel(basePath, installName, &config, model.Overrides, downloadStatus, enforceScan) + installedModel, err := InstallModel(systemState, installName, &config, model.Overrides, downloadStatus, enforceScan) if err != nil { return err } log.Debug().Msgf("Installed model %q", installedModel.Name) if automaticallyInstallBackend && installedModel.Backend != "" { log.Debug().Msgf("Installing backend %q", installedModel.Backend) - systemState, err := system.GetSystemState() - if err != nil { - return err - } - if err := InstallBackendFromGallery(backendGalleries, systemState, installedModel.Backend, backendBasePath, downloadStatus, false); err != nil { + if err := InstallBackendFromGallery(backendGalleries, systemState, installedModel.Backend, downloadStatus, false); err != nil { return err } } @@ -142,12 +139,12 @@ func InstallModelFromGallery( return nil } - models, err := AvailableGalleryModels(modelGalleries, basePath) + models, err := AvailableGalleryModels(modelGalleries, systemState) if err != nil { return err } - model := FindGalleryElement(models, name, basePath) + model := FindGalleryElement(models, name) if model == nil { return fmt.Errorf("no model found with name %q", name) } @@ -155,7 +152,8 @@ func InstallModelFromGallery( return applyModel(model) } -func InstallModel(basePath, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.BackendConfig, error) { +func InstallModel(systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) { + basePath := systemState.Model.ModelsPath // Create base path if it doesn't exist err := os.MkdirAll(basePath, 0750) if err != nil { @@ -221,7 +219,7 @@ func InstallModel(basePath, nameOverride 
string, config *ModelConfig, configOver return nil, err } - backendConfig := lconfig.BackendConfig{} + modelConfig := lconfig.ModelConfig{} // write config file if len(configOverrides) != 0 || len(config.ConfigFile) != 0 { @@ -246,12 +244,12 @@ func InstallModel(basePath, nameOverride string, config *ModelConfig, configOver return nil, fmt.Errorf("failed to marshal updated config YAML: %v", err) } - err = yaml.Unmarshal(updatedConfigYAML, &backendConfig) + err = yaml.Unmarshal(updatedConfigYAML, &modelConfig) if err != nil { return nil, fmt.Errorf("failed to unmarshal updated config YAML: %v", err) } - if !backendConfig.Validate() { + if !modelConfig.Validate() { return nil, fmt.Errorf("failed to validate updated config YAML") } @@ -272,7 +270,7 @@ func InstallModel(basePath, nameOverride string, config *ModelConfig, configOver log.Debug().Msgf("Written gallery file %s", modelFile) - return &backendConfig, os.WriteFile(modelFile, data, 0600) + return &modelConfig, os.WriteFile(modelFile, data, 0600) } func galleryFileName(name string) string { @@ -285,21 +283,39 @@ func GetLocalModelConfiguration(basePath string, name string) (*ModelConfig, err return ReadConfigFile[ModelConfig](galleryFile) } -func DeleteModelFromSystem(basePath string, name string, additionalFiles []string) error { - // os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths. 
- name = strings.ReplaceAll(name, string(os.PathSeparator), "__") +func DeleteModelFromSystem(systemState *system.SystemState, name string) error { + additionalFiles := []string{} - configFile := filepath.Join(basePath, fmt.Sprintf("%s.yaml", name)) + configFile := filepath.Join(systemState.Model.ModelsPath, fmt.Sprintf("%s.yaml", name)) + if err := utils.VerifyPath(configFile, systemState.Model.ModelsPath); err != nil { + return fmt.Errorf("failed to verify path %s: %w", configFile, err) + } + // Galleryname is the name of the model in this case + dat, err := os.ReadFile(configFile) + if err == nil { + modelConfig := &config.ModelConfig{} - galleryFile := filepath.Join(basePath, galleryFileName(name)) + err = yaml.Unmarshal(dat, &modelConfig) + if err != nil { + return err + } + if modelConfig.Model != "" { + additionalFiles = append(additionalFiles, modelConfig.ModelFileName()) + } - for _, f := range []string{configFile, galleryFile} { - if err := utils.VerifyPath(f, basePath); err != nil { - return fmt.Errorf("failed to verify path %s: %w", f, err) + if modelConfig.MMProj != "" { + additionalFiles = append(additionalFiles, modelConfig.MMProjFileName()) } } - var err error + // os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths. 
+ name = strings.ReplaceAll(name, string(os.PathSeparator), "__") + + galleryFile := filepath.Join(systemState.Model.ModelsPath, galleryFileName(name)) + if err := utils.VerifyPath(galleryFile, systemState.Model.ModelsPath); err != nil { + return fmt.Errorf("failed to verify path %s: %w", galleryFile, err) + } + // Delete all the files associated to the model // read the model config galleryconfig, err := ReadConfigFile[ModelConfig](galleryFile) @@ -312,13 +328,19 @@ func DeleteModelFromSystem(basePath string, name string, additionalFiles []strin // Remove additional files if galleryconfig != nil { for _, f := range galleryconfig.Files { - fullPath := filepath.Join(basePath, f.Filename) + fullPath := filepath.Join(systemState.Model.ModelsPath, f.Filename) + if err := utils.VerifyPath(fullPath, systemState.Model.ModelsPath); err != nil { + return fmt.Errorf("failed to verify path %s: %w", fullPath, err) + } filesToRemove = append(filesToRemove, fullPath) } } for _, f := range additionalFiles { - fullPath := filepath.Join(filepath.Join(basePath, f)) + fullPath := filepath.Join(filepath.Join(systemState.Model.ModelsPath, f)) + if err := utils.VerifyPath(fullPath, systemState.Model.ModelsPath); err != nil { + return fmt.Errorf("failed to verify path %s: %w", fullPath, err) + } filesToRemove = append(filesToRemove, fullPath) } @@ -340,8 +362,8 @@ func DeleteModelFromSystem(basePath string, name string, additionalFiles []strin // This is ***NEVER*** going to be perfect or finished. // This is a BEST EFFORT function to surface known-vulnerable models to users. 
-func SafetyScanGalleryModels(galleries []config.Gallery, basePath string) error { - galleryModels, err := AvailableGalleryModels(galleries, basePath) +func SafetyScanGalleryModels(galleries []config.Gallery, systemState *system.SystemState) error { + galleryModels, err := AvailableGalleryModels(galleries, systemState) if err != nil { return err } diff --git a/core/gallery/models_test.go b/core/gallery/models_test.go index 5ffd675d1013..f259f0743d94 100644 --- a/core/gallery/models_test.go +++ b/core/gallery/models_test.go @@ -7,6 +7,7 @@ import ( "github.com/mudler/LocalAI/core/config" . "github.com/mudler/LocalAI/core/gallery" + "github.com/mudler/LocalAI/pkg/system" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "gopkg.in/yaml.v3" @@ -29,7 +30,11 @@ var _ = Describe("Model test", func() { defer os.RemoveAll(tempdir) c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - _, err = InstallModel(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true) + systemState, err := system.GetSystemState( + system.WithModelPath(tempdir), + ) + Expect(err).ToNot(HaveOccurred()) + _, err = InstallModel(systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true) Expect(err).ToNot(HaveOccurred()) for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} { @@ -71,15 +76,19 @@ var _ = Describe("Model test", func() { URL: "file://" + galleryFilePath, }, } + systemState, err := system.GetSystemState( + system.WithModelPath(tempdir), + ) + Expect(err).ToNot(HaveOccurred()) - models, err := AvailableGalleryModels(galleries, tempdir) + models, err := AvailableGalleryModels(galleries, systemState) Expect(err).ToNot(HaveOccurred()) Expect(len(models)).To(Equal(1)) Expect(models[0].Name).To(Equal("bert")) Expect(models[0].URL).To(Equal(bertEmbeddingsURL)) 
Expect(models[0].Installed).To(BeFalse()) - err = InstallModelFromGallery(galleries, []config.Gallery{}, "test@bert", tempdir, "", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true) + err = InstallModelFromGallery(galleries, []config.Gallery{}, systemState, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true) Expect(err).ToNot(HaveOccurred()) dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml")) @@ -90,16 +99,16 @@ var _ = Describe("Model test", func() { Expect(err).ToNot(HaveOccurred()) Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this")) - models, err = AvailableGalleryModels(galleries, tempdir) + models, err = AvailableGalleryModels(galleries, systemState) Expect(err).ToNot(HaveOccurred()) Expect(len(models)).To(Equal(1)) Expect(models[0].Installed).To(BeTrue()) // delete - err = DeleteModelFromSystem(tempdir, "bert", []string{}) + err = DeleteModelFromSystem(systemState, "bert") Expect(err).ToNot(HaveOccurred()) - models, err = AvailableGalleryModels(galleries, tempdir) + models, err = AvailableGalleryModels(galleries, systemState) Expect(err).ToNot(HaveOccurred()) Expect(len(models)).To(Equal(1)) Expect(models[0].Installed).To(BeFalse()) @@ -116,7 +125,11 @@ var _ = Describe("Model test", func() { c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - _, err = InstallModel(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true) + systemState, err := system.GetSystemState( + system.WithModelPath(tempdir), + ) + Expect(err).ToNot(HaveOccurred()) + _, err = InstallModel(systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true) Expect(err).ToNot(HaveOccurred()) for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { @@ -132,7 +145,11 @@ var _ = Describe("Model test", func() { 
c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - _, err = InstallModel(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true) + systemState, err := system.GetSystemState( + system.WithModelPath(tempdir), + ) + Expect(err).ToNot(HaveOccurred()) + _, err = InstallModel(systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true) Expect(err).ToNot(HaveOccurred()) for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { @@ -158,7 +175,11 @@ var _ = Describe("Model test", func() { c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - _, err = InstallModel(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true) + systemState, err := system.GetSystemState( + system.WithModelPath(tempdir), + ) + Expect(err).ToNot(HaveOccurred()) + _, err = InstallModel(systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true) Expect(err).To(HaveOccurred()) }) }) diff --git a/core/http/app.go b/core/http/app.go index 11b3ae5732af..c73e94308fbe 100644 --- a/core/http/app.go +++ b/core/http/app.go @@ -198,7 +198,7 @@ func API(application *application.Application) (*fiber.App, error) { } galleryService := services.NewGalleryService(application.ApplicationConfig(), application.ModelLoader()) - err = galleryService.Start(application.ApplicationConfig().Context, application.BackendLoader()) + err = galleryService.Start(application.ApplicationConfig().Context, application.BackendLoader(), application.ApplicationConfig().SystemState) if err != nil { return nil, err } diff --git a/core/http/app_test.go b/core/http/app_test.go index 03aaf8a4c20a..b1f43d332d34 100644 --- a/core/http/app_test.go 
+++ b/core/http/app_test.go @@ -19,6 +19,7 @@ import ( "github.com/gofiber/fiber/v2" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/pkg/downloader" + "github.com/mudler/LocalAI/pkg/system" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "gopkg.in/yaml.v3" @@ -320,12 +321,17 @@ var _ = Describe("API test", func() { }, } + systemState, err := system.GetSystemState( + system.WithBackendPath(backendPath), + system.WithModelPath(modelDir), + ) + Expect(err).ToNot(HaveOccurred()) + application, err := application.New( append(commonOpts, config.WithContext(c), + config.WithSystemState(systemState), config.WithGalleries(galleries), - config.WithModelPath(modelDir), - config.WithBackendsPath(backendPath), config.WithApiKeys([]string{apiKey}), )...) Expect(err).ToNot(HaveOccurred()) @@ -523,13 +529,18 @@ var _ = Describe("API test", func() { }, } + systemState, err := system.GetSystemState( + system.WithBackendPath(backendPath), + system.WithModelPath(modelDir), + ) + Expect(err).ToNot(HaveOccurred()) + application, err := application.New( append(commonOpts, config.WithContext(c), config.WithGeneratedContentDir(tmpdir), - config.WithBackendsPath(backendPath), + config.WithSystemState(systemState), config.WithGalleries(galleries), - config.WithModelPath(modelDir), )..., ) Expect(err).ToNot(HaveOccurred()) @@ -729,12 +740,17 @@ var _ = Describe("API test", func() { var err error + systemState, err := system.GetSystemState( + system.WithBackendPath(backendPath), + system.WithModelPath(modelPath), + ) + Expect(err).ToNot(HaveOccurred()) + application, err := application.New( append(commonOpts, config.WithExternalBackend("transformers", os.Getenv("HUGGINGFACE_GRPC")), config.WithContext(c), - config.WithBackendsPath(backendPath), - config.WithModelPath(modelPath), + config.WithSystemState(systemState), )...) 
Expect(err).ToNot(HaveOccurred()) app, err = API(application) @@ -960,11 +976,17 @@ var _ = Describe("API test", func() { c, cancel = context.WithCancel(context.Background()) var err error + + systemState, err := system.GetSystemState( + system.WithBackendPath(backendPath), + system.WithModelPath(modelPath), + ) + Expect(err).ToNot(HaveOccurred()) + application, err := application.New( append(commonOpts, config.WithContext(c), - config.WithModelPath(modelPath), - config.WithBackendsPath(backendPath), + config.WithSystemState(systemState), config.WithConfigFile(os.Getenv("CONFIG_FILE")))..., ) Expect(err).ToNot(HaveOccurred()) diff --git a/core/http/endpoints/elevenlabs/soundgeneration.go b/core/http/endpoints/elevenlabs/soundgeneration.go index 548716def741..53da894e1c9b 100644 --- a/core/http/endpoints/elevenlabs/soundgeneration.go +++ b/core/http/endpoints/elevenlabs/soundgeneration.go @@ -15,7 +15,7 @@ import ( // @Param request body schema.ElevenLabsSoundGenerationRequest true "query params" // @Success 200 {string} binary "Response" // @Router /v1/sound-generation [post] -func SoundGenerationEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest) @@ -23,7 +23,7 @@ func SoundGenerationEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoad return fiber.ErrBadRequest } - cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || cfg == nil { return fiber.ErrBadRequest } diff --git a/core/http/endpoints/elevenlabs/tts.go b/core/http/endpoints/elevenlabs/tts.go index 
651a526fb644..dac5de70bf05 100644 --- a/core/http/endpoints/elevenlabs/tts.go +++ b/core/http/endpoints/elevenlabs/tts.go @@ -17,7 +17,7 @@ import ( // @Param request body schema.TTSRequest true "query params" // @Success 200 {string} binary "Response" // @Router /v1/text-to-speech/{voice-id} [post] -func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { voiceID := c.Params("voice-id") @@ -27,7 +27,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi return fiber.ErrBadRequest } - cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || cfg == nil { return fiber.ErrBadRequest } diff --git a/core/http/endpoints/jina/rerank.go b/core/http/endpoints/jina/rerank.go index 26a09c2dcbb2..7d9270247f4a 100644 --- a/core/http/endpoints/jina/rerank.go +++ b/core/http/endpoints/jina/rerank.go @@ -17,7 +17,7 @@ import ( // @Param request body schema.JINARerankRequest true "query params" // @Success 200 {object} schema.JINARerankResponse "Response" // @Router /v1/rerank [post] -func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest) @@ -25,7 +25,7 @@ func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a return fiber.ErrBadRequest } - cfg, ok := 
c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || cfg == nil { return fiber.ErrBadRequest } diff --git a/core/http/endpoints/localai/backend.go b/core/http/endpoints/localai/backend.go index 2a8df50d3764..f4d88ffec154 100644 --- a/core/http/endpoints/localai/backend.go +++ b/core/http/endpoints/localai/backend.go @@ -11,24 +11,27 @@ import ( "github.com/mudler/LocalAI/core/http/utils" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services" + "github.com/mudler/LocalAI/pkg/system" "github.com/rs/zerolog/log" ) type BackendEndpointService struct { - galleries []config.Gallery - backendPath string - backendApplier *services.GalleryService + galleries []config.Gallery + backendPath string + backendSystemPath string + backendApplier *services.GalleryService } type GalleryBackend struct { ID string `json:"id"` } -func CreateBackendEndpointService(galleries []config.Gallery, backendPath string, backendApplier *services.GalleryService) BackendEndpointService { +func CreateBackendEndpointService(galleries []config.Gallery, systemState *system.SystemState, backendApplier *services.GalleryService) BackendEndpointService { return BackendEndpointService{ - galleries: galleries, - backendPath: backendPath, - backendApplier: backendApplier, + galleries: galleries, + backendPath: systemState.Backend.BackendsPath, + backendSystemPath: systemState.Backend.BackendsSystemPath, + backendApplier: backendApplier, } } @@ -111,9 +114,9 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) er // @Summary List all Backends // @Success 200 {object} []gallery.GalleryBackend "Response" // @Router /backends [get] -func (mgs *BackendEndpointService) ListBackendsEndpoint() func(c *fiber.Ctx) error { +func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error 
{ return func(c *fiber.Ctx) error { - backends, err := gallery.ListSystemBackends(mgs.backendPath) + backends, err := gallery.ListSystemBackends(systemState) if err != nil { return err } @@ -141,9 +144,9 @@ func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() func(c *fiber. // @Summary List all available Backends // @Success 200 {object} []gallery.GalleryBackend "Response" // @Router /backends/available [get] -func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint() func(c *fiber.Ctx) error { +func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - backends, err := gallery.AvailableBackends(mgs.galleries, mgs.backendPath) + backends, err := gallery.AvailableBackends(mgs.galleries, systemState) if err != nil { return err } diff --git a/core/http/endpoints/localai/detection.go b/core/http/endpoints/localai/detection.go index 496a64c10f35..c4ab249110fe 100644 --- a/core/http/endpoints/localai/detection.go +++ b/core/http/endpoints/localai/detection.go @@ -16,7 +16,7 @@ import ( // @Param request body schema.DetectionRequest true "query params" // @Success 200 {object} schema.DetectionResponse "Response" // @Router /v1/detection [post] -func DetectionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest) @@ -24,7 +24,7 @@ func DetectionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, ap return fiber.ErrBadRequest } - cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) 
if !ok || cfg == nil { return fiber.ErrBadRequest } diff --git a/core/http/endpoints/localai/gallery.go b/core/http/endpoints/localai/gallery.go index 9fa00521f71b..8948fde34250 100644 --- a/core/http/endpoints/localai/gallery.go +++ b/core/http/endpoints/localai/gallery.go @@ -11,6 +11,7 @@ import ( "github.com/mudler/LocalAI/core/http/utils" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services" + "github.com/mudler/LocalAI/pkg/system" "github.com/rs/zerolog/log" ) @@ -26,11 +27,11 @@ type GalleryModel struct { gallery.GalleryModel } -func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, modelPath string, galleryApplier *services.GalleryService) ModelGalleryEndpointService { +func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, systemState *system.SystemState, galleryApplier *services.GalleryService) ModelGalleryEndpointService { return ModelGalleryEndpointService{ galleries: galleries, backendGalleries: backendGalleries, - modelPath: modelPath, + modelPath: systemState.Model.ModelsPath, galleryApplier: galleryApplier, } } @@ -115,10 +116,10 @@ func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() func(c *fib // @Summary List installable models. 
// @Success 200 {object} []gallery.GalleryModel "Response" // @Router /models/available [get] -func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error { +func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath) + models, err := gallery.AvailableGalleryModels(mgs.galleries, systemState) if err != nil { log.Error().Err(err).Msg("could not list models from galleries") return err diff --git a/core/http/endpoints/localai/get_token_metrics.go b/core/http/endpoints/localai/get_token_metrics.go index 30de2cdd5f20..db00193a8499 100644 --- a/core/http/endpoints/localai/get_token_metrics.go +++ b/core/http/endpoints/localai/get_token_metrics.go @@ -21,7 +21,7 @@ import ( // @Success 200 {string} binary "generated audio/wav file" // @Router /v1/tokenMetrics [get] // @Router /tokenMetrics [get] -func TokenMetricsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { input := new(schema.TokenMetricsRequest) @@ -37,7 +37,7 @@ func TokenMetricsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, log.Warn().Msgf("Model not found in context: %s", input.Model) } - cfg, err := cl.LoadBackendConfigFileByNameDefaultOptions(modelFile, appConfig) + cfg, err := cl.LoadModelConfigFileByNameDefaultOptions(modelFile, appConfig) if err != nil { log.Err(err) diff --git a/core/http/endpoints/localai/tokenize.go b/core/http/endpoints/localai/tokenize.go index e0c06a27afc3..cd12e50dfb7a 100644 --- a/core/http/endpoints/localai/tokenize.go +++ b/core/http/endpoints/localai/tokenize.go @@ -14,14 +14,14 @@ import ( // 
@Param request body schema.TokenizeRequest true "Request" // @Success 200 {object} schema.TokenizeResponse "Response" // @Router /v1/tokenize [post] -func TokenizeEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func TokenizeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(ctx *fiber.Ctx) error { input, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest) if !ok || input.Model == "" { return fiber.ErrBadRequest } - cfg, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + cfg, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || cfg == nil { return fiber.ErrBadRequest } diff --git a/core/http/endpoints/localai/tts.go b/core/http/endpoints/localai/tts.go index 18c692dfbe06..5c2f01dad4b9 100644 --- a/core/http/endpoints/localai/tts.go +++ b/core/http/endpoints/localai/tts.go @@ -22,14 +22,14 @@ import ( // @Success 200 {string} binary "generated audio/wav file" // @Router /v1/audio/speech [post] // @Router /tts [post] -func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest) if !ok || input.Model == "" { return fiber.ErrBadRequest } - cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || cfg == nil { return fiber.ErrBadRequest } diff --git a/core/http/endpoints/localai/vad.go b/core/http/endpoints/localai/vad.go index 
384b97546ba4..c3e310fe867b 100644 --- a/core/http/endpoints/localai/vad.go +++ b/core/http/endpoints/localai/vad.go @@ -16,14 +16,14 @@ import ( // @Param request body schema.VADRequest true "query params" // @Success 200 {object} proto.VADResponse "Response" // @Router /vad [post] -func VADEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func VADEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest) if !ok || input.Model == "" { return fiber.ErrBadRequest } - cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || cfg == nil { return fiber.ErrBadRequest } diff --git a/core/http/endpoints/localai/video.go b/core/http/endpoints/localai/video.go index bec8a6a12f57..df01ce316c5f 100644 --- a/core/http/endpoints/localai/video.go +++ b/core/http/endpoints/localai/video.go @@ -64,7 +64,7 @@ func downloadFile(url string) (string, error) { // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /video [post] -func VideoEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VideoRequest) if !ok || input.Model == "" { @@ -72,7 +72,7 @@ func VideoEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon return fiber.ErrBadRequest } - config, ok := 
c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || config == nil { log.Error().Msg("Video Endpoint - Invalid Config") return fiber.ErrBadRequest diff --git a/core/http/endpoints/localai/welcome.go b/core/http/endpoints/localai/welcome.go index ba291536efcd..7f5e0076c366 100644 --- a/core/http/endpoints/localai/welcome.go +++ b/core/http/endpoints/localai/welcome.go @@ -11,12 +11,12 @@ import ( ) func WelcomeEndpoint(appConfig *config.ApplicationConfig, - cl *config.BackendConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) func(*fiber.Ctx) error { + cl *config.ModelConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) func(*fiber.Ctx) error { return func(c *fiber.Ctx) error { - backendConfigs := cl.GetAllBackendConfigs() + modelConfigs := cl.GetAllModelsConfigs() galleryConfigs := map[string]*gallery.ModelConfig{} - for _, m := range backendConfigs { + for _, m := range modelConfigs { cfg, err := gallery.GetLocalModelConfiguration(ml.ModelPath, m.Name) if err != nil { continue @@ -34,7 +34,7 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig, "Version": internal.PrintableVersion(), "BaseURL": utils.BaseURL(c), "Models": modelsWithoutConfig, - "ModelsConfig": backendConfigs, + "ModelsConfig": modelConfigs, "GalleryConfig": galleryConfigs, "ApplicationConfig": appConfig, "ProcessingModels": processingModels, diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index 0d96e04cdecc..afb279dcee5b 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -27,11 +27,11 @@ import ( // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/chat/completions [post] -func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, 
startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error { +func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error { var id, textContentToReturn string var created int - process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) { + process := func(s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) { initialMessage := schema.OpenAIResponse{ ID: id, Created: created, @@ -66,7 +66,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat }) close(responses) } - processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) { + processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) { result := "" _, tokenUsage, _ := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { result += s @@ -183,7 +183,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat extraUsage := c.Get("Extra-Usage", "") != "" - config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || config == nil { return fiber.ErrBadRequest } @@ -501,7 +501,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat } } -func handleQuestion(config *config.BackendConfig, cl *config.BackendConfigLoader, input 
*schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, result, prompt string) (string, error) { +func handleQuestion(config *config.ModelConfig, cl *config.ModelConfigLoader, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, result, prompt string) (string, error) { if len(funcResults) == 0 && result != "" { log.Debug().Msgf("nothing function results but we had a message from the LLM") diff --git a/core/http/endpoints/openai/completion.go b/core/http/endpoints/openai/completion.go index 654166a1c48d..7700a4559097 100644 --- a/core/http/endpoints/openai/completion.go +++ b/core/http/endpoints/openai/completion.go @@ -27,10 +27,10 @@ import ( // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/completions [post] -func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { created := int(time.Now().Unix()) - process := func(id string, s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) { + process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) { tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool { usage := schema.OpenAIUsage{ PromptTokens: tokenUsage.Prompt, @@ -73,7 +73,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e return fiber.ErrBadRequest } - config, ok := 
c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || config == nil { return fiber.ErrBadRequest } diff --git a/core/http/endpoints/openai/edit.go b/core/http/endpoints/openai/edit.go index fbcd398d2017..0cdeba09f4a0 100644 --- a/core/http/endpoints/openai/edit.go +++ b/core/http/endpoints/openai/edit.go @@ -23,7 +23,7 @@ import ( // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/edits [post] -func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { @@ -34,7 +34,7 @@ func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat // Opt-in extra usage flag extraUsage := c.Get("Extra-Usage", "") != "" - config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || config == nil { return fiber.ErrBadRequest } diff --git a/core/http/endpoints/openai/embeddings.go b/core/http/endpoints/openai/embeddings.go index 9cbbe189457d..4154c435a5e9 100644 --- a/core/http/endpoints/openai/embeddings.go +++ b/core/http/endpoints/openai/embeddings.go @@ -21,14 +21,14 @@ import ( // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/embeddings [post] -func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func EmbeddingsEndpoint(cl 
*config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) if !ok || input.Model == "" { return fiber.ErrBadRequest } - config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || config == nil { return fiber.ErrBadRequest } diff --git a/core/http/endpoints/openai/image.go b/core/http/endpoints/openai/image.go index fa641d8fbe91..91ecdd23ac97 100644 --- a/core/http/endpoints/openai/image.go +++ b/core/http/endpoints/openai/image.go @@ -65,7 +65,7 @@ func downloadFile(url string) (string, error) { // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/images/generations [post] -func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) if !ok || input.Model == "" { @@ -73,7 +73,7 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon return fiber.ErrBadRequest } - config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || config == nil { log.Error().Msg("Image Endpoint - Invalid Config") return fiber.ErrBadRequest diff --git a/core/http/endpoints/openai/inference.go b/core/http/endpoints/openai/inference.go index c0deeb02e9b6..b7b256bad0c4 100644 --- 
a/core/http/endpoints/openai/inference.go +++ b/core/http/endpoints/openai/inference.go @@ -11,8 +11,8 @@ import ( func ComputeChoices( req *schema.OpenAIRequest, predInput string, - config *config.BackendConfig, - bcl *config.BackendConfigLoader, + config *config.ModelConfig, + bcl *config.ModelConfigLoader, o *config.ApplicationConfig, loader *model.ModelLoader, cb func(string, *[]schema.Choice), diff --git a/core/http/endpoints/openai/list.go b/core/http/endpoints/openai/list.go index 9d21f8fe2f74..6c0ffca04aa0 100644 --- a/core/http/endpoints/openai/list.go +++ b/core/http/endpoints/openai/list.go @@ -12,7 +12,7 @@ import ( // @Summary List and describe the various models available in the API. // @Success 200 {object} schema.ModelsDataResponse "Response" // @Router /v1/models [get] -func ListModelsEndpoint(bcl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(ctx *fiber.Ctx) error { +func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(ctx *fiber.Ctx) error { return func(c *fiber.Ctx) error { // If blank, no filter is applied. 
filter := c.Query("filter") diff --git a/core/http/endpoints/openai/realtime.go b/core/http/endpoints/openai/realtime.go index 4c4adef6f1ed..8f330e873f60 100644 --- a/core/http/endpoints/openai/realtime.go +++ b/core/http/endpoints/openai/realtime.go @@ -559,7 +559,7 @@ func sendNotImplemented(c *websocket.Conn, message string) { sendError(c, "not_implemented", message, "", "event_TODO") } -func updateTransSession(session *Session, update *types.ClientSession, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error { +func updateTransSession(session *Session, update *types.ClientSession, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error { sessionLock.Lock() defer sessionLock.Unlock() @@ -589,7 +589,7 @@ func updateTransSession(session *Session, update *types.ClientSession, cl *confi } // Function to update session configurations -func updateSession(session *Session, update *types.ClientSession, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error { +func updateSession(session *Session, update *types.ClientSession, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error { sessionLock.Lock() defer sessionLock.Unlock() @@ -628,7 +628,7 @@ func updateSession(session *Session, update *types.ClientSession, cl *config.Bac // handleVAD is a goroutine that listens for audio data from the client, // runs VAD on the audio data, and commits utterances to the conversation -func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn, done chan struct{}) { +func handleVAD(cfg *config.ModelConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn, done chan struct{}) { vadContext, cancel := context.WithCancel(context.Background()) go func() { <-done @@ -742,7 +742,7 @@ func handleVAD(cfg 
*config.BackendConfig, evaluator *templates.Evaluator, sessio } } -func commitUtterance(ctx context.Context, utt []byte, cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn) { +func commitUtterance(ctx context.Context, utt []byte, cfg *config.ModelConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn) { if len(utt) == 0 { return } @@ -853,7 +853,7 @@ func runVAD(ctx context.Context, session *Session, adata []int16) ([]*proto.VADS // TODO: Below needed for normal mode instead of transcription only // Function to generate a response based on the conversation -// func generateResponse(config *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conversation *Conversation, responseCreate ResponseCreate, c *websocket.Conn, mt int) { +// func generateResponse(config *config.ModelConfig, evaluator *templates.Evaluator, session *Session, conversation *Conversation, responseCreate ResponseCreate, c *websocket.Conn, mt int) { // // log.Debug().Msg("Generating realtime response...") // @@ -1067,7 +1067,7 @@ func runVAD(ctx context.Context, session *Session, adata []int16) ([]*proto.VADS // } // Function to process text response and detect function calls -func processTextResponse(config *config.BackendConfig, session *Session, prompt string) (string, *FunctionCall, error) { +func processTextResponse(config *config.ModelConfig, session *Session, prompt string) (string, *FunctionCall, error) { // Placeholder implementation // Replace this with actual model inference logic using session.Model and prompt diff --git a/core/http/endpoints/openai/realtime_model.go b/core/http/endpoints/openai/realtime_model.go index aeab31ad908a..a62e5a18a902 100644 --- a/core/http/endpoints/openai/realtime_model.go +++ b/core/http/endpoints/openai/realtime_model.go @@ -22,14 +22,14 @@ var ( // This means that we will fake an Any-to-Any model by overriding some of the gRPC 
client methods // which are for Any-To-Any models, but instead we will call a pipeline (for e.g STT->LLM->TTS) type wrappedModel struct { - TTSConfig *config.BackendConfig - TranscriptionConfig *config.BackendConfig - LLMConfig *config.BackendConfig + TTSConfig *config.ModelConfig + TranscriptionConfig *config.ModelConfig + LLMConfig *config.ModelConfig TTSClient grpcClient.Backend TranscriptionClient grpcClient.Backend LLMClient grpcClient.Backend - VADConfig *config.BackendConfig + VADConfig *config.ModelConfig VADClient grpcClient.Backend } @@ -37,17 +37,17 @@ type wrappedModel struct { // We have to wrap this out as well because we want to load two models one for VAD and one for the actual model. // In the future there could be models that accept continous audio input only so this design will be useful for that type anyToAnyModel struct { - LLMConfig *config.BackendConfig + LLMConfig *config.ModelConfig LLMClient grpcClient.Backend - VADConfig *config.BackendConfig + VADConfig *config.ModelConfig VADClient grpcClient.Backend } type transcriptOnlyModel struct { - TranscriptionConfig *config.BackendConfig + TranscriptionConfig *config.ModelConfig TranscriptionClient grpcClient.Backend - VADConfig *config.BackendConfig + VADConfig *config.ModelConfig VADClient grpcClient.Backend } @@ -105,8 +105,8 @@ func (m *anyToAnyModel) PredictStream(ctx context.Context, in *proto.PredictOpti return m.LLMClient.PredictStream(ctx, in, f) } -func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, *config.BackendConfig, error) { - cfgVAD, err := cl.LoadBackendConfigFileByName(pipeline.VAD, ml.ModelPath) +func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, *config.ModelConfig, error) { + cfgVAD, err := cl.LoadModelConfigFileByName(pipeline.VAD, ml.ModelPath) if err != nil { 
return nil, nil, fmt.Errorf("failed to load backend config: %w", err) @@ -122,7 +122,7 @@ func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.BackendConf return nil, nil, fmt.Errorf("failed to load tts model: %w", err) } - cfgSST, err := cl.LoadBackendConfigFileByName(pipeline.Transcription, ml.ModelPath) + cfgSST, err := cl.LoadModelConfigFileByName(pipeline.Transcription, ml.ModelPath) if err != nil { return nil, nil, fmt.Errorf("failed to load backend config: %w", err) @@ -139,17 +139,17 @@ func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.BackendConf } return &transcriptOnlyModel{ - VADConfig: cfgVAD, - VADClient: VADClient, + VADConfig: cfgVAD, + VADClient: VADClient, TranscriptionConfig: cfgSST, TranscriptionClient: transcriptionClient, }, cfgSST, nil } // returns and loads either a wrapped model or a model that support audio-to-audio -func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, error) { +func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, error) { - cfgVAD, err := cl.LoadBackendConfigFileByName(pipeline.VAD, ml.ModelPath) + cfgVAD, err := cl.LoadModelConfigFileByName(pipeline.VAD, ml.ModelPath) if err != nil { return nil, fmt.Errorf("failed to load backend config: %w", err) @@ -166,7 +166,7 @@ func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *mod } // TODO: Do we always need a transcription model? It can be disabled. 
Note that any-to-any instruction following models don't transcribe as such, so if transcription is required it is a separate process - cfgSST, err := cl.LoadBackendConfigFileByName(pipeline.Transcription, ml.ModelPath) + cfgSST, err := cl.LoadModelConfigFileByName(pipeline.Transcription, ml.ModelPath) if err != nil { return nil, fmt.Errorf("failed to load backend config: %w", err) @@ -185,7 +185,7 @@ func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *mod // TODO: Decide when we have a real any-to-any model if false { - cfgAnyToAny, err := cl.LoadBackendConfigFileByName(pipeline.LLM, ml.ModelPath) + cfgAnyToAny, err := cl.LoadModelConfigFileByName(pipeline.LLM, ml.ModelPath) if err != nil { return nil, fmt.Errorf("failed to load backend config: %w", err) @@ -212,7 +212,7 @@ func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *mod log.Debug().Msg("Loading a wrapped model") // Otherwise we want to return a wrapped model, which is a "virtual" model that re-uses other models to perform operations - cfgLLM, err := cl.LoadBackendConfigFileByName(pipeline.LLM, ml.ModelPath) + cfgLLM, err := cl.LoadModelConfigFileByName(pipeline.LLM, ml.ModelPath) if err != nil { return nil, fmt.Errorf("failed to load backend config: %w", err) @@ -222,7 +222,7 @@ func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *mod return nil, fmt.Errorf("failed to validate config: %w", err) } - cfgTTS, err := cl.LoadBackendConfigFileByName(pipeline.TTS, ml.ModelPath) + cfgTTS, err := cl.LoadModelConfigFileByName(pipeline.TTS, ml.ModelPath) if err != nil { return nil, fmt.Errorf("failed to load backend config: %w", err) @@ -232,7 +232,6 @@ func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *mod return nil, fmt.Errorf("failed to validate config: %w", err) } - opts = backend.ModelOptions(*cfgTTS, appConfig) ttsClient, err := ml.Load(opts...) 
if err != nil { diff --git a/core/http/endpoints/openai/transcription.go b/core/http/endpoints/openai/transcription.go index b10e06ef6963..56482576bd9d 100644 --- a/core/http/endpoints/openai/transcription.go +++ b/core/http/endpoints/openai/transcription.go @@ -24,14 +24,14 @@ import ( // @Param file formData file true "file" // @Success 200 {object} map[string]string "Response" // @Router /v1/audio/transcriptions [post] -func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) if !ok || input.Model == "" { return fiber.ErrBadRequest } - config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || config == nil { return fiber.ErrBadRequest } diff --git a/core/http/middleware/request.go b/core/http/middleware/request.go index dccf55a0c251..35f39f7f37f9 100644 --- a/core/http/middleware/request.go +++ b/core/http/middleware/request.go @@ -26,16 +26,16 @@ type correlationIDKeyType string const CorrelationIDKey correlationIDKeyType = "correlationID" type RequestExtractor struct { - backendConfigLoader *config.BackendConfigLoader - modelLoader *model.ModelLoader - applicationConfig *config.ApplicationConfig + modelConfigLoader *config.ModelConfigLoader + modelLoader *model.ModelLoader + applicationConfig *config.ApplicationConfig } -func NewRequestExtractor(backendConfigLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor { +func NewRequestExtractor(modelConfigLoader *config.ModelConfigLoader, modelLoader 
*model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor { return &RequestExtractor{ - backendConfigLoader: backendConfigLoader, - modelLoader: modelLoader, - applicationConfig: applicationConfig, + modelConfigLoader: modelConfigLoader, + modelLoader: modelLoader, + applicationConfig: applicationConfig, } } @@ -59,7 +59,7 @@ func (re *RequestExtractor) setModelNameFromRequest(ctx *fiber.Ctx) { // Set model from bearer token, if available bearer := strings.TrimLeft(ctx.Get("authorization"), "Bear ") // "Bearer " => "Bear" to please go-staticcheck. It looks dumb but we might as well take free performance on something called for nearly every request. if bearer != "" { - exists, err := services.CheckIfModelExists(re.backendConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE) + exists, err := services.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE) if err == nil && exists { model = bearer } @@ -81,7 +81,7 @@ func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModel } } -func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.BackendConfigFilterFn) fiber.Handler { +func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) fiber.Handler { return func(ctx *fiber.Ctx) error { re.setModelNameFromRequest(ctx) localModelName := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) @@ -89,7 +89,7 @@ func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn con return ctx.Next() } - modelNames, err := services.ListModels(re.backendConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED) + modelNames, err := services.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED) if err != nil { log.Error().Err(err).Msg("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()") return ctx.Next() @@ -129,7 +129,7 @@ func (re 
*RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIR } } - cfg, err := re.backendConfigLoader.LoadBackendConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig) + cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig) if err != nil { log.Err(err) @@ -152,7 +152,7 @@ func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error { return fiber.ErrBadRequest } - cfg, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + cfg, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || cfg == nil { return fiber.ErrBadRequest } @@ -168,7 +168,7 @@ func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error { input.Context = ctxWithCorrelationID input.Cancel = cancel - err := mergeOpenAIRequestAndBackendConfig(cfg, input) + err := mergeOpenAIRequestAndModelConfig(cfg, input) if err != nil { return err } @@ -184,7 +184,7 @@ func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error { return ctx.Next() } -func mergeOpenAIRequestAndBackendConfig(config *config.BackendConfig, input *schema.OpenAIRequest) error { +func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error { if input.Echo { config.Echo = input.Echo } diff --git a/core/http/routes/elevenlabs.go b/core/http/routes/elevenlabs.go index 9e735bb11b6c..96e132e94e86 100644 --- a/core/http/routes/elevenlabs.go +++ b/core/http/routes/elevenlabs.go @@ -11,7 +11,7 @@ import ( func RegisterElevenLabsRoutes(app *fiber.App, re *middleware.RequestExtractor, - cl *config.BackendConfigLoader, + cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) { diff --git a/core/http/routes/jina.go b/core/http/routes/jina.go index 1f7a1a7c3e40..a55ca79f5597 100644 --- a/core/http/routes/jina.go +++ b/core/http/routes/jina.go @@ -12,7 +12,7 @@ import ( func RegisterJINARoutes(app *fiber.App, re 
*middleware.RequestExtractor, - cl *config.BackendConfigLoader, + cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) { diff --git a/core/http/routes/localai.go b/core/http/routes/localai.go index ce9e8496b382..54b7021f17a5 100644 --- a/core/http/routes/localai.go +++ b/core/http/routes/localai.go @@ -14,7 +14,7 @@ import ( func RegisterLocalAIRoutes(router *fiber.App, requestExtractor *middleware.RequestExtractor, - cl *config.BackendConfigLoader, + cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService) { @@ -23,20 +23,23 @@ func RegisterLocalAIRoutes(router *fiber.App, // LocalAI API endpoints if !appConfig.DisableGalleryEndpoint { - modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.BackendGalleries, appConfig.ModelPath, galleryService) + modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.BackendGalleries, appConfig.SystemState, galleryService) router.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint()) router.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint()) - router.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint()) + router.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint(appConfig.SystemState)) router.Get("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint()) router.Get("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint()) router.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint()) - backendGalleryEndpointService := localai.CreateBackendEndpointService(appConfig.BackendGalleries, appConfig.BackendsPath, galleryService) + backendGalleryEndpointService := localai.CreateBackendEndpointService( + appConfig.BackendGalleries, + appConfig.SystemState, + 
galleryService) router.Post("/backends/apply", backendGalleryEndpointService.ApplyBackendEndpoint()) router.Post("/backends/delete/:name", backendGalleryEndpointService.DeleteBackendEndpoint()) - router.Get("/backends", backendGalleryEndpointService.ListBackendsEndpoint()) - router.Get("/backends/available", backendGalleryEndpointService.ListAvailableBackendsEndpoint()) + router.Get("/backends", backendGalleryEndpointService.ListBackendsEndpoint(appConfig.SystemState)) + router.Get("/backends/available", backendGalleryEndpointService.ListAvailableBackendsEndpoint(appConfig.SystemState)) router.Get("/backends/galleries", backendGalleryEndpointService.ListBackendGalleriesEndpoint()) router.Get("/backends/jobs/:uuid", backendGalleryEndpointService.GetOpStatusEndpoint()) } diff --git a/core/http/routes/ui.go b/core/http/routes/ui.go index 11b2ab485596..6e2bda7a9d69 100644 --- a/core/http/routes/ui.go +++ b/core/http/routes/ui.go @@ -15,7 +15,7 @@ import ( ) func RegisterUIRoutes(app *fiber.App, - cl *config.BackendConfigLoader, + cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService) { @@ -65,9 +65,9 @@ func RegisterUIRoutes(app *fiber.App, } app.Get("/talk/", func(c *fiber.Ctx) error { - backendConfigs, _ := services.ListModels(cl, ml, config.NoFilterFn, services.SKIP_IF_CONFIGURED) + modelConfigs, _ := services.ListModels(cl, ml, config.NoFilterFn, services.SKIP_IF_CONFIGURED) - if len(backendConfigs) == 0 { + if len(modelConfigs) == 0 { // If no model is available redirect to the index which suggests how to install models return c.Redirect(utils.BaseURL(c)) } @@ -75,8 +75,8 @@ func RegisterUIRoutes(app *fiber.App, summary := fiber.Map{ "Title": "LocalAI - Talk", "BaseURL": utils.BaseURL(c), - "ModelsConfig": backendConfigs, - "Model": backendConfigs[0], + "ModelsConfig": modelConfigs, + "Model": modelConfigs[0], "Version": internal.PrintableVersion(), } @@ -86,17 +86,17 @@ func 
RegisterUIRoutes(app *fiber.App, }) app.Get("/chat/", func(c *fiber.Ctx) error { - backendConfigs := cl.GetAllBackendConfigs() + modelConfigs := cl.GetAllModelsConfigs() modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY) - if len(backendConfigs)+len(modelsWithoutConfig) == 0 { + if len(modelConfigs)+len(modelsWithoutConfig) == 0 { // If no model is available redirect to the index which suggests how to install models return c.Redirect(utils.BaseURL(c)) } modelThatCanBeUsed := "" galleryConfigs := map[string]*gallery.ModelConfig{} - for _, m := range backendConfigs { + for _, m := range modelConfigs { cfg, err := gallery.GetLocalModelConfiguration(ml.ModelPath, m.Name) if err != nil { continue @@ -106,7 +106,7 @@ func RegisterUIRoutes(app *fiber.App, title := "LocalAI - Chat" - for _, b := range backendConfigs { + for _, b := range modelConfigs { if b.HasUsecases(config.FLAG_CHAT) { modelThatCanBeUsed = b.Name title = "LocalAI - Chat with " + modelThatCanBeUsed @@ -119,7 +119,7 @@ func RegisterUIRoutes(app *fiber.App, "BaseURL": utils.BaseURL(c), "ModelsWithoutConfig": modelsWithoutConfig, "GalleryConfig": galleryConfigs, - "ModelsConfig": backendConfigs, + "ModelsConfig": modelConfigs, "Model": modelThatCanBeUsed, "Version": internal.PrintableVersion(), } @@ -130,12 +130,12 @@ func RegisterUIRoutes(app *fiber.App, // Show the Chat page app.Get("/chat/:model", func(c *fiber.Ctx) error { - backendConfigs := cl.GetAllBackendConfigs() + modelConfigs := cl.GetAllModelsConfigs() modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY) galleryConfigs := map[string]*gallery.ModelConfig{} - for _, m := range backendConfigs { + for _, m := range modelConfigs { cfg, err := gallery.GetLocalModelConfiguration(ml.ModelPath, m.Name) if err != nil { continue @@ -146,7 +146,7 @@ func RegisterUIRoutes(app *fiber.App, summary := fiber.Map{ "Title": "LocalAI - Chat with " + c.Params("model"), "BaseURL": 
utils.BaseURL(c), - "ModelsConfig": backendConfigs, + "ModelsConfig": modelConfigs, "GalleryConfig": galleryConfigs, "ModelsWithoutConfig": modelsWithoutConfig, "Model": c.Params("model"), @@ -158,13 +158,13 @@ func RegisterUIRoutes(app *fiber.App, }) app.Get("/text2image/:model", func(c *fiber.Ctx) error { - backendConfigs := cl.GetAllBackendConfigs() + modelConfigs := cl.GetAllModelsConfigs() modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY) summary := fiber.Map{ "Title": "LocalAI - Generate images with " + c.Params("model"), "BaseURL": utils.BaseURL(c), - "ModelsConfig": backendConfigs, + "ModelsConfig": modelConfigs, "ModelsWithoutConfig": modelsWithoutConfig, "Model": c.Params("model"), "Version": internal.PrintableVersion(), @@ -175,10 +175,10 @@ func RegisterUIRoutes(app *fiber.App, }) app.Get("/text2image/", func(c *fiber.Ctx) error { - backendConfigs := cl.GetAllBackendConfigs() + modelConfigs := cl.GetAllModelsConfigs() modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY) - if len(backendConfigs)+len(modelsWithoutConfig) == 0 { + if len(modelConfigs)+len(modelsWithoutConfig) == 0 { // If no model is available redirect to the index which suggests how to install models return c.Redirect(utils.BaseURL(c)) } @@ -186,7 +186,7 @@ func RegisterUIRoutes(app *fiber.App, modelThatCanBeUsed := "" title := "LocalAI - Generate images" - for _, b := range backendConfigs { + for _, b := range modelConfigs { if b.HasUsecases(config.FLAG_IMAGE) { modelThatCanBeUsed = b.Name title = "LocalAI - Generate images with " + modelThatCanBeUsed @@ -197,7 +197,7 @@ func RegisterUIRoutes(app *fiber.App, summary := fiber.Map{ "Title": title, "BaseURL": utils.BaseURL(c), - "ModelsConfig": backendConfigs, + "ModelsConfig": modelConfigs, "ModelsWithoutConfig": modelsWithoutConfig, "Model": modelThatCanBeUsed, "Version": internal.PrintableVersion(), @@ -208,13 +208,13 @@ func RegisterUIRoutes(app 
*fiber.App, }) app.Get("/tts/:model", func(c *fiber.Ctx) error { - backendConfigs := cl.GetAllBackendConfigs() + modelConfigs := cl.GetAllModelsConfigs() modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY) summary := fiber.Map{ "Title": "LocalAI - Generate images with " + c.Params("model"), "BaseURL": utils.BaseURL(c), - "ModelsConfig": backendConfigs, + "ModelsConfig": modelConfigs, "ModelsWithoutConfig": modelsWithoutConfig, "Model": c.Params("model"), "Version": internal.PrintableVersion(), @@ -225,10 +225,10 @@ func RegisterUIRoutes(app *fiber.App, }) app.Get("/tts/", func(c *fiber.Ctx) error { - backendConfigs := cl.GetAllBackendConfigs() + modelConfigs := cl.GetAllModelsConfigs() modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY) - if len(backendConfigs)+len(modelsWithoutConfig) == 0 { + if len(modelConfigs)+len(modelsWithoutConfig) == 0 { // If no model is available redirect to the index which suggests how to install models return c.Redirect(utils.BaseURL(c)) } @@ -236,7 +236,7 @@ func RegisterUIRoutes(app *fiber.App, modelThatCanBeUsed := "" title := "LocalAI - Generate audio" - for _, b := range backendConfigs { + for _, b := range modelConfigs { if b.HasUsecases(config.FLAG_TTS) { modelThatCanBeUsed = b.Name title = "LocalAI - Generate audio with " + modelThatCanBeUsed @@ -246,7 +246,7 @@ func RegisterUIRoutes(app *fiber.App, summary := fiber.Map{ "Title": title, "BaseURL": utils.BaseURL(c), - "ModelsConfig": backendConfigs, + "ModelsConfig": modelConfigs, "ModelsWithoutConfig": modelsWithoutConfig, "Model": modelThatCanBeUsed, "Version": internal.PrintableVersion(), diff --git a/core/http/routes/ui_backend_gallery.go b/core/http/routes/ui_backend_gallery.go index d16cdb026cfb..94acd5bcf4c7 100644 --- a/core/http/routes/ui_backend_gallery.go +++ b/core/http/routes/ui_backend_gallery.go @@ -28,7 +28,7 @@ func registerBackendGalleryRoutes(app *fiber.App, appConfig 
*config.ApplicationC page := c.Query("page") items := c.Query("items") - backends, err := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.BackendsPath) + backends, err := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.SystemState) if err != nil { log.Error().Err(err).Msg("could not list backends from galleries") return c.Status(fiber.StatusInternalServerError).Render("views/error", fiber.Map{ @@ -129,7 +129,7 @@ func registerBackendGalleryRoutes(app *fiber.App, appConfig *config.ApplicationC return c.Status(fiber.StatusBadRequest).SendString(bluemonday.StrictPolicy().Sanitize(err.Error())) } - backends, _ := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.BackendsPath) + backends, _ := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.SystemState) if page != "" { // return a subset of the backends diff --git a/core/http/routes/ui_gallery.go b/core/http/routes/ui_gallery.go index 54de40d5ef32..47bd5d4c7682 100644 --- a/core/http/routes/ui_gallery.go +++ b/core/http/routes/ui_gallery.go @@ -20,7 +20,7 @@ import ( "github.com/rs/zerolog/log" ) -func registerGalleryRoutes(app *fiber.App, cl *config.BackendConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) { +func registerGalleryRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) { // Show the Models page (all models) app.Get("/browse", func(c *fiber.Ctx) error { @@ -28,7 +28,7 @@ func registerGalleryRoutes(app *fiber.App, cl *config.BackendConfigLoader, appCo page := c.Query("page") items := c.Query("items") - models, err := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.ModelPath) + models, err := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.SystemState) if err != nil { log.Error().Err(err).Msg("could not list models from galleries") return 
c.Status(fiber.StatusInternalServerError).Render("views/error", fiber.Map{ @@ -131,7 +131,7 @@ func registerGalleryRoutes(app *fiber.App, cl *config.BackendConfigLoader, appCo return c.Status(fiber.StatusBadRequest).SendString(bluemonday.StrictPolicy().Sanitize(err.Error())) } - models, _ := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.ModelPath) + models, _ := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.SystemState) if page != "" { // return a subset of the models @@ -224,7 +224,7 @@ func registerGalleryRoutes(app *fiber.App, cl *config.BackendConfigLoader, appCo } go func() { galleryService.ModelGalleryChannel <- op - cl.RemoveBackendConfig(galleryName) + cl.RemoveModelConfig(galleryName) }() return c.SendString(elements.StartModelProgressBar(uid, "0", "Deletion")) diff --git a/core/services/backend_monitor.go b/core/services/backend_monitor.go index 88fefa09c49b..f5b16e2874b9 100644 --- a/core/services/backend_monitor.go +++ b/core/services/backend_monitor.go @@ -16,21 +16,21 @@ import ( ) type BackendMonitorService struct { - backendConfigLoader *config.BackendConfigLoader - modelLoader *model.ModelLoader - options *config.ApplicationConfig // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name. + modelConfigLoader *config.ModelConfigLoader + modelLoader *model.ModelLoader + options *config.ApplicationConfig // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name. 
} -func NewBackendMonitorService(modelLoader *model.ModelLoader, configLoader *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *BackendMonitorService { +func NewBackendMonitorService(modelLoader *model.ModelLoader, configLoader *config.ModelConfigLoader, appConfig *config.ApplicationConfig) *BackendMonitorService { return &BackendMonitorService{ - modelLoader: modelLoader, - backendConfigLoader: configLoader, - options: appConfig, + modelLoader: modelLoader, + modelConfigLoader: configLoader, + options: appConfig, } } func (bms BackendMonitorService) getModelLoaderIDFromModelName(modelName string) (string, error) { - config, exists := bms.backendConfigLoader.GetBackendConfig(modelName) + config, exists := bms.modelConfigLoader.GetModelConfig(modelName) var backendId string if exists { backendId = config.Model @@ -47,7 +47,7 @@ func (bms BackendMonitorService) getModelLoaderIDFromModelName(modelName string) } func (bms *BackendMonitorService) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) { - config, exists := bms.backendConfigLoader.GetBackendConfig(model) + config, exists := bms.modelConfigLoader.GetModelConfig(model) var backend string if exists { backend = config.Model diff --git a/core/services/backends.go b/core/services/backends.go index 143c1bd14f8b..73001f97d2d5 100644 --- a/core/services/backends.go +++ b/core/services/backends.go @@ -20,21 +20,21 @@ func (g *GalleryService) backendHandler(op *GalleryOp[gallery.GalleryBackend], s var err error if op.Delete { - err = gallery.DeleteBackendFromSystem(g.appConfig.BackendsPath, op.GalleryElementName) + err = gallery.DeleteBackendFromSystem(g.appConfig.SystemState, op.GalleryElementName) g.modelLoader.DeleteExternalBackend(op.GalleryElementName) } else { log.Warn().Msgf("installing backend %s", op.GalleryElementName) log.Debug().Msgf("backend galleries: %v", g.appConfig.BackendGalleries) - err = gallery.InstallBackendFromGallery(g.appConfig.BackendGalleries, 
systemState, op.GalleryElementName, g.appConfig.BackendsPath, progressCallback, true) + err = gallery.InstallBackendFromGallery(g.appConfig.BackendGalleries, systemState, op.GalleryElementName, progressCallback, true) if err == nil { - err = gallery.RegisterBackends(g.appConfig.BackendsPath, g.modelLoader) + err = gallery.RegisterBackends(systemState, g.modelLoader) } } if err != nil { log.Error().Err(err).Msgf("error installing backend %s", op.GalleryElementName) if !op.Delete { // If we didn't install the backend, we need to make sure we don't have a leftover directory - gallery.DeleteBackendFromSystem(g.appConfig.BackendsPath, op.GalleryElementName) + gallery.DeleteBackendFromSystem(systemState, op.GalleryElementName) } return err } diff --git a/core/services/gallery.go b/core/services/gallery.go index cace4c15218e..4833aa334ba2 100644 --- a/core/services/gallery.go +++ b/core/services/gallery.go @@ -9,7 +9,6 @@ import ( "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/system" - "github.com/rs/zerolog/log" ) type GalleryService struct { @@ -52,7 +51,7 @@ func (g *GalleryService) GetAllStatus() map[string]*GalleryOpStatus { return g.statuses } -func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader) error { +func (g *GalleryService) Start(c context.Context, cl *config.ModelConfigLoader, systemState *system.SystemState) error { // updates the status with an error var updateError func(id string, e error) if !g.appConfig.OpaqueErrors { @@ -65,11 +64,6 @@ func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader } } - systemState, err := system.GetSystemState() - if err != nil { - log.Error().Err(err).Msg("failed to get system state") - } - go func() { for { select { @@ -82,7 +76,7 @@ func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader } case op := <-g.ModelGalleryChannel: - err := g.modelHandler(&op, cl) + err := 
g.modelHandler(&op, cl, systemState) if err != nil { updateError(op.ID, err) } diff --git a/core/services/list_models.go b/core/services/list_models.go index 45c05f5828c4..67b0e751fbce 100644 --- a/core/services/list_models.go +++ b/core/services/list_models.go @@ -14,7 +14,7 @@ const ( ALWAYS_INCLUDE ) -func ListModels(bcl *config.BackendConfigLoader, ml *model.ModelLoader, filter config.BackendConfigFilterFn, looseFilePolicy LooseFilePolicy) ([]string, error) { +func ListModels(bcl *config.ModelConfigLoader, ml *model.ModelLoader, filter config.ModelConfigFilterFn, looseFilePolicy LooseFilePolicy) ([]string, error) { var skipMap map[string]interface{} = map[string]interface{}{} @@ -22,7 +22,7 @@ func ListModels(bcl *config.BackendConfigLoader, ml *model.ModelLoader, filter c // Start with known configurations - for _, c := range bcl.GetBackendConfigsByFilter(filter) { + for _, c := range bcl.GetModelConfigsByFilter(filter) { // Is this better than looseFilePolicy <= SKIP_IF_CONFIGURED ? less performant but more readable? 
if (looseFilePolicy == SKIP_IF_CONFIGURED) || (looseFilePolicy == LOOSE_ONLY) { skipMap[c.Model] = nil @@ -50,7 +50,7 @@ func ListModels(bcl *config.BackendConfigLoader, ml *model.ModelLoader, filter c return dataModels, nil } -func CheckIfModelExists(bcl *config.BackendConfigLoader, ml *model.ModelLoader, modelName string, looseFilePolicy LooseFilePolicy) (bool, error) { +func CheckIfModelExists(bcl *config.ModelConfigLoader, ml *model.ModelLoader, modelName string, looseFilePolicy LooseFilePolicy) (bool, error) { filter, err := config.BuildNameFilterFn(modelName) if err != nil { return false, err diff --git a/core/services/models.go b/core/services/models.go index 98b2ddc8d08c..032857a5b8ee 100644 --- a/core/services/models.go +++ b/core/services/models.go @@ -3,7 +3,6 @@ package services import ( "encoding/json" "os" - "path/filepath" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" @@ -12,7 +11,7 @@ import ( "gopkg.in/yaml.v2" ) -func (g *GalleryService) modelHandler(op *GalleryOp[gallery.GalleryModel], cl *config.BackendConfigLoader) error { +func (g *GalleryService) modelHandler(op *GalleryOp[gallery.GalleryModel], cl *config.ModelConfigLoader, systemState *system.SystemState) error { utils.ResetDownloadTimers() g.UpdateStatus(op.ID, &GalleryOpStatus{Message: "processing", Progress: 0}) @@ -23,18 +22,18 @@ func (g *GalleryService) modelHandler(op *GalleryOp[gallery.GalleryModel], cl *c utils.DisplayDownloadFunction(fileName, current, total, percentage) } - err := processModelOperation(op, g.appConfig.ModelPath, g.appConfig.BackendsPath, g.appConfig.EnforcePredownloadScans, g.appConfig.AutoloadBackendGalleries, progressCallback) + err := processModelOperation(op, systemState, g.appConfig.EnforcePredownloadScans, g.appConfig.AutoloadBackendGalleries, progressCallback) if err != nil { return err } // Reload models - err = cl.LoadBackendConfigsFromPath(g.appConfig.ModelPath) + err = 
cl.LoadModelConfigsFromPath(systemState.Model.ModelsPath) if err != nil { return err } - err = cl.Preload(g.appConfig.ModelPath) + err = cl.Preload(systemState.Model.ModelsPath) if err != nil { return err } @@ -50,26 +49,21 @@ func (g *GalleryService) modelHandler(op *GalleryOp[gallery.GalleryModel], cl *c return nil } -func installModelFromRemoteConfig(modelPath string, req gallery.GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool, backendGalleries []config.Gallery, backendBasePath string) error { - config, err := gallery.GetGalleryConfigFromURL[gallery.ModelConfig](req.URL, modelPath) +func installModelFromRemoteConfig(systemState *system.SystemState, req gallery.GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool, backendGalleries []config.Gallery) error { + config, err := gallery.GetGalleryConfigFromURL[gallery.ModelConfig](req.URL, systemState.Model.ModelsPath) if err != nil { return err } config.Files = append(config.Files, req.AdditionalFiles...) 
- installedModel, err := gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus, enforceScan) + installedModel, err := gallery.InstallModel(systemState, req.Name, &config, req.Overrides, downloadStatus, enforceScan) if err != nil { return err } if automaticallyInstallBackend && installedModel.Backend != "" { - systemState, err := system.GetSystemState() - if err != nil { - return err - } - - if err := gallery.InstallBackendFromGallery(backendGalleries, systemState, installedModel.Backend, backendBasePath, downloadStatus, false); err != nil { + if err := gallery.InstallBackendFromGallery(backendGalleries, systemState, installedModel.Backend, downloadStatus, false); err != nil { return err } } @@ -82,22 +76,22 @@ type galleryModel struct { ID string `json:"id"` } -func processRequests(modelPath, backendBasePath string, enforceScan, automaticallyInstallBackend bool, galleries []config.Gallery, backendGalleries []config.Gallery, requests []galleryModel) error { +func processRequests(systemState *system.SystemState, enforceScan, automaticallyInstallBackend bool, galleries []config.Gallery, backendGalleries []config.Gallery, requests []galleryModel) error { var err error for _, r := range requests { utils.ResetDownloadTimers() if r.ID == "" { - err = installModelFromRemoteConfig(modelPath, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan, automaticallyInstallBackend, backendGalleries, backendBasePath) + err = installModelFromRemoteConfig(systemState, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan, automaticallyInstallBackend, backendGalleries) } else { err = gallery.InstallModelFromGallery( - galleries, backendGalleries, r.ID, modelPath, backendBasePath, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan, automaticallyInstallBackend) + galleries, backendGalleries, systemState, r.ID, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan, automaticallyInstallBackend) } } return err } -func 
ApplyGalleryFromFile(modelPath, backendBasePath string, enforceScan, automaticallyInstallBackend bool, galleries []config.Gallery, backendGalleries []config.Gallery, s string) error { +func ApplyGalleryFromFile(systemState *system.SystemState, enforceScan, automaticallyInstallBackend bool, galleries []config.Gallery, backendGalleries []config.Gallery, s string) error { dat, err := os.ReadFile(s) if err != nil { return err @@ -108,58 +102,35 @@ func ApplyGalleryFromFile(modelPath, backendBasePath string, enforceScan, automa return err } - return processRequests(modelPath, backendBasePath, enforceScan, automaticallyInstallBackend, galleries, backendGalleries, requests) + return processRequests(systemState, enforceScan, automaticallyInstallBackend, galleries, backendGalleries, requests) } -func ApplyGalleryFromString(modelPath, backendBasePath string, enforceScan, automaticallyInstallBackend bool, galleries []config.Gallery, backendGalleries []config.Gallery, s string) error { +func ApplyGalleryFromString(systemState *system.SystemState, enforceScan, automaticallyInstallBackend bool, galleries []config.Gallery, backendGalleries []config.Gallery, s string) error { var requests []galleryModel err := json.Unmarshal([]byte(s), &requests) if err != nil { return err } - return processRequests(modelPath, backendBasePath, enforceScan, automaticallyInstallBackend, galleries, backendGalleries, requests) + return processRequests(systemState, enforceScan, automaticallyInstallBackend, galleries, backendGalleries, requests) } // processModelOperation handles the installation or deletion of a model func processModelOperation( op *GalleryOp[gallery.GalleryModel], - modelPath string, - backendBasePath string, + systemState *system.SystemState, enforcePredownloadScans bool, automaticallyInstallBackend bool, progressCallback func(string, string, string, float64), ) error { // delete a model if op.Delete { - modelConfig := &config.BackendConfig{} - - // Galleryname is the name of the 
model in this case - dat, err := os.ReadFile(filepath.Join(modelPath, op.GalleryElementName+".yaml")) - if err != nil { - return err - } - err = yaml.Unmarshal(dat, modelConfig) - if err != nil { - return err - } - - files := []string{} - // Remove the model from the config - if modelConfig.Model != "" { - files = append(files, modelConfig.ModelFileName()) - } - - if modelConfig.MMProj != "" { - files = append(files, modelConfig.MMProjFileName()) - } - - return gallery.DeleteModelFromSystem(modelPath, op.GalleryElementName, files) + return gallery.DeleteModelFromSystem(systemState, op.GalleryElementName) } // if the request contains a gallery name, we apply the gallery from the gallery list if op.GalleryElementName != "" { - return gallery.InstallModelFromGallery(op.Galleries, op.BackendGalleries, op.GalleryElementName, modelPath, backendBasePath, op.Req, progressCallback, enforcePredownloadScans, automaticallyInstallBackend) + return gallery.InstallModelFromGallery(op.Galleries, op.BackendGalleries, systemState, op.GalleryElementName, op.Req, progressCallback, enforcePredownloadScans, automaticallyInstallBackend) // } else if op.ConfigURL != "" { // err := startup.InstallModels(op.Galleries, modelPath, enforcePredownloadScans, progressCallback, op.ConfigURL) // if err != nil { @@ -167,6 +138,6 @@ func processModelOperation( // } // return cl.Preload(modelPath) } else { - return installModelFromRemoteConfig(modelPath, op.Req, progressCallback, enforcePredownloadScans, automaticallyInstallBackend, op.BackendGalleries, backendBasePath) + return installModelFromRemoteConfig(systemState, op.Req, progressCallback, enforcePredownloadScans, automaticallyInstallBackend, op.BackendGalleries) } } diff --git a/core/startup/backend_preload.go b/core/startup/backend_preload.go index 1595a8e1c088..04296b4067b5 100644 --- a/core/startup/backend_preload.go +++ b/core/startup/backend_preload.go @@ -12,11 +12,7 @@ import ( "github.com/rs/zerolog/log" ) -func 
InstallExternalBackends(galleries []config.Gallery, backendPath string, downloadStatus func(string, string, string, float64), backend, name, alias string) error { - systemState, err := system.GetSystemState() - if err != nil { - return fmt.Errorf("failed to get system state: %w", err) - } +func InstallExternalBackends(galleries []config.Gallery, systemState *system.SystemState, downloadStatus func(string, string, string, float64), backend, name, alias string) error { uri := downloader.URI(backend) switch { case uri.LooksLikeDir(): @@ -24,7 +20,7 @@ func InstallExternalBackends(galleries []config.Gallery, backendPath string, dow name = filepath.Base(backend) } log.Info().Str("backend", backend).Str("name", name).Msg("Installing backend from path") - if err := gallery.InstallBackend(backendPath, &gallery.GalleryBackend{ + if err := gallery.InstallBackend(systemState, &gallery.GalleryBackend{ Metadata: gallery.Metadata{ Name: name, }, @@ -38,7 +34,7 @@ func InstallExternalBackends(galleries []config.Gallery, backendPath string, dow return fmt.Errorf("specifying a name is required for OCI images") } log.Info().Str("backend", backend).Str("name", name).Msg("Installing backend from OCI image") - if err := gallery.InstallBackend(backendPath, &gallery.GalleryBackend{ + if err := gallery.InstallBackend(systemState, &gallery.GalleryBackend{ Metadata: gallery.Metadata{ Name: name, }, @@ -56,7 +52,7 @@ func InstallExternalBackends(galleries []config.Gallery, backendPath string, dow name = strings.TrimSuffix(name, filepath.Ext(name)) log.Info().Str("backend", backend).Str("name", name).Msg("Installing backend from OCI image") - if err := gallery.InstallBackend(backendPath, &gallery.GalleryBackend{ + if err := gallery.InstallBackend(systemState, &gallery.GalleryBackend{ Metadata: gallery.Metadata{ Name: name, }, @@ -69,7 +65,7 @@ func InstallExternalBackends(galleries []config.Gallery, backendPath string, dow if name != "" || alias != "" { return fmt.Errorf("specifying a name or 
alias is not supported for this backend") } - err := gallery.InstallBackendFromGallery(galleries, systemState, backend, backendPath, downloadStatus, true) + err := gallery.InstallBackendFromGallery(galleries, systemState, backend, downloadStatus, true) if err != nil { return fmt.Errorf("error installing backend %s: %w", backend, err) } diff --git a/core/startup/model_preload.go b/core/startup/model_preload.go index 78214eedef75..aa4ef4d15bc1 100644 --- a/core/startup/model_preload.go +++ b/core/startup/model_preload.go @@ -24,7 +24,7 @@ const ( // InstallModels will preload models from the given list of URLs and galleries // It will download the model if it is not already present in the model path // It will also try to resolve if the model is an embedded model YAML configuration -func InstallModels(galleries, backendGalleries []config.Gallery, modelPath, backendBasePath string, enforceScan, autoloadBackendGalleries bool, downloadStatus func(string, string, string, float64), models ...string) error { +func InstallModels(galleries, backendGalleries []config.Gallery, systemState *system.SystemState, enforceScan, autoloadBackendGalleries bool, downloadStatus func(string, string, string, float64), models ...string) error { // create an error that groups all errors var err error @@ -36,7 +36,7 @@ func InstallModels(galleries, backendGalleries []config.Gallery, modelPath, back return e } - var model config.BackendConfig + var model config.ModelConfig if e := yaml.Unmarshal(modelYAML, &model); e != nil { log.Error().Err(e).Str("filepath", modelPath).Msg("error unmarshalling model definition") return e @@ -47,12 +47,7 @@ func InstallModels(galleries, backendGalleries []config.Gallery, modelPath, back return nil } - systemState, err := system.GetSystemState() - if err != nil { - return err - } - - if err := gallery.InstallBackendFromGallery(backendGalleries, systemState, model.Backend, backendBasePath, downloadStatus, false); err != nil { + if err := 
gallery.InstallBackendFromGallery(backendGalleries, systemState, model.Backend, downloadStatus, false); err != nil { log.Error().Err(err).Str("backend", model.Backend).Msg("error installing backend") return err } @@ -77,8 +72,8 @@ func InstallModels(galleries, backendGalleries []config.Gallery, modelPath, back ociName = strings.ReplaceAll(ociName, ":", "__") // check if file exists - if _, e := os.Stat(filepath.Join(modelPath, ociName)); errors.Is(e, os.ErrNotExist) { - modelDefinitionFilePath := filepath.Join(modelPath, ociName) + if _, e := os.Stat(filepath.Join(systemState.Model.ModelsPath, ociName)); errors.Is(e, os.ErrNotExist) { + modelDefinitionFilePath := filepath.Join(systemState.Model.ModelsPath, ociName) e := uri.DownloadFile(modelDefinitionFilePath, "", 0, 0, func(fileName, current, total string, percent float64) { utils.DisplayDownloadFunction(fileName, current, total, percent) }) @@ -100,7 +95,7 @@ func InstallModels(galleries, backendGalleries []config.Gallery, modelPath, back continue } - modelPath := filepath.Join(modelPath, fileName) + modelPath := filepath.Join(systemState.Model.ModelsPath, fileName) if e := utils.VerifyPath(fileName, modelPath); e != nil { log.Error().Err(e).Str("filepath", modelPath).Msg("error verifying path") @@ -138,7 +133,7 @@ func InstallModels(galleries, backendGalleries []config.Gallery, modelPath, back continue } - modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + YAML_EXTENSION + modelDefinitionFilePath := filepath.Join(systemState.Model.ModelsPath, md5Name) + YAML_EXTENSION if e := os.WriteFile(modelDefinitionFilePath, modelYAML, 0600); e != nil { log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error loading model: %s") err = errors.Join(err, e) @@ -152,7 +147,7 @@ func InstallModels(galleries, backendGalleries []config.Gallery, modelPath, back } } else { // Check if it's a model gallery, or print a warning - e, found := installModel(galleries, backendGalleries, url, modelPath, 
backendBasePath, downloadStatus, enforceScan, autoloadBackendGalleries) + e, found := installModel(galleries, backendGalleries, url, systemState, downloadStatus, enforceScan, autoloadBackendGalleries) if e != nil && found { log.Error().Err(err).Msgf("[startup] failed installing model '%s'", url) err = errors.Join(err, e) @@ -166,13 +161,13 @@ func InstallModels(galleries, backendGalleries []config.Gallery, modelPath, back return err } -func installModel(galleries, backendGalleries []config.Gallery, modelName, modelPath, backendBasePath string, downloadStatus func(string, string, string, float64), enforceScan, autoloadBackendGalleries bool) (error, bool) { - models, err := gallery.AvailableGalleryModels(galleries, modelPath) +func installModel(galleries, backendGalleries []config.Gallery, modelName string, systemState *system.SystemState, downloadStatus func(string, string, string, float64), enforceScan, autoloadBackendGalleries bool) (error, bool) { + models, err := gallery.AvailableGalleryModels(galleries, systemState) if err != nil { return err, false } - model := gallery.FindGalleryElement(models, modelName, modelPath) + model := gallery.FindGalleryElement(models, modelName) if model == nil { return err, false } @@ -182,7 +177,7 @@ func installModel(galleries, backendGalleries []config.Gallery, modelName, model } log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model") - err = gallery.InstallModelFromGallery(galleries, backendGalleries, modelName, modelPath, backendBasePath, gallery.GalleryModel{}, downloadStatus, enforceScan, autoloadBackendGalleries) + err = gallery.InstallModelFromGallery(galleries, backendGalleries, systemState, modelName, gallery.GalleryModel{}, downloadStatus, enforceScan, autoloadBackendGalleries) if err != nil { return err, true } diff --git a/core/startup/model_preload_test.go b/core/startup/model_preload_test.go index b99544d28d27..a15e13bec404 100644 --- a/core/startup/model_preload_test.go +++ 
b/core/startup/model_preload_test.go @@ -7,6 +7,7 @@ import ( "github.com/mudler/LocalAI/core/config" . "github.com/mudler/LocalAI/core/startup" + "github.com/mudler/LocalAI/pkg/system" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -21,7 +22,10 @@ var _ = Describe("Preload test", func() { url := "https://raw.githubusercontent.com/mudler/LocalAI-examples/main/configurations/phi-2.yaml" fileName := fmt.Sprintf("%s.yaml", "phi-2") - InstallModels([]config.Gallery{}, []config.Gallery{}, tmpdir, "", true, true, nil, url) + systemState, err := system.GetSystemState(system.WithModelPath(tmpdir)) + Expect(err).ToNot(HaveOccurred()) + + InstallModels([]config.Gallery{}, []config.Gallery{}, systemState, true, true, nil, url) resultFile := filepath.Join(tmpdir, fileName) @@ -36,7 +40,10 @@ var _ = Describe("Preload test", func() { url := "huggingface://TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf" fileName := fmt.Sprintf("%s.gguf", "tinyllama-1.1b-chat-v0.3.Q2_K") - err = InstallModels([]config.Gallery{}, []config.Gallery{}, tmpdir, "", false, true, nil, url) + systemState, err := system.GetSystemState(system.WithModelPath(tmpdir)) + Expect(err).ToNot(HaveOccurred()) + + err = InstallModels([]config.Gallery{}, []config.Gallery{}, systemState, false, true, nil, url) Expect(err).ToNot(HaveOccurred()) resultFile := filepath.Join(tmpdir, fileName) diff --git a/core/templates/evaluator.go b/core/templates/evaluator.go index f9bd313afce5..12c2080555f1 100644 --- a/core/templates/evaluator.go +++ b/core/templates/evaluator.go @@ -55,7 +55,7 @@ func NewEvaluator(modelPath string) *Evaluator { } } -func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config config.BackendConfig, in PromptTemplateData) (string, error) { +func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config config.ModelConfig, in PromptTemplateData) (string, error) { template := "" // A model can have a "file.bin.tmpl" 
file associated with a prompt template prefix @@ -135,7 +135,7 @@ func (e *Evaluator) evaluateJinjaTemplateForPrompt(templateType TemplateType, te return e.cache.evaluateJinjaTemplate(templateType, templateName, conversation) } -func (e *Evaluator) TemplateMessages(input schema.OpenAIRequest, messages []schema.Message, config *config.BackendConfig, funcs []functions.Function, shouldUseFn bool) string { +func (e *Evaluator) TemplateMessages(input schema.OpenAIRequest, messages []schema.Message, config *config.ModelConfig, funcs []functions.Function, shouldUseFn bool) string { if config.TemplateConfig.JinjaTemplate { var messageData []ChatMessageTemplateData diff --git a/core/templates/evaluator_test.go b/core/templates/evaluator_test.go index 41c17e6a517d..6d29c876b519 100644 --- a/core/templates/evaluator_test.go +++ b/core/templates/evaluator_test.go @@ -53,7 +53,7 @@ Function response: var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{ "user": { "expected": "<|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>", - "config": &config.BackendConfig{ + "config": &config.ModelConfig{ TemplateConfig: config.TemplateConfig{ ChatMessage: llama3, }, @@ -69,7 +69,7 @@ var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]in }, "assistant": { "expected": "<|start_header_id|>assistant<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>", - "config": &config.BackendConfig{ + "config": &config.ModelConfig{ TemplateConfig: config.TemplateConfig{ ChatMessage: llama3, }, @@ -86,7 +86,7 @@ var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]in "function_call": { "expected": "<|start_header_id|>assistant<|end_header_id|>\n\nFunction call:\n{\"function\":\"test\"}<|eot_id|>", - "config": &config.BackendConfig{ + "config": &config.ModelConfig{ TemplateConfig: config.TemplateConfig{ ChatMessage: llama3, }, @@ -102,7 +102,7 @@ 
var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]in }, "function_response": { "expected": "<|start_header_id|>tool<|end_header_id|>\n\nFunction response:\nResponse from tool<|eot_id|>", - "config": &config.BackendConfig{ + "config": &config.ModelConfig{ TemplateConfig: config.TemplateConfig{ ChatMessage: llama3, }, @@ -121,7 +121,7 @@ var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]in var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{ "user": { "expected": "<|im_start|>user\nA long time ago in a galaxy far, far away...<|im_end|>", - "config": &config.BackendConfig{ + "config": &config.ModelConfig{ TemplateConfig: config.TemplateConfig{ ChatMessage: chatML, }, @@ -137,7 +137,7 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in }, "assistant": { "expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...<|im_end|>", - "config": &config.BackendConfig{ + "config": &config.ModelConfig{ TemplateConfig: config.TemplateConfig{ ChatMessage: chatML, }, @@ -153,7 +153,7 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in }, "function_call": { "expected": "<|im_start|>assistant\n\n{\"function\":\"test\"}\n<|im_end|>", - "config": &config.BackendConfig{ + "config": &config.ModelConfig{ TemplateConfig: config.TemplateConfig{ ChatMessage: chatML, }, @@ -175,7 +175,7 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in }, "function_response": { "expected": "<|im_start|>tool\n\nResponse from tool\n<|im_end|>", - "config": &config.BackendConfig{ + "config": &config.ModelConfig{ TemplateConfig: config.TemplateConfig{ ChatMessage: chatML, }, @@ -194,7 +194,7 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in var jinjaTest map[string]map[string]interface{} = map[string]map[string]interface{}{ "user": { "expected": 
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", - "config": &config.BackendConfig{ + "config": &config.ModelConfig{ TemplateConfig: config.TemplateConfig{ ChatMessage: toolCallJinja, JinjaTemplate: true, @@ -219,7 +219,7 @@ var _ = Describe("Templates", func() { for key := range chatMLTestMatch { foo := chatMLTestMatch[key] It("renders correctly `"+key+"`", func() { - templated := evaluator.TemplateMessages(schema.OpenAIRequest{}, foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool)) + templated := evaluator.TemplateMessages(schema.OpenAIRequest{}, foo["messages"].([]schema.Message), foo["config"].(*config.ModelConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool)) Expect(templated).To(Equal(foo["expected"]), templated) }) } @@ -232,7 +232,7 @@ var _ = Describe("Templates", func() { for key := range llama3TestMatch { foo := llama3TestMatch[key] It("renders correctly `"+key+"`", func() { - templated := evaluator.TemplateMessages(schema.OpenAIRequest{}, foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool)) + templated := evaluator.TemplateMessages(schema.OpenAIRequest{}, foo["messages"].([]schema.Message), foo["config"].(*config.ModelConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool)) Expect(templated).To(Equal(foo["expected"]), templated) }) } @@ -245,7 +245,7 @@ var _ = Describe("Templates", func() { for key := range jinjaTest { foo := jinjaTest[key] It("renders correctly `"+key+"`", func() { - templated := evaluator.TemplateMessages(schema.OpenAIRequest{}, foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool)) + templated := 
evaluator.TemplateMessages(schema.OpenAIRequest{}, foo["messages"].([]schema.Message), foo["config"].(*config.ModelConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool)) Expect(templated).To(Equal(foo["expected"]), templated) }) } diff --git a/pkg/model/loader.go b/pkg/model/loader.go index e1543a2c26aa..5d215eaf6989 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -10,6 +10,7 @@ import ( "sync" "time" + "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/LocalAI/pkg/utils" "github.com/rs/zerolog/log" @@ -28,9 +29,9 @@ type ModelLoader struct { externalBackends map[string]string } -func NewModelLoader(modelPath string, singleActiveBackend bool) *ModelLoader { +func NewModelLoader(system *system.SystemState, singleActiveBackend bool) *ModelLoader { nml := &ModelLoader{ - ModelPath: modelPath, + ModelPath: system.Model.ModelsPath, models: make(map[string]*Model), singletonMode: singleActiveBackend, externalBackends: make(map[string]string), diff --git a/pkg/model/loader_test.go b/pkg/model/loader_test.go index a8e77bd28807..fc48265dbb7c 100644 --- a/pkg/model/loader_test.go +++ b/pkg/model/loader_test.go @@ -6,6 +6,7 @@ import ( "path/filepath" "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/system" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" ) @@ -21,7 +22,12 @@ var _ = Describe("ModelLoader", func() { // Setup the model loader with a test directory modelPath = "/tmp/test_model_path" os.Mkdir(modelPath, 0755) - modelLoader = model.NewModelLoader(modelPath, false) + + systemState, err := system.GetSystemState( + system.WithModelPath(modelPath), + ) + Expect(err).ToNot(HaveOccurred()) + modelLoader = model.NewModelLoader(systemState, false) }) AfterEach(func() { diff --git a/pkg/system/capabilities.go b/pkg/system/capabilities.go index 056e569fa105..6102cf1efd60 100644 --- a/pkg/system/capabilities.go +++ b/pkg/system/capabilities.go @@ -6,16 +6,9 @@ import ( "strings" "github.com/jaypipes/ghw/pkg/gpu" - "github.com/mudler/LocalAI/pkg/xsysinfo" "github.com/rs/zerolog/log" ) -type SystemState struct { - GPUVendor string - gpus []*gpu.GraphicsCard - VRAM uint64 -} - const ( defaultCapability = "default" nvidiaL4T = "nvidia-l4t" @@ -103,22 +96,6 @@ func (s *SystemState) getSystemCapabilities() string { return s.GPUVendor } -func GetSystemState() (*SystemState, error) { - // Detection is best-effort here, we don't want to fail if it fails - gpus, _ := xsysinfo.GPUs() - log.Debug().Any("gpus", gpus).Msg("GPUs") - gpuVendor, _ := detectGPUVendor(gpus) - log.Debug().Str("gpuVendor", gpuVendor).Msg("GPU vendor") - vram, _ := xsysinfo.TotalAvailableVRAM() - log.Debug().Any("vram", vram).Msg("Total available VRAM") - - return &SystemState{ - GPUVendor: gpuVendor, - gpus: gpus, - VRAM: vram, - }, nil -} - func detectGPUVendor(gpus []*gpu.GraphicsCard) (string, error) { for _, gpu := range gpus { if gpu.DeviceInfo != nil { diff --git a/pkg/system/state.go b/pkg/system/state.go new file mode 100644 index 000000000000..7380c7f4ba92 --- /dev/null +++ b/pkg/system/state.go @@ -0,0 +1,61 @@ +package system + +import ( + "github.com/jaypipes/ghw/pkg/gpu" + "github.com/mudler/LocalAI/pkg/xsysinfo" + "github.com/rs/zerolog/log" +) + +type Backend struct { + BackendsPath string + 
BackendsSystemPath string +} + +type Model struct { + ModelsPath string +} + +type SystemState struct { + GPUVendor string + Backend Backend + Model Model + gpus []*gpu.GraphicsCard + VRAM uint64 +} + +type SystemStateOptions func(*SystemState) + +func WithBackendPath(path string) SystemStateOptions { + return func(s *SystemState) { + s.Backend.BackendsPath = path + } +} + +func WithBackendSystemPath(path string) SystemStateOptions { + return func(s *SystemState) { + s.Backend.BackendsSystemPath = path + } +} + +func WithModelPath(path string) SystemStateOptions { + return func(s *SystemState) { + s.Model.ModelsPath = path + } +} + +func GetSystemState(opts ...SystemStateOptions) (*SystemState, error) { + state := &SystemState{} + for _, opt := range opts { + opt(state) + } + + // Detection is best-effort here, we don't want to fail if it fails + state.gpus, _ = xsysinfo.GPUs() + log.Debug().Any("gpus", state.gpus).Msg("GPUs") + state.GPUVendor, _ = detectGPUVendor(state.gpus) + log.Debug().Str("gpuVendor", state.GPUVendor).Msg("GPU vendor") + state.VRAM, _ = xsysinfo.TotalAvailableVRAM() + log.Debug().Any("vram", state.VRAM).Msg("Total available VRAM") + + return state, nil +} diff --git a/tests/integration/stores_test.go b/tests/integration/stores_test.go index dfe992c1d53f..2cb3afd8a898 100644 --- a/tests/integration/stores_test.go +++ b/tests/integration/stores_test.go @@ -15,6 +15,7 @@ import ( "github.com/mudler/LocalAI/pkg/grpc" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/store" + "github.com/mudler/LocalAI/pkg/system" ) func normalize(vecs [][]float32) { @@ -46,7 +47,7 @@ var _ = Describe("Integration tests for the stores backend(s) and internal APIs" debug := true - bc := config.BackendConfig{ + bc := config.ModelConfig{ Name: "store test", Debug: &debug, Backend: model.LocalStoreBackend, @@ -57,7 +58,12 @@ var _ = Describe("Integration tests for the stores backend(s) and internal APIs" model.WithModel("test"), } - sl = 
model.NewModelLoader("", false) + systemState, err := system.GetSystemState( + system.WithModelPath(tmpdir), + ) + Expect(err).ToNot(HaveOccurred()) + + sl = model.NewModelLoader(systemState, false) sc, err = sl.Load(storeOpts...) Expect(err).ToNot(HaveOccurred()) Expect(sc).ToNot(BeNil())