Skip to content

Commit 7050c9f

Browse files
authored
feat(webui): add import/edit model page (#6050)
* feat(webui): add import/edit model page Signed-off-by: Ettore Di Giacinto <[email protected]> * Convert to a YAML editor Signed-off-by: Ettore Di Giacinto <[email protected]> * Pass by the baseurl Signed-off-by: Ettore Di Giacinto <[email protected]> * Fixups Signed-off-by: Ettore Di Giacinto <[email protected]> * Add tests Signed-off-by: Ettore Di Giacinto <[email protected]> * Simplify Signed-off-by: Ettore Di Giacinto <[email protected]> * Improve visibility of the yaml editor Signed-off-by: Ettore Di Giacinto <[email protected]> * Add test file Signed-off-by: Ettore Di Giacinto <[email protected]> * Make reset work Signed-off-by: Ettore Di Giacinto <[email protected]> * Emit error only if we can't delete the model yaml file Signed-off-by: Ettore Di Giacinto <[email protected]> --------- Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent 089efe0 commit 7050c9f

File tree

10 files changed

+1664
-138
lines changed

10 files changed

+1664
-138
lines changed

core/config/backend_config.go

Lines changed: 129 additions & 126 deletions
Large diffs are not rendered by default.

core/gallery/models.go

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -316,24 +316,21 @@ func DeleteModelFromSystem(systemState *system.SystemState, name string) error {
316316
return fmt.Errorf("failed to verify path %s: %w", galleryFile, err)
317317
}
318318

319+
var filesToRemove []string
320+
319321
// Delete all the files associated to the model
320322
// read the model config
321323
galleryconfig, err := ReadConfigFile[ModelConfig](galleryFile)
322-
if err != nil {
323-
log.Error().Err(err).Msgf("failed to read gallery file %s", configFile)
324-
}
325-
326-
var filesToRemove []string
327-
328-
// Remove additional files
329-
if galleryconfig != nil {
324+
if err == nil && galleryconfig != nil {
330325
for _, f := range galleryconfig.Files {
331326
fullPath := filepath.Join(systemState.Model.ModelsPath, f.Filename)
332327
if err := utils.VerifyPath(fullPath, systemState.Model.ModelsPath); err != nil {
333328
return fmt.Errorf("failed to verify path %s: %w", fullPath, err)
334329
}
335330
filesToRemove = append(filesToRemove, fullPath)
336331
}
332+
} else {
333+
log.Error().Err(err).Msgf("failed to read gallery file %s", configFile)
337334
}
338335

339336
for _, f := range additionalFiles {
@@ -344,7 +341,6 @@ func DeleteModelFromSystem(systemState *system.SystemState, name string) error {
344341
filesToRemove = append(filesToRemove, fullPath)
345342
}
346343

347-
filesToRemove = append(filesToRemove, configFile)
348344
filesToRemove = append(filesToRemove, galleryFile)
349345

350346
// skip duplicates
@@ -353,11 +349,11 @@ func DeleteModelFromSystem(systemState *system.SystemState, name string) error {
353349
// Removing files
354350
for _, f := range filesToRemove {
355351
if e := os.Remove(f); e != nil {
356-
err = errors.Join(err, fmt.Errorf("failed to remove file %s: %w", f, e))
352+
log.Error().Err(e).Msgf("failed to remove file %s", f)
357353
}
358354
}
359355

360-
return err
356+
return os.Remove(configFile)
361357
}
362358

363359
// This is ***NEVER*** going to be perfect or finished.
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
package localai
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"os"
7+
"path/filepath"
8+
"strings"
9+
10+
"github.com/gofiber/fiber/v2"
11+
"github.com/mudler/LocalAI/core/config"
12+
httpUtils "github.com/mudler/LocalAI/core/http/utils"
13+
"github.com/mudler/LocalAI/internal"
14+
"github.com/mudler/LocalAI/pkg/utils"
15+
16+
"gopkg.in/yaml.v3"
17+
)
18+
19+
// GetEditModelPage renders the edit model page with current configuration
20+
func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
21+
return func(c *fiber.Ctx) error {
22+
modelName := c.Params("name")
23+
if modelName == "" {
24+
response := ModelResponse{
25+
Success: false,
26+
Error: "Model name is required",
27+
}
28+
return c.Status(400).JSON(response)
29+
}
30+
31+
modelConfig, exists := cl.GetModelConfig(modelName)
32+
if !exists {
33+
response := ModelResponse{
34+
Success: false,
35+
Error: "Model configuration not found",
36+
}
37+
return c.Status(404).JSON(response)
38+
}
39+
40+
configData, err := yaml.Marshal(modelConfig)
41+
if err != nil {
42+
response := ModelResponse{
43+
Success: false,
44+
Error: "Failed to marshal configuration: " + err.Error(),
45+
}
46+
return c.Status(500).JSON(response)
47+
}
48+
49+
// Marshal the config to JSON for the template
50+
configJSON, err := json.Marshal(modelConfig)
51+
if err != nil {
52+
response := ModelResponse{
53+
Success: false,
54+
Error: "Failed to marshal configuration: " + err.Error(),
55+
}
56+
return c.Status(500).JSON(response)
57+
}
58+
59+
// Render the edit page with the current configuration
60+
templateData := struct {
61+
Title string
62+
ModelName string
63+
Config *config.ModelConfig
64+
ConfigJSON string
65+
ConfigYAML string
66+
BaseURL string
67+
Version string
68+
}{
69+
Title: "LocalAI - Edit Model " + modelName,
70+
ModelName: modelName,
71+
Config: &modelConfig,
72+
ConfigJSON: string(configJSON),
73+
ConfigYAML: string(configData),
74+
BaseURL: httpUtils.BaseURL(c),
75+
Version: internal.PrintableVersion(),
76+
}
77+
78+
return c.Render("views/model-editor", templateData)
79+
}
80+
}
81+
82+
// EditModelEndpoint handles updating existing model configurations
83+
func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
84+
return func(c *fiber.Ctx) error {
85+
modelName := c.Params("name")
86+
if modelName == "" {
87+
response := ModelResponse{
88+
Success: false,
89+
Error: "Model name is required",
90+
}
91+
return c.Status(400).JSON(response)
92+
}
93+
94+
// Get the raw body
95+
body := c.Body()
96+
if len(body) == 0 {
97+
response := ModelResponse{
98+
Success: false,
99+
Error: "Request body is empty",
100+
}
101+
return c.Status(400).JSON(response)
102+
}
103+
104+
// Check content type to determine how to parse
105+
contentType := string(c.Context().Request.Header.ContentType())
106+
var req config.ModelConfig
107+
var err error
108+
109+
if strings.Contains(contentType, "application/json") {
110+
// Parse JSON
111+
if err := json.Unmarshal(body, &req); err != nil {
112+
response := ModelResponse{
113+
Success: false,
114+
Error: "Failed to parse JSON: " + err.Error(),
115+
}
116+
return c.Status(400).JSON(response)
117+
}
118+
} else if strings.Contains(contentType, "application/x-yaml") || strings.Contains(contentType, "text/yaml") {
119+
// Parse YAML
120+
if err := yaml.Unmarshal(body, &req); err != nil {
121+
response := ModelResponse{
122+
Success: false,
123+
Error: "Failed to parse YAML: " + err.Error(),
124+
}
125+
return c.Status(400).JSON(response)
126+
}
127+
} else {
128+
// Try to auto-detect format
129+
if strings.TrimSpace(string(body))[0] == '{' {
130+
// Looks like JSON
131+
if err := json.Unmarshal(body, &req); err != nil {
132+
response := ModelResponse{
133+
Success: false,
134+
Error: "Failed to parse JSON: " + err.Error(),
135+
}
136+
return c.Status(400).JSON(response)
137+
}
138+
} else {
139+
// Assume YAML
140+
if err := yaml.Unmarshal(body, &req); err != nil {
141+
response := ModelResponse{
142+
Success: false,
143+
Error: "Failed to parse YAML: " + err.Error(),
144+
}
145+
return c.Status(400).JSON(response)
146+
}
147+
}
148+
}
149+
150+
// Validate required fields
151+
if req.Name == "" {
152+
response := ModelResponse{
153+
Success: false,
154+
Error: "Name is required",
155+
}
156+
return c.Status(400).JSON(response)
157+
}
158+
159+
// Load the existing configuration
160+
configPath := filepath.Join(appConfig.SystemState.Model.ModelsPath, modelName+".yaml")
161+
if err := utils.InTrustedRoot(configPath, appConfig.SystemState.Model.ModelsPath); err != nil {
162+
response := ModelResponse{
163+
Success: false,
164+
Error: "Model configuration not trusted: " + err.Error(),
165+
}
166+
return c.Status(404).JSON(response)
167+
}
168+
169+
// Set defaults
170+
req.SetDefaults()
171+
172+
// Validate the configuration
173+
if !req.Validate() {
174+
response := ModelResponse{
175+
Success: false,
176+
Error: "Validation failed",
177+
Details: []string{"Configuration validation failed. Please check your YAML syntax and required fields."},
178+
}
179+
return c.Status(400).JSON(response)
180+
}
181+
182+
// Create the YAML file
183+
yamlData, err := yaml.Marshal(req)
184+
if err != nil {
185+
response := ModelResponse{
186+
Success: false,
187+
Error: "Failed to marshal configuration: " + err.Error(),
188+
}
189+
return c.Status(500).JSON(response)
190+
}
191+
192+
// Write to file
193+
if err := os.WriteFile(configPath, yamlData, 0644); err != nil {
194+
response := ModelResponse{
195+
Success: false,
196+
Error: "Failed to write configuration file: " + err.Error(),
197+
}
198+
return c.Status(500).JSON(response)
199+
}
200+
201+
// Reload configurations
202+
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil {
203+
response := ModelResponse{
204+
Success: false,
205+
Error: "Failed to reload configurations: " + err.Error(),
206+
}
207+
return c.Status(500).JSON(response)
208+
}
209+
210+
// Preload the model
211+
if err := cl.Preload(appConfig.SystemState.Model.ModelsPath); err != nil {
212+
response := ModelResponse{
213+
Success: false,
214+
Error: "Failed to preload model: " + err.Error(),
215+
}
216+
return c.Status(500).JSON(response)
217+
}
218+
219+
// Return success response
220+
response := ModelResponse{
221+
Success: true,
222+
Message: fmt.Sprintf("Model '%s' updated successfully", modelName),
223+
Filename: configPath,
224+
Config: req,
225+
}
226+
return c.JSON(response)
227+
}
228+
}
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package localai_test
2+
3+
import (
4+
"bytes"
5+
"io"
6+
"net/http/httptest"
7+
"os"
8+
"path/filepath"
9+
10+
"github.com/gofiber/fiber/v2"
11+
"github.com/mudler/LocalAI/core/config"
12+
. "github.com/mudler/LocalAI/core/http/endpoints/localai"
13+
"github.com/mudler/LocalAI/pkg/system"
14+
. "github.com/onsi/ginkgo/v2"
15+
. "github.com/onsi/gomega"
16+
)
17+
18+
var _ = Describe("Edit Model test", func() {
19+
20+
var tempDir string
21+
BeforeEach(func() {
22+
var err error
23+
tempDir, err = os.MkdirTemp("", "localai-test")
24+
Expect(err).ToNot(HaveOccurred())
25+
})
26+
AfterEach(func() {
27+
os.RemoveAll(tempDir)
28+
})
29+
30+
Context("Edit Model endpoint", func() {
31+
It("should edit a model", func() {
32+
systemState, err := system.GetSystemState(
33+
system.WithModelPath(filepath.Join(tempDir)),
34+
)
35+
Expect(err).ToNot(HaveOccurred())
36+
37+
applicationConfig := config.NewApplicationConfig(
38+
config.WithSystemState(systemState),
39+
)
40+
//modelLoader := model.NewModelLoader(systemState, true)
41+
modelConfigLoader := config.NewModelConfigLoader(systemState.Model.ModelsPath)
42+
43+
// Define Fiber app.
44+
app := fiber.New()
45+
app.Put("/import-model", ImportModelEndpoint(modelConfigLoader, applicationConfig))
46+
47+
requestBody := bytes.NewBufferString(`{"name": "foo", "backend": "foo", "model": "foo"}`)
48+
49+
req := httptest.NewRequest("PUT", "/import-model", requestBody)
50+
resp, err := app.Test(req, 5000)
51+
Expect(err).ToNot(HaveOccurred())
52+
53+
body, err := io.ReadAll(resp.Body)
54+
defer resp.Body.Close()
55+
Expect(err).ToNot(HaveOccurred())
56+
Expect(string(body)).To(ContainSubstring("Model configuration created successfully"))
57+
Expect(resp.StatusCode).To(Equal(fiber.StatusOK))
58+
59+
app.Get("/edit-model/:name", EditModelEndpoint(modelConfigLoader, applicationConfig))
60+
requestBody = bytes.NewBufferString(`{"name": "foo", "parameters": { "model": "foo"}}`)
61+
62+
req = httptest.NewRequest("GET", "/edit-model/foo", requestBody)
63+
resp, _ = app.Test(req, 1)
64+
65+
body, err = io.ReadAll(resp.Body)
66+
defer resp.Body.Close()
67+
Expect(err).ToNot(HaveOccurred())
68+
Expect(string(body)).To(ContainSubstring(`"model":"foo"`))
69+
Expect(resp.StatusCode).To(Equal(fiber.StatusOK))
70+
})
71+
})
72+
})

0 commit comments

Comments
 (0)