Skip to content

Commit 9b36947

Browse files
committed
feat: support HF_ENDPOINT env for the HuggingFace endpoint
ie: `HF_ENDPOINT=https://hf-mirror.com`
1 parent e905e90 commit 9b36947

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

pkg/downloader/huggingface.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ var ErrUnsafeFilesFound = errors.New("unsafe files found")
2323

2424
func HuggingFaceScan(uri URI) (*HuggingFaceScanResult, error) {
2525
cleanParts := strings.Split(uri.ResolveURL(), "/")
26-
if len(cleanParts) <= 4 || cleanParts[2] != "huggingface.co" {
26+
if len(cleanParts) <= 4 || cleanParts[2] != "huggingface.co" && cleanParts[2] != HF_ENDPOINT {
2727
return nil, ErrNonHuggingFaceFile
2828
}
29-
results, err := http.Get(fmt.Sprintf("https://huggingface.co/api/models/%s/%s/scan", cleanParts[3], cleanParts[4]))
29+
results, err := http.Get(fmt.Sprintf("%s/api/models/%s/%s/scan", HF_ENDPOINT, cleanParts[3], cleanParts[4]))
3030
if err != nil {
3131
return nil, err
3232
}

pkg/downloader/uri.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,17 @@ const (
3737

3838
type URI string
3939

40+
// HF_ENDPOINT is the HuggingFace endpoint, can be overridden by setting the HF_ENDPOINT environment variable.
41+
var HF_ENDPOINT string = loadConfig()
42+
43+
func loadConfig() string {
44+
HF_ENDPOINT := os.Getenv("HF_ENDPOINT")
45+
if HF_ENDPOINT == "" {
46+
HF_ENDPOINT = "https://huggingface.co"
47+
}
48+
return HF_ENDPOINT
49+
}
50+
4051
func (uri URI) DownloadWithCallback(basePath string, f func(url string, i []byte) error) error {
4152
return uri.DownloadWithAuthorizationAndCallback(basePath, "", f)
4253
}
@@ -213,7 +224,7 @@ func (s URI) ResolveURL() string {
213224
filepath = strings.Split(filepath, "@")[0]
214225
}
215226

216-
return fmt.Sprintf("https://huggingface.co/%s/%s/resolve/%s/%s", owner, repo, branch, filepath)
227+
return fmt.Sprintf("%s/%s/%s/resolve/%s/%s", HF_ENDPOINT, owner, repo, branch, filepath)
217228
}
218229

219230
return string(s)

0 commit comments

Comments
 (0)